torchfilter.train._train_particle_filter_measurement

Private module; avoid importing from it directly.

Module Contents

Functions

train_particle_filter_measurement(buddy: fannypack.utils.Buddy, measurement_model: torchfilter.base.ParticleFilterMeasurementModel, dataloader: DataLoader, *, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = F.mse_loss, log_interval: int = 10) → None

Reference implementation for pre-training a particle filter measurement model.

torchfilter.train._train_particle_filter_measurement.train_particle_filter_measurement(buddy: fannypack.utils.Buddy, measurement_model: torchfilter.base.ParticleFilterMeasurementModel, dataloader: DataLoader, *, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = F.mse_loss, log_interval: int = 10) → None

Reference implementation for pre-training a particle filter measurement model. Minimizes prediction error for log-likelihood outputs from (state, observation) pairs.

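Parameters:
  • buddy (fannypack.utils.Buddy) – Training helper.

  • measurement_model (torchfilter.base.ParticleFilterMeasurementModel) – Measurement model to pre-train.

  • dataloader (DataLoader) – Loader for the training minibatches.

Keyword Arguments: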
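  • loss_function (callable, optional) – Loss function mapping two tensors to a loss tensor, typically taken from torch.nn.functional. Defaults to F.mse_loss (mean squared error).

  • log_interval (int, optional) – Number of minibatches between Tensorboard logs. Defaults to 10.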
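
The following is a minimal plain-PyTorch sketch of the kind of loop described above: regress predicted log-likelihoods onto target values, with a pluggable loss_function and a log_interval. It is not the library's implementation. ToyMeasurementModel, pretrain_sketch, the optimizer choice, and the synthetic data are all hypothetical stand-ins, and a Python list replaces the Buddy and Tensorboard logging that the real function uses.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyMeasurementModel(nn.Module):
    """Hypothetical stand-in: maps each (state, observation) pair to one scalar."""

    def __init__(self, state_dim: int, obs_dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + obs_dim, 16), nn.ReLU(), nn.Linear(16, 1)
        )

    def forward(self, states: torch.Tensor, observations: torch.Tensor) -> torch.Tensor:
        # (N, state_dim) and (N, obs_dim) -> (N,) predicted log-likelihoods
        return self.net(torch.cat([states, observations], dim=-1)).squeeze(-1)


def pretrain_sketch(model, batches, *, loss_function=F.mse_loss, log_interval=10):
    """Fit predictions to target log-likelihoods; return the (step, loss) log entries."""
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)  # assumed optimizer
    logged = []
    for step, (states, observations, targets) in enumerate(batches):
        loss = loss_function(model(states, observations), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Record the loss once every `log_interval` minibatches.
        if step % log_interval == 0:
            logged.append((step, loss.item()))
    return logged


torch.manual_seed(0)
states = torch.randn(64, 2)
observations = states + torch.randn(64, 2)
# Synthetic targets: unit-Gaussian log-density around the state, up to a constant.
targets = -0.5 * ((observations - states) ** 2).sum(dim=-1)
batches = [(states, observations, targets)] * 30  # 30 identical minibatches

model = ToyMeasurementModel(state_dim=2, obs_dim=2)
log = pretrain_sketch(model, batches, loss_function=F.smooth_l1_loss, log_interval=10)
```

Any callable with the same two-tensor signature can be swapped in for the loss; here F.smooth_l1_loss replaces the default F.mse_loss to show the keyword argument in use.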