torchfilter.train._train_dynamics
Private module; avoid importing from it directly.
Module Contents
Functions
- train_dynamics_single_step() – Optimizes a dynamics model’s single-step prediction accuracy.
- train_dynamics_recurrent() – Trains a dynamics model via backpropagation through time.
- torchfilter.train._train_dynamics.train_dynamics_single_step(buddy: fannypack.utils.Buddy, dynamics_model: torchfilter.base.DynamicsModel, dataloader: DataLoader, *, loss_function: str = 'nll', log_interval: int = 10) → None
Optimizes a dynamics model’s single-step prediction accuracy. This is roughly equivalent to training with train_dynamics_recurrent() with a subsequence length of 2.
- Parameters:
buddy (fannypack.utils.Buddy) – Training helper.
dynamics_model (torchfilter.base.DynamicsModel) – Model to train.
dataloader (DataLoader) – Loader for a SingleStepDataset.
- Keyword Arguments:
loss_function (str, optional) – Either “nll” for negative log-likelihood or “mse” for mean-squared error. Defaults to “nll”.
log_interval (int, optional) – Number of minibatches between TensorBoard logs. Defaults to 10.
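As a rough illustration of the two loss_function options, the sketch below scores one predicted state against its true successor. It assumes, for simplicity, that the prediction is a diagonal Gaussian given by a mean and a per-dimension standard deviation; the helper single_step_loss is hypothetical and not part of torchfilter, which works on batched tensors and may use a full covariance.

```python
import math
from typing import Sequence


def single_step_loss(
    predicted_mean: Sequence[float],
    predicted_std: Sequence[float],
    target: Sequence[float],
    loss_function: str = "nll",
) -> float:
    """Hypothetical helper: loss of one predicted state vs. the true next state.

    Assumes a diagonal Gaussian prediction N(mean, diag(std**2)).
    """
    if loss_function == "mse":
        # Mean-squared error ignores the predicted uncertainty entirely.
        return sum((m - t) ** 2 for m, t in zip(predicted_mean, target)) / len(target)
    if loss_function == "nll":
        # Negative log-likelihood of the target under each 1-D Gaussian,
        # summed over state dimensions.
        total = 0.0
        for m, s, t in zip(predicted_mean, predicted_std, target):
            total += 0.5 * math.log(2.0 * math.pi * s**2) + (t - m) ** 2 / (2.0 * s**2)
        return total
    raise ValueError(f"Unknown loss_function: {loss_function!r}")


# Same prediction error, different claimed uncertainty.
mean, target = [0.0, 1.0], [0.5, 1.0]
print(single_step_loss(mean, [1.0, 1.0], target, "mse"))  # 0.125
print(single_step_loss(mean, [1.0, 1.0], target, "nll"))
print(single_step_loss(mean, [0.1, 0.1], target, "nll"))  # overconfident here, so larger loss
```

The design point this is meant to show: "mse" only rewards an accurate mean, while "nll" also penalizes a model that reports too little uncertainty for the error it actually makes.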
- torchfilter.train._train_dynamics.train_dynamics_recurrent(buddy: fannypack.utils.Buddy, dynamics_model: torchfilter.base.DynamicsModel, dataloader: DataLoader, *, loss_function: str = 'nll', log_interval: int = 10) → None
Trains a dynamics model via backpropagation through time.
- Parameters:
buddy (fannypack.utils.Buddy) – Training helper.
dynamics_model (torchfilter.base.DynamicsModel) – Model to train.
dataloader (DataLoader) – Loader for a SubsequenceDataset.
- Keyword Arguments:
loss_function (str, optional) – Either “nll” for negative log-likelihood or “mse” for mean-squared error. Defaults to “nll”.
log_interval (int, optional) – Number of minibatches between TensorBoard logs. Defaults to 10.
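To make the relationship between the two training functions concrete, the pure-Python sketch below uses a hypothetical scalar linear model x_next = a * x + u. The recurrent loss starts from the first true state of a subsequence and feeds the model's own predictions back in (backpropagation through time would differentiate through this whole chain; gradients are omitted here), while the single-step loss always starts from the true previous state. With a subsequence of length 2 there is only one transition, so the two losses coincide. The names predict, single_step_mse, and recurrent_mse, and the convention that controls[t] drives the transition into states[t], are illustrative assumptions, not torchfilter's API.

```python
from typing import List


def predict(a: float, x: float, u: float) -> float:
    """Hypothetical scalar dynamics model: x_next = a * x + u."""
    return a * x + u


def single_step_mse(a: float, states: List[float], controls: List[float]) -> float:
    # Each prediction starts from the *true* previous state.
    # Assumed convention: controls[t] drives the transition into states[t].
    errors = [
        (predict(a, states[t], controls[t + 1]) - states[t + 1]) ** 2
        for t in range(len(states) - 1)
    ]
    return sum(errors) / len(errors)


def recurrent_mse(a: float, states: List[float], controls: List[float]) -> float:
    # Start from the first true state, then feed predictions back in,
    # so errors can compound along the subsequence.
    x = states[0]
    errors = []
    for t in range(1, len(states)):
        x = predict(a, x, controls[t])
        errors.append((x - states[t]) ** 2)
    return sum(errors) / len(errors)


a = 1.2
states = [1.0, 2.0, 3.5, 6.0]
controls = [0.0, 0.5, 0.5, 0.5]

# Subsequence length 2: one transition, identical losses.
print(recurrent_mse(a, states[:2], controls[:2]) == single_step_mse(a, states[:2], controls[:2]))  # True
# Longer subsequence: the rollout accumulates its own prediction errors.
print(single_step_mse(a, states, controls), recurrent_mse(a, states, controls))
```

Training on longer subsequences exposes the model to its own compounding errors, which the single-step objective never sees.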