import numpy as np
from ..config.preprocessing_config import (
TailPreprocessingConfig,
TrajPreprocessingConfig,
)
from ..preprocessing.traj_preprocessing import TrajPreprocessing
from ..preprocessing.tail_preprocessing import TailPreprocessing
from ..config.segmentation_config import TailSegmentationConfig, TrajSegmentationConfig
from ..segmentation.segmentation import Segmentation
from ..classification.classification import TailBouts, BoutClassifier
from ..utils.data_utils import create_hierarchical_df
from ..pipeline.base_pipeline import Pipeline
class EthogramHeadTracking:
    """Frame-by-frame ethogram table for head-tracking-only recordings.

    Parameters
    ----------
    segments : SegmentationResult
        Segmentation results. ``T``, ``onset`` and ``offset`` are read from it.
    bouts : TailBouts
        Classified bout data. ``category`` and ``sign`` are read from it.
    traj : TrajPreprocessingResult
        Preprocessed trajectory data. ``x``, ``y``, ``yaw`` and ``vigor``
        are read from it.

    Attributes
    ----------
    segments
        The segmentation results passed to the constructor.
    df
        Table returned by ``create_hierarchical_df`` for the trajectory
        signals and the per-frame bout labels.
    """

    def __init__(self, segments, bouts, traj):
        # ``segments`` must be stored first: compute_df relies on it through
        # compute_time_series.
        self.segments = segments
        self.df = self.compute_df(bouts, traj)

    def compute_df(self, bouts, traj):
        """Assemble the ethogram table.

        Parameters
        ----------
        bouts : TailBouts
            Provides one ``category`` and one ``sign`` entry per segment.
        traj : TrajPreprocessingResult
            Provides the ``x``, ``y``, ``yaw`` and ``vigor`` signals.

        Returns
        -------
        The object returned by ``create_hierarchical_df`` for the
        ``(group, name, values)`` triplets built here.
        """
        trajectory_columns = [
            ("trajectory", "x", traj.x),
            ("trajectory", "y", traj.y),
            ("trajectory", "angle", traj.yaw),
            ("trajectory", "vigor", traj.vigor),
        ]

        # Per-bout values are broadcast onto the frames each bout spans;
        # frames outside any bout get -1 for id/category and 0 for sign.
        n_bouts = len(bouts.category)
        bout_columns = [
            ("bout", "id", self.compute_time_series(np.arange(n_bouts), default_val=-1)),
            ("bout", "cat", self.compute_time_series(bouts.category, default_val=-1)),
            ("bout", "sign", self.compute_time_series(bouts.sign)),
        ]
        return create_hierarchical_df(trajectory_columns + bout_columns)

    def compute_time_series(self, x, default_val=0):
        """Spread one value per segment over that segment's frames.

        Parameters
        ----------
        x : indexable
            ``x[i]`` is the value written for the i-th segment.
        default_val : scalar, optional
            Fill value for frames that belong to no segment, by default 0.

        Returns
        -------
        numpy.ndarray
            Array of length ``self.segments.T`` where the slice
            ``onset[i]:offset[i]`` holds ``x[i]``.
        """
        timeline = np.full(self.segments.T, default_val)
        boundaries = zip(self.segments.onset, self.segments.offset)
        for seg_index, (start, stop) in enumerate(boundaries):
            timeline[start:stop] = x[seg_index]
        return timeline
