from abc import ABC, abstractmethod
import numpy as np
import pandas as pd
from .convert_tracking import compute_angles_from_keypoints
from .convert_tracking import convert_tail_angle_to_keypoints
from .convert_tracking import interpolate_tail_keypoint, interpolate_tail_angle
[docs]
class TrackingConfig:
    """Settings describing a zebrafish tracking recording.

    Parameters
    ----------
    fps : int
        Frame rate of the recording. Must be integer-valued and lie between
        20 and 700 (inclusive). It is stored as a built-in ``int``.
    tracking : str
        Tracking mode, one of 'full_tracking', 'head_tracking' or
        'tail_tracking'.

    Raises
    ------
    AttributeError
        If ``tracking`` is not one of the supported modes, or if ``fps`` is
        outside the 20-700 range or is not a whole number.

    Examples
    --------
    >>> config = TrackingConfig(fps=700, tracking='full_tracking')
    >>> config.fps
    700
    >>> config = TrackingConfig(fps=25, tracking='head_tracking')
    >>> config.tracking
    'head_tracking'
    """

    def __init__(self, *, fps, tracking):
        supported_modes = ["full_tracking", "head_tracking", "tail_tracking"]
        if tracking not in supported_modes:
            raise AttributeError(f"tracking should be among {supported_modes}")

        # The range test runs first; the whole-number test is only evaluated
        # when the range test passes.
        fps_is_valid = 20 <= fps <= 700 and fps == np.round(fps)
        if not fps_is_valid:
            raise AttributeError("fps should be an integer between 20 and 700")

        self.tracking = tracking
        self.fps = int(fps)
[docs]
class TrackingData(ABC):
    """Abstract interface shared by the tracking-data containers.

    A concrete subclass must provide the two alternate constructors
    (``from_keypoints`` and ``from_posture``) and the two matching private
    validators. All four hooks are abstract, so this class cannot be
    instantiated directly.
    """

    @classmethod
    @abstractmethod
    def from_keypoints(*args):
        """Alternate constructor from keypoint coordinates (subclass hook)."""
        raise NotImplementedError("This method should be overridden by subclasses")

    @classmethod
    @abstractmethod
    def from_posture(*args):
        """Alternate constructor from posture variables (subclass hook)."""
        raise NotImplementedError("This method should be overridden by subclasses")

    @staticmethod
    @abstractmethod
    def _validate_keypoints(*args):
        """Validate the inputs given to ``from_keypoints`` (subclass hook)."""
        raise NotImplementedError("This method should be overridden by subclasses")

    @staticmethod
    @abstractmethod
    def _validate_posture(*args):
        """Validate the inputs given to ``from_posture`` (subclass hook)."""
        raise NotImplementedError("This method should be overridden by subclasses")
