# Source code for torchfilter.types

"""Data structures and semantic type aliases for filtering.
"""
from typing import Dict, NamedTuple, Union

import numpy as np
import torch

# Make an explicit list of names to expose.
# Every public alias and class defined in this module is listed, including the
# `ControlsNoDict*` aliases, which mirror the `ObservationsNoDict*` ones.
__all__ = [
    "NumpyDict",
    "TorchDict",
    "NumpyArrayOrDict",
    "TorchTensorOrDict",
    "StatesNumpy",
    "StatesTorch",
    "ObservationsNumpy",
    "ObservationsTorch",
    "ObservationsNoDictNumpy",
    "ObservationsNoDictTorch",
    "ControlsNumpy",
    "ControlsTorch",
    "ControlsNoDictNumpy",
    "ControlsNoDictTorch",
    "CovarianceTorch",
    "ScaleTrilTorch",
    "Trajectory",
    "TrajectoryNumpy",
    "TrajectoryTorch",
]

NumpyDict = Dict[str, np.ndarray]
"""Dictionary from `str` keys to `np.ndarray` values."""
TorchDict = Dict[str, torch.Tensor]
"""Dictionary from `str` keys to `torch.Tensor` values."""

NumpyArrayOrDict = Union[np.ndarray, NumpyDict]
"""Union of np.ndarray and NumpyDict types."""
TorchTensorOrDict = Union[torch.Tensor, TorchDict]
"""Union of torch.Tensor and TorchDict types."""

StatesNumpy = np.ndarray
"""State array type hint. Needs to be a raw `np.ndarray`."""
StatesTorch = torch.Tensor
"""State array type hint. Needs to be a raw `torch.Tensor`."""

ObservationsNumpy = NumpyArrayOrDict
"""Observations can be either `np.ndarray` objects or `str->np.ndarray` dictionaries."""
ObservationsTorch = TorchTensorOrDict
"""Observations can be either `torch.Tensor` objects or `str->torch.Tensor` dictionaries."""

ObservationsNoDictNumpy = np.ndarray
"""Same as `ObservationsNumpy`, but no dictionaries."""
ObservationsNoDictTorch = torch.Tensor
"""Same as `ObservationsTorch`, but no dictionaries."""

ControlsNumpy = NumpyArrayOrDict
"""Controls can be either `np.ndarray` objects or `str->np.ndarray` dictionaries."""
ControlsTorch = TorchTensorOrDict
"""Controls can be either `torch.Tensor` objects or `str->torch.Tensor` dictionaries."""

ControlsNoDictNumpy = np.ndarray
"""Same as `ControlsNumpy`, but no dictionaries."""
ControlsNoDictTorch = torch.Tensor
"""Same as `ControlsTorch`, but no dictionaries."""

CovarianceTorch = torch.Tensor
"""Covariance matrix as `torch.Tensor`. Must be positive semi-definite."""
ScaleTrilTorch = torch.Tensor
"""Lower-triangular Cholesky decomposition of covariance matrix as `torch.Tensor`."""


class Trajectory(NamedTuple):
    """Named tuple containing states, observations, and controls.

    Each field is annotated as accepting either the NumPy or the Torch variant
    of the corresponding type alias.
    """

    states: Union[StatesNumpy, StatesTorch]
    observations: Union[ObservationsNumpy, ObservationsTorch]
    controls: Union[ControlsNumpy, ControlsTorch]


class TrajectoryNumpy(Trajectory):
    """Named tuple containing states, observations, and controls represented in NumPy."""

    # Annotation-only declarations: they re-annotate the fields inherited from
    # `Trajectory` with the NumPy aliases and do not add new tuple fields.
    states: StatesNumpy
    observations: ObservationsNumpy
    controls: ControlsNumpy


class TrajectoryTorch(Trajectory):
    """Named tuple containing states, observations, and controls represented in Torch."""

    # Annotation-only declarations: they re-annotate the fields inherited from
    # `Trajectory` with the Torch aliases and do not add new tuple fields.
    states: StatesTorch
    observations: ObservationsTorch
    controls: ControlsTorch