torchfilter.train

Reference implementations for training state estimators with learnable parameters.

These are written with a custom model manager (fannypack.utils.Buddy) for brevity, but can easily be translated to raw PyTorch.
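To illustrate the translation to raw PyTorch, here is a minimal sketch of the kind of bookkeeping a model manager takes care of. MinimalTrainingHelper and its minimize() method are hypothetical names, not fannypack's actual API; the sketch only assumes that the helper keeps one optimizer per optimizer_name (matching the optimizer_name keyword seen in several signatures below) and counts optimizer steps.

```python
from typing import Dict

import torch
import torch.nn as nn


class MinimalTrainingHelper:
    """Hypothetical stand-in for a model manager such as fannypack.utils.Buddy.

    Keeps one Adam optimizer per name plus a step counter. This is an
    illustrative assumption, not fannypack's implementation.
    """

    def __init__(self, model: nn.Module, learning_rate: float = 1e-3) -> None:
        self.model = model
        self.learning_rate = learning_rate
        self.optimizers: Dict[str, torch.optim.Optimizer] = {}
        self.optimizer_steps = 0

    def minimize(self, loss: torch.Tensor, optimizer_name: str = "primary") -> None:
        # Lazily create a separate optimizer per name, so different training
        # routines keep independent optimizer state (e.g. Adam moments).
        if optimizer_name not in self.optimizers:
            self.optimizers[optimizer_name] = torch.optim.Adam(
                self.model.parameters(), lr=self.learning_rate
            )
        optimizer = self.optimizers[optimizer_name]
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        self.optimizer_steps += 1


# Usage: one optimization step on a toy regression loss.
torch.manual_seed(0)
model = nn.Linear(2, 2)
helper = MinimalTrainingHelper(model)
weight_before = model.weight.detach().clone()
inputs = torch.randn(8, 2)
loss = ((model(inputs) - inputs) ** 2).mean()
helper.minimize(loss, optimizer_name="train_measurement")
```

Keying optimizers by name is what lets, for example, measurement pre-training and end-to-end filter training share one model without sharing Adam state; in raw PyTorch you would simply construct and keep those optimizers yourself.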

Package Contents

Functions

train_dynamics_recurrent(buddy: fannypack.utils.Buddy, dynamics_model: torchfilter.base.DynamicsModel, dataloader: DataLoader, *, loss_function: str = 'nll', log_interval: int = 10) → None

Trains a dynamics model via backpropagation through time.

train_dynamics_single_step(buddy: fannypack.utils.Buddy, dynamics_model: torchfilter.base.DynamicsModel, dataloader: DataLoader, *, loss_function: str = 'nll', log_interval: int = 10) → None

Optimizes a dynamics model’s single-step prediction accuracy.

train_filter(buddy: fannypack.utils.Buddy, filter_model: torchfilter.base.Filter, dataloader: DataLoader, initial_covariance: torch.Tensor, *, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = F.mse_loss, log_interval: int = 10, measurement_initialize=False, optimizer_name='train_filter_recurrent') → None

Trains a filter end-to-end via backpropagation through time for 1 epoch over a subsequence dataset.

train_kalman_filter_measurement(buddy: fannypack.utils.Buddy, measurement_model: torchfilter.base.KalmanFilterMeasurementModel, dataloader: DataLoader, *, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = F.mse_loss, log_interval: int = 10, optimizer_name='train_kalman_filter_measurement') → None

Optimizes a Kalman filter measurement model’s prediction accuracy.

train_particle_filter_measurement(buddy: fannypack.utils.Buddy, measurement_model: torchfilter.base.ParticleFilterMeasurementModel, dataloader: DataLoader, *, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = F.mse_loss, log_interval: int = 10) → None

Reference implementation for pre-training a particle filter measurement model.

train_virtual_sensor(buddy: fannypack.utils.Buddy, virtual_sensor_model: torchfilter.base.VirtualSensorModel, dataloader: DataLoader, *, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = F.mse_loss, log_interval: int = 10, optimizer_name='train_measurement') → None

Optimizes a virtual sensor model’s prediction accuracy.

torchfilter.train.train_dynamics_recurrent(buddy: fannypack.utils.Buddy, dynamics_model: torchfilter.base.DynamicsModel, dataloader: DataLoader, *, loss_function: str = 'nll', log_interval: int = 10) → None

Trains a dynamics model via backpropagation through time.

Parameters
  • buddy (fannypack.utils.Buddy) – Training helper.

  • dynamics_model (torchfilter.base.DynamicsModel) – Model to train.

  • dataloader (DataLoader) – Loader for a SubsequenceDataset.

Keyword Arguments
  • loss_function (str, optional) – Either “nll” for negative log-likelihood or “mse” for mean-squared error. Defaults to “nll”.

  • log_interval (int, optional) – Number of minibatches between TensorBoard logs. Defaults to 10.
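The following is a minimal plain-PyTorch sketch of backpropagation through time for a dynamics model, assuming a batch of subsequences with ground-truth states of shape (N, T, state_dim) and controls of shape (N, T, control_dim). ToyDynamicsModel and recurrent_dynamics_loss are hypothetical names; the diagonal-Gaussian output, the convention that controls[:, t] drives the transition into state t, and feeding back only the predicted mean are all assumptions, not torchfilter's implementation.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal


class ToyDynamicsModel(nn.Module):
    """Hypothetical model: next-state mean plus a learned diagonal std."""

    def __init__(self, state_dim: int, control_dim: int) -> None:
        super().__init__()
        self.linear = nn.Linear(state_dim + control_dim, state_dim)
        self.log_std = nn.Parameter(torch.zeros(state_dim))

    def forward(self, state: torch.Tensor, control: torch.Tensor):
        mean = state + self.linear(torch.cat([state, control], dim=-1))
        std = self.log_std.exp().expand_as(mean)
        return mean, std


def recurrent_dynamics_loss(model, states, controls, loss_function="nll"):
    """Average loss over a rollout that starts at the ground-truth first state."""
    predicted = states[:, 0]
    losses = []
    for t in range(1, states.shape[1]):
        mean, std = model(predicted, controls[:, t])
        if loss_function == "nll":
            # Negative log-likelihood of the true state under the prediction.
            nll = -Normal(mean, std).log_prob(states[:, t]).sum(dim=-1)
            losses.append(nll.mean())
        elif loss_function == "mse":
            losses.append(((mean - states[:, t]) ** 2).mean())
        else:
            raise ValueError(f"Unknown loss_function: {loss_function!r}")
        # Feed the prediction back in, so gradients flow through time.
        predicted = mean
    return torch.stack(losses).mean()


torch.manual_seed(0)
model = ToyDynamicsModel(state_dim=3, control_dim=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
states = torch.randn(4, 5, 3)
controls = torch.randn(4, 5, 2)
log_interval = 10
history = []
for step in range(20):
    loss = recurrent_dynamics_loss(model, states, controls)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if step % log_interval == 0:
        history.append(loss.item())  # stand-in for a TensorBoard scalar log
```

Because each step consumes the model's own previous prediction, errors compound along the subsequence and the loss penalizes drift, which single-step training does not.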

torchfilter.train.train_dynamics_single_step(buddy: fannypack.utils.Buddy, dynamics_model: torchfilter.base.DynamicsModel, dataloader: DataLoader, *, loss_function: str = 'nll', log_interval: int = 10) → None

Optimizes a dynamics model’s single-step prediction accuracy. This is roughly equivalent to training with train_dynamics_recurrent() with a subsequence length of 2.

Parameters
  • buddy (fannypack.utils.Buddy) – Training helper.

  • dynamics_model (torchfilter.base.DynamicsModel) – Model to train.

  • dataloader (DataLoader) – Loader for a SingleStepDataset.

Keyword Arguments
  • loss_function (str, optional) – Either “nll” for negative log-likelihood or “mse” for mean-squared error. Defaults to “nll”.

  • log_interval (int, optional) – Number of minibatches between TensorBoard logs. Defaults to 10.
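A minimal plain-PyTorch sketch of single-step training follows. It assumes each batch provides (previous_state, control, next_state) tensors; that layout, ToyDynamicsModel, and single_step_loss are hypothetical and not taken from SingleStepDataset. The "nll" branch uses torch.nn.functional.gaussian_nll_loss (available in PyTorch 1.9 and later), which omits the constant term.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyDynamicsModel(nn.Module):
    """Hypothetical model: next-state mean plus a learned diagonal variance."""

    def __init__(self, state_dim: int, control_dim: int) -> None:
        super().__init__()
        self.linear = nn.Linear(state_dim + control_dim, state_dim)
        self.log_var = nn.Parameter(torch.zeros(state_dim))

    def forward(self, state: torch.Tensor, control: torch.Tensor):
        mean = state + self.linear(torch.cat([state, control], dim=-1))
        var = self.log_var.exp().expand_as(mean)
        return mean, var


def single_step_loss(model, previous_state, control, next_state, loss_function="nll"):
    # Every prediction starts from a ground-truth state: no error compounding.
    mean, var = model(previous_state, control)
    if loss_function == "nll":
        return F.gaussian_nll_loss(mean, next_state, var)
    if loss_function == "mse":
        return F.mse_loss(mean, next_state)
    raise ValueError(f"Unknown loss_function: {loss_function!r}")


torch.manual_seed(0)
model = ToyDynamicsModel(state_dim=3, control_dim=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
previous_state = torch.randn(16, 3)
control = torch.randn(16, 2)
next_state = previous_state + 0.1 * torch.randn(16, 3)

with torch.no_grad():
    initial_loss = single_step_loss(model, previous_state, control, next_state).item()
for _ in range(200):
    loss = single_step_loss(model, previous_state, control, next_state)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
with torch.no_grad():
    final_loss = single_step_loss(model, previous_state, control, next_state).item()
```

A length-2 subsequence in the recurrent setting also yields exactly one prediction from a ground-truth state, which is why the two routines are described as roughly equivalent.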

torchfilter.train.train_filter(buddy: fannypack.utils.Buddy, filter_model: torchfilter.base.Filter, dataloader: DataLoader, initial_covariance: torch.Tensor, *, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = F.mse_loss, log_interval: int = 10, measurement_initialize=False, optimizer_name='train_filter_recurrent') → None

Trains a filter end-to-end via backpropagation through time for 1 epoch over a subsequence dataset.

Parameters
  • buddy (fannypack.utils.Buddy) – Training helper.

  • filter_model (torchfilter.base.Filter) – Model to train.

  • dataloader (DataLoader) – Loader for a SubsequenceDataset.

  • initial_covariance (torch.Tensor) – Covariance matrix of the error in the initial posterior: the posterior mean is sampled from a Gaussian with this covariance, centered at the ground-truth start state. Shape should be (state_dim, state_dim).

Keyword Arguments
  • loss_function (callable, optional) – Loss function, such as one from torch.nn.functional. Defaults to F.mse_loss.

  • log_interval (int, optional) – Number of minibatches between TensorBoard logs. Defaults to 10.
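The sketch below shows, in plain PyTorch, the two ingredients described above: an initial posterior mean sampled from a Gaussian centered at the ground-truth start state with covariance initial_covariance, and a loss on the filter's state estimates over the whole subsequence. ToyLearnedGainFilter (known additive dynamics with a learned per-dimension gain) and filter_loss are hypothetical stand-ins, not torchfilter's Filter interface.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import MultivariateNormal


class ToyLearnedGainFilter(nn.Module):
    """Hypothetical filter: assumed dynamics x_t = x_{t-1} + u_t, learned gain."""

    def __init__(self, state_dim: int) -> None:
        super().__init__()
        self.gain_logit = nn.Parameter(torch.zeros(state_dim))

    def forward(self, estimate, observation, control):
        predicted = estimate + control
        gain = torch.sigmoid(self.gain_logit)  # keeps each gain in (0, 1)
        return predicted + gain * (observation - predicted)


def filter_loss(filter_model, true_states, observations, controls,
                initial_covariance, loss_function=F.mse_loss):
    # Initial posterior mean ~ N(ground-truth start state, initial_covariance).
    estimate = MultivariateNormal(
        true_states[:, 0], covariance_matrix=initial_covariance
    ).sample()
    estimates = []
    for t in range(1, true_states.shape[1]):
        estimate = filter_model(estimate, observations[:, t], controls[:, t])
        estimates.append(estimate)
    # Backpropagating this loss differentiates through every filter step.
    return loss_function(torch.stack(estimates, dim=1), true_states[:, 1:])


torch.manual_seed(0)
N, T, D = 32, 10, 2
controls = 0.1 * torch.randn(N, T, D)
true_states = torch.cumsum(controls, dim=1)
observations = true_states + 0.05 * torch.randn(N, T, D)
initial_covariance = 0.5 * torch.eye(D)

filter_model = ToyLearnedGainFilter(D)
optimizer = torch.optim.Adam(filter_model.parameters(), lr=0.05)
for _ in range(50):
    loss = filter_loss(filter_model, true_states, observations, controls, initial_covariance)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

With a large initial error and precise observations, end-to-end training pushes the gain upward, a behavior that would be hard to obtain by training each component in isolation.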

torchfilter.train.train_kalman_filter_measurement(buddy: fannypack.utils.Buddy, measurement_model: torchfilter.base.KalmanFilterMeasurementModel, dataloader: DataLoader, *, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = F.mse_loss, log_interval: int = 10, optimizer_name='train_kalman_filter_measurement') → None

Optimizes a Kalman filter measurement model’s prediction accuracy. Minimizes output mean error only; does not define a loss on uncertainty.

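Parameters
  • buddy (fannypack.utils.Buddy) – Training helper.

  • measurement_model (torchfilter.base.KalmanFilterMeasurementModel) – Model to train.

  • dataloader (DataLoader) – Loader for the training data.
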
Keyword Arguments
  • loss_function (callable, optional) – Loss function, such as one from torch.nn.functional. Defaults to F.mse_loss.

  • log_interval (int, optional) – Number of minibatches between TensorBoard logs. Defaults to 10.
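To make "minimizes output mean error only" concrete, here is a plain-PyTorch sketch, assuming a measurement model that maps states to an expected observation and an observation covariance. ToyKalmanMeasurementModel and measurement_mean_loss are hypothetical names; the point is that the covariance output never enters the loss, so its parameters receive no gradient.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyKalmanMeasurementModel(nn.Module):
    """Hypothetical model: state -> (expected observation, covariance)."""

    def __init__(self, state_dim: int, observation_dim: int) -> None:
        super().__init__()
        self.mean_head = nn.Linear(state_dim, observation_dim)
        self.log_std = nn.Parameter(torch.zeros(observation_dim))

    def forward(self, states: torch.Tensor):
        mean = self.mean_head(states)
        # Diagonal covariance, shape (N, observation_dim, observation_dim).
        covariance = torch.diag_embed(self.log_std.exp().expand_as(mean) ** 2)
        return mean, covariance


def measurement_mean_loss(model, states, observations, loss_function=F.mse_loss):
    predicted_mean, _covariance = model(states)  # covariance deliberately unused
    return loss_function(predicted_mean, observations)


torch.manual_seed(0)
states = torch.randn(128, 3)
true_map = torch.randn(4, 3)
observations = states @ true_map.T + 0.01 * torch.randn(128, 4)

model = ToyKalmanMeasurementModel(state_dim=3, observation_dim=4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
with torch.no_grad():
    initial_loss = measurement_mean_loss(model, states, observations).item()
for _ in range(200):
    loss = measurement_mean_loss(model, states, observations)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
with torch.no_grad():
    final_loss = measurement_mean_loss(model, states, observations).item()
```

Since the uncertainty parameters are left untouched, they would need a separate objective (for example a likelihood loss or end-to-end filter training) to become meaningful.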

torchfilter.train.train_particle_filter_measurement(buddy: fannypack.utils.Buddy, measurement_model: torchfilter.base.ParticleFilterMeasurementModel, dataloader: DataLoader, *, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = F.mse_loss, log_interval: int = 10) → None

Reference implementation for pre-training a particle filter measurement model. Minimizes prediction error for log-likelihood outputs from (state, observation) pairs.

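Parameters
  • buddy (fannypack.utils.Buddy) – Training helper.

  • measurement_model (torchfilter.base.ParticleFilterMeasurementModel) – Model to train.

  • dataloader (DataLoader) – Loader for the training data.
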
Keyword Arguments
  • loss_function (callable, optional) – Loss function, such as one from torch.nn.functional. Defaults to F.mse_loss.

  • log_interval (int, optional) – Number of minibatches between TensorBoard logs. Defaults to 10.
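A plain-PyTorch sketch of regressing log-likelihood outputs for (state, observation) pairs follows. How the target log-likelihoods are produced is an assumption made only for illustration: particles are ground-truth states perturbed by Gaussian noise, and the target is the unnormalized Gaussian log-density of that perturbation. ToyParticleMeasurementModel and make_training_batch are hypothetical names.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyParticleMeasurementModel(nn.Module):
    """Hypothetical model: (particles, observation) -> log-likelihood per particle."""

    def __init__(self, state_dim: int, observation_dim: int, hidden_dim: int = 32) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + observation_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, states: torch.Tensor, observations: torch.Tensor) -> torch.Tensor:
        # states: (N, M, state_dim); observations: (N, observation_dim)
        num_particles = states.shape[1]
        obs = observations[:, None, :].expand(-1, num_particles, -1)
        return self.net(torch.cat([states, obs], dim=-1)).squeeze(-1)  # (N, M)


def make_training_batch(true_states, num_particles=8, noise_std=0.3):
    # ASSUMPTION: targets are Gaussian log-densities (constant dropped) of the
    # offset between each particle and the ground-truth state.
    noise = noise_std * torch.randn(true_states.shape[0], num_particles, true_states.shape[-1])
    particles = true_states[:, None, :] + noise
    target_log_likelihoods = -0.5 * (noise / noise_std).pow(2).sum(dim=-1)
    return particles, target_log_likelihoods


torch.manual_seed(0)
true_states = torch.randn(64, 2)
observations = true_states + 0.05 * torch.randn(64, 2)
particles, targets = make_training_batch(true_states)

model = ToyParticleMeasurementModel(state_dim=2, observation_dim=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
with torch.no_grad():
    initial_loss = F.mse_loss(model(particles, observations), targets).item()
for _ in range(300):
    loss = F.mse_loss(model(particles, observations), targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
with torch.no_grad():
    final_loss = F.mse_loss(model(particles, observations), targets).item()
```

Scoring particles jointly with the observation, rather than predicting an observation from each state, is what lets a particle filter reweight hypotheses directly from these log-likelihoods.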

torchfilter.train.train_virtual_sensor(buddy: fannypack.utils.Buddy, virtual_sensor_model: torchfilter.base.VirtualSensorModel, dataloader: DataLoader, *, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = F.mse_loss, log_interval: int = 10, optimizer_name='train_measurement') → None

Optimizes a virtual sensor model’s prediction accuracy. Minimizes output mean error only; does not define a loss on uncertainty.

Parameters
  • buddy (fannypack.utils.Buddy) – Training helper.

  • virtual_sensor_model (torchfilter.base.VirtualSensorModel) – Model to train.

  • dataloader (DataLoader) – Loader for a SingleStepDataset.

Keyword Arguments
  • loss_function (callable, optional) – Loss function, such as one from torch.nn.functional. Defaults to F.mse_loss.

  • log_interval (int, optional) – Number of minibatches between TensorBoard logs. Defaults to 10.
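The sketch below shows one epoch of mean-only virtual sensor training over a standard torch.utils.data.DataLoader, logging every log_interval minibatches. It assumes batches of (observation, state) pairs built with TensorDataset; ToyVirtualSensor and train_one_epoch are hypothetical names, and the returned list stands in for TensorBoard logging.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


class ToyVirtualSensor(nn.Module):
    """Hypothetical virtual sensor: observation -> (state estimate, Cholesky factor)."""

    def __init__(self, observation_dim: int, state_dim: int) -> None:
        super().__init__()
        self.mean_head = nn.Linear(observation_dim, state_dim)
        self.log_std_head = nn.Linear(observation_dim, state_dim)

    def forward(self, observations: torch.Tensor):
        mean = self.mean_head(observations)
        scale_tril = torch.diag_embed(self.log_std_head(observations).exp())
        return mean, scale_tril


def train_one_epoch(model, dataloader, optimizer, loss_function=F.mse_loss, log_interval=10):
    logged = []
    for batch_index, (observations, states) in enumerate(dataloader):
        predicted_states, _scale_tril = model(observations)  # uncertainty ignored
        loss = loss_function(predicted_states, states)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch_index % log_interval == 0:
            logged.append((batch_index, loss.item()))
    return logged


torch.manual_seed(0)
states = torch.randn(256, 2)
observations = torch.cat([states, states ** 2], dim=-1) + 0.01 * torch.randn(256, 4)
loader = DataLoader(TensorDataset(observations, states), batch_size=16, shuffle=True)

model = ToyVirtualSensor(observation_dim=4, state_dim=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
logged = train_one_epoch(model, loader, optimizer)
```

With 256 samples and a batch size of 16 the epoch has 16 minibatches, so the default interval of 10 logs at minibatches 0 and 10.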