from abc import ABC, abstractmethod
import numpy as np
import pandas as pd
from .convert_tracking import compute_angles_from_keypoints
from .convert_tracking import convert_tail_angle_to_keypoints
from .convert_tracking import interpolate_tail_keypoint, interpolate_tail_angle
class TrackingConfig:
    """Configuration for zebrafish tracking datasets.

    Parameters
    ----------
    fps : int
        Frames per second of the recording (between 20 and 700)
    tracking : str
        Type of tracking ('full_tracking', 'head_tracking', 'tail_tracking')

    Raises
    ------
    AttributeError
        If ``tracking`` is not one of the supported options, or if ``fps``
        is not a whole number between 20 and 700 (inclusive).

    Examples
    --------
    >>> config = TrackingConfig(fps=700, tracking='full_tracking')
    >>> config.fps
    700
    >>> config = TrackingConfig(fps=25, tracking='head_tracking')
    >>> config.tracking
    'head_tracking'
    """

    def __init__(self, *, fps, tracking):
        """Validate and store the frame rate and the tracking type."""
        tracking_options = ["full_tracking", "head_tracking", "tail_tracking"]
        if tracking not in tracking_options:
            raise AttributeError(f"tracking should be among {tracking_options}")

        # Whole-valued floats (e.g. 25.0) are accepted; fractional rates are
        # not. Non-numeric values make the comparison raise TypeError, which
        # is reported through the same validation error as any other bad fps.
        try:
            fps_is_valid = bool(20 <= fps <= 700 and fps == np.round(fps))
        except TypeError:
            fps_is_valid = False
        if not fps_is_valid:
            raise AttributeError("fps should be an integer between 20 and 700")

        self.fps = int(fps)
        self.tracking = tracking
 
class TrackingData(ABC):
    """Abstract interface shared by the tracking-data containers.

    Concrete subclasses must provide two alternate constructors
    (``from_keypoints`` and ``from_posture``) and the two matching input
    validators (``_validate_keypoints`` and ``_validate_posture``).

    The stubs accept arbitrary positional and keyword arguments so that
    calling one on a class that has not overridden it always ends in
    ``NotImplementedError``, whatever call signature the subclasses use.
    """

    @classmethod
    @abstractmethod
    def from_keypoints(cls, *args, **kwargs):
        """Build an instance from keypoint coordinates."""
        raise NotImplementedError("This method should be overridden by subclasses")

    @classmethod
    @abstractmethod
    def from_posture(cls, *args, **kwargs):
        """Build an instance from posture variables."""
        raise NotImplementedError("This method should be overridden by subclasses")

    @staticmethod
    @abstractmethod
    def _validate_keypoints(*args, **kwargs):
        """Check the inputs given to ``from_keypoints``."""
        raise NotImplementedError("This method should be overridden by subclasses")

    @staticmethod
    @abstractmethod
    def _validate_posture(*args, **kwargs):
        """Check the inputs given to ``from_posture``."""
        raise NotImplementedError("This method should be overridden by subclasses")
