:orphan:

:mod:`torchfilter.train._train_filter`
======================================

.. py:module:: torchfilter.train._train_filter

.. autoapi-nested-parse::

   Private module; avoid importing from directly.


Module Contents
---------------


Functions
~~~~~~~~~

.. autoapisummary::

   torchfilter.train._train_filter.train_filter


.. function:: train_filter(buddy: fannypack.utils.Buddy, filter_model: torchfilter.base.Filter, dataloader: DataLoader, initial_covariance: torch.Tensor, *, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = F.mse_loss, log_interval: int = 10, measurement_initialize=False, optimizer_name='train_filter_recurrent') -> None

   Trains a filter end-to-end via backpropagation through time for 1 epoch over a
   subsequence dataset.

   :param buddy: Training helper.
   :type buddy: fannypack.utils.Buddy
   :param filter_model: Model to train.
   :type filter_model: torchfilter.base.Filter
   :param dataloader: Loader for a SubsequenceDataset.
   :type dataloader: DataLoader
   :param initial_covariance: Covariance matrix of error in the initial posterior, whose
                              mean is sampled from a Gaussian centered at the ground-truth
                              start state. Shape should be ``(state_dim, state_dim)``.
   :type initial_covariance: torch.Tensor

   :keyword loss_function: Loss function, from ``torch.nn.functional``. Defaults to MSE.
   :kwtype loss_function: callable, optional
   :keyword log_interval: Minibatches between each Tensorboard log.
   :kwtype log_interval: int, optional
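
   A minimal usage sketch is shown below. ``MyFilter`` and ``trajectories`` are
   hypothetical placeholders: any ``torchfilter.base.Filter`` subclass and any list of
   trajectories accepted by ``SubsequenceDataset`` would work, and the experiment name
   passed to ``Buddy`` is arbitrary.

   .. code-block:: python

      import fannypack
      import torch
      import torchfilter
      from torch.utils.data import DataLoader

      # Hypothetical model + data: MyFilter subclasses torchfilter.base.Filter,
      # and `trajectories` is a list of trajectories collected elsewhere.
      filter_model = MyFilter()
      buddy = fannypack.utils.Buddy("filter_experiment", filter_model)

      # SubsequenceDataset slices full trajectories into fixed-length chunks;
      # train_filter backpropagates through each chunk end-to-end.
      dataloader = DataLoader(
          torchfilter.data.SubsequenceDataset(
              trajectories=trajectories, subsequence_length=20
          ),
          batch_size=32,
          shuffle=True,
      )

      # One epoch of backpropagation-through-time over the subsequences.
      torchfilter.train.train_filter(
          buddy,
          filter_model,
          dataloader,
          initial_covariance=torch.eye(filter_model.state_dim) * 0.1,
      )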