Full tracking pipeline#

In this notebook, we will walk through the process of running the full tracking pipeline. The pipeline handles preprocessing of tracking data and segmentation/classification of tail bouts.

  • Loading dependencies

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from cycler import cycler

from megabouts.tracking_data import TrackingConfig, FullTrackingData, load_example_data
from megabouts.pipeline import FullTrackingPipeline
from megabouts.utils import (
    bouts_category_name,
    bouts_category_name_short,
    bouts_category_color,
    cmp_bouts,
)
  • Loading data into the FullTrackingData:

df_recording, fps, mm_per_unit = load_example_data("fulltracking_posture")

tracking_cfg = TrackingConfig(fps=fps, tracking="full_tracking")

head_x = df_recording["head_x"].values * mm_per_unit
head_y = df_recording["head_y"].values * mm_per_unit
head_yaw = df_recording["head_angle"].values
tail_angle = df_recording.filter(like="tail_angle").values

tracking_data = FullTrackingData.from_posture(
    head_x=head_x, head_y=head_y, head_yaw=head_yaw, tail_angle=tail_angle
)
  • Define the pipeline:

pipeline = FullTrackingPipeline(tracking_cfg, exclude_CS=True)

The pipeline has a default configuration, but we can change it if needed, for instance let’s change the segmentation threshold:

pipeline.segmentation_cfg.threshold = 50
  • Run the pipeline:

pipeline.tail_preprocessing_cfg
TailPreprocessingConfig(fps=700, limit_na_ms=100, num_pcs=4, savgol_window_ms=15, baseline_method='median', baseline_params={'fps': 700, 'half_window': 350}, tail_speed_filter_ms=100, tail_speed_boxcar_filter_ms=14)
pipeline.tail_preprocessing_cfg.savgol_window
11
ethogram, bouts, segments, tail, traj = pipeline.run(tracking_data)
  • We can check the segmentation and first tail beat detection:

Hide code cell source
fig, ax = plt.subplots(2, 1, figsize=(15, 5), sharex=True)

x = tracking_data._tail_angle[:, 7]
ax[0].plot(x)
ax[0].plot(segments.onset, x[segments.onset], "x", color="green")
ax[0].plot(segments.offset, x[segments.offset], "x", color="red")
ax[0].plot(segments.HB1, x[segments.HB1], "x", color="blue")
ax[0].set_ylim(-4, 4)

x = tail.vigor
ax[1].plot(x)
ax[1].plot(segments.onset, x[segments.onset], "x", color="green")
ax[1].plot(segments.offset, x[segments.offset], "x", color="red")

t = np.arange(tracking_data.T) / tracking_cfg.fps
IdSt = 386502  # np.random.randint(tracking_data.T)
Duration = 10 * tracking_cfg.fps
ax[1].set_xlim(IdSt, IdSt + Duration)

fig, ax = plt.subplots(1, 1, figsize=(10, 3))
x = tail.df.angle_smooth.iloc[:, 7]
ax.plot(t, x, color="tab:grey", lw=1)
ax.plot(t[segments.onset], x[segments.onset], "x", color="tab:green", label="onset")
ax.plot(t[segments.offset], x[segments.offset], "x", color="tab:red", label="offset")
ax.plot(
    t[segments.HB1], x[segments.HB1], "x", color="tab:blue", label="first tail beat"
)
ax.set(
    **{
        "title": "segmentation",
        "xlim": (t[IdSt], t[IdSt + Duration]),
        "ylim": (-4, 4),
        "ylabel": "tail angle (rad)",
        "xlabel": "time (s)",
    }
)
ax.legend()
plt.show()
../_images/15383dd428f5e2a2c304bc70c6cba586ba6df5f00253c2c82e4c8418808e474b.png ../_images/f110695becb15faa2a5b8d931c0395e1f213408aa217c1d6fa3cc1be85554419.png
  • Let’s display the bouts classified with a probability greater than 0.5:

Hide code cell source
id_b = np.unique(bouts.df.label.category[bouts.df.label.proba > 0.5]).astype("int")

fig, ax = plt.subplots(facecolor="white", figsize=(25, 4))

ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.spines["bottom"].set_visible(False)
ax.spines["left"].set_visible(False)
ax.set_xticks([])
ax.set_yticks([])
G = gridspec.GridSpec(1, len(id_b))
ax0 = {}
for i, b in enumerate(id_b):
    ax0 = plt.subplot(G[i])
    ax0.set_title(bouts_category_name_short[b], fontsize=15)
    for i_sg, sg in enumerate([1, -1]):
        id = bouts.df[
            (bouts.df.label.category == b)
            & (bouts.df.label.sign == sg)
            & (bouts.df.label.proba > 0.5)
        ].index
        if len(id) > 0:
            ax0.plot(sg * bouts.tail[id, 7, :].T, color="k", alpha=0.3)
        ax0.set_xlim(0, pipeline.segmentation_cfg.bout_duration)
        ax0.set_ylim(-4, 4)
        ax0.set_xticks([])
        ax0.set_yticks([])
        for sp in ["top", "bottom", "left", "right"]:
            ax0.spines[sp].set_color(bouts_category_color[b])
            ax0.spines[sp].set_linewidth(5)

plt.show()
../_images/fabeca031a70c10c51e815d8b16a61bae20f4a7df58878a2ef9c8b809b5960d7.png
  • Finally, we can display a sample ethogram:

Hide code cell source
IdSt = 161011
T = 3
Duration = T * tracking_cfg.fps
IdEd = IdSt + Duration
t = np.arange(Duration) / tracking_cfg.fps


fig = plt.figure(facecolor="white", figsize=(15, 5), constrained_layout=True)
G = gridspec.GridSpec(2, 1, height_ratios=[1, 0.2], hspace=0.5, figure=fig)
ax = plt.subplot(G[0, 0])
blue_cycler = cycler(color=plt.cm.Blues(np.linspace(0.2, 0.9, 10)))
ax.set_prop_cycle(blue_cycler)

ax.plot(t, ethogram.df["tail_angle"].values[IdSt:IdEd, :7], lw=1)
ax.set_ylim(-4, 4)
ax.set_xlim(0, T)

ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.spines["bottom"].set_visible(False)
ax.get_yaxis().tick_left()
ax.get_xaxis().set_ticks([])
ax.set_ylabel("tail angle (rad)", rotation=0, labelpad=100)

ax = plt.subplot(G[1, 0])
ax.imshow(
    ethogram.df[("bout", "cat")].values[IdSt:IdEd].T,
    cmap=cmp_bouts,
    aspect="auto",
    vmin=0,
    vmax=12,
    interpolation="nearest",
    extent=(0, T, 0, 1),
)
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.spines["bottom"].set_visible(False)
ax.get_yaxis().tick_left()
ax.get_xaxis().set_ticks([])
ax.get_yaxis().set_ticks([])
ax.set_xlim(0, T)
ax.set_ylim(0, 1.1)

id_b = np.unique(ethogram.df[("bout", "id")].values[IdSt:IdEd]).astype("int")
id_b = id_b[id_b > -1]
for i in id_b:
    on_ = bouts.df.iloc[i][("location", "onset")]
    b = bouts.df.iloc[i][("label", "category")]
    ax.text((on_ - IdSt) / tracking_cfg.fps, 1.1, bouts_category_name[int(b)])

ax.set_ylabel("bout category", rotation=0, labelpad=100)
plt.show()
../_images/028d8e7a58912b1c1021b8e2e3f7d46fa3037422631896f5954cfbab43f59bb6.png