Source code for megabouts.classification.transformer_network

import torch
from torch.utils.data import Dataset
from torch import nn
import numpy as np


class BoutsDataset(Dataset):
    """Dataset class for bout data with continuous positional encoding.

    Parameters
    ----------
    X : np.ndarray
        Input features, shape (n_bouts, bout_duration, n_features)
    t_sample : np.ndarray
        Time points for each sample
    sampling_mask : np.ndarray
        Boolean mask for valid samples
    device : torch.device, optional
        Device to store tensors on
    precision : torch.dtype, optional
        Precision of tensors
    """
    def __init__(self, X, t_sample, sampling_mask, device=None, precision=None):
        device = (
            device
            if device
            else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )
        precision = (
            precision
            if precision
            else (torch.float64 if device.type == "cuda" else torch.float32)
        )
        # Swap to (n_bouts, n_features, bout_duration) before converting to tensors
        self.X = torch.from_numpy(np.swapaxes(X, 1, 2)).to(dtype=precision).to(device)
        self.t_sample = (
            torch.from_numpy(t_sample[:, :, np.newaxis]).to(dtype=precision).to(device)
        )
        self.sampling_mask = (
            torch.from_numpy(sampling_mask).to(dtype=torch.bool).to(device)
        )
    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        return self.X[idx, :, :], self.t_sample[idx, :], self.sampling_mask[idx, :]
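
# --- Illustrative usage sketch (not part of the original module) ------------
# A minimal example of building a BoutsDataset and iterating over it with a
# DataLoader. The array shapes, the 140-sample bout duration and the all-valid
# sampling mask below are assumptions chosen for illustration only.
def _example_bouts_dataset():
    from torch.utils.data import DataLoader

    n_bouts, bout_duration, n_features = 32, 140, 10
    X = np.random.randn(n_bouts, bout_duration, n_features)
    # One time point per sample; shared across bouts here for simplicity
    t_sample = np.tile(np.linspace(0.0, 1.0, bout_duration), (n_bouts, 1))
    sampling_mask = np.zeros((n_bouts, bout_duration), dtype=bool)

    dataset = BoutsDataset(X, t_sample, sampling_mask)
    loader = DataLoader(dataset, batch_size=8)
    x_batch, t_batch, mask_batch = next(iter(loader))
    # x_batch: (8, n_features, bout_duration), t_batch: (8, bout_duration, 1),
    # mask_batch: (8, bout_duration)
    return x_batch, t_batch, mask_batch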
class ContinuousPositionalEncoding(nn.Module):
    """Sinusoidal positional encoding evaluated at continuous time points."""
    def __init__(self, d_model, max_seq_length):
        super(ContinuousPositionalEncoding, self).__init__()
        self.d, self.T = d_model, max_seq_length
        # 10000^(2i/d_model), where i is the index of the embedding dimension
        denominators = torch.pow(10000, 2 * torch.arange(0, self.d // 2) / self.d)
        self.register_buffer("denominators", denominators)
    def forward(self, x, t):
        pe = torch.zeros((x.shape[0], self.T, self.d), device=x.device)
        pe[:, :, 0::2] = torch.sin(t / self.denominators)  # sin(pos/10000^(2i/d_model))
        pe[:, :, 1::2] = torch.cos(t / self.denominators)  # cos(pos/10000^(2i/d_model))
        return x + pe
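
# --- Illustrative usage sketch (not part of the original module) ------------
# The module evaluates the standard sinusoidal formulas at continuous times t
# rather than integer positions:
#   PE(t, 2i)   = sin(t / 10000^(2i/d_model))
#   PE(t, 2i+1) = cos(t / 10000^(2i/d_model))
# The d_model, max_seq_length and batch size below are illustrative values only.
def _example_continuous_positional_encoding():
    d_model, max_seq_length, batch = 64, 140, 4
    pos_encoder = ContinuousPositionalEncoding(d_model, max_seq_length)
    x = torch.zeros(batch, max_seq_length, d_model)
    # Continuous (possibly irregular) time stamps with shape (batch, T, 1)
    t = torch.sort(torch.rand(batch, max_seq_length, 1), dim=1).values
    encoded = pos_encoder(x, t)  # same shape as x, with PE(t) added
    return encoded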
class TransAm(nn.Module):
    """Transformer model for bout classification.

    Parameters
    ----------
    mapping_label_to_sublabel : dict
        Mapping from main labels to sublabels
    feature_size : int, optional
        Size of feature embedding, by default 64
    num_layers : int, optional
        Number of transformer layers, by default 3
    dropout : float, optional
        Dropout rate, by default 0.0
    nhead : int, optional
        Number of attention heads, by default 8
    """
    def __init__(
        self,
        mapping_label_to_sublabel,
        feature_size=64,
        num_layers=3,
        dropout=0.0,
        nhead=8,
    ):
        super(TransAm, self).__init__()
        self.model_type = "Transformer"
        self.input_embedding = nn.Linear(11, feature_size)
        self.pos_encoder = ContinuousPositionalEncoding(
            d_model=feature_size, max_seq_length=140
        )
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=feature_size, nhead=nhead, dropout=dropout, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            self.encoder_layer, num_layers=num_layers
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, feature_size))
        self.feature_fill = nn.Parameter(torch.randn(1, 1, 11) - 10)
        self.dense_bout_cat = nn.Linear(feature_size, 18, bias=True)
        self.dense_bout_sign = nn.Linear(feature_size, 2, bias=True)
        self.dense_peak_loc_1 = nn.Linear(feature_size, feature_size, bias=True)
        self.dense_peak_loc_2 = nn.Linear(feature_size, 1, bias=True)
        self.mapping_label_to_sublabel = mapping_label_to_sublabel
    def forward(self, input, t, mask):
        # Replace the body angle feature by its cosine and append its sine,
        # giving the angle a continuous 2D representation
        body_angle = input[:, :, 9]
        input[:, :, 9] = torch.cos(body_angle)
        input = torch.cat([input, torch.sin(body_angle[:, :, None])], axis=2)

        # Fill missing (NaN) features with the learned filler values
        mask_feature = torch.isnan(input)
        feature_filler = torch.broadcast_to(self.feature_fill, input.shape).to(
            dtype=input.dtype
        )  # match dtype with input
        input[mask_feature] = feature_filler[mask_feature]

        # Linear embedding followed by continuous positional encoding
        output = self.input_embedding(input)
        output = self.pos_encoder(output, t)

        # Prepend the CLS token and extend the padding mask accordingly
        output = torch.cat(
            [self.cls_token.expand(output.shape[0], -1, -1), output], dim=1
        )
        mask = torch.cat(
            [
                torch.zeros((output.shape[0], 1), dtype=torch.bool, device=mask.device),
                mask,
            ],
            dim=1,
        )
        output = self.transformer_encoder(output, src_key_padding_mask=mask)

        # Classification and regression heads read from the CLS token
        output_CLS = output[:, 0, :]
        output_bout_cat = self.dense_bout_cat(output_CLS)
        output_bout_sign = self.dense_bout_sign(output_CLS)
        output_t_peak = nn.Sigmoid()(self.dense_peak_loc_1(output_CLS))
        output_t_peak = self.dense_peak_loc_2(output_t_peak)

        # Aggregate the 18 sublabel logits into 13 main-label logits via log-sum-exp
        logit_sublabel = output_bout_cat
        logit_label = torch.zeros(
            (output_bout_cat.shape[0], 13),
            dtype=output_bout_cat.dtype,
            device=output_bout_cat.device,
        )
        for i in range(13):
            sub_ids = self.mapping_label_to_sublabel[i]
            logit_label[:, i] = torch.log(
                torch.sum(torch.exp(logit_sublabel[:, sub_ids]), axis=1)
            )
        return logit_label, logit_sublabel, output_bout_sign, output_t_peak
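
# --- Illustrative usage sketch (not part of the original module) ------------
# A minimal forward pass through TransAm. The mapping from the 13 main labels
# to the 18 sublabel logits below is a made-up placeholder; the real mapping is
# supplied by the caller. Inputs use the (batch, time, feature) layout expected
# by forward(), with 10 raw features (the body angle at index 9 is expanded to
# cos/sin internally).
def _example_transam_forward():
    batch, T, n_raw_features = 4, 140, 10
    # Placeholder mapping: each label points to one or two sublabel columns
    mapping = {i: [i] for i in range(13)}
    for i in range(13, 18):
        mapping[i - 13].append(i)
    model = TransAm(mapping_label_to_sublabel=mapping)

    x = torch.randn(batch, T, n_raw_features)
    t = torch.sort(torch.rand(batch, T, 1), dim=1).values
    mask = torch.zeros(batch, T, dtype=torch.bool)  # True marks padded positions

    logit_label, logit_sublabel, logit_sign, t_peak = model(x, t, mask)
    # logit_label: (batch, 13), logit_sublabel: (batch, 18),
    # logit_sign: (batch, 2), t_peak: (batch, 1)
    return logit_label, logit_sublabel, logit_sign, t_peak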