class EthogramFullTracking:
    """Frame-by-frame ethogram table for recordings with tail and head tracking.

    Parameters
    ----------
    segments : SegmentationResult
        Segmentation results. ``T``, ``onset`` and ``offset`` are read from it.
    bouts : TailBouts
        Classified bout data. ``category`` and ``sign`` are read from it.
    tail : TailPreprocessingResult
        Preprocessed tail data. ``angle_smooth`` and ``vigor`` are read from it.
    traj : TrajPreprocessingResult
        Preprocessed trajectory data. ``x``, ``y`` and ``yaw`` are read from it.

    Attributes
    ----------
    segments
        The segmentation results passed to the constructor.
    df
        Table returned by ``create_hierarchical_df`` for the tail and
        trajectory signals and the per-frame bout labels.
    """

    def __init__(self, segments, bouts, tail, traj):
        # ``segments`` must be stored first: compute_df relies on it through
        # compute_time_series.
        self.segments = segments
        self.df = self.compute_df(bouts, tail, traj)

    def compute_df(self, bouts, tail, traj):
        """Assemble the ethogram table.

        Parameters
        ----------
        bouts : TailBouts
            Provides one ``category`` and one ``sign`` entry per segment.
        tail : TailPreprocessingResult
            Provides the ``angle_smooth`` and ``vigor`` signals.
        traj : TrajPreprocessingResult
            Provides the ``x``, ``y`` and ``yaw`` signals.

        Returns
        -------
        The object returned by ``create_hierarchical_df`` for the
        ``(group, name, values)`` triplets built here.
        """
        signal_columns = [
            ("tail_angle", "segment", tail.angle_smooth),
            ("tail_angle", "vigor", tail.vigor),
            ("trajectory", "x", traj.x),
            ("trajectory", "y", traj.y),
            ("trajectory", "angle", traj.yaw),
        ]

        # Per-bout values are broadcast onto the frames each bout spans;
        # frames outside any bout get -1 for id/category and 0 for sign.
        n_bouts = len(bouts.category)
        bout_columns = [
            ("bout", "id", self.compute_time_series(np.arange(n_bouts), default_val=-1)),
            ("bout", "cat", self.compute_time_series(bouts.category, default_val=-1)),
            ("bout", "sign", self.compute_time_series(bouts.sign)),
        ]
        return create_hierarchical_df(signal_columns + bout_columns)

    def compute_time_series(self, x, default_val=0):
        """Spread one value per segment over that segment's frames.

        Parameters
        ----------
        x : indexable
            ``x[i]`` is the value written for the i-th segment.
        default_val : scalar, optional
            Fill value for frames that belong to no segment, by default 0.

        Returns
        -------
        numpy.ndarray
            Array of length ``self.segments.T`` where the slice
            ``onset[i]:offset[i]`` holds ``x[i]``.
        """
        timeline = np.full(self.segments.T, default_val)
        boundaries = zip(self.segments.onset, self.segments.offset)
        for seg_index, (start, stop) in enumerate(boundaries):
            timeline[start:stop] = x[seg_index]
        return timeline
class HeadTrackingPipeline(Pipeline):
    """Pipeline for processing freely swimming fish with head tracking only.

    The pipeline preprocesses the trajectory, segments it from the trajectory
    vigor, classifies the resulting bouts and builds an
    ``EthogramHeadTracking``.

    Parameters
    ----------
    tracking_cfg : TrackingConfig
        Configuration for tracking data
    exclude_CS : bool, optional
        Whether to exclude capture swim bouts, by default False

    Examples
    --------
    >>> import pandas as pd
    >>> from megabouts.tracking_data import TrackingConfig, HeadTrackingData, load_example_data
    >>> df, fps, mm_per_unit = load_example_data('fulltracking_posture')
    >>> head_x = df["head_x"].values * mm_per_unit
    >>> head_y = df["head_y"].values * mm_per_unit
    >>> head_yaw = df["head_angle"].values
    >>> tracking_data = HeadTrackingData.from_posture(
    ...     head_x=head_x, head_y=head_y, head_yaw=head_yaw
    ... )
    >>> tracking_cfg = TrackingConfig(fps=fps, tracking='head_tracking')
    >>> pipeline = HeadTrackingPipeline(tracking_cfg)
    >>> ethogram, bouts, segments, traj = pipeline.run(tracking_data)
    >>> isinstance(ethogram.df, pd.DataFrame)
    True
    """

    def __init__(self, tracking_cfg, exclude_CS=False):
        self.tracking_cfg = tracking_cfg
        self.initialize_parameters_for_pipeline()
        self.exclude_CS = exclude_CS

    def initialize_parameters_for_pipeline(self):
        """Create the default preprocessing and segmentation configs.

        Both configs are built from ``self.tracking_cfg.fps`` and stored as
        ``traj_preprocessing_cfg`` and ``traj_segmentation_cfg``.
        """
        self.traj_preprocessing_cfg = TrajPreprocessingConfig(fps=self.tracking_cfg.fps)
        self.traj_segmentation_cfg = TrajSegmentationConfig(fps=self.tracking_cfg.fps)

    def preprocess_traj(self, traj_df):
        """Run ``TrajPreprocessing`` with ``self.traj_preprocessing_cfg`` on ``traj_df``.

        Returns
        -------
        The result of ``TrajPreprocessing.preprocess_traj_df``.
        """
        traj = TrajPreprocessing(self.traj_preprocessing_cfg).preprocess_traj_df(
            traj_df
        )
        return traj

    def segment_traj(self, traj_vigor):
        """Segment the recording from the trajectory vigor signal.

        Returns
        -------
        The result of the ``segment`` call on the segmentation object built
        from ``self.traj_segmentation_cfg``.
        """
        segmentation_function = Segmentation.from_config(self.traj_segmentation_cfg)
        segments = segmentation_function.segment(traj_vigor)
        return segments

    def classify_bouts(self, traj, segments):
        """Classify the segmented bouts from the smoothed head trajectory.

        Frames flagged by ``traj.no_tracking`` are replaced by NaN before the
        trajectory is cut into bouts. The masking is applied to private
        copies, so the ``traj`` object given by the caller is left untouched.

        Parameters
        ----------
        traj : TrajPreprocessingResult
            Provides ``x_smooth``, ``y_smooth``, ``yaw_smooth`` and
            ``no_tracking``.
        segments : SegmentationResult
            Segmentation results. Updated in place through ``set_HB1`` with
            the ``"first_half_beat"`` entry of the classification results.

        Returns
        -------
        TailBouts
            Bouts built with ``tail_array=None`` and the trajectory array
            extracted with ``align_to_onset=False``.
        """
        # Copy before masking: writing NaN straight into traj.x_smooth & co.
        # would modify the object later returned by run(), and would fail
        # outright if those arrays are read-only views.
        # NOTE(review): assumes the smoothed signals are float arrays and
        # ``no_tracking`` is a mask/index usable on them — confirm.
        untracked = traj.no_tracking
        x = np.array(traj.x_smooth, copy=True)
        y = np.array(traj.y_smooth, copy=True)
        yaw = np.array(traj.yaw_smooth, copy=True)
        x[untracked] = np.nan
        y[untracked] = np.nan
        yaw[untracked] = np.nan

        traj_array = segments.extract_traj_array(head_x=x, head_y=y, head_angle=yaw)
        classifier = BoutClassifier(
            self.tracking_cfg, self.traj_segmentation_cfg, exclude_CS=self.exclude_CS
        )
        classif_results = classifier.run_classification(traj_array=traj_array)

        # Store the first half beat found by the classifier on the segments,
        # then extract the bouts again, this time without onset alignment.
        segments.set_HB1(classif_results["first_half_beat"])
        traj_array = segments.extract_traj_array(
            head_x=x, head_y=y, head_angle=yaw, align_to_onset=False
        )
        bouts = TailBouts(
            segments=segments,
            classif_results=classif_results,
            tail_array=None,
            traj_array=traj_array,
        )
        return bouts

    def run(self, tracking_data):
        """Run preprocessing, segmentation and classification.

        Parameters
        ----------
        tracking_data
            Object exposing a ``traj_df`` attribute.

        Returns
        -------
        tuple
            ``(ethogram, bouts, segments, traj)`` where ``ethogram`` is an
            ``EthogramHeadTracking``.
        """
        traj = self.preprocess_traj(tracking_data.traj_df)
        segments = self.segment_traj(traj.vigor)
        bouts = self.classify_bouts(traj, segments)
        ethogram = EthogramHeadTracking(segments, bouts, traj)
        return ethogram, bouts, segments, traj

    def __str__(self):
        lin1 = (
            f"Parameters are: traj_preprocessing_cfg: {self.traj_preprocessing_cfg}"
            + "\n"
        )
        lin2 = f"Parameters are: traj_segmentation_cfg: {self.traj_segmentation_cfg}"
        return lin1 + lin2

    def __repr__(self):
        return self.__str__()