[docs]
class FullTrackingData(TrackingData):
    """Container for full tracking data including both head and tail information.

    The class provides two constructors:

    - from_keypoints: construct from raw x,y coordinates
    - from_posture: construct from head position and tail angles

    Attributes
    ----------
    T : int
        Number of time points, taken as ``len(tail_angle)``.

    Examples
    --------
    >>> from megabouts.tracking_data import load_example_data
    >>> df, fps, mm_per_unit = load_example_data('fulltracking_posture')
    >>> head_x = df["head_x"].values * mm_per_unit
    >>> head_y = df["head_y"].values * mm_per_unit
    >>> head_yaw = df["head_angle"].values
    >>> tail_angle = df.filter(like="tail_angle").values
    >>> tracking_data = FullTrackingData.from_posture(
    ...     head_x=head_x, head_y=head_y, head_yaw=head_yaw, tail_angle=tail_angle
    ... )
    >>> isinstance(tracking_data.tail_df, pd.DataFrame)
    True
    >>> isinstance(tracking_data.traj_df, pd.DataFrame)
    True
    """

    def __init__(self, head_x, head_y, head_yaw, tail_x, tail_y, tail_angle):
        """Store already-prepared arrays; prefer the ``from_*`` constructors.

        ``tail_x`` / ``tail_y`` may be ``None``; in that case they are
        reconstructed on first access to ``tail_keypoints_df``.
        """
        self._tail_x = tail_x
        self._tail_y = tail_y
        self._tail_angle = tail_angle
        self._head_x = head_x
        self._head_y = head_y
        self._head_yaw = head_yaw
        self.T = len(self._tail_angle)

    @classmethod
    def from_keypoints(cls, *, head_x, head_y, tail_x, tail_y):
        """Construct from raw keypoint coordinates.

        Parameters
        ----------
        head_x, head_y : array-like
            Head coordinates, shape (T,)
        tail_x, tail_y : array-like
            Tail coordinates, shape (T, N_segments)

        Raises
        ------
        ValueError
            If the inputs disagree on the number of time points or keypoints,
            or fewer than 4 tail keypoints are supplied.
        """
        # The tail arrays are indexed through ``.shape`` below, so plain
        # nested lists are coerced up front (a no-op for ndarrays).
        tail_x = np.asarray(tail_x)
        tail_y = np.asarray(tail_y)
        cls._validate_keypoints(head_x, head_y, tail_x, tail_y)
        if tail_x.shape[1] != 11:
            # Presumably resamples the tail to 10 segments (11 keypoints);
            # confirm against convert_tracking.interpolate_tail_keypoint.
            tail_x, tail_y = interpolate_tail_keypoint(tail_x, tail_y, 10)
        tail_angle, head_yaw = compute_angles_from_keypoints(
            head_x, head_y, tail_x, tail_y
        )
        return cls(head_x, head_y, head_yaw, tail_x, tail_y, tail_angle)

    @classmethod
    def from_posture(cls, *, head_x, head_y, head_yaw, tail_angle):
        """Construct from head position and tail angles.

        Parameters
        ----------
        head_x, head_y : array-like
            Head coordinates, shape (T,)
        head_yaw : array-like
            Head orientation, shape (T,)
        tail_angle : array-like
            Tail angles, shape (T, N_segments)

        Raises
        ------
        ValueError
            If the inputs disagree on the number of time points, or the tail
            angles describe fewer than 4 keypoints.
        """
        # Coerce so that ``.shape`` is available for any array-like input.
        tail_angle = np.asarray(tail_angle)
        cls._validate_posture(head_x, head_y, head_yaw, tail_angle)
        if tail_angle.shape[1] != 10:
            tail_angle = interpolate_tail_angle(tail_angle, 10)
        # Tail keypoints are not computed here; ``tail_keypoints_df``
        # derives them lazily when first requested.
        return cls(head_x, head_y, head_yaw, None, None, tail_angle)

    @property
    def tail_df(self):
        """Tail angles as a DataFrame with columns ``angle_0`` .. ``angle_9``."""
        tail_df = pd.DataFrame(
            self._tail_angle, columns=[f"angle_{i}" for i in range(10)]
        )
        return tail_df

    @property
    def traj_df(self):
        """Head trajectory as a DataFrame with columns ``x``, ``y``, ``yaw``."""
        traj_df = pd.DataFrame(
            {"x": self._head_x, "y": self._head_y, "yaw": self._head_yaw}
        )
        return traj_df

    @property
    def tail_keypoints_df(self):
        """Tail keypoint coordinates as a DataFrame.

        When the instance was built without tail keypoints (``from_posture``),
        they are computed from the head pose and tail angles on first access
        and cached on the instance.

        Returns
        -------
        pandas.DataFrame
            For 2-D keypoint arrays, one column per keypoint and axis named
            ``tail_x_<i>`` followed by ``tail_y_<i>``. For 1-D arrays, the two
            columns ``tail_x`` and ``tail_y``.
        """
        if self._tail_x is None or self._tail_y is None:
            self._tail_x, self._tail_y = convert_tail_angle_to_keypoints(
                self._head_x,
                self._head_y,
                self._head_yaw,
                self._tail_angle,
                body_to_tail_mm=0.5,
                tail_to_tail_mm=0.32,
            )

        if np.ndim(self._tail_x) == 1 and np.ndim(self._tail_y) == 1:
            return pd.DataFrame({"tail_x": self._tail_x, "tail_y": self._tail_y})

        # A dict-of-arrays DataFrame needs 1-D values per column, so a
        # (T, N) keypoint matrix is expanded into one column per keypoint.
        tail_x = np.asarray(self._tail_x)
        tail_y = np.asarray(self._tail_y)
        columns = {f"tail_x_{i}": tail_x[:, i] for i in range(tail_x.shape[1])}
        columns.update(
            {f"tail_y_{i}": tail_y[:, i] for i in range(tail_y.shape[1])}
        )
        return pd.DataFrame(columns)

    @staticmethod
    def _validate_keypoints(head_x, head_y, tail_x, tail_y):
        """Check keypoint inputs for consistent lengths and enough tail points.

        Raises
        ------
        ValueError
            On mismatched time dimension, fewer than 4 tail keypoints, or a
            different number of keypoints in ``tail_x`` and ``tail_y``.
        """
        T = len(head_x)
        if not (len(head_y) == T and tail_x.shape[0] == T and tail_y.shape[0] == T):
            raise ValueError("All inputs must have the same number of time points (T).")
        N_keypoints = tail_x.shape[1]
        if N_keypoints < 4:
            raise ValueError(
                "At least 4 points from swim bladder to tail tips are required for full tracking"
            )
        if tail_x.shape[1] != tail_y.shape[1]:
            raise ValueError(
                "tail_x and tail_y must have the same number of keypoints (N)."
            )

    @staticmethod
    def _validate_posture(head_x, head_y, head_yaw, tail_angle):
        """Check posture inputs for consistent lengths and enough tail angles.

        Raises
        ------
        ValueError
            On mismatched time dimension, or when the tail angles correspond
            to fewer than 4 keypoints (number of angles + 1).
        """
        T = len(head_x)
        # ``len`` is used for head_yaw so that plain sequences are accepted.
        if not (len(head_y) == T and len(head_yaw) == T and tail_angle.shape[0] == T):
            raise ValueError("All inputs must have the same number of time points (T).")
        N_keypoints = tail_angle.shape[1] + 1
        if N_keypoints < 4:
            raise ValueError(
                "At least 4 points from swim bladder to tail tips are required for full tracking"
            )
