torchfilter.train._train_dynamics

Private module; avoid importing from it directly.

Module Contents

Functions

train_dynamics_single_step(buddy: fannypack.utils.Buddy, dynamics_model: torchfilter.base.DynamicsModel, dataloader: DataLoader, *, loss_function: str = 'nll', log_interval: int = 10) → None

Optimizes a dynamics model's single-step prediction accuracy.

train_dynamics_recurrent(buddy: fannypack.utils.Buddy, dynamics_model: torchfilter.base.DynamicsModel, dataloader: DataLoader, *, loss_function: str = 'nll', log_interval: int = 10) → None

Trains a dynamics model via backpropagation through time.

torchfilter.train._train_dynamics.train_dynamics_single_step(buddy: fannypack.utils.Buddy, dynamics_model: torchfilter.base.DynamicsModel, dataloader: DataLoader, *, loss_function: str = 'nll', log_interval: int = 10) → None

Optimizes a dynamics model’s single-step prediction accuracy. This is roughly equivalent to training with train_dynamics_recurrent() with a subsequence length of 2.
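To make the comparison with train_dynamics_recurrent() concrete, here is a minimal pure-Python sketch (not torchfilter code) contrasting the two prediction targets for a toy scalar model. The names `single_step_predictions`, `recurrent_predictions`, and `toy_model`, and the `(state, control)` call convention, are hypothetical; it also assumes that `controls[t]` drives the transition from `states[t]` to `states[t + 1]`.

```python
from typing import Callable, List, Sequence

# Hypothetical toy model signature: (state, control) -> predicted next state.
Model = Callable[[float, float], float]


def single_step_predictions(
    model: Model, states: Sequence[float], controls: Sequence[float]
) -> List[float]:
    # Every prediction starts from a ground-truth state.
    return [model(states[t], controls[t]) for t in range(len(states) - 1)]


def recurrent_predictions(
    model: Model, states: Sequence[float], controls: Sequence[float]
) -> List[float]:
    # Only the first state is ground truth; each prediction is fed back in.
    predictions = []
    x = states[0]
    for t in range(len(states) - 1):
        x = model(x, controls[t])
        predictions.append(x)
    return predictions


def toy_model(x: float, u: float) -> float:
    return 0.5 * x + u


# Subsequence of length 2: a single transition, so both targets agree.
print(single_step_predictions(toy_model, [1.0, 2.0], [1.0]))  # [1.5]
print(recurrent_predictions(toy_model, [1.0, 2.0], [1.0]))  # [1.5]

# Longer subsequence: the recurrent rollout reuses its own (imperfect) predictions.
states, controls = [1.0, 2.0, 2.5], [1.0, 0.5]
print(single_step_predictions(toy_model, states, controls))  # [1.5, 1.5]
print(recurrent_predictions(toy_model, states, controls))  # [1.5, 1.25]
```

With a two-state subsequence the rollout makes exactly one prediction from a ground-truth state, so both targets coincide. Over longer subsequences the recurrent version conditions later predictions on earlier ones, which is where errors can compound.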

Parameters:
  • buddy (fannypack.utils.Buddy) – Training helper.

  • dynamics_model (torchfilter.base.DynamicsModel) – Model to train.

  • dataloader (DataLoader) – Loader for a SingleStepDataset.

Keyword Arguments:
  • loss_function (str, optional) – Either “nll” for negative log-likelihood or “mse” for mean-squared error. Defaults to “nll”.

  • log_interval (int, optional) – Number of minibatches between TensorBoard logs. Defaults to 10.
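The two `loss_function` options can be sketched in plain Python as below. This assumes the model outputs a mean and a per-dimension variance (a diagonal Gaussian); the helper names are hypothetical, and this is not torchfilter's loss code, which may differ in how it reduces over the batch and whether it includes the constant log(2π) term.

```python
import math
from typing import Sequence


def gaussian_nll(
    mean: Sequence[float], variance: Sequence[float], target: Sequence[float]
) -> float:
    """Mean negative log-likelihood of `target` under independent Gaussians."""
    total = 0.0
    for m, v, y in zip(mean, variance, target):
        total += 0.5 * (math.log(2.0 * math.pi * v) + (y - m) ** 2 / v)
    return total / len(target)


def mse(mean: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error; any predicted variance is ignored."""
    return sum((y - m) ** 2 for m, y in zip(mean, target)) / len(target)


prediction, truth = [1.0], [1.5]
print(mse(prediction, truth))  # 0.25
# Same error, different predicted variances: NLL rewards calibrated uncertainty.
print(gaussian_nll(prediction, [0.01], truth) > gaussian_nll(prediction, [0.25], truth))  # True
```

Unlike MSE, the NLL also trains the predicted uncertainty: an overconfident (too small) variance on a wrong prediction costs more than a calibrated one.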

torchfilter.train._train_dynamics.train_dynamics_recurrent(buddy: fannypack.utils.Buddy, dynamics_model: torchfilter.base.DynamicsModel, dataloader: DataLoader, *, loss_function: str = 'nll', log_interval: int = 10) → None

Trains a dynamics model via backpropagation through time.
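Backpropagation through time rolls the model forward over a subsequence, feeding each prediction into the next step, then propagates the loss gradient backward through every step. The sketch below does this by hand for a hypothetical scalar model x[t+1] = a * x[t] + u[t] with a summed squared-error loss. It illustrates the technique only and does not reflect torchfilter's internals, where autograd would compute these gradients.

```python
from typing import List, Sequence, Tuple


def rollout(a: float, x0: float, controls: Sequence[float]) -> List[float]:
    """Roll the toy model x[t+1] = a * x[t] + u[t] forward from x0."""
    xs = [x0]
    for u in controls:
        xs.append(a * xs[-1] + u)  # each prediction is fed into the next step
    return xs


def bptt_loss_and_grad(
    a: float, states: Sequence[float], controls: Sequence[float]
) -> Tuple[float, float]:
    """Summed squared error of a rollout and its gradient with respect to `a`."""
    xs = rollout(a, states[0], controls)  # only states[0] is ground truth
    loss = sum((x - s) ** 2 for x, s in zip(xs[1:], states[1:]))

    grad_a = 0.0
    grad_x = 0.0  # dL/dx[t+1], accumulated from later time steps
    for t in reversed(range(len(controls))):
        grad_x += 2.0 * (xs[t + 1] - states[t + 1])  # direct loss term at t+1
        grad_a += grad_x * xs[t]  # d(x[t+1])/da = x[t]
        grad_x *= a  # d(x[t+1])/d(x[t]) = a, so pass the gradient back to x[t]
    return loss, grad_a


states = [1.0, 2.0, 2.5, 2.0]
controls = [1.0, 0.5, -0.25]
loss, grad = bptt_loss_and_grad(0.5, states, controls)
print(loss, grad)  # 4.453125 -13.3125
```

The gradient contribution from the first transition includes the errors at the second and third steps, because later predictions depend on earlier ones; single-step training drops these cross-step terms.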

Parameters:
  • buddy (fannypack.utils.Buddy) – Training helper.

  • dynamics_model (torchfilter.base.DynamicsModel) – Model to train.

  • dataloader (DataLoader) – Loader for a SubsequenceDataset.

Keyword Arguments:
  • loss_function (str, optional) – Either “nll” for negative log-likelihood or “mse” for mean-squared error. Defaults to “nll”.

  • log_interval (int, optional) – Number of minibatches between TensorBoard logs. Defaults to 10.