# Source code for torchfilter.base._virtual_sensor_model

"""Private module; avoid importing from directly.
"""

import abc
from typing import Tuple

import torch.nn as nn
from overrides import overrides

from .. import types


class VirtualSensorModel(abc.ABC, nn.Module):
    """Virtual sensor base class for our differentiable Kalman filters.

    Maps each observation input to a predicted state and uncertainty, in the style of
    BackpropKF. This is often necessary for complex observation spaces like images or
    point clouds, where it's not possible to learn a standard state->observation
    measurement model.
    """

    def __init__(self, state_dim: int):
        """Initialize the virtual sensor.

        Args:
            state_dim (int): Dimensionality of our state.
        """
        super().__init__()

        #: int: Dimensionality of our state.
        self.state_dim = state_dim

    @abc.abstractmethod
    def forward(
        self, *, observations: types.ObservationsTorch
    ) -> Tuple[types.StatesTorch, types.ScaleTrilTorch]:
        """Predicts states and uncertainties from observation inputs.

        Uncertainties should be lower-triangular Cholesky decompositions of covariance
        matrices.

        Args:
            observations (dict or torch.Tensor): Measurement inputs. Should be either a
                dict of tensors or tensor of size `(N, ...)`.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Predicted states & uncertainties. States
            should have shape `(N, state_dim)`, and uncertainties should be lower
            triangular with shape `(N, state_dim, state_dim)`.
        """