Source code for torchfilter.train._train_filter

"""Private module; avoid importing from directly.
"""

from typing import Callable, cast

import fannypack
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

import torchfilter


def _swap_batch_sequence_axes(tensor: torch.Tensor) -> torch.Tensor:
    """Converts data formatted as (N, T, ...) to (T, N, ...)"""
    return torch.transpose(tensor, 0, 1)
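

# Illustrative sketch (not in the original source): `torch.transpose(tensor, 0, 1)`
# swaps only the leading two axes, so a batch-major tensor becomes time-major:
#
#     batch = torch.zeros(32, 100, 7)  # (N=32, T=100, state_dim=7)
#     assert _swap_batch_sequence_axes(batch).shape == (100, 32, 7)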


def train_filter(
    buddy: fannypack.utils.Buddy,
    filter_model: torchfilter.base.Filter,
    dataloader: DataLoader,
    initial_covariance: torch.Tensor,
    *,
    loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = F.mse_loss,
    log_interval: int = 10,
    measurement_initialize: bool = False,
    optimizer_name: str = "train_filter_recurrent",
) -> None:
    """Trains a filter end-to-end via backpropagation through time for 1 epoch over a
    subsequence dataset.

    Args:
        buddy (fannypack.utils.Buddy): Training helper.
        filter_model (torchfilter.base.Filter): Model to train.
        dataloader (DataLoader): Loader for a SubsequenceDataset.
        initial_covariance (torch.Tensor): Covariance matrix of error in initial
            posterior, whose mean is sampled from a Gaussian centered at the
            ground-truth start state. Shape should be `(state_dim, state_dim)`.

    Keyword Args:
        loss_function (callable, optional): Loss function, from `torch.nn.functional`.
            Defaults to MSE.
        log_interval (int, optional): Minibatches between each Tensorboard log.
        measurement_initialize (bool, optional): If True, initialize beliefs from the
            first observation via a virtual sensor; only supported for
            `VirtualSensorExtendedKalmanFilter`. Defaults to False.
        optimizer_name (str, optional): Optimizer name passed to `buddy.minimize()`.
            Defaults to "train_filter_recurrent".
    """
    # Dataloader should load a SubsequenceDataset
    assert isinstance(dataloader.dataset, torchfilter.data.SubsequenceDataset)
    assert initial_covariance.shape == (filter_model.state_dim, filter_model.state_dim)
    assert filter_model.training, "Model needs to be set to train mode"

    # Track mean epoch loss
    epoch_loss = 0.0

    # Train filter model for 1 epoch
    batch_data: torchfilter.types.TrajectoryTorch
    for batch_idx, batch_data in enumerate(tqdm(dataloader)):
        # Move & unpack data
        batch_gpu = fannypack.utils.to_device(batch_data, buddy.device)
        states_labels = batch_gpu.states
        observations = batch_gpu.observations
        controls = batch_gpu.controls

        # Swap batch size, sequence length axes
        states_labels = _swap_batch_sequence_axes(states_labels)
        observations = fannypack.utils.SliceWrapper(observations).map(
            _swap_batch_sequence_axes
        )
        controls = fannypack.utils.SliceWrapper(controls).map(_swap_batch_sequence_axes)

        # Shape checks
        T, N, state_dim = states_labels.shape
        assert state_dim == filter_model.state_dim
        assert fannypack.utils.SliceWrapper(observations).shape[:2] == (T, N)
        assert fannypack.utils.SliceWrapper(controls).shape[:2] == (T, N)
        assert batch_idx != 0 or N == dataloader.batch_size

        # Populate initial filter belief
        if measurement_initialize:
            assert isinstance(
                filter_model, torchfilter.filters.VirtualSensorExtendedKalmanFilter
            ), "Only supported for virtual sensor EKFs!"
            # Initialize belief using virtual sensor
            cast(
                torchfilter.filters.VirtualSensorExtendedKalmanFilter, filter_model
            ).virtual_sensor_initialize_beliefs(
                observations=fannypack.utils.SliceWrapper(observations)[0]
            )
        else:
            # Initialize belief using `initial_covariance`
            initial_states_covariance = initial_covariance[None, :, :].expand(
                (N, state_dim, state_dim)
            )
            initial_states = torch.distributions.MultivariateNormal(
                states_labels[0], covariance_matrix=initial_states_covariance
            ).sample()
            filter_model.initialize_beliefs(
                mean=initial_states, covariance=initial_states_covariance
            )

        # Forward pass from the first state
        state_predictions = filter_model.forward_loop(
            observations=fannypack.utils.SliceWrapper(observations)[1:],
            controls=fannypack.utils.SliceWrapper(controls)[1:],
        )
        assert state_predictions.shape == (T - 1, N, state_dim)

        # Minimize loss
        loss = loss_function(state_predictions, states_labels[1:])
        buddy.minimize(loss, optimizer_name=optimizer_name)
        epoch_loss += fannypack.utils.to_numpy(loss)

        # Logging
        if batch_idx % log_interval == 0:
            with buddy.log_scope("train_filter_recurrent"):
                buddy.log_scalar("Training loss", loss)

    # Print average training loss
    epoch_loss /= len(dataloader)
    print("(train_filter) Epoch training loss: ", epoch_loss)
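

# Illustrative usage sketch (not in the original source). The `trajectories`
# list, experiment name, `subsequence_length`, batch size, and covariance scale
# below are assumptions for illustration, not values from this module:
#
#     filter_model = ...  # any torchfilter.base.Filter subclass
#     filter_model.train()  # train_filter() asserts the model is in train mode
#     buddy = fannypack.utils.Buddy("my_experiment", filter_model)
#     dataloader = DataLoader(
#         torchfilter.data.SubsequenceDataset(trajectories, subsequence_length=20),
#         batch_size=32,
#         shuffle=True,
#     )
#     train_filter(
#         buddy,
#         filter_model,
#         dataloader,
#         initial_covariance=torch.eye(filter_model.state_dim) * 0.1,
#     )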