class FullTrackingData(TrackingData):
    """Container for full tracking data including both head and tail information.

    The class provides two constructors:

    - from_keypoints: construct from raw x,y coordinates
    - from_posture: construct from head position and tail angles

    Examples
    --------
    >>> from megabouts.tracking_data import load_example_data
    >>> df, fps, mm_per_unit = load_example_data('fulltracking_posture')
    >>> head_x = df["head_x"].values * mm_per_unit
    >>> head_y = df["head_y"].values * mm_per_unit
    >>> head_yaw = df["head_angle"].values
    >>> tail_angle = df.filter(like="tail_angle").values
    >>> tracking_data = FullTrackingData.from_posture(
    ...     head_x=head_x, head_y=head_y, head_yaw=head_yaw, tail_angle=tail_angle
    ... )
    >>> isinstance(tracking_data.tail_df, pd.DataFrame)
    True
    >>> isinstance(tracking_data.traj_df, pd.DataFrame)
    True
    """

    def __init__(self, head_x, head_y, head_yaw, tail_x, tail_y, tail_angle):
        """Store already-validated arrays; prefer the ``from_*`` constructors.

        ``tail_x`` and ``tail_y`` may be ``None``; they are then computed on
        first access to ``tail_keypoints_df``.
        """
        self._tail_x = tail_x
        self._tail_y = tail_y
        self._tail_angle = tail_angle
        self._head_x = head_x
        self._head_y = head_y
        self._head_yaw = head_yaw
        # Number of time points, taken from the tail-angle array.
        self.T = len(self._tail_angle)

    @classmethod
    def from_keypoints(cls, *, head_x, head_y, tail_x, tail_y):
        """Construct from raw keypoint coordinates.

        Parameters
        ----------
        head_x, head_y : array-like
            Head coordinates, shape (T,)
        tail_x, tail_y : array-like
            Tail coordinates, shape (T, N_segments)

        Raises
        ------
        ValueError
            If the inputs fail ``_validate_keypoints``.
        """
        cls._validate_keypoints(head_x, head_y, tail_x, tail_y)
        # Any tail that does not already have 11 keypoints is resampled by
        # interpolate_tail_keypoint (called with 10, presumably 10 segments).
        if tail_x.shape[1] != 11:
            tail_x, tail_y = interpolate_tail_keypoint(tail_x, tail_y, 10)
        tail_angle, head_yaw = compute_angles_from_keypoints(
            head_x, head_y, tail_x, tail_y
        )
        return cls(head_x, head_y, head_yaw, tail_x, tail_y, tail_angle)

    @classmethod
    def from_posture(cls, *, head_x, head_y, head_yaw, tail_angle):
        """Construct from head position and tail angles.

        Parameters
        ----------
        head_x, head_y : array-like
            Head coordinates, shape (T,)
        head_yaw : array-like
            Head orientation, shape (T,)
        tail_angle : array-like
            Tail angles, shape (T, N_segments)

        Raises
        ------
        ValueError
            If the inputs fail ``_validate_posture``.
        """
        cls._validate_posture(head_x, head_y, head_yaw, tail_angle)
        if tail_angle.shape[1] != 10:
            tail_angle = interpolate_tail_angle(tail_angle, 10)
        # Tail keypoints are left unset here and derived lazily in
        # ``tail_keypoints_df``.
        return cls(head_x, head_y, head_yaw, None, None, tail_angle)

    @property
    def tail_df(self):
        """Tail angles as a DataFrame with columns ``angle_0`` .. ``angle_9``."""
        return pd.DataFrame(
            self._tail_angle, columns=[f"angle_{i}" for i in range(10)]
        )

    @property
    def traj_df(self):
        """Head trajectory as a DataFrame with columns ``x``, ``y`` and ``yaw``."""
        return pd.DataFrame(
            {"x": self._head_x, "y": self._head_y, "yaw": self._head_yaw}
        )

    @property
    def tail_keypoints_df(self):
        """Tail keypoint coordinates as a wide DataFrame.

        The frame has one row per time point and one column per keypoint and
        axis, named ``tail_x_0``, ``tail_x_1``, ... followed by ``tail_y_0``,
        ``tail_y_1``, ...

        When the instance was built without keypoints (``from_posture``), they
        are computed from the head pose and tail angles on first access and
        cached on the instance.
        """
        if self._tail_x is None or self._tail_y is None:
            self._tail_x, self._tail_y = convert_tail_angle_to_keypoints(
                self._head_x,
                self._head_y,
                self._head_yaw,
                self._tail_angle,
                body_to_tail_mm=0.5,
                tail_to_tail_mm=0.32,
            )
        # A dict of 2-D arrays is rejected by the DataFrame constructor
        # ("Per-column arrays must each be 1-dimensional"), so every keypoint
        # gets its own 1-D column.
        tail_x = np.asarray(self._tail_x)
        tail_y = np.asarray(self._tail_y)
        columns = {f"tail_x_{i}": tail_x[:, i] for i in range(tail_x.shape[1])}
        columns.update(
            {f"tail_y_{i}": tail_y[:, i] for i in range(tail_y.shape[1])}
        )
        return pd.DataFrame(columns)

    @staticmethod
    def _validate_keypoints(head_x, head_y, tail_x, tail_y):
        """Raise ``ValueError`` if the keypoint inputs are inconsistent."""
        # Guard the shape[1] look-ups below so that 1-D input produces a
        # readable error rather than an IndexError.
        if np.ndim(tail_x) != 2 or np.ndim(tail_y) != 2:
            raise ValueError("tail_x and tail_y must be 2D arrays of shape (T, N).")
        T = len(head_x)
        if not (len(head_y) == T and tail_x.shape[0] == T and tail_y.shape[0] == T):
            raise ValueError("All inputs must have the same number of time points (T).")
        N_keypoints = tail_x.shape[1]
        if N_keypoints < 4:
            raise ValueError(
                "At least 4 points from swim bladder to tail tips are required for full tracking"
            )
        if tail_x.shape[1] != tail_y.shape[1]:
            raise ValueError(
                "tail_x and tail_y must have the same number of keypoints (N)."
            )

    @staticmethod
    def _validate_posture(head_x, head_y, head_yaw, tail_angle):
        """Raise ``ValueError`` if the posture inputs are inconsistent."""
        if np.ndim(tail_angle) != 2:
            raise ValueError("tail_angle must be a 2D array of shape (T, N).")
        T = len(head_x)
        if not (
            len(head_y) == T and len(head_yaw) == T and tail_angle.shape[0] == T
        ):
            raise ValueError("All inputs must have the same number of time points (T).")
        # N angles are counted as N + 1 keypoints.
        N_keypoints = tail_angle.shape[1] + 1
        if N_keypoints < 4:
            raise ValueError(
                "At least 4 points from swim bladder to tail tips are required for full tracking"
            )
