torchfilter.base._filter
Private module; avoid importing from it directly.
Module Contents
Classes
Filter – Base class for a generic differentiable state estimator.
- class torchfilter.base._filter.Filter(*, state_dim: int)
Bases: torch.nn.Module, abc.ABC
Base class for a generic differentiable state estimator.
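At a minimum, subclasses should override:
- initialize_beliefs, for populating the initial belief of our estimator
- forward or forward_loop, for computing state predictions
Because forward() and forward_loop() are each implemented in terms of the other by default, at least one of them must be overridden; otherwise the two defaults would call each other indefinitely.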
- state_dim
Dimensionality of our state.
- Type: int
- abstract initialize_beliefs(self, *, mean: types.StatesTorch, covariance: types.CovarianceTorch) → None
Initialize our filter with a Gaussian prior.
- Parameters:
mean (torch.Tensor) – Mean of belief. Shape should be (N, state_dim).
covariance (torch.Tensor) – Covariance of belief. Shape should be (N, state_dim, state_dim).
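As a sketch of the expected shapes, the snippet below builds batched Gaussian priors with plain torch. The batch size, state dimension, and variances are arbitrary illustrative values. The resulting tensors would be passed as filt.initialize_beliefs(mean=mean, covariance=covariance), where filt is a placeholder for any concrete subclass instance.

```python
import torch

N, state_dim = 4, 3  # hypothetical batch size and state dimension

# Shared prior: zero mean and identity covariance for every batch element.
mean = torch.zeros(N, state_dim)  # (N, state_dim)
# expand() returns a stride-0 view; clone() gives each element its own storage.
covariance = torch.eye(state_dim).expand(N, state_dim, state_dim).clone()

# Per-dimension variances turned into diagonal covariance matrices.
variances = torch.tensor([[1.0, 0.5, 0.1]]).repeat(N, 1)  # (N, state_dim)
diag_covariance = torch.diag_embed(variances)  # (N, state_dim, state_dim)

# Sanity check: a valid covariance must admit a Cholesky factorization.
torch.linalg.cholesky(diag_covariance)
```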
- forward(self, *, observations: types.ObservationsTorch, controls: types.ControlsTorch) → types.StatesTorch
Filtering forward pass, over a single timestep.
By default, this is implemented by bootstrapping the forward_loop() method.
- Parameters:
observations (dict or torch.Tensor) – Observation inputs. Should be either a dict of tensors or a tensor of size (N, ...).
controls (dict or torch.Tensor) – Control inputs. Should be either a dict of tensors or a tensor of size (N, ...).
- Returns:
torch.Tensor – Predicted state for each batch element. Shape should be (N, state_dim).
- forward_loop(self, *, observations: types.ObservationsTorch, controls: types.ControlsTorch) → types.StatesTorch
Filtering forward pass, over sequence length T and batch size N. By default, this is implemented by iteratively calling forward().
To inject code between timesteps (for example, to inspect hidden state), use register_forward_hook().
- Parameters:
observations (dict or torch.Tensor) – Observation inputs. Should be either a dict of tensors or a tensor of size (T, N, ...).
controls (dict or torch.Tensor) – Control inputs. Should be either a dict of tensors or a tensor of size (T, N, ...).
- Returns:
torch.Tensor – Predicted states at each timestep. Shape should be (T, N, state_dim).