Source code for torchfilter.filters._extended_information_filter

"""Private module; avoid importing from directly.
"""
from typing import cast

import fannypack
import torch
from overrides import overrides

from .. import types
from ..base import DynamicsModel, KalmanFilterBase, KalmanFilterMeasurementModel


class ExtendedInformationFilter(KalmanFilterBase):
    """Information form of a Kalman filter; generally equivalent to an EKF, but
    internally parameterizes uncertainties with the inverse covariance matrix.

    For building estimators with more complex observation spaces (e.g. images), see
    `VirtualSensorExtendedInformationFilter`.
    """

    def __init__(
        self,
        *,
        dynamics_model: DynamicsModel,
        measurement_model: KalmanFilterMeasurementModel,
    ):
        super().__init__(
            dynamics_model=dynamics_model, measurement_model=measurement_model
        )

        # Parameterize posterior uncertainty with the inverse covariance
        self.information_vector: torch.Tensor
        """torch.Tensor: Information vector of our posterior; shape should be
        `(N, state_dim)`."""

        self.information_matrix: torch.Tensor
        """torch.Tensor: Information matrix of our posterior; shape should be
        `(N, state_dim, state_dim)`."""

    # overrides
    @property
    def belief_covariance(self) -> types.CovarianceTorch:
        """Posterior covariance. Shape should be `(N, state_dim, state_dim)`."""
        return fannypack.utils.cholesky_inverse(
            torch.linalg.cholesky(self.information_matrix)
        )

    # overrides
    @belief_covariance.setter
    def belief_covariance(self, covariance: types.CovarianceTorch):
        self.information_matrix = fannypack.utils.cholesky_inverse(
            torch.linalg.cholesky(covariance)
        )

    @overrides
    def _predict_step(self, *, controls: types.ControlsTorch) -> None:
        # Get previous belief
        prev_mean = self._belief_mean
        prev_covariance = self.belief_covariance
        N, state_dim = prev_mean.shape

        # Compute mu_{t+1|t}, covariance, and Jacobian
        pred_mean, dynamics_tril = self.dynamics_model(
            initial_states=prev_mean, controls=controls
        )
        dynamics_covariance = dynamics_tril @ dynamics_tril.transpose(-1, -2)
        dynamics_A_matrix = self.dynamics_model.jacobian(
            initial_states=prev_mean, controls=controls
        )
        assert dynamics_covariance.shape == (N, state_dim, state_dim)
        assert dynamics_A_matrix.shape == (N, state_dim, state_dim)

        # Calculate the predicted information matrix (inverse of Sigma_{t+1|t})
        pred_information_matrix = fannypack.utils.cholesky_inverse(
            torch.linalg.cholesky(
                dynamics_A_matrix @ prev_covariance @ dynamics_A_matrix.transpose(-1, -2)
                + dynamics_covariance
            )
        )
        pred_information_vector = (
            pred_information_matrix @ pred_mean[:, :, None]
        ).squeeze(-1)

        # Update internal state
        self._belief_mean = pred_mean
        self.information_matrix = pred_information_matrix
        self.information_vector = pred_information_vector

    @overrides
    def _update_step(self, *, observations: types.ObservationsTorch) -> None:
        # Extract/validate inputs
        assert isinstance(
            observations, types.ObservationsNoDictTorch
        ), "For the standard EIF, observations must be a tensor!"
        observations = cast(types.ObservationsNoDictTorch, observations)
        pred_mean = self._belief_mean
        pred_information_matrix = self.information_matrix
        pred_information_vector = self.information_vector

        # Measurement model forward pass, Jacobian
        observations_mean = observations
        pred_observations, observations_tril = self.measurement_model(states=pred_mean)
        observations_information = fannypack.utils.cholesky_inverse(observations_tril)
        C_matrix = self.measurement_model.jacobian(states=pred_mean)
        C_matrix_transpose = C_matrix.transpose(-1, -2)
        assert observations_mean.shape == pred_observations.shape

        # Check shapes
        N, observation_dim = observations_mean.shape
        assert observations_information.shape == (N, observation_dim, observation_dim)
        assert observations_mean.shape == (N, observation_dim)

        # Compute update
        information_vector = pred_information_vector + (
            C_matrix_transpose
            @ observations_information
            @ (
                observations_mean[:, :, None]
                - pred_observations[:, :, None]
                + C_matrix @ pred_mean[:, :, None]
            )
        ).squeeze(-1)
        assert information_vector.shape == (N, self.state_dim)

        information_matrix = (
            pred_information_matrix
            + C_matrix_transpose @ observations_information @ C_matrix
        )
        assert information_matrix.shape == (N, self.state_dim, self.state_dim)

        # Update internal state
        self.information_matrix = information_matrix
        self.information_vector = information_vector
        self._belief_mean = (
            fannypack.utils.cholesky_inverse(torch.linalg.cholesky(information_matrix))
            @ information_vector[:, :, None]
        ).squeeze(-1)
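
For reference, below is a minimal usage sketch; it is not part of the module above. The `LinearDynamics` and `LinearMeasurement` classes are hypothetical stand-ins, and the sketch assumes the torchfilter base-class interfaces: `DynamicsModel(state_dim=...)` with `forward(initial_states, controls)` returning predicted states and the Cholesky factor of the process noise, `KalmanFilterMeasurementModel(state_dim=..., observation_dim=...)` with `forward(states)` returning predicted observations and their noise Cholesky factor, plus `initialize_beliefs(mean=..., covariance=...)` and a per-timestep `__call__(observations=..., controls=...)` on the filter. Treat it as a sketch under those assumptions rather than a definitive example.

# Hypothetical usage sketch; not part of torchfilter.filters._extended_information_filter.
import torch
import torchfilter


class LinearDynamics(torchfilter.base.DynamicsModel):
    """Integrator dynamics: the control is added to the state each step."""

    def __init__(self):
        super().__init__(state_dim=2)

    def forward(self, *, initial_states, controls):
        N, state_dim = initial_states.shape
        predicted_states = initial_states + controls
        # Lower-triangular Cholesky factor of the process noise covariance
        scale_tril = 0.1 * torch.eye(state_dim).repeat(N, 1, 1)
        return predicted_states, scale_tril


class LinearMeasurement(torchfilter.base.KalmanFilterMeasurementModel):
    """Observe the state directly, with additive noise."""

    def __init__(self):
        super().__init__(state_dim=2, observation_dim=2)

    def forward(self, *, states):
        N, state_dim = states.shape
        # Lower-triangular Cholesky factor of the observation noise covariance
        scale_tril = 0.2 * torch.eye(state_dim).repeat(N, 1, 1)
        return states, scale_tril


eif = torchfilter.filters.ExtendedInformationFilter(
    dynamics_model=LinearDynamics(),
    measurement_model=LinearMeasurement(),
)

# Batched beliefs: N parallel filters over a 2-dimensional state
N, state_dim = 4, 2
eif.initialize_beliefs(
    mean=torch.zeros(N, state_dim),
    covariance=torch.eye(state_dim).repeat(N, 1, 1),
)

# One predict + update cycle; returns the posterior means, shape (N, state_dim)
estimated_states = eif(
    observations=torch.ones(N, state_dim),
    controls=torch.zeros(N, state_dim),
)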