megabouts.classification
- class megabouts.classification.transformer_network.BoutsDataset(X, t_sample, sampling_mask, device=None, precision=None)
Bases: Dataset
Dataset class for bout data with continuous positional encoding.
- Parameters:
X (np.ndarray) – Input features, shape (n_bouts, bout_duration, n_features)
t_sample (np.ndarray) – Time points for each sample
sampling_mask (np.ndarray) – Boolean mask for valid samples
device (torch.device, optional) – Device on which to store the tensors
precision (torch.dtype, optional) – Data type (precision) of the tensors
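The three arrays are meant to be indexed by bout along their first axis. The sketch below is a hypothetical, NumPy-only stand-in (the name `ToyBoutsDataset` is illustrative, not part of megabouts) that follows PyTorch's map-style dataset protocol of `__len__` and `__getitem__`. It assumes that `t_sample` and `sampling_mask` both have shape `(n_bouts, bout_duration)`, which the parameter descriptions above do not state explicitly.

```python
import numpy as np


class ToyBoutsDataset:
    """Hypothetical NumPy-only stand-in for BoutsDataset (not the megabouts code).

    Assumes t_sample and sampling_mask have shape (n_bouts, bout_duration).
    """

    def __init__(self, X, t_sample, sampling_mask):
        X = np.asarray(X, dtype=np.float32)
        t_sample = np.asarray(t_sample, dtype=np.float32)
        sampling_mask = np.asarray(sampling_mask, dtype=bool)
        # Every array must describe the same bouts and the same time axis.
        if not (X.shape[:2] == t_sample.shape == sampling_mask.shape):
            raise ValueError("expected X of shape (n_bouts, bout_duration, n_features) "
                             "and t_sample/sampling_mask of shape (n_bouts, bout_duration)")
        self.X = X
        self.t_sample = t_sample
        self.sampling_mask = sampling_mask

    def __len__(self):
        # One item per bout.
        return self.X.shape[0]

    def __getitem__(self, idx):
        # Features, time points and validity mask of a single bout.
        return self.X[idx], self.t_sample[idx], self.sampling_mask[idx]


n_bouts, bout_duration, n_features = 4, 10, 3
X = np.random.default_rng(0).normal(size=(n_bouts, bout_duration, n_features))
t_sample = np.tile(np.arange(bout_duration, dtype=float), (n_bouts, 1))  # frame index as time
sampling_mask = np.ones((n_bouts, bout_duration), dtype=bool)
sampling_mask[:, 7:] = False  # pretend the last 3 frames of every bout are invalid

ds = ToyBoutsDataset(X, t_sample, sampling_mask)
x0, t0, m0 = ds[0]
print(len(ds), x0.shape, t0.shape, int(m0.sum()))  # 4 (10, 3) (10,) 7
```

Because it implements the map-style protocol, an object like this could be passed to `torch.utils.data.DataLoader`, which only needs `__getitem__` and `__len__` for the default sampler.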
- class megabouts.classification.transformer_network.ContinuousPositionalEncoding(d_model, max_seq_length)
Bases: Module
- __init__(d_model, max_seq_length)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x, t)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
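This page does not document how the encoding is computed. One common way to make a positional encoding "continuous" is to evaluate the standard Transformer sinusoidal encoding at real-valued times `t` instead of integer positions; the NumPy sketch below does exactly that, under the assumption that the megabouts class works similarly. The function name and the `base=10000.0` constant are illustrative choices, and `max_seq_length` is not modeled.

```python
import numpy as np


def continuous_sinusoidal_encoding(t, d_model, base=10000.0):
    """Sinusoidal positional encoding evaluated at real-valued times t.

    t: array of shape (..., seq_len); returns an array of shape (..., seq_len, d_model).
    Hypothetical sketch; d_model must be even.
    """
    if d_model % 2:
        raise ValueError("d_model must be even")
    t = np.asarray(t, dtype=float)[..., None]   # (..., seq_len, 1)
    i = np.arange(0, d_model, 2, dtype=float)   # even dimension indices 0, 2, 4, ...
    freqs = base ** (-i / d_model)              # (d_model / 2,)
    angles = t * freqs                          # (..., seq_len, d_model / 2)
    pe = np.empty(angles.shape[:-1] + (d_model,))
    pe[..., 0::2] = np.sin(angles)              # even dimensions: sine
    pe[..., 1::2] = np.cos(angles)              # odd dimensions: cosine
    return pe


# Times need not be integers or evenly spaced.
t = np.array([[0.0, 0.5, 1.25, 2.0]])           # (batch, seq_len)
x = np.zeros((1, 4, 8))                         # (batch, seq_len, d_model)
x_encoded = x + continuous_sinusoidal_encoding(t, d_model=8)
print(x_encoded.shape)                          # (1, 4, 8)
```

Adding the encoding to the input, rather than concatenating it, keeps the feature dimension at `d_model`, which is the usual Transformer convention.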
- class megabouts.classification.transformer_network.TransAm(mapping_label_to_sublabel, feature_size=64, num_layers=3, dropout=0.0, nhead=8)
Bases: Module
Transformer model for bout classification.
- Parameters:
mapping_label_to_sublabel (dict) – Mapping from main labels to sublabels
feature_size (int, optional) – Dimension of the feature embedding, by default 64
num_layers (int, optional) – Number of transformer layers, by default 3
dropout (float, optional) – Dropout rate, by default 0.0
nhead (int, optional) – Number of attention heads, by default 8
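`mapping_label_to_sublabel` groups fine-grained sublabels under main labels. The sketch below assumes the values are lists of sublabel names and uses hypothetical label names (not the actual megabouts categories) to show how such a mapping can be flattened into output indices and inverted to recover each sublabel's parent label. It also checks that `feature_size` is divisible by `nhead`, which `torch.nn.MultiheadAttention` requires; that this constraint applies to TransAm is an assumption based on its `nhead` parameter.

```python
# Hypothetical mapping: main label -> list of sublabels (names are illustrative only).
mapping_label_to_sublabel = {
    "forward": ["slow_forward", "fast_forward"],
    "turn": ["routine_turn", "sharp_turn"],
    "escape": ["escape"],
}

# Flatten sublabels into a fixed order, e.g. for the columns of a sublabel output layer.
sublabels = [s for subs in mapping_label_to_sublabel.values() for s in subs]
sublabel_index = {s: i for i, s in enumerate(sublabels)}

# Invert the mapping so that a predicted sublabel can be traced back to its main label.
sublabel_to_label = {
    s: label for label, subs in mapping_label_to_sublabel.items() for s in subs
}

# Multi-head attention splits the embedding evenly across the heads.
feature_size, nhead = 64, 8
assert feature_size % nhead == 0, "feature_size must be divisible by nhead"
head_dim = feature_size // nhead

print(len(sublabels), sublabel_to_label["sharp_turn"], head_dim)  # 5 turn 8
```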
- __init__(mapping_label_to_sublabel, feature_size=64, num_layers=3, dropout=0.0, nhead=8)
- forward(input, t, mask)
- class megabouts.classification.classification.TailBouts(*, segments, classif_results, tail_array=None, traj_array=None)
Bases: object
Container for classified bout data.
- Parameters:
segments (SegmentationResult) – Segmentation results containing bout timing information
classif_results (dict) – Classification results containing:
  - ‘cat’: bout categories (numpy array)
  - ‘subcat’: bout subcategories (numpy array)
  - ‘sign’: bout direction, -1 or 1 (numpy array)
  - ‘proba’: classification probabilities (numpy array)
  - ‘first_half_beat’: frame indices of first half-beats (numpy array)
tail_array (np.ndarray, optional) – Tail angles for each bout, shape (n_bouts, n_segments, bout_duration)
traj_array (np.ndarray, optional) – Trajectory data (x, y, angle) for each bout, shape (n_bouts, 3, bout_duration)
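Because each array in `classif_results` is indexed by bout, per-bout fields can be combined with boolean masks. The sketch below builds a small, made-up `classif_results` dict with a subset of the documented keys (the integer category codes are arbitrary and do not correspond to real megabouts categories), selects the bouts of one category, and splits them by the sign of their direction. The `bout_duration` of 10 frames is likewise an arbitrary choice.

```python
import numpy as np

# Made-up results for 6 bouts, using a subset of the documented keys.
classif_results = {
    "cat": np.array([0, 2, 2, 1, 2, 0]),
    "subcat": np.array([0, 5, 6, 3, 5, 1]),
    "sign": np.array([1, -1, 1, 1, -1, -1]),
    "first_half_beat": np.array([3, 4, 2, 5, 3, 4]),
}

# Select bouts of one (arbitrary) category and split them by direction.
is_cat = classif_results["cat"] == 2
positive = np.flatnonzero(is_cat & (classif_results["sign"] == 1))
negative = np.flatnonzero(is_cat & (classif_results["sign"] == -1))

# The same mask indexes per-bout arrays such as tail_array,
# shape (n_bouts, n_segments, bout_duration).
tail_array = np.zeros((6, 7, 10))
tail_of_cat = tail_array[is_cat]

print(positive, negative, tail_of_cat.shape)  # [2] [1 4] (3, 7, 10)
```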
- class megabouts.classification.classification.BoutClassifier(tracking_cfg: TrackingConfig, segmentation_cfg: SegmentationConfig, exclude_CS: bool = False, device=None, precision=None)
Bases: object
Classifier for zebrafish swimming bouts.
- Parameters:
tracking_cfg (TrackingConfig) – Configuration for tracking data
segmentation_cfg (SegmentationConfig) – Configuration for bout segmentation
exclude_CS (bool, optional) – Whether to exclude capture swim bouts, by default False
device (torch.device, optional) – Device to run model on, by default None (auto-select)
precision (torch.dtype, optional) – Model precision, by default None (auto-select)
Examples
>>> import numpy as np
>>> from megabouts.tracking_data import TrackingConfig
>>> from megabouts.config.segmentation_config import TailSegmentationConfig
>>> from megabouts.classification.classification import BoutClassifier
>>> # Initialize configurations
>>> tracking_cfg = TrackingConfig(fps=100, tracking='full_tracking')
>>> segmentation_cfg = TailSegmentationConfig(fps=100)
>>> bout_duration = segmentation_cfg.bout_duration
>>> # Create fake bout data (10 bouts, 7 tail segments, bout_duration frames)
>>> tail_array = np.zeros((10, 7, bout_duration))
>>> traj_array = np.zeros((10, 3, bout_duration))  # x, y, angle
>>> # Initialize the classifier and run it
>>> classifier = BoutClassifier(tracking_cfg, segmentation_cfg)
>>> results = classifier.run_classification(tail_array=tail_array, traj_array=traj_array)
>>> isinstance(results, dict)
True
>>> "cat" in results and "sign" in results
True
- __init__(tracking_cfg: TrackingConfig, segmentation_cfg: SegmentationConfig, exclude_CS: bool = False, device=None, precision=None)