megabouts.classification

class megabouts.classification.transformer_network.BoutsDataset(X, t_sample, sampling_mask, device=None, precision=None)

Bases: Dataset

Dataset class for bout data with continuous positional encoding.

Parameters:
  • X (np.ndarray) – Input features, shape (n_bouts, bout_duration, n_features)

  • t_sample (np.ndarray) – Time points for each sample

  • sampling_mask (np.ndarray) – Boolean mask for valid samples

  • device (torch.device, optional) – Device to store tensors on

  • precision (torch.dtype, optional) – Precision of tensors

__init__(X, t_sample, sampling_mask, device=None, precision=None)
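A minimal sketch of the map-style dataset pattern (`__len__` plus `__getitem__`) that a class like this follows. It is written in NumPy so it runs without PyTorch. Because the shapes of `t_sample` and `sampling_mask` are not documented above, the sketch assumes they hold one entry per frame of each bout, i.e. shape (n_bouts, bout_duration). The class name `BoutsDatasetSketch` is hypothetical, and the sketch leaves out the conversion to tensors on `device` with `precision`.

```python
import numpy as np


class BoutsDatasetSketch:
    """Hypothetical stand-in for BoutsDataset: one item per bout (axis 0).

    Assumes X has shape (n_bouts, bout_duration, n_features) and that
    t_sample and sampling_mask have shape (n_bouts, bout_duration).
    """

    def __init__(self, X, t_sample, sampling_mask):
        if not (len(X) == len(t_sample) == len(sampling_mask)):
            raise ValueError("all inputs must have the same number of bouts")
        self.X = np.asarray(X, dtype=np.float32)
        self.t_sample = np.asarray(t_sample, dtype=np.float32)
        self.sampling_mask = np.asarray(sampling_mask, dtype=bool)

    def __len__(self):
        # Number of bouts in the dataset
        return self.X.shape[0]

    def __getitem__(self, idx):
        # Features, time points and validity mask for a single bout
        return self.X[idx], self.t_sample[idx], self.sampling_mask[idx]


# Fake data: 4 bouts, 50 frames each, 10 features per frame
n_bouts, bout_duration, n_features = 4, 50, 10
X = np.zeros((n_bouts, bout_duration, n_features))
t_sample = np.tile(np.arange(bout_duration), (n_bouts, 1))
mask = np.ones((n_bouts, bout_duration), dtype=bool)

ds = BoutsDatasetSketch(X, t_sample, mask)
x0, t0, m0 = ds[0]
```

A `torch.utils.data.DataLoader` can batch any object exposing these two methods, which is why the real class subclasses `Dataset`.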
class megabouts.classification.transformer_network.ContinuousPositionalEncoding(d_model, max_seq_length)

Bases: Module

__init__(d_model, max_seq_length)

forward(x, t)
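This class is documented only by its signature. One common way to encode continuous time is to evaluate the standard Transformer sinusoidal encoding at real-valued time points `t` instead of integer positions. The NumPy sketch below shows that technique as an assumption about what the class computes, not its actual implementation. The function name `continuous_positional_encoding` and the `base` parameter are hypothetical, and the sketch ignores `max_seq_length`, whose role is not documented here.

```python
import numpy as np


def continuous_positional_encoding(t, d_model, base=10000.0):
    """Hypothetical helper: sinusoidal encoding evaluated at real-valued times t.

    t: array of shape (..., seq_len), time points that need not be integers.
    Returns an array of shape (..., seq_len, d_model).
    """
    if d_model % 2 != 0:
        raise ValueError("d_model must be even")
    t = np.asarray(t, dtype=np.float64)[..., None]           # (..., seq_len, 1)
    # Frequencies 1 / base^(2i / d_model) for i = 0 .. d_model/2 - 1
    inv_freq = base ** (-np.arange(0, d_model, 2) / d_model)  # (d_model/2,)
    angles = t * inv_freq                                     # (..., seq_len, d_model/2)
    pe = np.empty(angles.shape[:-1] + (d_model,))
    pe[..., 0::2] = np.sin(angles)  # even channels
    pe[..., 1::2] = np.cos(angles)  # odd channels
    return pe


# Time points between frames are encoded just like integer positions
pe = continuous_positional_encoding(np.array([0.0, 0.5, 1.0]), d_model=8)
```

At integer `t` this reduces to the usual sinusoidal table; non-integer `t` interpolates smoothly between rows, which is what makes it usable for irregularly sampled frames.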

class megabouts.classification.transformer_network.TransAm(mapping_label_to_sublabel, feature_size=64, num_layers=3, dropout=0.0, nhead=8)

Bases: Module

Transformer model for bout classification.

Parameters:
  • mapping_label_to_sublabel (dict) – Mapping from main labels to sublabels

  • feature_size (int, optional) – Size of feature embedding, by default 64

  • num_layers (int, optional) – Number of transformer layers, by default 3

  • dropout (float, optional) – Dropout rate, by default 0.0

  • nhead (int, optional) – Number of attention heads, by default 8

__init__(mapping_label_to_sublabel, feature_size=64, num_layers=3, dropout=0.0, nhead=8)

forward(input, t, mask)
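The format of `mapping_label_to_sublabel` is not specified beyond "dict". The sketch below assumes each main-label index maps to a list of integer sublabel indices, and shows how such a hierarchy can be inverted and used to aggregate sublabel probabilities into main-label probabilities by summation. The mapping values and the helpers `sublabel_to_label` and `label_proba_from_sublabel_proba` are hypothetical, not part of megabouts.

```python
import numpy as np

# Hypothetical mapping: main label index -> list of sublabel indices
mapping_label_to_sublabel = {0: [0, 1], 1: [2], 2: [3, 4, 5]}


def sublabel_to_label(mapping):
    """Hypothetical helper: entry j of the result is the main label of sublabel j."""
    n_sub = sum(len(v) for v in mapping.values())
    parent = np.full(n_sub, -1, dtype=int)
    for label, sublabels in mapping.items():
        parent[list(sublabels)] = label
    if (parent < 0).any():
        raise ValueError("every sublabel must belong to exactly one label")
    return parent


def label_proba_from_sublabel_proba(p_sub, parent, n_labels):
    """Hypothetical helper: sum sublabel probabilities (n_bouts, n_sub) per main label."""
    one_hot = np.zeros((parent.size, n_labels))
    one_hot[np.arange(parent.size), parent] = 1.0
    return p_sub @ one_hot  # (n_bouts, n_labels)


parent = sublabel_to_label(mapping_label_to_sublabel)
p_sub = np.array([[0.1, 0.2, 0.3, 0.1, 0.1, 0.2]])
p_label = label_proba_from_sublabel_proba(p_sub, parent, len(mapping_label_to_sublabel))
```

Separately, if `TransAm` is built on `torch.nn.TransformerEncoderLayer` (not stated here), `feature_size` must be divisible by `nhead`; the defaults give 64 / 8 = 8 dimensions per head.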

class megabouts.classification.classification.TailBouts(*, segments, classif_results, tail_array=None, traj_array=None)

Bases: object

Container for classified bout data.

Parameters:
  • segments (SegmentationResult) – Segmentation results containing bout timing information

  • classif_results (dict) – Classification results with the following keys:
      - 'cat': bout categories (numpy array)
      - 'subcat': bout subcategories (numpy array)
      - 'sign': bout direction, -1 or 1 (numpy array)
      - 'proba': classification probabilities (numpy array)
      - 'first_half_beat': frame indices of the first half-beats (numpy array)

  • tail_array (np.ndarray, optional) – Tail angles for each bout, shape (n_bouts, n_segments, bout_duration)

  • traj_array (np.ndarray, optional) – Trajectory data for each bout, shape (n_bouts, 3, bout_duration)

__init__(*, segments, classif_results, tail_array=None, traj_array=None)
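A sketch of summarising a `classif_results` dict with the documented keys, counting bouts per (category, direction) pair. The values below are fake and only illustrate the structure. It assumes, since the shape is not documented, that `'proba'` holds one confidence value per bout; the helper `count_by_category_and_side` and its `min_proba` threshold are hypothetical.

```python
import numpy as np

# Fake classification results using the documented keys (illustrative values only)
classif_results = {
    "cat": np.array([0, 0, 2, 2, 2]),
    "subcat": np.array([0, 1, 3, 3, 4]),
    "sign": np.array([-1, 1, 1, 1, -1]),
    "proba": np.array([0.9, 0.8, 0.95, 0.7, 0.6]),  # assumed: one value per bout
    "first_half_beat": np.array([3, 4, 2, 5, 3]),
}


def count_by_category_and_side(results, min_proba=0.0):
    """Hypothetical helper: count bouts per (cat, sign), keeping proba >= min_proba."""
    keep = results["proba"] >= min_proba
    counts = {}
    for cat, sign in zip(results["cat"][keep], results["sign"][keep]):
        key = (int(cat), int(sign))
        counts[key] = counts.get(key, 0) + 1
    return counts


all_counts = count_by_category_and_side(classif_results)
confident_counts = count_by_category_and_side(classif_results, min_proba=0.75)
```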
class megabouts.classification.classification.BoutClassifier(tracking_cfg: TrackingConfig, segmentation_cfg: SegmentationConfig, exclude_CS: bool = False, device=None, precision=None)

Bases: object

Classifier for zebrafish swimming bouts.

Parameters:
  • tracking_cfg (TrackingConfig) – Configuration for tracking data

  • segmentation_cfg (SegmentationConfig) – Configuration for bout segmentation

  • exclude_CS (bool, optional) – Whether to exclude capture swim bouts, by default False

  • device (torch.device, optional) – Device to run model on, by default None (auto-select)

  • precision (torch.dtype, optional) – Model precision, by default None (auto-select)

Examples

>>> import numpy as np
>>> from megabouts.tracking_data import TrackingConfig
>>> from megabouts.config.segmentation_config import TailSegmentationConfig
>>> from megabouts.classification.classification import BoutClassifier
>>> # Initialize classifier
>>> tracking_cfg = TrackingConfig(fps=100, tracking='full_tracking')
>>> segmentation_cfg = TailSegmentationConfig(fps=100)
>>> bout_duration = segmentation_cfg.bout_duration
>>> # Create fake bout data (10 bouts, 7 tail segments, bout_duration frames)
>>> tail_array = np.zeros((10, 7, bout_duration))
>>> traj_array = np.zeros((10, 3, bout_duration))  # x, y, angle
>>> classifier = BoutClassifier(tracking_cfg, segmentation_cfg)
>>> results = classifier.run_classification(tail_array=tail_array, traj_array=traj_array)
>>> isinstance(results, dict)
True
>>> "cat" in results and "sign" in results
True
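The example above passes `tail_array` with shape (n_bouts, 7, bout_duration) and `traj_array` with shape (n_bouts, 3, bout_duration), while `BoutsDataset` documents `X` as (n_bouts, bout_duration, n_features). The sketch below only illustrates that axis reordering, by concatenating the channels and moving time to axis 1. The helper `stack_bout_features` is hypothetical; the features actually built by `prepare_tensor_input` are not documented here and may differ (for example, selected segments or derived quantities).

```python
import numpy as np


def stack_bout_features(tail_array, traj_array):
    """Hypothetical helper: join tail and trajectory channels, then put time on axis 1.

    tail_array: (n_bouts, n_segments, bout_duration)
    traj_array: (n_bouts, 3, bout_duration)   # x, y, angle
    returns:    (n_bouts, bout_duration, n_segments + 3)
    """
    if tail_array.shape[0] != traj_array.shape[0] or tail_array.shape[2] != traj_array.shape[2]:
        raise ValueError("tail_array and traj_array must agree on n_bouts and bout_duration")
    features = np.concatenate([tail_array, traj_array], axis=1)  # (n_bouts, n_features, T)
    return np.transpose(features, (0, 2, 1))                       # (n_bouts, T, n_features)


tail_array = np.random.default_rng(0).normal(size=(10, 7, 50))
traj_array = np.zeros((10, 3, 50))
X = stack_bout_features(tail_array, traj_array)
```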
__init__(tracking_cfg: TrackingConfig, segmentation_cfg: SegmentationConfig, exclude_CS: bool = False, device=None, precision=None)
load_classifier()
prepare_tensor_input(**kwargs)
extract_bouts(**kwargs)
extract_bouts_full_tracking(*, tail_array, traj_array)
extract_bouts_head_tracking(*, traj_array)
compute_sampling_input(num_samples)
run_classification(**kwargs)
process_results(results, num_samples)
filter_logit(logit_label, logit_sublabel)
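How `exclude_CS` and `filter_logit` work is not documented here. One standard way to exclude a class from a classifier's output is to set its logits to -inf before the softmax, so it receives zero probability and can never be the argmax. The NumPy sketch below shows that technique as an assumption, not as the library's behaviour; `CS_INDEX` and the helpers `softmax` and `mask_excluded_classes` are hypothetical.

```python
import numpy as np


def softmax(logits, axis=-1):
    # Subtract the row max for numerical stability; -inf entries become probability 0
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)


def mask_excluded_classes(logits, excluded):
    """Hypothetical helper: set logits of excluded class indices to -inf."""
    masked = np.array(logits, dtype=float, copy=True)
    masked[..., list(excluded)] = -np.inf
    return masked


CS_INDEX = 2  # hypothetical index of the capture-swim category
logit_label = np.array([[0.1, 0.5, 3.0],
                        [2.0, 0.0, 1.0]])
proba = softmax(mask_excluded_classes(logit_label, [CS_INDEX]))
pred = proba.argmax(axis=-1)  # first bout falls back to class 1 instead of class 2
```

Masking logits rather than dropping bouts keeps every bout classified, redistributing the excluded class's probability mass over the remaining categories.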