torchfilter.train._train_filter

Private module; avoid importing from directly.

Module Contents

Functions

train_filter(buddy: fannypack.utils.Buddy, filter_model: torchfilter.base.Filter, dataloader: DataLoader, initial_covariance: torch.Tensor, *, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = F.mse_loss, log_interval: int = 10, measurement_initialize=False, optimizer_name='train_filter_recurrent') → None

Trains a filter end-to-end via backpropagation through time for 1 epoch over a subsequence dataset.

torchfilter.train._train_filter.train_filter(buddy: fannypack.utils.Buddy, filter_model: torchfilter.base.Filter, dataloader: DataLoader, initial_covariance: torch.Tensor, *, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = F.mse_loss, log_interval: int = 10, measurement_initialize=False, optimizer_name='train_filter_recurrent') → None

Trains a filter end-to-end via backpropagation through time for 1 epoch over a subsequence dataset.
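Backpropagation through time means unrolling the model over a whole subsequence, computing the loss on every predicted step, and then pushing gradients backward through the unrolled steps. The sketch below illustrates that idea on a toy scalar system, assuming linear dynamics x_t = a * x_{t-1} with one learnable parameter `a` and hand-written gradients. It is not torchfilter's implementation: the helpers `split_subsequences` and `bptt_epoch`, the non-overlapping windowing, and the learning rate are hypothetical, and for simplicity each rollout starts at the true first state rather than at a sampled initial belief.

```python
from typing import List


def split_subsequences(trajectory: List[float], length: int) -> List[List[float]]:
    """Hypothetical helper: cut one trajectory into non-overlapping windows."""
    return [
        trajectory[i : i + length]
        for i in range(0, len(trajectory) - length + 1, length)
    ]


def bptt_epoch(a: float, subsequences: List[List[float]], lr: float = 0.05) -> float:
    """One epoch of gradient descent on `a` for the toy model x_t = a * x_{t-1}.

    Each subsequence is rolled out from its first state; the loss is the mean
    squared error over the remaining steps, and dL/da comes from a reverse
    (backward-in-time) pass over the unrolled predictions.
    """
    for seq in subsequences:
        T = len(seq) - 1  # number of predicted steps
        # Forward pass: unroll the model and keep every prediction.
        preds = [seq[0]]
        for _ in range(T):
            preds.append(a * preds[-1])
        # Backward pass: dL/dx_t is the direct loss term plus the gradient
        # flowing back from x_{t+1} = a * x_t.
        grad_a = 0.0
        grad_x = 0.0
        for t in range(T, 0, -1):
            grad_x += 2.0 * (preds[t] - seq[t]) / T
            grad_a += grad_x * preds[t - 1]
            grad_x *= a
        a -= lr * grad_a  # one update per subsequence (minibatch of size 1)
    return a
```

Calling `bptt_epoch` repeatedly on windows of a noise-free trajectory generated with a = 0.9 drives the estimate toward 0.9, because the gradient accounts for how an error in `a` compounds across every unrolled step.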

Parameters:
  • buddy (fannypack.utils.Buddy) – Training helper.

  • filter_model (torchfilter.base.Filter) – Filter to train.

  • dataloader (DataLoader) – Loader for a SubsequenceDataset.

  • initial_covariance (torch.Tensor) – Error covariance of the initial posterior (belief). The initial belief mean is sampled from a Gaussian centered at the ground-truth start state. Shape should be (state_dim, state_dim).
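Sampling an initial mean from a Gaussian centered at the ground-truth start state amounts to adding correlated noise: with a Cholesky factor L of the covariance (L Lᵀ = Σ), a draw is x₀ + L z with z ~ N(0, I). The pure-Python sketch below shows that step under the assumption that `initial_covariance` is symmetric positive definite; `cholesky` and `sample_initial_mean` are hypothetical helpers, not torchfilter functions. In PyTorch, `torch.distributions.MultivariateNormal(mean, covariance_matrix=cov).sample()` produces the same kind of draw.

```python
import math
import random
from typing import List

Matrix = List[List[float]]


def cholesky(cov: Matrix) -> Matrix:
    """Lower-triangular L with L @ L^T == cov (cov must be symmetric positive definite)."""
    n = len(cov)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(cov[i][i] - s)
            else:
                L[i][j] = (cov[i][j] - s) / L[j][j]
    return L


def sample_initial_mean(true_start: List[float], initial_covariance: Matrix, rng=random) -> List[float]:
    """Draw mean ~ N(true_start, initial_covariance) as true_start + L @ z, z ~ N(0, I)."""
    L = cholesky(initial_covariance)
    z = [rng.gauss(0.0, 1.0) for _ in true_start]
    # L is lower-triangular, so row i only needs z[0..i].
    return [
        mu + sum(L[i][k] * z[k] for k in range(i + 1))
        for i, mu in enumerate(true_start)
    ]
```

Averaged over many draws, the samples recover the ground-truth start state as their mean and `initial_covariance` as their covariance, which is why a larger `initial_covariance` trains the filter to recover from worse initial guesses.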

Keyword Arguments:
  • loss_function (callable, optional) – Loss function taking (predictions, targets) and returning a loss tensor, e.g. one from torch.nn.functional. Defaults to F.mse_loss.

  • log_interval (int, optional) – Number of minibatches between TensorBoard logs. Defaults to 10.
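The two keyword arguments plug into the minibatch loop as a swappable loss callable and a logging cadence. The sketch below is a pure-Python illustration, not torchfilter code: `mean_abs_error` and `run_epoch` are hypothetical names, no gradient step is taken, and it assumes a log entry is written whenever the minibatch index is a multiple of `log_interval`. With PyTorch, `torch.nn.functional.l1_loss` has the same `(input, target)` call shape as the default `F.mse_loss`, so passing it as `loss_function` would switch to an L1 objective.

```python
from typing import Callable, List, Sequence, Tuple


def mean_abs_error(pred: Sequence[float], target: Sequence[float]) -> float:
    """Same (prediction, target) -> scalar call shape as F.mse_loss, but L1."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def run_epoch(
    batches: Sequence[Tuple[Sequence[float], Sequence[float]]],
    predict: Callable[[Sequence[float]], Sequence[float]],
    loss_function: Callable[[Sequence[float], Sequence[float]], float] = mean_abs_error,
    log_interval: int = 10,
) -> List[Tuple[int, float]]:
    """Evaluate the loss on every minibatch; record (index, loss) every `log_interval` batches."""
    logged = []
    for batch_idx, (inputs, targets) in enumerate(batches):
        loss = loss_function(predict(inputs), targets)
        if batch_idx % log_interval == 0:  # assumed logging rule
            logged.append((batch_idx, loss))
    return logged
```

With 25 minibatches and `log_interval=10`, only minibatches 0, 10, and 20 are logged, while the loss is still computed for all 25.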