"""Private module; avoid importing from directly.
"""
from typing import Optional
import fannypack
import numpy as np
import torch
from overrides import overrides
from .. import types
from ..base import DynamicsModel, Filter, ParticleFilterMeasurementModel
class ParticleFilter(Filter):
    """Generic differentiable particle filter."""

    def __init__(
        self,
        *,
        dynamics_model: DynamicsModel,
        measurement_model: ParticleFilterMeasurementModel,
        num_particles: int = 100,
        resample: Optional[bool] = None,
        soft_resample_alpha: float = 1.0,
        estimation_method: str = "weighted_average",
    ):
        # Check submodule consistency
        assert isinstance(dynamics_model, DynamicsModel)
        assert isinstance(measurement_model, ParticleFilterMeasurementModel)
        assert dynamics_model.state_dim == measurement_model.state_dim

        # Check settings: `_resample()` takes `log(alpha)` and `log(1 - alpha)`,
        # which are only meaningful for alpha in [0, 1]
        assert num_particles > 0, "num_particles must be positive!"
        assert (
            0.0 <= soft_resample_alpha <= 1.0
        ), "soft_resample_alpha must be in [0, 1]!"

        # Initialize state dimension
        state_dim = dynamics_model.state_dim
        super().__init__(state_dim=state_dim)

        # Assign submodules
        self.dynamics_model = dynamics_model
        """torchfilter.base.DynamicsModel: Forward model."""
        self.measurement_model = measurement_model
        """torchfilter.base.ParticleFilterMeasurementModel: Observation model."""

        # Settings
        self.num_particles = num_particles
        """int: Number of particles to represent our belief distribution.
        Defaults to 100."""
        self.resample = resample
        """bool: If True, we resample particles & normalize weights at each
        timestep. If unset (None), we automatically turn resampling on in eval mode
        and off in train mode."""

        self.soft_resample_alpha = soft_resample_alpha
        """float: Tunable constant for differentiable resampling, as described
        by Karkus et al. in "Particle Filter Networks with Application to Visual
        Localization": https://arxiv.org/abs/1805.08975
        Must be in [0, 1]. Defaults to 1.0 (disabled)."""

        assert estimation_method in ("weighted_average", "argmax")
        self.estimation_method = estimation_method
        """str: Method of producing state estimates. Options include:
        - 'weighted_average': average of particles weighted by their weights.
        - 'argmax': state of highest weighted particle.
        """

        # "Hidden state" tensors
        self.particle_states: torch.Tensor
        """torch.Tensor: Discrete particles representing our current belief
        distribution. Shape should be `(N, M, state_dim)`.
        """
        self.particle_log_weights: torch.Tensor
        """torch.Tensor: Weights corresponding to each particle, stored as
        log-likelihoods. Shape should be `(N, M)`.
        """
        self._initialized = False

    @overrides
    def initialize_beliefs(
        self, *, mean: types.StatesTorch, covariance: types.CovarianceTorch
    ) -> None:
        """Populates initial particles, which will be normally distributed.

        Args:
            mean (torch.Tensor): Mean of belief. Shape should be
                `(N, state_dim)`.
            covariance (torch.Tensor): Covariance of belief. Shape should be
                `(N, state_dim, state_dim)`.
        """
        N = mean.shape[0]
        assert mean.shape == (N, self.state_dim)
        assert covariance.shape == (N, self.state_dim, self.state_dim)
        M = self.num_particles

        # Sample particles: `sample((M,))` yields `(M, N, state_dim)`, so swap
        # the first two axes to get `(N, M, state_dim)`
        self.particle_states = (
            torch.distributions.MultivariateNormal(mean, covariance)
            .sample((M,))
            .transpose(0, 1)
        )
        assert self.particle_states.shape == (N, M, self.state_dim)

        # Normalize weights: uniform, i.e. log(1 / M) for every particle
        self.particle_log_weights = self.particle_states.new_full(
            (N, M), float(-np.log(M, dtype=np.float32))
        )
        assert self.particle_log_weights.shape == (N, M)

        # Set initialized flag
        self._initialized = True

    @overrides
    def forward(
        self,
        *,
        observations: types.ObservationsTorch,
        controls: types.ControlsTorch,
    ) -> types.StatesTorch:
        """Particle filter forward pass, single timestep.

        Args:
            observations (dict or torch.Tensor): observation inputs. should be
                either a dict of tensors or tensor of shape `(N, ...)`.
            controls (dict or torch.Tensor): control inputs. should be either a
                dict of tensors or tensor of shape `(N, ...)`.

        Returns:
            torch.Tensor: Predicted state for each batch element. Shape should
            be `(N, state_dim)`.
        """
        # Make sure our particle filter's been initialized
        assert self._initialized, "Particle filter not initialized!"

        # Get our batch size (N), current particle count (M), & state dimension
        N, M, state_dim = self.particle_states.shape
        assert state_dim == self.state_dim
        assert len(fannypack.utils.SliceWrapper(controls)) == N

        # Decide whether or not we're resampling
        resample = self.resample
        if resample is None:
            # If not explicitly set, we disable resampling in train mode (to allow
            # gradients to propagate through time) and enable in eval mode (to prevent
            # particle deprivation)
            resample = not self.training

        # If we're not resampling and our current particle count doesn't match
        # our desired particle count, we need to either expand or contract our
        # particle set. When resampling, `_resample()` produces the desired
        # count at the end of this step instead.
        if not resample and self.num_particles != M:
            self._resize_particles()
            M = self.num_particles

        # Normalize particle weights to sum to 1.0
        self.particle_log_weights = self.particle_log_weights - torch.logsumexp(
            self.particle_log_weights, dim=1, keepdim=True
        )

        # Propagate particles through our dynamics model
        # A bit of extra effort is required for the extra particle dimension
        # > For our states, we flatten along the N/M axes
        # > For our controls, we repeat each one `M` times, if M=3:
        #       [u0 u1 u2] should become [u0 u0 u0 u1 u1 u1 u2 u2 u2]
        #
        # Currently each of the M particles within a "sample" get the same action, but
        # we could also add noise in the action space (a la Jonschkowski et al. 2018)
        reshaped_states = self.particle_states.reshape(-1, self.state_dim)
        reshaped_controls = fannypack.utils.SliceWrapper(controls).map(
            lambda tensor: torch.repeat_interleave(tensor, repeats=M, dim=0)
        )
        predicted_states, scale_trils = self.dynamics_model(
            initial_states=reshaped_states, controls=reshaped_controls
        )
        self.particle_states = (
            torch.distributions.MultivariateNormal(
                loc=predicted_states, scale_tril=scale_trils
            )
            .rsample()  # Note that we use `rsample` to make sampling differentiable
            .view(N, M, self.state_dim)
        )
        assert self.particle_states.shape == (N, M, self.state_dim)

        # Re-weight particles using observations
        self.particle_log_weights = self.particle_log_weights + self.measurement_model(
            states=self.particle_states,
            observations=observations,
        )

        # Normalize particle weights to sum to 1.0
        self.particle_log_weights = self.particle_log_weights - torch.logsumexp(
            self.particle_log_weights, dim=1, keepdim=True
        )

        # Compute output from the re-weighted particles, before any resampling
        state_estimates = self._compute_state_estimates()

        # Resampling
        if resample:
            self._resample()

        # Post-condition :)
        assert state_estimates.shape == (N, state_dim)
        assert self.particle_states.shape == (N, self.num_particles, state_dim)
        assert self.particle_log_weights.shape == (N, self.num_particles)

        return state_estimates

    def _resize_particles(self) -> None:
        """Expand or contract the particle set to `self.num_particles` particles
        without weight-based resampling; particle weights are carried along."""
        N, M, state_dim = self.particle_states.shape
        target_count = self.num_particles
        indices = self.particle_states.new_zeros((N, target_count), dtype=torch.long)

        # If output particles > our input particles, for the beginning part we copy
        # particles directly to reduce variance
        copy_count = (target_count // M) * M
        if copy_count > 0:
            indices[:, :copy_count] = torch.arange(M, device=indices.device).repeat(
                copy_count // M
            )[None, :]

        # For remaining particles (fewer than M), we sample w/o replacement (also
        # lowers variance). The same permutation is shared by the whole batch.
        remaining_count = target_count - copy_count
        assert remaining_count >= 0
        if remaining_count > 0:
            indices[:, copy_count:] = torch.randperm(M, device=indices.device)[
                None, :remaining_count
            ]

        # Gather new particles, weights
        self.particle_states = self.particle_states.gather(
            1, indices[:, :, None].expand((N, target_count, state_dim))
        )
        self.particle_log_weights = self.particle_log_weights.gather(1, indices)
        assert self.particle_states.shape == (N, target_count, state_dim)
        assert self.particle_log_weights.shape == (N, target_count)

    def _compute_state_estimates(self) -> types.StatesTorch:
        """Reduce the particle set to one state estimate per batch element,
        according to `self.estimation_method`.

        `forward()` normalizes the log weights right before calling this.

        Returns:
            torch.Tensor: State estimates. Shape is `(N, state_dim)`.
        """
        N, _, state_dim = self.particle_states.shape
        if self.estimation_method == "weighted_average":
            return torch.sum(
                torch.exp(self.particle_log_weights[:, :, np.newaxis])
                * self.particle_states,
                dim=1,
            )
        if self.estimation_method == "argmax":
            # `gather` requires an index with as many dimensions as its input,
            # so lift the `(N,)` argmax to `(N, 1, state_dim)`, then drop the
            # particle axis
            best_indices = torch.argmax(self.particle_log_weights, dim=1)
            return torch.gather(
                self.particle_states,
                dim=1,
                index=best_indices[:, None, None].expand((N, 1, state_dim)),
            ).squeeze(1)
        # Raised explicitly (not `assert False`) so it still fails under `python -O`
        raise AssertionError("Unsupported estimation method!")

    def _resample(self) -> None:
        """Resample particles, producing `self.num_particles` particles per batch
        element with normalized weights.

        If `soft_resample_alpha < 1.0`, particles are drawn from a mixture of their
        weights and a uniform distribution, and importance weights are kept so
        that gradients can flow through them. Otherwise particles are drawn
        directly from their weights and the new weights are uniform.
        """
        # Note the distinction between `M`, the current number of particles, and
        # `self.num_particles`, the desired number of particles
        N, M, state_dim = self.particle_states.shape
        assert self.particle_log_weights.shape == (N, M)

        sample_logits: torch.Tensor
        importance_log_weights: Optional[torch.Tensor]
        if self.soft_resample_alpha < 1.0:
            # Soft resampling: the proposal q = alpha * w + (1 - alpha) / M is a
            # mixture over the *current* M particles, so its uniform component
            # has shape (N, M)
            uniform_log_weights = self.particle_log_weights.new_full(
                (N, M), float(-np.log(M, dtype=np.float32))
            )
            sample_logits = torch.logsumexp(
                torch.stack(
                    [
                        self.particle_log_weights + np.log(self.soft_resample_alpha),
                        uniform_log_weights + np.log(1.0 - self.soft_resample_alpha),
                    ],
                    dim=0,
                ),
                dim=0,
            )
            # Importance weights w / q, still indexed by the current particles;
            # they are gathered onto the resampled particles below
            importance_log_weights = self.particle_log_weights - sample_logits
        else:
            # Standard particle filter re-sampling -- this stops gradients
            # This is the most naive flavor of resampling, and not the low
            # variance approach
            sample_logits = self.particle_log_weights
            importance_log_weights = None

        assert sample_logits.shape == (N, M)
        distribution = torch.distributions.Categorical(logits=sample_logits)
        state_indices = distribution.sample((self.num_particles,)).T
        assert state_indices.shape == (N, self.num_particles)

        self.particle_states = torch.gather(
            self.particle_states,
            dim=1,
            index=state_indices[:, :, None].expand((N, self.num_particles, state_dim)),
        )
        # ^This gather magic is equivalent to:
        #
        # particle_states_alt = torch.zeros_like(self.particle_states)
        # for i in range(N):
        #     particle_states_alt[i] = self.particle_states[i][state_indices[i]]

        if importance_log_weights is None:
            # Uniform over the *new* particle count, so weights sum to 1.0
            self.particle_log_weights = self.particle_log_weights.new_full(
                (N, self.num_particles),
                float(-np.log(self.num_particles, dtype=np.float32)),
            )
        else:
            # Each resampled particle inherits the importance weight of the
            # particle it was drawn from; then renormalize
            resampled_log_weights = torch.gather(
                importance_log_weights, dim=1, index=state_indices
            )
            self.particle_log_weights = resampled_log_weights - torch.logsumexp(
                resampled_log_weights, dim=1, keepdim=True
            )
        assert self.particle_log_weights.shape == (N, self.num_particles)