torchfilter.train._train_virtual_sensor

Private module; avoid importing from directly.

Module Contents

Functions

train_virtual_sensor(buddy: fannypack.utils.Buddy, virtual_sensor_model: torchfilter.base.VirtualSensorModel, dataloader: DataLoader, *, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = F.mse_loss, log_interval: int = 10, optimizer_name='train_measurement') → None

Optimizes a virtual sensor model's prediction accuracy. Minimizes output mean error only; does not define a loss on uncertainty.

torchfilter.train._train_virtual_sensor.train_virtual_sensor(buddy: fannypack.utils.Buddy, virtual_sensor_model: torchfilter.base.VirtualSensorModel, dataloader: DataLoader, *, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = F.mse_loss, log_interval: int = 10, optimizer_name='train_measurement') → None

Optimizes a virtual sensor model’s prediction accuracy. Minimizes output mean error only; does not define a loss on uncertainty.
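To make "mean error only" concrete, here is a minimal plain-Python sketch. It assumes that a virtual sensor prediction can be represented as a (mean, uncertainty) pair of lists; this representation, and the names `mse` and `mean_only_loss`, are hypothetical illustrations, not torchfilter's API, which operates on tensors. The sketch shows that two predictions with the same mean but very different uncertainty receive the same loss, so training like this gives the uncertainty output no learning signal.

```python
from typing import List, Sequence, Tuple


def mse(predicted: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error, analogous to F.mse_loss with its default 'mean' reduction."""
    assert len(predicted) == len(target)
    return sum((p - t) ** 2 for p, t in zip(predicted, target)) / len(target)


def mean_only_loss(prediction: Tuple[List[float], List[float]], target: List[float]) -> float:
    """Hypothetical loss that scores the predicted mean and ignores the uncertainty."""
    predicted_mean, _predicted_uncertainty = prediction  # uncertainty is never used
    return mse(predicted_mean, target)


# Same predicted mean, very different (hypothetical) uncertainty values.
confident = ([1.0, 2.0], [0.01, 0.01])
uncertain = ([1.0, 2.0], [10.0, 10.0])
target = [1.5, 2.0]

print(mean_only_loss(confident, target))  # 0.125
print(mean_only_loss(uncertain, target))  # 0.125
```

Training the uncertainty output would require a different objective, for example a likelihood-based loss, which this function does not define.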

Parameters:
  • buddy (fannypack.utils.Buddy) – Training helper.

  • virtual_sensor_model (torchfilter.base.VirtualSensorModel) – Model to train.

  • dataloader (DataLoader) – Loader for a SingleStepDataset.

Keyword Arguments:
  • loss_function (callable, optional) – Loss function, from torch.nn.functional. Defaults to MSE.

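  • log_interval (int, optional) – Minibatches between each TensorBoard log. Defaults to 10.

  • optimizer_name (str, optional) – Name of the Buddy optimizer used for training. Defaults to "train_measurement".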
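The loop that these arguments configure can be sketched without torch or fannypack. The following is a minimal, framework-free sketch, not torchfilter's implementation. It makes several assumptions: a hypothetical 1-D linear "virtual sensor" y = w * x + b stands in for the model, a shuffled Python list stands in for the `DataLoader`, hand-derived MSE gradients stand in for autograd and the Buddy optimizer, and appending to a list stands in for a TensorBoard scalar log. The name `train_linear_sensor`, the learning rate, batch size, epoch count, and toy data are all illustrative. The key point is how `log_interval` works: one log entry is recorded every `log_interval` minibatch updates.

```python
import random
from typing import List, Tuple


def mse_loss(predicted: List[float], target: List[float]) -> float:
    """Mean squared error over one minibatch (mean reduction)."""
    return sum((p - t) ** 2 for p, t in zip(predicted, target)) / len(target)


def train_linear_sensor(
    data: List[Tuple[float, float]],
    *,
    batch_size: int = 4,
    learning_rate: float = 0.05,
    log_interval: int = 10,
    epochs: int = 200,
    seed: int = 0,
) -> Tuple[float, float, List[Tuple[int, float]]]:
    """Fit a hypothetical sensor y = w * x + b with minibatch SGD on the MSE.

    Returns the learned (w, b) and a list of (step, loss) entries,
    one recorded every `log_interval` minibatches.
    """
    rng = random.Random(seed)
    samples = list(data)  # copy so the caller's list is not reordered
    w, b = 0.0, 0.0
    logs: List[Tuple[int, float]] = []
    step = 0
    for _ in range(epochs):
        rng.shuffle(samples)
        for start in range(0, len(samples), batch_size):
            batch = samples[start : start + batch_size]
            xs = [x for x, _ in batch]
            ys = [y for _, y in batch]
            preds = [w * x + b for x in xs]
            loss = mse_loss(preds, ys)

            # Analytic gradients of the minibatch MSE with respect to w and b.
            n = len(batch)
            grad_w = sum(2.0 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / n
            grad_b = sum(2.0 * (p - y) for p, y in zip(preds, ys)) / n
            w -= learning_rate * grad_w
            b -= learning_rate * grad_b

            step += 1
            if step % log_interval == 0:
                logs.append((step, loss))  # stand-in for a TensorBoard scalar log
    return w, b, logs


# Noise-free toy data on the line y = 2x + 1.
toy_data = [(i / 10, 2.0 * (i / 10) + 1.0) for i in range(20)]
w, b, logs = train_linear_sensor(toy_data, log_interval=10)
print(f"w={w:.3f}, b={b:.3f}, {len(logs)} log entries")
```

With 20 samples and a batch size of 4 there are 5 minibatches per epoch, so 200 epochs give 1000 updates and, with `log_interval=10`, 100 log entries. A larger `log_interval` reduces logging overhead at the cost of a coarser loss curve.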