# Source code for torchfilter.filters._particle_filter

"""Private module; avoid importing from directly.
"""

from typing import Optional

import fannypack
import numpy as np
import torch
from overrides import overrides

from .. import types
from ..base import DynamicsModel, Filter, ParticleFilterMeasurementModel


class ParticleFilter(Filter):
    """Generic differentiable particle filter."""

    def __init__(
        self,
        *,
        dynamics_model: DynamicsModel,
        measurement_model: ParticleFilterMeasurementModel,
        num_particles: int = 100,
        resample: Optional[bool] = None,
        soft_resample_alpha: float = 1.0,
        estimation_method: str = "weighted_average",
    ):
        # Check submodule consistency
        assert isinstance(dynamics_model, DynamicsModel)
        assert isinstance(measurement_model, ParticleFilterMeasurementModel)
        assert dynamics_model.state_dim == measurement_model.state_dim

        # Initialize state dimension
        state_dim = dynamics_model.state_dim
        super().__init__(state_dim=state_dim)

        # Assign submodules
        self.dynamics_model = dynamics_model
        """torchfilter.base.DynamicsModel: Forward model."""

        self.measurement_model = measurement_model
        """torchfilter.base.ParticleFilterMeasurementModel: Observation model."""

        # Settings
        self.num_particles = num_particles
        """int: Number of particles to represent our belief distribution.
        Defaults to 100."""

        self.resample = resample
        """bool: If True, we resample particles & normalize weights at each
        timestep. If unset (None), we automatically turn resampling on in eval
        mode and off in train mode."""

        self.soft_resample_alpha = soft_resample_alpha
        """float: Tunable constant for differentiable resampling, as described
        by Karkus et al. in "Particle Filter Networks with Application to
        Visual Localization": https://arxiv.org/abs/1805.08975 Defaults to 1.0
        (disabled)."""

        assert estimation_method in ("weighted_average", "argmax")
        self.estimation_method = estimation_method
        """str: Method of producing state estimates. Options include:
        - 'weighted_average': average of particles weighted by their weights.
        - 'argmax': state of highest weighted particle.
        """

        # "Hidden state" tensors
        self.particle_states: torch.Tensor
        """torch.Tensor: Discrete particles representing our current belief
        distribution. Shape should be `(N, M, state_dim)`.
        """

        self.particle_log_weights: torch.Tensor
        """torch.Tensor: Weights corresponding to each particle, stored as
        log-likelihoods. Shape should be `(N, M)`.
        """

        self._initialized = False

    @overrides
    def initialize_beliefs(
        self, *, mean: types.StatesTorch, covariance: types.CovarianceTorch
    ) -> None:
        """Populates initial particles, which will be normally distributed.

        Args:
            mean (torch.Tensor): Mean of belief. Shape should be
                `(N, state_dim)`.
            covariance (torch.Tensor): Covariance of belief. Shape should be
                `(N, state_dim, state_dim)`.
        """
        N = mean.shape[0]
        assert mean.shape == (N, self.state_dim)
        assert covariance.shape == (N, self.state_dim, self.state_dim)
        M = self.num_particles

        # Sample particles
        self.particle_states = (
            torch.distributions.MultivariateNormal(mean, covariance)
            .sample((M,))
            .transpose(0, 1)
        )
        assert self.particle_states.shape == (N, M, self.state_dim)

        # Normalize weights: each particle starts with weight 1/M (in log space)
        self.particle_log_weights = self.particle_states.new_full(
            (N, M), float(-np.log(M, dtype=np.float32))
        )
        assert self.particle_log_weights.shape == (N, M)

        # Set initialized flag
        self._initialized = True

    @overrides
    def forward(
        self,
        *,
        observations: types.ObservationsTorch,
        controls: types.ControlsTorch,
    ) -> types.StatesTorch:
        """Particle filter forward pass, single timestep.

        Args:
            observations (dict or torch.Tensor): observation inputs. should be
                either a dict of tensors or tensor of shape `(N, ...)`.
            controls (dict or torch.Tensor): control inputs. should be either a
                dict of tensors or tensor of shape `(N, ...)`.

        Returns:
            torch.Tensor: Predicted state for each batch element. Shape should
            be `(N, state_dim).`
        """
        # Make sure our particle filter's been initialized
        assert self._initialized, "Particle filter not initialized!"

        # Get our batch size (N), current particle count (M), & state dimension
        N, M, state_dim = self.particle_states.shape
        assert state_dim == self.state_dim
        assert len(fannypack.utils.SliceWrapper(controls)) == N

        # Decide whether or not we're resampling
        resample = self.resample
        if resample is None:
            # If not explicitly set, we disable resampling in train mode (to allow
            # gradients to propagate through time) and enable in eval mode (to prevent
            # particle deprivation)
            resample = not self.training

        # If we're not resampling and our current particle count doesn't match
        # our desired particle count, we need to either expand or contract our
        # particle set
        if not resample and self.num_particles != M:
            indices = self.particle_states.new_zeros(
                (N, self.num_particles), dtype=torch.long
            )

            # If output particles > our input particles, for the beginning part we copy
            # particles directly to reduce variance
            copy_count = (self.num_particles // M) * M
            if copy_count > 0:
                # Bugfix: allocate the arange on the same device as `indices`; the
                # default-device `torch.arange(M)` broke the CUDA path (the sibling
                # `torch.randperm` call below already passes `device=`)
                indices[:, :copy_count] = torch.arange(M, device=indices.device).repeat(
                    copy_count // M
                )[None, :]

            # For remaining particles, we sample w/o replacement (also lowers variance)
            remaining_count = self.num_particles - copy_count
            assert remaining_count >= 0
            if remaining_count > 0:
                indices[:, copy_count:] = torch.randperm(M, device=indices.device)[
                    None, :remaining_count
                ]

            # Gather new particles, weights
            M = self.num_particles
            self.particle_states = self.particle_states.gather(
                1, indices[:, :, None].expand((N, M, state_dim))
            )
            self.particle_log_weights = self.particle_log_weights.gather(1, indices)
            assert self.particle_states.shape == (N, self.num_particles, state_dim)
            assert self.particle_log_weights.shape == (N, self.num_particles)

            # Normalize particle weights to sum to 1.0
            self.particle_log_weights = self.particle_log_weights - torch.logsumexp(
                self.particle_log_weights, dim=1, keepdim=True
            )

        # Propagate particles through our dynamics model
        # A bit of extra effort is required for the extra particle dimension
        # > For our states, we flatten along the N/M axes
        # > For our controls, we repeat each one `M` times, if M=3:
        #       [u0 u1 u2] should become [u0 u0 u0 u1 u1 u1 u2 u2 u2]
        #
        # Currently each of the M particles within a "sample" get the same action, but
        # we could also add noise in the action space (a la Jonschkowski et al. 2018)
        reshaped_states = self.particle_states.reshape(-1, self.state_dim)
        reshaped_controls = fannypack.utils.SliceWrapper(controls).map(
            lambda tensor: torch.repeat_interleave(tensor, repeats=M, dim=0)
        )
        predicted_states, scale_trils = self.dynamics_model(
            initial_states=reshaped_states, controls=reshaped_controls
        )
        self.particle_states = (
            torch.distributions.MultivariateNormal(
                loc=predicted_states, scale_tril=scale_trils
            )
            .rsample()  # Note that we use `rsample` to make sampling differentiable
            .view(N, M, self.state_dim)
        )
        assert self.particle_states.shape == (N, M, self.state_dim)

        # Re-weight particles using observations
        self.particle_log_weights = self.particle_log_weights + self.measurement_model(
            states=self.particle_states,
            observations=observations,
        )

        # Normalize particle weights to sum to 1.0
        self.particle_log_weights = self.particle_log_weights - torch.logsumexp(
            self.particle_log_weights, dim=1, keepdim=True
        )

        # Compute output
        state_estimates: types.StatesTorch
        if self.estimation_method == "weighted_average":
            state_estimates = torch.sum(
                torch.exp(self.particle_log_weights[:, :, np.newaxis])
                * self.particle_states,
                dim=1,
            )
        elif self.estimation_method == "argmax":
            best_indices = torch.argmax(self.particle_log_weights, dim=1)
            # Bugfix: `torch.gather` requires `index` to have the same number of
            # dimensions as its input, so gathering a 3D tensor with the 1D
            # `best_indices` raised a runtime error. Select each batch element's
            # highest-weight particle via advanced indexing instead.
            state_estimates = self.particle_states[
                torch.arange(N, device=best_indices.device), best_indices
            ]
        else:
            assert False, "Unsupported estimation method!"

        # Resampling
        if resample:
            self._resample()

        # Post-condition :)
        assert state_estimates.shape == (N, state_dim)
        assert self.particle_states.shape == (N, self.num_particles, state_dim)
        assert self.particle_log_weights.shape == (N, self.num_particles)

        return state_estimates

    def _resample(self) -> None:
        """Resample particles."""
        # Note the distinction between `M`, the current number of particles, and
        # `self.num_particles`, the desired number of particles
        N, M, state_dim = self.particle_states.shape

        sample_logits: torch.Tensor
        uniform_log_weights = self.particle_log_weights.new_full(
            (N, self.num_particles), float(-np.log(M, dtype=np.float32))
        )
        if self.soft_resample_alpha < 1.0:
            # Soft resampling: sample from a mixture of the particle weights and a
            # uniform distribution, then importance-reweight; keeps gradients alive
            # (Karkus et al., https://arxiv.org/abs/1805.08975)
            #
            # NOTE(review): the addition below broadcasts only when
            # M == self.num_particles; that invariant appears to hold on all current
            # call paths (forward() only reshapes the particle set when resampling is
            # off) — confirm if new callers are added
            assert self.particle_log_weights.shape == (N, M)
            sample_logits = torch.logsumexp(
                torch.stack(
                    [
                        self.particle_log_weights + np.log(self.soft_resample_alpha),
                        uniform_log_weights + np.log(1.0 - self.soft_resample_alpha),
                    ],
                    dim=0,
                ),
                dim=0,
            )
            self.particle_log_weights = self.particle_log_weights - sample_logits
        else:
            # Standard particle filter re-sampling -- this stops gradients
            # This is the most naive flavor of resampling, and not the low
            # variance approach
            #
            # Note the distinction between M, the current # of particles,
            # and self.num_particles, the desired # of particles
            sample_logits = self.particle_log_weights
            self.particle_log_weights = uniform_log_weights

        assert sample_logits.shape == (N, M)
        distribution = torch.distributions.Categorical(logits=sample_logits)
        state_indices = distribution.sample((self.num_particles,)).T
        assert state_indices.shape == (N, self.num_particles)

        self.particle_states = torch.gather(
            self.particle_states,
            dim=1,
            index=state_indices[:, :, None].expand((N, self.num_particles, state_dim)),
        )
        #
        # ^This gather magic is equivalent to:
        # particle_states_alt = torch.zeros_like(self.particle_states)
        # for i in range(N):
        #     particle_states_alt[i] = self.particle_states[i][state_indices[i]]