[docs]
class HeadTrackingData(TrackingData):
    """Container for head tracking data.

    Holds the head position and yaw, plus the swim bladder coordinates when
    the instance was built from keypoints.

    Attributes
    ----------
    T : int
        Number of time points, taken as ``len(head_x)``.

    Examples
    --------
    >>> from megabouts.tracking_data import load_example_data
    >>> df, fps, mm_per_unit = load_example_data('zebrabox_SLEAP')
    >>> head_x = df['mid_eye.x'].values * mm_per_unit
    >>> head_y = df['mid_eye.y'].values * mm_per_unit
    >>> swimbladder_x = df['swim_bladder.x'].values * mm_per_unit
    >>> swimbladder_y = df['swim_bladder.y'].values * mm_per_unit
    >>> tracking_data = HeadTrackingData.from_keypoints(
    ...     head_x=head_x, head_y=head_y,
    ...     swimbladder_x=swimbladder_x, swimbladder_y=swimbladder_y)
    >>> isinstance(tracking_data.traj_df, pd.DataFrame)
    True
    """

    def __init__(self, head_x, head_y, head_yaw, swimbladder_x, swimbladder_y):
        """Store the given arrays; the swim bladder coordinates may be ``None``."""
        self._head_x = head_x
        self._head_y = head_y
        self._head_yaw = head_yaw
        self._swimbladder_x = swimbladder_x
        self._swimbladder_y = swimbladder_y
        self.T = len(self._head_x)

    @classmethod
    def from_keypoints(cls, *, head_x, head_y, swimbladder_x, swimbladder_y):
        """Construct from head and swim bladder coordinates.

        Parameters
        ----------
        head_x, head_y : array-like
            Head coordinates, one value per time point.
        swimbladder_x, swimbladder_y : array-like
            Swim bladder coordinates, one value per time point.

        Raises
        ------
        ValueError
            If the inputs do not all have the same length.
        """
        cls._validate_keypoints(head_x, head_y, swimbladder_x, swimbladder_y)
        # The swim bladder is handed over as a single-column "tail" so the
        # shared angle routine can be reused; only the yaw output is kept.
        bladder_x_column = swimbladder_x[:, np.newaxis]
        bladder_y_column = swimbladder_y[:, np.newaxis]
        _, head_yaw = compute_angles_from_keypoints(
            head_x, head_y, bladder_x_column, bladder_y_column
        )
        return cls(head_x, head_y, head_yaw, swimbladder_x, swimbladder_y)

    @classmethod
    def from_posture(cls, *, head_x, head_y, head_yaw):
        """Construct from head position and yaw.

        Parameters
        ----------
        head_x, head_y : array-like
            Head coordinates, one value per time point.
        head_yaw : array-like
            Head orientation, one value per time point.

        Raises
        ------
        ValueError
            If the inputs do not all have the same length.
        """
        cls._validate_posture(head_x, head_y, head_yaw)
        # NOTE: the swim bladder position is not reconstructed from the
        # posture; both coordinates are stored as None.
        return cls(head_x, head_y, head_yaw, None, None)

    @property
    def traj_df(self):
        """Head trajectory as a DataFrame with columns ``x``, ``y``, ``yaw``."""
        return pd.DataFrame(
            {"x": self._head_x, "y": self._head_y, "yaw": self._head_yaw}
        )

    @staticmethod
    def _validate_keypoints(head_x, head_y, swimbladder_x, swimbladder_y):
        """Raise ``ValueError`` unless all four inputs share the same length."""
        n_frames = len(head_x)
        if (
            len(head_y) != n_frames
            or len(swimbladder_x) != n_frames
            or len(swimbladder_y) != n_frames
        ):
            raise ValueError("All inputs must have the same number of time points (T).")

    @staticmethod
    def _validate_posture(head_x, head_y, head_yaw):
        """Raise ``ValueError`` unless all three inputs share the same length."""
        n_frames = len(head_x)
        if len(head_y) != n_frames or head_yaw.shape[0] != n_frames:
            raise ValueError("All inputs must have the same number of time points (T).")
