{ "cells": [ { "cell_type": "markdown", "id": "5a232aef", "metadata": {}, "source": [ "# Trajectory Preprocessing\n", "\n", "**This notebook illustrates how to run the trajectory preprocessing steps with the `TrajPreprocessing` class.**\n", "- **Several preprocessing steps are available for the trajectory (`x`, `y`, `yaw`)**:\n", "  - Interpolate missing values\n", "  - Apply the 1€ filter\n", "\n", "- **The kinematic vigor is also computed from the speed; it is later used to segment the trajectory into bouts.**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Load dependencies" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "from megabouts.tracking_data import TrackingConfig, FullTrackingData, load_example_data\n", "from megabouts.config import TrajPreprocessingConfig\n", "from megabouts.preprocessing import TrajPreprocessing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loading Data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Set up `TrackingConfig` and `FullTrackingData` as in [tutorial_Loading_Data](./Loading_Data.ipynb)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Load data and set tracking configuration\n", "df_recording, fps, mm_per_unit = load_example_data(\"SLEAP_fulltracking\")\n", "tracking_cfg = TrackingConfig(fps=fps, tracking=\"full_tracking\")\n", "\n", "# List of keypoints\n", "keypoints = [\"left_eye\", \"right_eye\", \"tail0\", \"tail1\", \"tail2\", \"tail3\", \"tail4\"]\n", "\n", "# Place NaN where the instance score or the keypoint score is below a threshold\n", "thresh_score = 0.0\n", "score_below_thresh = df_recording[\"instance.score\"] < thresh_score\n", "for kps in keypoints:\n", "    df_recording.loc[\n", "        score_below_thresh | (df_recording[f\"{kps}.score\"] < thresh_score),\n", "        [f\"{kps}.x\", f\"{kps}.y\"],\n", "    ] = np.nan\n", "\n", "# Compute head and tail coordinates and convert to mm\n", 
"head_x = ((df_recording[\"left_eye.x\"] + df_recording[\"right_eye.x\"]) / 2) * mm_per_unit\n", "head_y = ((df_recording[\"left_eye.y\"] + df_recording[\"right_eye.y\"]) / 2) * mm_per_unit\n", "tail_x = df_recording[[f\"tail{i}.x\" for i in range(5)]].values * mm_per_unit\n", "tail_y = df_recording[[f\"tail{i}.y\" for i in range(5)]].values * mm_per_unit\n", "\n", "# Create the FullTrackingData object\n", "tracking_data = FullTrackingData.from_keypoints(\n", "    head_x=head_x.values, head_y=head_y.values, tail_x=tail_x, tail_y=tail_y\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run Preprocessing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Define the preprocessing configuration" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "traj_preprocessing_cfg = TrajPreprocessingConfig(fps=tracking_cfg.fps)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Apply the trajectory preprocessing" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "traj_df_input = tracking_data.traj_df\n", "traj = TrajPreprocessing(traj_preprocessing_cfg).preprocess_traj_df(traj_df_input)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* `traj.df` contains the raw trajectory (`x`, `y`, `yaw`), the smoothed values (`x_smooth`, `y_smooth`, `yaw_smooth`), the axial, lateral and yaw speeds, and the kinematic vigor:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n",
"   | x | y | yaw | x_smooth | y_smooth | yaw_smooth | axial_speed | lateral_speed | yaw_speed | vigor | no_tracking |\n",
"---|---|---|---|---|---|---|---|---|---|---|---|\n",
"0 | 20.381254 | 20.996582 | 1.491111 | 20.381254 | 20.996582 | 1.491111 | NaN | NaN | NaN | 0.0 | 0.0 |\n",
"1 | 20.381340 | 20.996807 | 1.491013 | 20.381277 | 20.996642 | 1.491085 | NaN | NaN | NaN | 0.0 | 0.0 |\n",
"2 | 20.381337 | 20.997020 | 1.491086 | 20.381293 | 20.996742 | 1.491085 | NaN | NaN | NaN | 0.0 | 0.0 |\n",
"3 | 20.418551 | 20.997656 | 1.430916 | 20.391138 | 20.996983 | 1.475189 | NaN | NaN | NaN | 0.0 | 0.0 |\n",
"4 | 20.418671 | 20.997872 | 1.430749 | 20.399294 | 20.997219 | 1.461210 | 0.293531 | -1.616157 | -2.659344 | 0.0 | 0.0 |\n",
"