"""Private module; avoid importing from directly.
"""
from typing import cast
import torch
from overrides import overrides
from .. import types
from ..base import KalmanFilterBase
class ExtendedKalmanFilter(KalmanFilterBase):
    """Generic differentiable EKF.

    For building estimators with more complex observation spaces (eg images), see
    `VirtualSensorExtendedKalmanFilter`.
    """

    @overrides
    def _predict_step(self, *, controls: types.ControlsTorch) -> None:
        """EKF prediction step.

        Propagates the belief through the dynamics model:
        mu_{t+1|t} = f(mu_t, u_t) and Sigma_{t+1|t} = A Sigma_t A^T + Q,
        where A is the dynamics Jacobian evaluated at the previous mean and
        Q is the process noise covariance.

        Args:
            controls: Control inputs for this timestep.
        """
        # Get previous belief
        prev_mean = self._belief_mean
        prev_covariance = self._belief_covariance
        N, state_dim = prev_mean.shape

        # Compute mu_{t+1|t}, covariance, and Jacobian. The dynamics model
        # returns the predicted mean and the lower-triangular Cholesky factor
        # of the process noise covariance, so Q = L @ L^T.
        pred_mean, dynamics_tril = self.dynamics_model(
            initial_states=prev_mean, controls=controls
        )
        dynamics_covariance = dynamics_tril @ dynamics_tril.transpose(-1, -2)
        dynamics_A_matrix = self.dynamics_model.jacobian(
            initial_states=prev_mean, controls=controls
        )
        assert dynamics_covariance.shape == (N, state_dim, state_dim)
        assert dynamics_A_matrix.shape == (N, state_dim, state_dim)

        # Calculate Sigma_{t+1|t} = A Sigma_t A^T + Q
        pred_covariance = (
            dynamics_A_matrix @ prev_covariance @ dynamics_A_matrix.transpose(-1, -2)
            + dynamics_covariance
        )

        # Update internal state
        self._belief_mean = pred_mean
        self._belief_covariance = pred_covariance

    @overrides
    def _update_step(self, *, observations: types.ObservationsTorch) -> None:
        """EKF correction step.

        Incorporates an observation into the predicted belief via the standard
        Kalman gain: K = Sigma C^T (C Sigma C^T + R)^{-1}, then
        mu <- mu + K (z - h(mu)) and Sigma <- (I - K C) Sigma.

        Args:
            observations: Observation tensor for this timestep. Must be a raw
                tensor (not a dict) for the standard EKF.
        """
        # Extract/validate inputs
        assert isinstance(
            observations, types.ObservationsNoDictTorch
        ), "For standard EKF, observations must be tensor!"
        observations = cast(types.ObservationsNoDictTorch, observations)
        pred_mean = self._belief_mean
        pred_covariance = self._belief_covariance

        # Measurement model forward pass, Jacobian. The measurement model
        # returns the predicted observation and the lower-triangular Cholesky
        # factor of the observation noise covariance, so R = L @ L^T.
        observations_mean = observations
        pred_observations, observations_tril = self.measurement_model(states=pred_mean)
        observations_covariance = observations_tril @ observations_tril.transpose(
            -1, -2
        )
        C_matrix = self.measurement_model.jacobian(states=pred_mean)
        assert observations_mean.shape == pred_observations.shape

        # Check shapes
        N, observation_dim = observations_mean.shape
        assert observations_covariance.shape == (N, observation_dim, observation_dim)
        assert observations_mean.shape == (N, observation_dim)

        # Compute Kalman Gain, innovation
        innovation = observations_mean - pred_observations
        innovation_covariance = (
            C_matrix @ pred_covariance @ C_matrix.transpose(-1, -2)
            + observations_covariance
        )
        # NOTE(review): explicit matrix inversion; torch.linalg.solve would be
        # more numerically stable, but is avoided here to preserve behavior.
        kalman_gain = (
            pred_covariance
            @ C_matrix.transpose(-1, -2)
            @ torch.inverse(innovation_covariance)
        )

        # Get mu_{t+1|t+1}, Sigma_{t+1|t+1}
        corrected_mean = pred_mean + (kalman_gain @ innovation[:, :, None]).squeeze(-1)
        assert corrected_mean.shape == (N, self.state_dim)

        # Match dtype as well as device: torch.eye otherwise defaults to the
        # global default dtype, causing silent promotion/drift of the belief
        # covariance when beliefs are float64/float16.
        identity = torch.eye(
            self.state_dim, dtype=kalman_gain.dtype, device=kalman_gain.device
        )
        corrected_covariance = (identity - kalman_gain @ C_matrix) @ pred_covariance
        assert corrected_covariance.shape == (N, self.state_dim, self.state_dim)

        # Update internal state
        self._belief_mean = corrected_mean
        self._belief_covariance = corrected_covariance