class FullTrackingPipeline(Pipeline):
    """Pipeline for processing freely swimming fish with full tracking data.

    The pipeline preprocesses tail and trajectory, segments the recording
    from the vigor signal matching ``segmentation_cfg``, classifies the
    resulting bouts and builds an ``EthogramFullTracking``.

    Parameters
    ----------
    tracking_cfg : TrackingConfig
        Configuration for tracking data
    exclude_CS : bool, optional
        Whether to exclude capture swim bouts, by default False

    Examples
    --------
    >>> import pandas as pd
    >>> from megabouts.tracking_data import TrackingConfig, FullTrackingData, load_example_data
    >>> df, fps, mm_per_unit = load_example_data('fulltracking_posture')
    >>> head_x = df["head_x"].values * mm_per_unit
    >>> head_y = df["head_y"].values * mm_per_unit
    >>> head_yaw = df["head_angle"].values
    >>> tail_angle = df.filter(like="tail_angle").values
    >>> tracking_data = FullTrackingData.from_posture(
    ...     head_x=head_x, head_y=head_y, head_yaw=head_yaw, tail_angle=tail_angle
    ... )
    >>> tracking_cfg = TrackingConfig(fps=fps, tracking='full_tracking')
    >>> pipeline = FullTrackingPipeline(tracking_cfg)
    >>> ethogram, bouts, segments, tail, traj = pipeline.run(tracking_data)
    >>> isinstance(ethogram.df, pd.DataFrame)
    True
    """

    def __init__(self, tracking_cfg, exclude_CS=False):
        self.tracking_cfg = tracking_cfg
        self.initialize_parameters_for_pipeline()
        self.exclude_CS = exclude_CS

    def initialize_parameters_for_pipeline(self):
        """Create the default configs from ``self.tracking_cfg.fps``.

        Sets ``tail_preprocessing_cfg``, ``traj_preprocessing_cfg`` and
        ``segmentation_cfg`` (a ``TailSegmentationConfig`` by default).
        """
        fps = self.tracking_cfg.fps
        self.tail_preprocessing_cfg = TailPreprocessingConfig(fps=fps)
        self.traj_preprocessing_cfg = TrajPreprocessingConfig(fps=fps)
        self.segmentation_cfg = TailSegmentationConfig(fps=fps)

    def preprocess_tail(self, tail_df):
        """Run ``TailPreprocessing`` with ``self.tail_preprocessing_cfg`` on ``tail_df``."""
        preprocessor = TailPreprocessing(self.tail_preprocessing_cfg)
        return preprocessor.preprocess_tail_df(tail_df)

    def preprocess_traj(self, traj_df):
        """Run ``TrajPreprocessing`` with ``self.traj_preprocessing_cfg`` on ``traj_df``."""
        preprocessor = TrajPreprocessing(self.traj_preprocessing_cfg)
        return preprocessor.preprocess_traj_df(traj_df)

    def segment(self, vigor):
        """Segment ``vigor`` with the segmentation built from ``self.segmentation_cfg``."""
        return Segmentation.from_config(self.segmentation_cfg).segment(vigor)

    def classify_bouts(self, tail, traj, segments):
        """Classify the segmented bouts from smoothed tail and head signals.

        Parameters
        ----------
        tail : TailPreprocessingResult
            Provides ``angle_smooth``.
        traj : TrajPreprocessingResult
            Provides ``x_smooth``, ``y_smooth`` and ``yaw_smooth``.
        segments : SegmentationResult
            Segmentation results. Updated in place through ``set_HB1`` with
            the ``"first_half_beat"`` entry of the classification results.

        Returns
        -------
        TailBouts
            Bouts built from the tail and trajectory arrays extracted with
            ``align_to_onset=False``.
        """
        # First pass: default extraction, used as classifier input.
        classifier_tail = segments.extract_tail_array(tail_angle=tail.angle_smooth)
        classifier_traj = segments.extract_traj_array(
            head_x=traj.x_smooth, head_y=traj.y_smooth, head_angle=traj.yaw_smooth
        )
        bout_classifier = BoutClassifier(
            self.tracking_cfg, self.segmentation_cfg, exclude_CS=self.exclude_CS
        )
        classif_results = bout_classifier.run_classification(
            tail_array=classifier_tail, traj_array=classifier_traj
        )

        # Store the first half beat on the segments before the second pass.
        segments.set_HB1(classif_results["first_half_beat"])

        # Second pass: same signals, extracted without onset alignment.
        final_tail = segments.extract_tail_array(
            tail_angle=tail.angle_smooth, align_to_onset=False
        )
        final_traj = segments.extract_traj_array(
            head_x=traj.x_smooth,
            head_y=traj.y_smooth,
            head_angle=traj.yaw_smooth,
            align_to_onset=False,
        )
        return TailBouts(
            segments=segments,
            classif_results=classif_results,
            tail_array=final_tail,
            traj_array=final_traj,
        )

    def run(self, tracking_data):
        """Run preprocessing, segmentation and classification.

        Parameters
        ----------
        tracking_data
            Object exposing ``tail_df`` and ``traj_df`` attributes.

        Returns
        -------
        tuple
            ``(ethogram, bouts, segments, tail, traj)`` where ``ethogram`` is
            an ``EthogramFullTracking``.

        Raises
        ------
        ValueError
            If ``self.segmentation_cfg`` is neither a
            ``TailSegmentationConfig`` nor a ``TrajSegmentationConfig``.
        """
        tail = self.preprocess_tail(tracking_data.tail_df)
        traj = self.preprocess_traj(tracking_data.traj_df)

        # The config type selects which vigor signal drives segmentation.
        cfg = self.segmentation_cfg
        if isinstance(cfg, TailSegmentationConfig):
            vigor = tail.vigor
        elif isinstance(cfg, TrajSegmentationConfig):
            vigor = traj.vigor
        else:
            raise ValueError(
                "segmentation_cfg should be an instance of TailSegmentationConfig or TrajSegmentationConfig"
            )
        segments = self.segment(vigor)

        bouts = self.classify_bouts(tail, traj, segments)
        ethogram = EthogramFullTracking(segments, bouts, tail, traj)
        return ethogram, bouts, segments, tail, traj

    def __str__(self):
        described = (
            f"Parameters are: tail_preprocessing_cfg: {self.tail_preprocessing_cfg}",
            f"Parameters are: traj_preprocessing_cfg: {self.traj_preprocessing_cfg}",
            f"Parameters are: tail_segmentation_cfg: {self.segmentation_cfg}",
        )
        return "\n".join(described)

    def __repr__(self):
        return str(self)