class HeadTrackingData(TrackingData):
    """Container for head tracking data.

    Examples
    --------
    >>> from megabouts.tracking_data import load_example_data
    >>> df, fps, mm_per_unit = load_example_data('zebrabox_SLEAP')
    >>> head_x = df['mid_eye.x'].values * mm_per_unit
    >>> head_y = df['mid_eye.y'].values * mm_per_unit
    >>> swimbladder_x = df['swim_bladder.x'].values * mm_per_unit
    >>> swimbladder_y = df['swim_bladder.y'].values * mm_per_unit
    >>> tracking_data = HeadTrackingData.from_keypoints(
    ...     head_x=head_x, head_y=head_y,
    ...     swimbladder_x=swimbladder_x, swimbladder_y=swimbladder_y)
    >>> isinstance(tracking_data.traj_df, pd.DataFrame)
    True
    """

    def __init__(self, head_x, head_y, head_yaw, swimbladder_x, swimbladder_y):
        """Store already-validated arrays; prefer the ``from_*`` constructors.

        ``swimbladder_x`` and ``swimbladder_y`` may be ``None``.
        """
        self._head_x = head_x
        self._head_y = head_y
        self._head_yaw = head_yaw
        self._swimbladder_x = swimbladder_x
        self._swimbladder_y = swimbladder_y
        # Number of time points, taken from the head x-coordinate.
        self.T = len(self._head_x)

    @classmethod
    def from_keypoints(cls, *, head_x, head_y, swimbladder_x, swimbladder_y):
        """Construct from head and swim-bladder coordinates.

        Parameters
        ----------
        head_x, head_y : array-like
            Head coordinates, one value per time point.
        swimbladder_x, swimbladder_y : array-like
            Swim-bladder coordinates, one value per time point.

        Raises
        ------
        ValueError
            If the inputs do not all have the same length.
        """
        cls._validate_keypoints(head_x, head_y, swimbladder_x, swimbladder_y)
        # The swim bladder is handed over as a single-column "tail"; only the
        # yaw is kept from the result. np.asarray makes the column indexing
        # work for any 1-D array-like, not just ndarrays.
        _, head_yaw = compute_angles_from_keypoints(
            head_x,
            head_y,
            np.asarray(swimbladder_x)[:, np.newaxis],
            np.asarray(swimbladder_y)[:, np.newaxis],
        )
        return cls(head_x, head_y, head_yaw, swimbladder_x, swimbladder_y)

    @classmethod
    def from_posture(cls, *, head_x, head_y, head_yaw):
        """Construct from head position and orientation.

        Parameters
        ----------
        head_x, head_y : array-like
            Head coordinates, one value per time point.
        head_yaw : array-like
            Head orientation, one value per time point.

        Raises
        ------
        ValueError
            If the inputs do not all have the same length.
        """
        cls._validate_posture(head_x, head_y, head_yaw)
        # Swim-bladder coordinates are not reconstructed from the posture;
        # they are stored as None.
        return cls(head_x, head_y, head_yaw, None, None)

    @property
    def traj_df(self):
        """Head trajectory as a DataFrame with columns ``x``, ``y`` and ``yaw``."""
        return pd.DataFrame(
            {"x": self._head_x, "y": self._head_y, "yaw": self._head_yaw}
        )

    @staticmethod
    def _validate_keypoints(head_x, head_y, swimbladder_x, swimbladder_y):
        """Raise ``ValueError`` unless all four inputs have the same length."""
        T = len(head_x)
        if not (
            len(head_y) == T and len(swimbladder_x) == T and len(swimbladder_y) == T
        ):
            raise ValueError("All inputs must have the same number of time points (T).")

    @staticmethod
    def _validate_posture(head_x, head_y, head_yaw):
        """Raise ``ValueError`` unless all three inputs have the same length."""
        T = len(head_x)
        if not (len(head_y) == T and len(head_yaw) == T):
            raise ValueError("All inputs must have the same number of time points (T).")
