Source code for megabouts.pipeline.head_restrained_pipeline
from ..config.preprocessing_config import TailPreprocessingConfig
from ..preprocessing.tail_preprocessing import TailPreprocessing
from ..config.segmentation_config import TailSegmentationConfig
from ..segmentation.segmentation import Segmentation
from ..config.sparse_coding_config import SparseCodingConfig
from ..sparse_coding.sparse_coding import SparseCoding
from ..pipeline.base_pipeline import Pipeline
[docs]
class HeadRestrainedPipeline(Pipeline):
    """Pipeline for processing head-restrained fish data.

    Three stages are chained by :meth:`run`: tail preprocessing, bout
    segmentation on ``tail.vigor``, and sparse coding of
    ``tail.angle_smooth``. The configuration of each stage is created in
    the constructor from the frame rate stored in ``tracking_cfg``.

    Parameters
    ----------
    tracking_cfg : TrackingConfig
        Configuration for tracking data. Only its ``fps`` attribute is
        read by this class.

    Examples
    --------
    >>> import pandas as pd
    >>> from megabouts.tracking_data import TrackingConfig, TailTrackingData, load_example_data
    >>> df, fps, mm_per_unit = load_example_data('HR_DLC')
    >>> tail_x = df["DLC_resnet50_Zebrafish"].loc[:, [(f"tail{i}", "x") for i in range(11)]].values * mm_per_unit
    >>> tail_y = df["DLC_resnet50_Zebrafish"].loc[:, [(f"tail{i}", "y") for i in range(11)]].values * mm_per_unit
    >>> tracking_data = TailTrackingData.from_keypoints(tail_x=tail_x, tail_y=tail_y)
    >>> tracking_cfg = TrackingConfig(fps=fps, tracking='tail_tracking')
    >>> pipeline = HeadRestrainedPipeline(tracking_cfg)
    >>> sparse_coding_result, segments, tail = pipeline.run(tracking_data)
    >>> isinstance(sparse_coding_result.df, pd.DataFrame)
    True
    """

    def __init__(self, tracking_cfg):
        """Store the tracking configuration and build the stage configs.

        Parameters
        ----------
        tracking_cfg : TrackingConfig
            Configuration for tracking data
        """
        # The tracking config must be stored first: the stage configs
        # created below read their frame rate from it.
        self.tracking_cfg = tracking_cfg
        self.initialize_parameters_for_pipeline()

    def initialize_parameters_for_pipeline(self):
        """Create the default configuration object of every stage.

        Sets ``tail_preprocessing_cfg``, ``tail_segmentation_cfg`` and
        ``sparse_coding_cfg`` on the instance. All three share the frame
        rate of ``self.tracking_cfg``; the preprocessing config also
        selects the ``"whittaker"`` baseline method with
        ``lmbda=1e5`` and ``half_window=200``.
        """
        fps = self.tracking_cfg.fps
        whittaker_params = {"lmbda": 1e5, "half_window": 200}

        self.tail_preprocessing_cfg = TailPreprocessingConfig(
            fps=fps,
            baseline_method="whittaker",
            baseline_params=whittaker_params,
        )
        self.tail_segmentation_cfg = TailSegmentationConfig(fps=fps)
        self.sparse_coding_cfg = SparseCodingConfig(fps=fps)

    def preprocess_tail(self, tail_df):
        """Preprocess tail angle data.

        Parameters
        ----------
        tail_df : pd.DataFrame
            DataFrame containing tail angle data

        Returns
        -------
        TailPreprocessingResult
            Preprocessed tail data
        """
        preprocessor = TailPreprocessing(self.tail_preprocessing_cfg)
        return preprocessor.preprocess_tail_df(tail_df)

    def segment_tail(self, tail_vigor):
        """Segment tail movement into bouts.

        Parameters
        ----------
        tail_vigor : np.ndarray
            Tail vigor signal

        Returns
        -------
        SegmentationResult
            Detected segments
        """
        segmenter = Segmentation.from_config(self.tail_segmentation_cfg)
        return segmenter.segment(tail_vigor)

    def compute_sparse_coding(self, tail_angle):
        """Compute sparse coding of tail angles.

        Parameters
        ----------
        tail_angle : np.ndarray
            Tail angle data

        Returns
        -------
        SparseCodingResult
            Sparse coding results
        """
        coder = SparseCoding(self.sparse_coding_cfg)
        return coder.sparse_code_tail_angle(tail_angle)

    def run(self, tracking_data):
        """Run the full pipeline on tail tracking data.

        Parameters
        ----------
        tracking_data : object
            Tracking data exposing a ``tail_df`` attribute (the class
            example passes a ``TailTrackingData`` instance).

        Returns
        -------
        sparse_coding_result : SparseCodingResult
            Output of :meth:`compute_sparse_coding` applied to
            ``tail.angle_smooth``.
        segments : SegmentationResult
            Output of :meth:`segment_tail` applied to ``tail.vigor``.
        tail : TailPreprocessingResult
            Output of :meth:`preprocess_tail` applied to
            ``tracking_data.tail_df``.
        """
        tail = self.preprocess_tail(tracking_data.tail_df)
        # Segmentation runs before sparse coding, as in the original flow.
        bouts = self.segment_tail(tail.vigor)
        coding = self.compute_sparse_coding(tail.angle_smooth)
        return coding, bouts, tail