torchfilter.base._filter

Private module; avoid importing from it directly.

Module Contents

Classes

Filter

Base class for a generic differentiable state estimator.

class torchfilter.base._filter.Filter(*, state_dim: int)

Bases: torch.nn.Module, abc.ABC

Base class for a generic differentiable state estimator.

At a minimum, subclasses should override:

  • initialize_beliefs() to populate the estimator's initial belief

  • forward() or forward_loop() to compute state predictions
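The override requirements above can be illustrated with a minimal sketch. To keep it runnable without torchfilter installed, it defines FilterSketch, a hypothetical stand-in that only mirrors the documented interface (keyword-only state_dim, abstract initialize_beliefs()); it is not torchfilter's implementation. ExponentialSmoother, its gain parameter, and the assumption that observations have the same shape as the state are likewise invented for illustration.

```python
import abc

import torch
from torch import nn


class FilterSketch(nn.Module, abc.ABC):
    """Hypothetical stand-in mirroring the documented interface."""

    def __init__(self, *, state_dim: int):
        super().__init__()
        self.state_dim = state_dim

    @abc.abstractmethod
    def initialize_beliefs(self, *, mean, covariance) -> None:
        """Populate the initial belief of the estimator."""


class ExponentialSmoother(FilterSketch):
    """Toy estimator: blends the previous mean with each observation."""

    def __init__(self, *, state_dim: int, gain: float = 0.5):
        super().__init__(state_dim=state_dim)
        self.gain = gain
        self._mean = None
        self._covariance = None

    def initialize_beliefs(self, *, mean, covariance) -> None:
        # Check the documented shapes: (N, state_dim) and (N, state_dim, state_dim).
        assert mean.shape[1:] == (self.state_dim,)
        assert covariance.shape[1:] == (self.state_dim, self.state_dim)
        self._mean = mean
        self._covariance = covariance  # stored but unused by this toy model

    def forward(self, *, observations, controls):
        # Single-timestep update; controls are ignored in this toy model.
        self._mean = (1 - self.gain) * self._mean + self.gain * observations
        return self._mean


N, state_dim = 4, 3
smoother = ExponentialSmoother(state_dim=state_dim)
smoother.initialize_beliefs(
    mean=torch.zeros(N, state_dim),
    covariance=torch.eye(state_dim).expand(N, -1, -1),
)
prediction = smoother(
    observations=torch.ones(N, state_dim),
    controls=torch.zeros(N, 1),
)
```

Here only forward() is overridden. With the real base class, forward_loop() would then fall back to its documented default of calling forward() once per timestep.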

state_dim

Dimensionality of our state.

Type:

int

abstract initialize_beliefs(self, *, mean: types.StatesTorch, covariance: types.CovarianceTorch) → None

Initialize our filter with a Gaussian prior.

Parameters:
  • mean (torch.Tensor) – Mean of belief. Shape should be (N, state_dim).

  • covariance (torch.Tensor) – Covariance of belief. Shape should be (N, state_dim, state_dim).
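As a rough illustration of the shapes above, the following sketch builds a Gaussian prior for a batch of N estimators. The sizes and the 0.1 variance scale are arbitrary values chosen for the example.

```python
import torch

N, state_dim = 8, 3  # illustrative batch size and state dimensionality

# One mean vector per batch element: shape (N, state_dim).
mean = torch.zeros(N, state_dim)

# One covariance matrix per batch element: shape (N, state_dim, state_dim).
# Here every element gets the same isotropic covariance, 0.1 * I.
covariance = 0.1 * torch.eye(state_dim).repeat(N, 1, 1)
```

These tensors would then be passed by keyword, as in initialize_beliefs(mean=mean, covariance=covariance), since the signature is keyword-only.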

forward(self, *, observations: types.ObservationsTorch, controls: types.ControlsTorch) → types.StatesTorch

Filtering forward pass, over a single timestep.

By default, this is implemented by calling forward_loop(), so subclasses that override forward_loop() do not also need to override forward().

Parameters:
  • observations (dict or torch.Tensor) – Observation inputs. Should be either a dict of tensors or a tensor of shape (N, ...).

  • controls (dict or torch.Tensor) – Control inputs. Should be either a dict of tensors or a tensor of shape (N, ...).

Returns:

torch.Tensor – Predicted state for each batch element. Shape should be (N, state_dim).
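One way a single-timestep pass can be built on top of a sequence method is to add a time dimension of length one, run the loop, and strip the dimension again. The sketch below assumes plain tensor inputs (dict inputs would need the same treatment per entry); single_step_via_loop and toy_forward_loop are hypothetical names, and this is not necessarily how the library implements its default.

```python
import torch


def single_step_via_loop(forward_loop, *, observations, controls):
    """Run one timestep by calling a sequence method with T = 1."""
    # (N, ...) -> (1, N, ...): add a leading time dimension.
    output = forward_loop(
        observations=observations[None],
        controls=controls[None],
    )
    assert output.shape[0] == 1
    # (1, N, state_dim) -> (N, state_dim): strip the time dimension.
    return output[0]


def toy_forward_loop(*, observations, controls):
    # Hypothetical sequence model: the state is observation plus control.
    return observations + controls


N, state_dim = 4, 3
step = single_step_via_loop(
    toy_forward_loop,
    observations=torch.ones(N, state_dim),
    controls=torch.ones(N, state_dim),
)
```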

forward_loop(self, *, observations: types.ObservationsTorch, controls: types.ControlsTorch) → types.StatesTorch

Filtering forward pass, over sequence length T and batch size N. By default, this is implemented by iteratively calling forward().

To inject code between timesteps (for example, to inspect hidden state), use register_forward_hook().

Parameters:
  • observations (dict or torch.Tensor) – Observation inputs. Should be either a dict of tensors or a tensor of shape (T, N, ...).

  • controls (dict or torch.Tensor) – Control inputs. Should be either a dict of tensors or a tensor of shape (T, N, ...).

Returns:

torch.Tensor – Predicted states at each timestep. Shape should be (T, N, state_dim).
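The iteration and hook behaviour described for forward_loop() can be sketched with a plain torch.nn.Module. RunningSum is a hypothetical stand-in with tensor-only inputs, not the library's code; the point is that a loop which calls the module itself (rather than its forward method directly) triggers any hook registered with register_forward_hook() once per timestep.

```python
import torch
from torch import nn


class RunningSum(nn.Module):
    """Hypothetical stand-in: the state accumulates observations and controls."""

    def __init__(self, *, state_dim: int):
        super().__init__()
        self.state_dim = state_dim
        self.state = None

    def forward(self, *, observations, controls):
        # Single-timestep update on inputs of shape (N, state_dim).
        self.state = self.state + observations + controls
        return self.state

    def forward_loop(self, *, observations, controls):
        T = observations.shape[0]
        # Call self(...), not self.forward(...), so that forward hooks fire.
        steps = [
            self(observations=observations[t], controls=controls[t])
            for t in range(T)
        ]
        return torch.stack(steps)  # shape (T, N, state_dim)


T, N, state_dim = 5, 2, 3
model = RunningSum(state_dim=state_dim)
model.state = torch.zeros(N, state_dim)

per_step = []  # filled by the hook, one entry per timestep


def record_output(module, args, output):
    # Returning None leaves the module output unchanged.
    per_step.append(output.detach().clone())


handle = model.register_forward_hook(record_output)
states = model.forward_loop(
    observations=torch.ones(T, N, state_dim),
    controls=torch.zeros(T, N, state_dim),
)
handle.remove()  # detach the hook once it is no longer needed
```

Because the inputs are passed by keyword, the hook's positional args tuple is empty here; the sketch therefore inspects only the output.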