:mod:`torchfilter.train`
========================

.. py:module:: torchfilter.train

.. autoapi-nested-parse::

   Reference implementations for training state estimators with learnable parameters.

   These are written with a custom model manager (``fannypack.utils.Buddy``) for
   brevity, but can be easily translated to raw PyTorch.


Package Contents
----------------

Functions
~~~~~~~~~

.. autoapisummary::

   torchfilter.train.train_dynamics_recurrent
   torchfilter.train.train_dynamics_single_step
   torchfilter.train.train_filter
   torchfilter.train.train_kalman_filter_measurement
   torchfilter.train.train_particle_filter_measurement
   torchfilter.train.train_virtual_sensor


.. function:: train_dynamics_recurrent(buddy: fannypack.utils.Buddy, dynamics_model: torchfilter.base.DynamicsModel, dataloader: DataLoader, *, loss_function: str = 'nll', log_interval: int = 10) -> None

   Trains a dynamics model via backpropagation through time.

   :param buddy: Training helper.
   :type buddy: fannypack.utils.Buddy
   :param dynamics_model: Model to train.
   :type dynamics_model: torchfilter.base.DynamicsModel
   :param dataloader: Loader for a SubsequenceDataset.
   :type dataloader: DataLoader

   :keyword loss_function: Either "nll" for negative log-likelihood or "mse" for
                           mean-squared error. Defaults to "nll".
   :kwtype loss_function: str, optional
   :keyword log_interval: Number of minibatches between TensorBoard logs.
                          Defaults to 10.
   :kwtype log_interval: int, optional


.. function:: train_dynamics_single_step(buddy: fannypack.utils.Buddy, dynamics_model: torchfilter.base.DynamicsModel, dataloader: DataLoader, *, loss_function: str = 'nll', log_interval: int = 10) -> None

   Optimizes a dynamics model's single-step prediction accuracy. This is roughly
   equivalent to calling ``train_dynamics_recurrent()`` with a subsequence length
   of 2.

   :param buddy: Training helper.
   :type buddy: fannypack.utils.Buddy
   :param dynamics_model: Model to train.
   :type dynamics_model: torchfilter.base.DynamicsModel
   :param dataloader: Loader for a SingleStepDataset.
   :type dataloader: DataLoader

   :keyword loss_function: Either "nll" for negative log-likelihood or "mse" for
                           mean-squared error. Defaults to "nll".
   :kwtype loss_function: str, optional
   :keyword log_interval: Number of minibatches between TensorBoard logs.
                          Defaults to 10.
   :kwtype log_interval: int, optional


.. function:: train_filter(buddy: fannypack.utils.Buddy, filter_model: torchfilter.base.Filter, dataloader: DataLoader, initial_covariance: torch.Tensor, *, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = F.mse_loss, log_interval: int = 10, measurement_initialize=False, optimizer_name='train_filter_recurrent') -> None

   Trains a filter end-to-end via backpropagation through time, for one epoch
   over a subsequence dataset.

   :param buddy: Training helper.
   :type buddy: fannypack.utils.Buddy
   :param filter_model: Model to train.
   :type filter_model: torchfilter.base.Filter
   :param dataloader: Loader for a SubsequenceDataset.
   :type dataloader: DataLoader
   :param initial_covariance: Covariance matrix of the error in the initial
                              posterior. The posterior mean is sampled from a
                              Gaussian centered at the ground-truth start state.
                              Shape should be ``(state_dim, state_dim)``.
   :type initial_covariance: torch.Tensor

   :keyword loss_function: Loss function, from ``torch.nn.functional``. Defaults
                           to ``F.mse_loss``.
   :kwtype loss_function: callable, optional
   :keyword log_interval: Number of minibatches between TensorBoard logs.
                          Defaults to 10.
   :kwtype log_interval: int, optional


.. function:: train_kalman_filter_measurement(buddy: fannypack.utils.Buddy, measurement_model: torchfilter.base.KalmanFilterMeasurementModel, dataloader: DataLoader, *, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = F.mse_loss, log_interval: int = 10, optimizer_name='train_kalman_filter_measurement') -> None

   Optimizes a Kalman filter measurement model's prediction accuracy. Minimizes
   the error of the output mean only; no loss is defined on the uncertainty.

   :param buddy: Training helper.
   :type buddy: fannypack.utils.Buddy
   :param measurement_model: Model to train.
   :type measurement_model: torchfilter.base.KalmanFilterMeasurementModel
   :param dataloader: Loader for a SingleStepDataset.
   :type dataloader: DataLoader

   :keyword loss_function: Loss function, from ``torch.nn.functional``. Defaults
                           to ``F.mse_loss``.
   :kwtype loss_function: callable, optional
   :keyword log_interval: Number of minibatches between TensorBoard logs.
                          Defaults to 10.
   :kwtype log_interval: int, optional


.. function:: train_particle_filter_measurement(buddy: fannypack.utils.Buddy, measurement_model: torchfilter.base.ParticleFilterMeasurementModel, dataloader: DataLoader, *, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = F.mse_loss, log_interval: int = 10) -> None

   Reference implementation for pre-training a particle filter measurement model.
   Minimizes the prediction error of the log-likelihood outputs for (state,
   observation) pairs.

   :param buddy: Training helper.
   :type buddy: fannypack.utils.Buddy
   :param measurement_model: Model to train.
   :type measurement_model: torchfilter.base.ParticleFilterMeasurementModel
   :param dataloader: Loader for a ParticleFilterMeasurementDataset.
   :type dataloader: DataLoader

   :keyword loss_function: Loss function, from ``torch.nn.functional``. Defaults
                           to ``F.mse_loss``.
   :kwtype loss_function: callable, optional
   :keyword log_interval: Number of minibatches between TensorBoard logs.
                          Defaults to 10.
   :kwtype log_interval: int, optional


.. function:: train_virtual_sensor(buddy: fannypack.utils.Buddy, virtual_sensor_model: torchfilter.base.VirtualSensorModel, dataloader: DataLoader, *, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = F.mse_loss, log_interval: int = 10, optimizer_name='train_measurement') -> None

   Optimizes a virtual sensor model's prediction accuracy. Minimizes the error of
   the output mean only; no loss is defined on the uncertainty.

   :param buddy: Training helper.
   :type buddy: fannypack.utils.Buddy
   :param virtual_sensor_model: Model to train.
   :type virtual_sensor_model: torchfilter.base.VirtualSensorModel
   :param dataloader: Loader for a SingleStepDataset.
   :type dataloader: DataLoader

   :keyword loss_function: Loss function, from ``torch.nn.functional``. Defaults
                           to ``F.mse_loss``.
   :kwtype loss_function: callable, optional
   :keyword log_interval: Number of minibatches between TensorBoard logs.
                          Defaults to 10.
   :kwtype log_interval: int, optional
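
The two dynamics training functions accept ``loss_function="nll"`` or ``loss_function="mse"``. The sketch below illustrates how the two objectives differ for a single prediction. It is not the library's implementation: it uses plain Python instead of PyTorch, assumes a diagonal covariance for simplicity, and the helper names ``mse_loss`` and ``gaussian_nll`` as well as the example values are hypothetical.

```python
import math
from typing import Sequence


def mse_loss(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean-squared error between a predicted and a ground-truth state."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def gaussian_nll(
    pred: Sequence[float], variance: Sequence[float], target: Sequence[float]
) -> float:
    """Negative log-likelihood of `target` under N(pred, diag(variance)).

    Assumption: the covariance is diagonal, so the dimensions are independent
    and their log-densities simply add up.
    """
    nll = 0.0
    for p, v, t in zip(pred, variance, target):
        nll += 0.5 * (math.log(2.0 * math.pi * v) + (t - p) ** 2 / v)
    return nll


# Hypothetical two-dimensional state, chosen only for illustration.
target = [1.0, 2.0]
pred = [1.5, 2.0]

# MSE only looks at the predicted mean.
mse = mse_loss(pred, target)

# NLL also depends on the predicted uncertainty: the same mean error costs
# more when the model claims a small variance than when it claims a large one.
nll_confident = gaussian_nll(pred, [0.01, 0.01], target)
nll_uncertain = gaussian_nll(pred, [1.0, 1.0], target)
```

With these values, ``nll_confident`` is larger than ``nll_uncertain`` although the mean error is identical. An NLL objective therefore also shapes the predicted uncertainty, whereas an MSE objective leaves it unconstrained.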