[docs]
class TailTrackingData(TrackingData):
    """Container for tail tracking data.

    Attributes
    ----------
    T : int
        Number of time points, taken as ``len(tail_x)``.

    Examples
    --------
    >>> from megabouts.tracking_data import load_example_data
    >>> df, fps, mm_per_unit = load_example_data('HR_DLC')
    >>> tail_x = df["DLC_resnet50_Zebrafish"].loc[:, [(f"tail{i}", "x") for i in range(11)]].values * mm_per_unit
    >>> tail_y = df["DLC_resnet50_Zebrafish"].loc[:, [(f"tail{i}", "y") for i in range(11)]].values * mm_per_unit
    >>> tracking_data = TailTrackingData.from_keypoints(tail_x=tail_x, tail_y=tail_y)
    >>> isinstance(tracking_data.tail_df, pd.DataFrame)
    True
    """

    def __init__(self, tail_x, tail_y, tail_angle):
        """Store already-prepared arrays; prefer the ``from_*`` constructors."""
        self._tail_x = tail_x
        self._tail_y = tail_y
        self._tail_angle = tail_angle
        self.T = len(self._tail_x)

    @classmethod
    def from_keypoints(cls, *, tail_x, tail_y):
        """Construct from tail keypoint coordinates.

        Parameters
        ----------
        tail_x, tail_y : array-like
            Tail coordinates, shape (T, N_keypoints) with at least 4 keypoints.

        Raises
        ------
        ValueError
            If ``tail_x`` and ``tail_y`` differ in shape along either axis, or
            fewer than 4 keypoints are supplied.
        """
        # Coerce so that ``.shape`` and 2-D slicing work for any array-like.
        tail_x = np.asarray(tail_x)
        tail_y = np.asarray(tail_y)
        cls._validate_keypoints(tail_x, tail_y)
        if tail_x.shape[1] != 11:
            # Presumably resamples the tail to 10 segments (11 keypoints);
            # confirm against convert_tracking.interpolate_tail_keypoint.
            tail_x, tail_y = interpolate_tail_keypoint(tail_x, tail_y, 10)
        # No head is tracked, so a stand-in head position is supplied: the
        # first tail keypoint shifted by 0.5 along x. The yaw computed from
        # this placeholder is discarded.
        tail_angle, _ = compute_angles_from_keypoints(
            tail_x[:, 0] + 0.5, tail_y[:, 0], tail_x, tail_y
        )
        return cls(tail_x, tail_y, tail_angle)

    @classmethod
    def from_posture(cls, *, tail_angle):
        """Construct from tail angles.

        Parameters
        ----------
        tail_angle : array-like
            Tail angles, shape (T, N_segments) with at least 3 segments.

        Raises
        ------
        ValueError
            If the tail angles correspond to fewer than 4 keypoints.
        """
        # Coerce so that ``.shape`` is available for any array-like input.
        tail_angle = np.asarray(tail_angle)
        cls._validate_posture(tail_angle)
        if tail_angle.shape[1] != 10:
            tail_angle = interpolate_tail_angle(tail_angle, 10)
        n_frames = tail_angle.shape[0]
        # Keypoints are rebuilt around a fixed head at the origin with zero
        # yaw, since no head pose is available for tail-only tracking.
        head_x, head_y, head_yaw = (
            np.zeros(n_frames),
            np.zeros(n_frames),
            np.zeros(n_frames),
        )
        tail_x, tail_y = convert_tail_angle_to_keypoints(
            head_x,
            head_y,
            head_yaw,
            tail_angle,
            body_to_tail_mm=0.0,
            tail_to_tail_mm=0.32,
        )
        return cls(tail_x, tail_y, tail_angle)

    @property
    def tail_df(self):
        """Tail angles as a DataFrame with columns ``angle_0`` .. ``angle_9``."""
        tail_df = pd.DataFrame(
            self._tail_angle, columns=[f"angle_{i}" for i in range(10)]
        )
        return tail_df

    @staticmethod
    def _validate_keypoints(tail_x, tail_y):
        """Check tail keypoints for consistent shape and enough points.

        Raises
        ------
        ValueError
            On mismatched time dimension, fewer than 4 keypoints, or a
            different number of keypoints in ``tail_x`` and ``tail_y``.
        """
        if tail_x.shape[0] != tail_y.shape[0]:
            raise ValueError("All inputs must have the same number of time points (T).")
        N_keypoints = tail_x.shape[1]
        if N_keypoints < 4:
            raise ValueError(
                "At least 4 points from swim bladder to tail tips are required for tail tracking"
            )
        if tail_x.shape[1] != tail_y.shape[1]:
            raise ValueError(
                "tail_x and tail_y must have the same number of keypoints (N)."
            )

    @staticmethod
    def _validate_posture(tail_angle):
        """Check that the tail angles describe at least 4 keypoints.

        Raises
        ------
        ValueError
            When the number of angles + 1 is smaller than 4.
        """
        N_keypoints = tail_angle.shape[1] + 1
        if N_keypoints < 4:
            # The message names tail tracking, matching _validate_keypoints.
            raise ValueError(
                "At least 4 points from swim bladder to tail tips are required for tail tracking"
            )