:orphan:

:mod:`torchfilter.filters._particle_filter`
===========================================

.. py:module:: torchfilter.filters._particle_filter

.. autoapi-nested-parse::

   Private module; avoid importing from directly.



Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   torchfilter.filters._particle_filter.ParticleFilter



.. py:class:: ParticleFilter(*, dynamics_model: DynamicsModel, measurement_model: ParticleFilterMeasurementModel, num_particles: int = 100, resample: Optional[bool] = None, soft_resample_alpha: float = 1.0, estimation_method: str = 'weighted_average')

   Bases: :class:`torchfilter.base.Filter`

   .. autoapi-inheritance-diagram:: torchfilter.filters._particle_filter.ParticleFilter
      :parts: 1

   Generic differentiable particle filter.

   .. attribute:: dynamics_model

      Forward model.

      :type: torchfilter.base.DynamicsModel

   .. attribute:: measurement_model

      Observation model.

      :type: torchfilter.base.ParticleFilterMeasurementModel

   .. attribute:: num_particles

      Number of particles used to represent the belief distribution. Defaults to 100.

      :type: int

   .. attribute:: resample

      If True, particles are resampled and weights normalized at each timestep.
      If unset (None), resampling is automatically enabled in eval mode and
      disabled in train mode.

      :type: Optional[bool]

   .. attribute:: soft_resample_alpha

      Tunable constant for differentiable resampling, as described by Karkus et al.
      in "Particle Filter Networks with Application to Visual Localization":
      https://arxiv.org/abs/1805.08975

      Defaults to 1.0 (disabled).

      :type: float

   .. attribute:: estimation_method

      Method of producing state estimates. Options include:

      * ``'weighted_average'``: average of the particles, weighted by their weights.
      * ``'argmax'``: state of the particle with the highest weight.

      :type: str

   .. attribute:: particle_states
      :annotation: :torch.Tensor

      Discrete particles representing the current belief distribution. Shape should
      be ``(N, M, state_dim)``, where ``N`` is the batch size and ``M`` is the number
      of particles.

      :type: torch.Tensor

   .. attribute:: particle_log_weights
      :annotation: :torch.Tensor

      Weights corresponding to each particle, stored as log-likelihoods. Shape
      should be ``(N, M)``.

      :type: torch.Tensor

   .. method:: initialize_beliefs(self, *, mean: types.StatesTorch, covariance: types.CovarianceTorch) -> None

      Populates initial particles, which will be normally distributed.

      :param mean: Mean of belief. Shape should be ``(N, state_dim)``.
      :type mean: torch.Tensor
      :param covariance: Covariance of belief. Shape should be ``(N, state_dim, state_dim)``.
      :type covariance: torch.Tensor

   .. method:: forward(self, *, observations: types.ObservationsTorch, controls: types.ControlsTorch) -> types.StatesTorch

      Particle filter forward pass, single timestep.

      :param observations: Observation inputs. Should be either a dict of tensors or a tensor of shape ``(N, ...)``.
      :type observations: dict or torch.Tensor
      :param controls: Control inputs. Should be either a dict of tensors or a tensor of shape ``(N, ...)``.
      :type controls: dict or torch.Tensor

      :returns: *torch.Tensor* -- Predicted state for each batch element. Shape should be ``(N, state_dim)``.
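
The class documentation above gives shapes but not the arithmetic behind them. The following is a minimal NumPy sketch of that bookkeeping, assuming batch size ``N``, ``M`` particles, and weights stored as log-likelihoods as in ``particle_log_weights``. The helpers ``normalize_log_weights``, ``initialize_particles``, ``estimate``, and ``soft_resample`` are hypothetical illustrations, not part of the torchfilter API, and the resampling step follows our reading of Karkus et al.'s soft resampling rather than torchfilter's own code.

.. code-block:: python

   # Hypothetical NumPy helpers illustrating particle filter bookkeeping;
   # these are NOT torchfilter functions.
   import numpy as np


   def normalize_log_weights(log_weights):
       """Shift (N, M) log-weights so each row's weights sum to 1."""
       # Per-row log-sum-exp, shifted by the row max for numerical stability.
       row_max = log_weights.max(axis=1, keepdims=True)
       log_norm = row_max + np.log(
           np.exp(log_weights - row_max).sum(axis=1, keepdims=True)
       )
       return log_weights - log_norm


   def initialize_particles(mean, covariance, num_particles, rng):
       """Draw (N, M, state_dim) Gaussian particles with uniform (N, M) log-weights."""
       n, state_dim = mean.shape
       scale_tril = np.linalg.cholesky(covariance)  # (N, state_dim, state_dim)
       noise = rng.standard_normal((n, num_particles, state_dim))
       # x = mean + L @ z for every particle in every batch element.
       states = mean[:, None, :] + np.einsum("nij,nmj->nmi", scale_tril, noise)
       log_weights = np.full((n, num_particles), -np.log(num_particles))
       return states, log_weights


   def estimate(states, log_weights, method="weighted_average"):
       """Collapse (N, M, state_dim) particles into one (N, state_dim) estimate."""
       if method == "weighted_average":
           weights = np.exp(normalize_log_weights(log_weights))  # (N, M)
           return np.einsum("nm,nmi->ni", weights, states)
       if method == "argmax":
           best = np.argmax(log_weights, axis=1)  # (N,)
           return states[np.arange(states.shape[0]), best]
       raise ValueError(f"Unknown estimation method: {method}")


   def soft_resample(states, log_weights, alpha, rng):
       """Resample from q = alpha * w + (1 - alpha) / M, then reweight by w / q."""
       n, m, _ = states.shape
       log_w = normalize_log_weights(log_weights)
       q = alpha * np.exp(log_w) + (1.0 - alpha) / m  # (N, M), rows sum to 1
       indices = np.stack([rng.choice(m, size=m, p=q[i]) for i in range(n)])
       rows = np.arange(n)[:, None]
       new_states = states[rows, indices]
       # Importance correction w / q, kept in log space.
       new_log_weights = log_w[rows, indices] - np.log(q[rows, indices])
       return new_states, normalize_log_weights(new_log_weights)

With ``alpha=1.0`` the proposal equals the weights themselves, so every resampled particle ends up with log-weight ``-log(M)``; this is ordinary resampling, consistent with the ``soft_resample_alpha`` default being described as disabled. Smaller values keep part of the weight information after resampling, which is the property Karkus et al. rely on for gradients through the weights.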