{ "cells": [ { "cell_type": "markdown", "id": "5a232aef", "metadata": {}, "source": [ "# Tail Angle Preprocessing\n", "\n", "**This notebook illustrates the `TailPreprocessing` class and how to run the different preprocessing steps.**\n", "- **Several preprocessing steps are applied to the tail angle**:\n", "    - Interpolating missing values\n", "    - PCA denoising of posture using 'eigen-fish'\n", "    - Savitzky-Golay (Savgol) filter for time series smoothing\n", "    - Baseline correction\n", "\n", "- **The tail vigor is also computed and will be useful for detecting tail bouts**\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Loading dependencies" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "from megabouts.tracking_data import TrackingConfig, FullTrackingData, load_example_data\n", "from megabouts.config import TailPreprocessingConfig\n", "from megabouts.preprocessing import TailPreprocessing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loading Data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* `TrackingConfig` and `TrackingData` are set up as in [tutorial_Loading_Data](./Loading_Data.ipynb). Here we use the dataset from the poorly trained SLEAP model to highlight the effect of smoothing."
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Load data and set tracking configuration\n", "df_recording, fps, mm_per_unit = load_example_data(\"SLEAP_fulltracking\")\n", "tracking_cfg = TrackingConfig(fps=fps, tracking=\"full_tracking\")\n", "\n", "# List of keypoints\n", "keypoints = [\"left_eye\", \"right_eye\", \"tail0\", \"tail1\", \"tail2\", \"tail3\", \"tail4\"]\n", "\n", "# Set x/y to NaN where the instance score or the keypoint score is below a threshold\n", "thresh_score = 0.0\n", "for kps in keypoints:\n", "    score_below_thresh = df_recording[\"instance.score\"] < thresh_score\n", "    df_recording.loc[\n", "        score_below_thresh | (df_recording[f\"{kps}.score\"] < thresh_score),\n", "        [f\"{kps}.x\", f\"{kps}.y\"],\n", "    ] = np.nan\n", "\n", "# Compute head (midpoint of the eyes) and tail coordinates and convert to mm\n", "head_x = ((df_recording[\"left_eye.x\"] + df_recording[\"right_eye.x\"]) / 2) * mm_per_unit\n", "head_y = ((df_recording[\"left_eye.y\"] + df_recording[\"right_eye.y\"]) / 2) * mm_per_unit\n", "tail_x = df_recording[[f\"tail{i}.x\" for i in range(5)]].values * mm_per_unit\n", "tail_y = df_recording[[f\"tail{i}.y\" for i in range(5)]].values * mm_per_unit\n", "\n", "# Create FullTrackingData object\n", "tracking_data = FullTrackingData.from_keypoints(\n", "    head_x=head_x.values, head_y=head_y.values, tail_x=tail_x, tail_y=tail_y\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run Preprocessing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Define the preprocessing config using the default values:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "TailPreprocessingConfig(fps=350, limit_na_ms=100, num_pcs=4, savgol_window_ms=15, baseline_method='median', baseline_params={'fps': 350, 'half_window': 175}, tail_speed_filter_ms=100, tail_speed_boxcar_filter_ms=14)\n" ] } ], "source": [ "tail_preprocessing_cfg = 
TailPreprocessingConfig(fps=tracking_cfg.fps)\n", "print(tail_preprocessing_cfg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Alternatively, the tail preprocessing parameters can be set to custom values:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "tail_preprocessing_cfg = TailPreprocessingConfig(\n", "    fps=tracking_cfg.fps,\n", "    num_pcs=3,\n", "    savgol_window_ms=30,\n", "    tail_speed_filter_ms=100,\n", "    tail_speed_boxcar_filter_ms=14,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Here we set the half-window of the median filter used to compute the tail baseline to 100 frames, i.e. a window about 200 frames wide:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "tail_preprocessing_cfg.baseline_params[\"half_window\"] = 100" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Let's run the preprocessing pipeline:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "tail_df_input = tracking_data.tail_df\n", "tail = TailPreprocessing(tail_preprocessing_cfg).preprocess_tail_df(tail_df_input)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* `tail.df` contains the raw angle, the baseline, the smoothed angle, and the vigor of the tail oscillations:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table>\n",
"  <thead>\n",
"    <tr><th></th><th colspan=10>angle</th><th>...</th><th colspan=8>angle_smooth</th><th>no_tracking</th><th>vigor</th></tr>\n",
"    <tr><th></th><th colspan=10>segments</th><th>...</th><th colspan=8>segments</th><th></th><th></th></tr>\n",
"    <tr><th></th><th>0</th><th>1</th><th>2</th><th>3</th><th>4</th><th>5</th><th>6</th><th>7</th><th>8</th><th>9</th><th>...</th><th>2</th><th>3</th><th>4</th><th>5</th><th>6</th><th>7</th><th>8</th><th>9</th><th></th><th></th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>0.122614</td><td>-0.015221</td><td>-0.098553</td><td>-0.126124</td><td>-0.098301</td><td>-0.015937</td><td>0.062022</td><td>0.057392</td><td>-0.035583</td><td>-0.215341</td><td>...</td><td>-0.100224</td><td>-0.096774</td><td>-0.067329</td><td>-0.018151</td><td>0.056155</td><td>0.058781</td><td>-0.090561</td><td>-0.172721</td><td>False</td><td>NaN</td></tr>\n",
"    <tr><th>1</th><td>0.122639</td><td>-0.015076</td><td>-0.098353</td><td>-0.125934</td><td>-0.098185</td><td>-0.015953</td><td>0.061969</td><td>0.057438</td><td>-0.035308</td><td>-0.214723</td><td>...</td><td>-0.086259</td><td>-0.081402</td><td>-0.051451</td><td>-0.002093</td><td>0.069939</td><td>0.071179</td><td>-0.076765</td><td>-0.157507</td><td>False</td><td>NaN</td></tr>\n",
"    <tr><th>2</th><td>0.122712</td><td>-0.015160</td><td>-0.098521</td><td>-0.126112</td><td>-0.098298</td><td>-0.015927</td><td>0.062090</td><td>0.057466</td><td>-0.035575</td><td>-0.215469</td><td>...</td><td>-0.077366</td><td>-0.071614</td><td>-0.041347</td><td>0.008123</td><td>0.078698</td><td>0.079057</td><td>-0.067974</td><td>-0.147809</td><td>False</td><td>NaN</td></tr>\n",
"    <tr><th>3</th><td>0.182927</td><td>0.045016</td><td>-0.038378</td><td>-0.065993</td><td>-0.038198</td><td>0.044170</td><td>0.122267</td><td>0.117708</td><td>0.024690</td><td>-0.155223</td><td>...</td><td>-0.073543</td><td>-0.067412</td><td>-0.037017</td><td>0.012494</td><td>0.082430</td><td>0.082415</td><td>-0.064188</td><td>-0.143629</td><td>False</td><td>NaN</td></tr>\n",
"    <tr><th>4</th><td>0.182883</td><td>0.045139</td><td>-0.038150</td><td>-0.065731</td><td>-0.037968</td><td>0.044304</td><td>0.122333</td><td>0.117786</td><td>0.024859</td><td>-0.154890</td><td>...</td><td>-0.074792</td><td>-0.068794</td><td>-0.038462</td><td>0.011022</td><td>0.081137</td><td>0.081254</td><td>-0.065407</td><td>-0.144965</td><td>False</td><td>NaN</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"<p>5 rows × 32 columns</p>\n",
"</div>
" ], "text/plain": [ " angle \\\n", " segments \n", " 0 1 2 3 4 5 6 \n", "0 0.122614 -0.015221 -0.098553 -0.126124 -0.098301 -0.015937 0.062022 \n", "1 0.122639 -0.015076 -0.098353 -0.125934 -0.098185 -0.015953 0.061969 \n", "2 0.122712 -0.015160 -0.098521 -0.126112 -0.098298 -0.015927 0.062090 \n", "3 0.182927 0.045016 -0.038378 -0.065993 -0.038198 0.044170 0.122267 \n", "4 0.182883 0.045139 -0.038150 -0.065731 -0.037968 0.044304 0.122333 \n", "\n", " ... angle_smooth \\\n", " ... segments \n", " 7 8 9 ... 2 3 4 \n", "0 0.057392 -0.035583 -0.215341 ... -0.100224 -0.096774 -0.067329 \n", "1 0.057438 -0.035308 -0.214723 ... -0.086259 -0.081402 -0.051451 \n", "2 0.057466 -0.035575 -0.215469 ... -0.077366 -0.071614 -0.041347 \n", "3 0.117708 0.024690 -0.155223 ... -0.073543 -0.067412 -0.037017 \n", "4 0.117786 0.024859 -0.154890 ... -0.074792 -0.068794 -0.038462 \n", "\n", " no_tracking vigor \n", " \n", " 5 6 7 8 9 \n", "0 -0.018151 0.056155 0.058781 -0.090561 -0.172721 False NaN \n", "1 -0.002093 0.069939 0.071179 -0.076765 -0.157507 False NaN \n", "2 0.008123 0.078698 0.079057 -0.067974 -0.147809 False NaN \n", "3 0.012494 0.082430 0.082415 -0.064188 -0.143629 False NaN \n", "4 0.011022 0.081137 0.081254 -0.065407 -0.144965 False NaN \n", "\n", "[5 rows x 32 columns]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tail.df.head(5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* We can visualize the result of preprocessing:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "tags": [ "hide-input" ] }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from cycler import cycler\n", "\n", "blue_cycler = cycler(color=plt.cm.Blues(np.linspace(0.2, 0.9, 10)))\n", "\n", "t = np.arange(tracking_data.T) / tracking_cfg.fps\n", "IdSt = 140612 # np.random.randint(tracking_data.T)\n", "Duration = 150\n", "t_win = t[IdSt : IdSt + Duration] - t[IdSt]\n", "\n", "# Prepare the data, titles, and subtitles\n", "angle_data = [tail.angle, tail.angle_smooth]\n", "subtitles = [\"Raw Data\", \"Smoothed Data\"]\n", "\n", "# Create subplots\n", "fig, ax = plt.subplots(2, 1, figsize=(10, 7), sharex=True)\n", "\n", "# Set a main title for the figure\n", "fig.suptitle(\"Tail Preprocessing\", fontsize=16)\n", "\n", "# Loop over the axes, data, and subtitles\n", "for axis, data, subtitle in zip(ax, angle_data, subtitles):\n", " axis.set_prop_cycle(blue_cycler)\n", " axis.plot(t_win, data[IdSt : IdSt + Duration, :7])\n", " axis.set(ylabel=\"Angle (rad)\", ylim=(-4, 4))\n", " axis.set_title(subtitle, fontsize=12)\n", "\n", "# Set x-label for the last subplot\n", "ax[-1].set_xlabel(\"Time (s)\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> ### 📝 Note\n", "> Smoothing the tail tracking data is optional for classifying tail bouts. The transformer model was trained on raw tracking data, so it can handle unsmoothed input just as well." ] } ], "metadata": { "kernelspec": { "display_name": "megabouts_dev", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 2 }