Source code for torchfilter.data._subsequence_dataset

"""Private module; avoid importing from directly.
"""

from typing import List

from torch.utils.data import Dataset

from .. import types
from ._split_trajectories import split_trajectories


class SubsequenceDataset(Dataset):
    """Dataset of fixed-length training subsequences cut from trajectories.

    A thin adapter over `torchfilter.data.split_trajectories()`. All of the
    slicing happens once, at construction time. Items are afterwards served
    straight out of the stored list.

    Args:
        trajectories (list): list of trajectories, where each is a tuple of
            `(states, observations, controls)`. Each tuple member should be
            either a numpy array or dict of numpy arrays with shape `(T, ...)`.
        subsequence_length (int): # of timesteps per subsequence.
    """

    def __init__(
        self, trajectories: List[types.TrajectoryNumpy], subsequence_length: int
    ):
        # Cut every trajectory into overlapping windows up front, so that
        # item access below is a plain list lookup.
        chunks = split_trajectories(trajectories, subsequence_length)
        self.subsequences: List[types.TrajectoryNumpy] = chunks

    def __getitem__(self, index: int) -> types.TrajectoryNumpy:
        """Fetch one subsequence by position.

        Args:
            index (int): Subsequence number in our dataset.

        Returns:
            tuple: `(states, observations, controls)` tuple holding the data
            for a single subsequence. Each tuple member should be either a
            numpy array or dict of numpy arrays with shape
            `(subsequence_length, ...)`.
        """
        subsequence = self.subsequences[index]
        return subsequence

    def __len__(self) -> int:
        """Count the subsequences held by this dataset.

        Returns:
            int: Length of dataset.
        """
        count = len(self.subsequences)
        return count