class TailTrackingData(TrackingData):
    """Container for tail tracking data.

    Examples
    --------
    >>> from megabouts.tracking_data import load_example_data
    >>> df, fps, mm_per_unit = load_example_data('HR_DLC')
    >>> tail_x = df["DLC_resnet50_Zebrafish"].loc[:, [(f"tail{i}", "x") for i in range(11)]].values * mm_per_unit
    >>> tail_y = df["DLC_resnet50_Zebrafish"].loc[:, [(f"tail{i}", "y") for i in range(11)]].values * mm_per_unit
    >>> tracking_data = TailTrackingData.from_keypoints(tail_x=tail_x, tail_y=tail_y)
    >>> isinstance(tracking_data.tail_df, pd.DataFrame)
    True
    """

    def __init__(self, tail_x, tail_y, tail_angle):
        """Store already-validated arrays; prefer the ``from_*`` constructors."""
        self._tail_x = tail_x
        self._tail_y = tail_y
        self._tail_angle = tail_angle
        # Number of time points, taken from the tail x-coordinates.
        self.T = len(self._tail_x)

    @classmethod
    def from_keypoints(cls, *, tail_x, tail_y):
        """Construct from raw tail keypoint coordinates.

        Parameters
        ----------
        tail_x, tail_y : array-like
            Tail coordinates, shape (T, N_keypoints)

        Raises
        ------
        ValueError
            If the inputs fail ``_validate_keypoints``.
        """
        cls._validate_keypoints(tail_x, tail_y)
        if tail_x.shape[1] != 11:
            tail_x, tail_y = interpolate_tail_keypoint(tail_x, tail_y, 10)
        # No head is tracked, so a stand-in head position is placed 0.5 units
        # along +x from the first tail keypoint; the returned yaw is unused.
        tail_angle, _ = compute_angles_from_keypoints(
            tail_x[:, 0] + 0.5, tail_y[:, 0], tail_x, tail_y
        )
        return cls(tail_x, tail_y, tail_angle)

    @classmethod
    def from_posture(cls, *, tail_angle):
        """Construct from tail angles.

        Parameters
        ----------
        tail_angle : array-like
            Tail angles, shape (T, N_segments)

        Raises
        ------
        ValueError
            If the input fails ``_validate_posture``.
        """
        cls._validate_posture(tail_angle)
        if tail_angle.shape[1] != 10:
            tail_angle = interpolate_tail_angle(tail_angle, 10)
        T = tail_angle.shape[0]
        # Keypoints are rebuilt with the head fixed at the origin and zero yaw.
        head_x, head_y, head_yaw = np.zeros(T), np.zeros(T), np.zeros(T)
        tail_x, tail_y = convert_tail_angle_to_keypoints(
            head_x,
            head_y,
            head_yaw,
            tail_angle,
            body_to_tail_mm=0.0,
            tail_to_tail_mm=0.32,
        )
        return cls(tail_x, tail_y, tail_angle)

    @property
    def tail_df(self):
        """Tail angles as a DataFrame with columns ``angle_0`` .. ``angle_9``."""
        return pd.DataFrame(
            self._tail_angle, columns=[f"angle_{i}" for i in range(10)]
        )

    @staticmethod
    def _validate_keypoints(tail_x, tail_y):
        """Raise ``ValueError`` if the keypoint inputs are inconsistent."""
        # Guard the shape[1] look-ups below so that 1-D input produces a
        # readable error rather than an IndexError.
        if np.ndim(tail_x) != 2 or np.ndim(tail_y) != 2:
            raise ValueError("tail_x and tail_y must be 2D arrays of shape (T, N).")
        if tail_x.shape[0] != tail_y.shape[0]:
            raise ValueError("All inputs must have the same number of time points (T).")
        N_keypoints = tail_x.shape[1]
        if N_keypoints < 4:
            raise ValueError(
                "At least 4 points from swim bladder to tail tips are required for tail tracking"
            )
        if tail_x.shape[1] != tail_y.shape[1]:
            raise ValueError(
                "tail_x and tail_y must have the same number of keypoints (N)."
            )

    @staticmethod
    def _validate_posture(tail_angle):
        """Raise ``ValueError`` if the tail-angle input has too few segments."""
        if np.ndim(tail_angle) != 2:
            raise ValueError("tail_angle must be a 2D array of shape (T, N).")
        # N angles are counted as N + 1 keypoints.
        N_keypoints = tail_angle.shape[1] + 1
        if N_keypoints < 4:
            raise ValueError(
                "At least 4 points from swim bladder to tail tips are required for tail tracking"
            )