Resampling
This sub-repository provides a unified interface to a variety of resampling methods, which convert a set of weighted samples into an equally weighted set that typically contains duplicates.
A typical call to the library looks like this:
```python
import jax
import jax.numpy as jnp
from cuthbertlib import resampling
sampling_key, resampling_key = jax.random.split(jax.random.key(0))
particles = jax.random.normal(sampling_key, (100, 2))
# Log-weight 0 for particles in the positive quadrant, -inf (zero weight) elsewhere
logits = jax.vmap(lambda x: jnp.where(jnp.all(x > 0), 0, -jnp.inf))(particles)
resampled_indices, _, resampled_particles = resampling.multinomial.resampling(resampling_key, logits, particles, 100)
```
Or, for conditional resampling:
```python
# Here we resample but keep the particle at index 0 fixed
conditional_resampled_indices, _, conditional_resampled_particles = resampling.multinomial.conditional_resampling(
    resampling_key, logits, particles, 100, pivot_in=0, pivot_out=0
)
```
Adaptive resampling (i.e. resampling only when the effective sample size is below a threshold) is also supported via a decorator:
```python
adaptive_resampling = resampling.adaptive.ess_decorator(
    resampling.multinomial.resampling,
    threshold=0.5,
)
adaptive_resampled_indices, _, adaptive_resampled_particles = adaptive_resampling(
    resampling_key, logits, particles, 100
)
```
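As a worked example, assume the standard definition of the effective sample size in terms of the normalised weights $\bar w_i$ obtained from the logits $\ell_i$ of $N$ particles (that `ess_decorator` uses exactly this formula is an assumption here):

$$
\bar w_i = \frac{\exp(\ell_i)}{\sum_{j=1}^{N} \exp(\ell_j)}, \qquad
\mathrm{ESS} = \frac{1}{\sum_{i=1}^{N} \bar w_i^{2}}, \qquad
1 \le \mathrm{ESS} \le N.
$$

In the example above, each particle lies in the positive quadrant with probability 1/4. If $k$ of the 100 particles do, each of them has $\bar w_i = 1/k$ and the rest have $\bar w_i = 0$, so $\mathrm{ESS} = 1/(k \cdot k^{-2}) = k \approx 25$. With `threshold=0.5` and `n=100`, resampling is triggered whenever $\mathrm{ESS} < 50$, so this call would almost surely resample.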
For consistent gradient estimates with respect to model parameters, the stop-gradient particle filter is also implemented as a decorator:
```python
differentiable_resampling = resampling.autodiff.stop_gradient_decorator(
    resampling.multinomial.resampling
)
# Can be combined with adaptive resampling
adaptive_and_differentiable_resampling = resampling.adaptive.ess_decorator(
    differentiable_resampling,
    threshold=0.5,
)
resampled_indices, _, resampled_particles = adaptive_and_differentiable_resampling(
    resampling_key, logits, particles, 100
)
```
cuthbertlib.resampling.protocols
Shared protocols for resampling algorithms.
Resampling
Bases: Protocol
Protocol for resampling operations.
__call__(key, logits, positions, n)
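The protocol only fixes a call signature. As a hedged illustration, the sketch below mirrors it with `typing.Protocol`, using NumPy arrays and a NumPy `Generator` in place of JAX arrays and PRNG keys. `ResamplingLike` and `first_n_resampling` are hypothetical names, and the `(indices, logits_out, positions_out)` return triple follows the signature documented for the decorators in this module.

```python
from typing import Protocol, Tuple

import numpy as np


class ResamplingLike(Protocol):
    """Hypothetical mirror of the Resampling protocol, using NumPy types."""

    def __call__(
        self,
        key: np.random.Generator,
        logits: np.ndarray,
        positions: np.ndarray,
        n: int,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Return (indices, logits_out, positions_out) for n output particles."""
        ...


def first_n_resampling(key, logits, positions, n):
    """Toy resampler that satisfies the protocol: it keeps the first n particles."""
    indices = np.arange(n)
    return indices, logits[indices], positions[indices]


# A static type checker accepts this assignment because the signatures match.
resampler: ResamplingLike = first_n_resampling
```

Any function with this shape, including the decorated resamplers below, can be passed wherever a `Resampling` is expected.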
cuthbertlib.resampling.multinomial
Implements multinomial resampling.
resampling(key, logits, positions, n)
conditional_resampling(key, logits, positions, n, pivot_in, pivot_out)
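As a hedged sketch of what multinomial resampling does (not the library's implementation), the NumPy code below draws `n` ancestor indices i.i.d. from the categorical distribution defined by the normalised weights. The conditional variant overwrites output slot `pivot_out` with `pivot_in`, the usual construction for multinomial resampling inside conditional SMC. The function names, the NumPy `Generator` standing in for a JAX key, and the all-zero output logits are assumptions made for illustration.

```python
import numpy as np


def normalise(logits):
    """Turn log-weights into probabilities; -inf logits get probability 0."""
    w = np.exp(logits - np.max(logits))
    return w / w.sum()


def multinomial_resampling_sketch(rng, logits, positions, n):
    """Draw n ancestor indices i.i.d. from Categorical(normalise(logits))."""
    indices = rng.choice(len(logits), size=n, p=normalise(logits))
    # Assumption: resampled particles are returned with equal (zero) log-weights.
    return indices, np.zeros(n), positions[indices]


def multinomial_conditional_resampling_sketch(rng, logits, positions, n, pivot_in, pivot_out):
    """Multinomial resampling with output slot pivot_out forced to descend from pivot_in."""
    indices = rng.choice(len(logits), size=n, p=normalise(logits))
    indices[pivot_out] = pivot_in  # the reference particle's ancestry is fixed
    return indices, np.zeros(n), positions[indices]
```

For multinomial resampling the non-pivot ancestors are independent of the pivot, which is why simply overwriting one slot is enough; schemes that couple the draws, such as systematic resampling, need a more careful conditional construction.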
cuthbertlib.resampling.systematic
Implements systematic resampling.
resampling(key, logits, positions, n)
conditional_resampling(key, logits, positions, n, pivot_in, pivot_out)
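Systematic resampling replaces the `n` independent uniforms of multinomial resampling with a single uniform `U` shifted onto the evenly spaced points `(U + k) / n`, which typically reduces resampling variance. A minimal NumPy sketch, assuming a NumPy `Generator` in place of a JAX key and all-zero output logits; `systematic_resampling_sketch` is a hypothetical name, not a library function:

```python
import numpy as np


def systematic_resampling_sketch(rng, logits, positions, n):
    """Systematic resampling: a single uniform draw shifted onto n evenly spaced points."""
    w = np.exp(logits - np.max(logits))  # unnormalised weights; -inf logits give 0
    cdf = np.cumsum(w)
    cdf /= cdf[-1]  # normalise so the last entry is exactly 1.0
    u = (rng.uniform() + np.arange(n)) / n  # points u_k = (U + k) / n in [0, 1)
    # Ancestor of point u_k is the first index i with u_k < cdf[i].
    indices = np.searchsorted(cdf, u, side="right")
    # Assumption: resampled particles are returned with equal (zero) log-weights.
    return indices, np.zeros(n), positions[indices]
```

Each particle receives either the floor or the ceiling of `n` times its normalised weight in copies. For example, with normalised weights (0.25, 0, 0.75, 0) and `n = 8`, the sketch returns 2 copies of particle 0 and 6 of particle 2, and the indices come out sorted.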
cuthbertlib.resampling.killing
Implements killing resampling.
resampling(key, logits, positions, n)
conditional_resampling(key, logits, positions, n, pivot_in, pivot_out)
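Killing resampling keeps particle `i` in its own slot with probability `w_i / max_j w_j` and replaces only the killed particles with multinomial draws. A hedged NumPy sketch of that scheme, assuming `n` equals the number of input particles, a NumPy `Generator` in place of a JAX key, and all-zero output logits; the function name is hypothetical and the library may handle these details differently:

```python
import numpy as np


def killing_resampling_sketch(rng, logits, positions, n):
    """Killing resampling; assumes n equals the number of input particles."""
    assert n == len(logits), "this sketch only handles n == number of particles"
    w = np.exp(logits - np.max(logits))  # scaled so max(w) == 1, i.e. w_i / max_j w_j
    survive = rng.uniform(size=n) < w  # particle i survives with probability w_i / max_j w_j
    # Killed particles are replaced by multinomial draws from the normalised weights.
    replacements = rng.choice(n, size=n, p=w / w.sum())
    indices = np.where(survive, np.arange(n), replacements)
    # Assumption: resampled particles are returned with equal (zero) log-weights.
    return indices, np.zeros(n), positions[indices]
```

The heaviest particle always keeps its slot, and particle `j` still receives `n` times its normalised weight in copies on average, so the scheme is unbiased in the same sense as multinomial resampling while moving fewer particles.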
cuthbertlib.resampling.adaptive
Adaptive resampling decorator.
Provides a decorator to turn any Resampling function into an adaptive resampling function which performs resampling only when the effective sample size (ESS) falls below a threshold.
ess_decorator(func, threshold)
Wrap a Resampling function so that it only resamples when ESS < threshold * n.
The returned function is jitted and has n as a static argument. The
original resampler's docstring is appended to this wrapper's docstring so
IDEs and users can see the underlying algorithm documentation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| func | Resampling | A resampling function with signature (key, logits, positions, n) -> (indices, logits_out, positions_out). | required |
| threshold | float | Fraction of the particle count that determines when to resample. Resampling is triggered when ESS < threshold * n. | required |
Returns:
| Type | Description |
|---|---|
| Resampling | A Resampling function implementing adaptive resampling. |
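The decorator's logic can be sketched in NumPy as follows. This is an illustration under assumptions, not the library's code: `effective_sample_size` and `ess_decorator_sketch` are hypothetical names, `n` is assumed to equal the number of input particles, and when no resampling happens the sketch returns identity indices with the original logits and positions. Because the real `ess_decorator` returns a jitted function, it cannot branch with a Python `if` on a traced ESS value; a JAX version would need a traced branch such as `jax.lax.cond`.

```python
import functools

import numpy as np


def effective_sample_size(logits):
    """ESS = 1 / sum_i w_i**2, where w are the normalised weights."""
    w = np.exp(logits - np.max(logits))
    w = w / w.sum()
    return 1.0 / np.sum(w**2)


def ess_decorator_sketch(func, threshold):
    """Resample with func only when ESS < threshold * n (assumes n == len(logits))."""

    @functools.wraps(func)
    def wrapper(key, logits, positions, n):
        if effective_sample_size(logits) < threshold * n:
            return func(key, logits, positions, n)
        # Skip resampling: identity ancestors, weights and positions carried forward.
        return np.arange(n), logits, positions

    # Append the wrapped resampler's docstring, as the documented decorator does.
    wrapper.__doc__ = (
        f"Adaptive resampling, triggered when ESS < {threshold} * n.\n\n"
        + (func.__doc__ or "")
    )
    return wrapper
```

With uniform logits the ESS equals the particle count, so a threshold below 1 never triggers resampling; with all weight on one particle the ESS is 1, so any `threshold * n > 1` does.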
cuthbertlib.resampling.autodiff
Implements decorators for automatic differentiation of resampling schemes.
Currently supported is the stop_gradient resampling scheme, which yields the classical Fisher-identity estimates of the score function via automatic differentiation. It can be wrapped around a resampling scheme such as multinomial or systematic resampling.
See Scibior and Wood (2021) for more details.
stop_gradient_decorator(func)
Wrap a Resampling function to use stop gradient resampling.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| func | Resampling | A resampling function with signature (key, logits, positions, n) -> (indices, logits_out, positions_out). | required |
Returns:
| Type | Description |
|---|---|
| Resampling | A Resampling function implementing stop gradient resampling. |
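Stop-gradient resampling can be sketched in JAX as a wrapper that leaves the resampled indices and positions alone and rewrites each output log-weight as $\ell_{a_i} - \bot(\ell_{a_i})$, where $\bot$ is `jax.lax.stop_gradient` and $a_i$ is the ancestor index. The forward value is exactly zero (uniform weights, as after ordinary resampling), but its gradient equals the gradient of the selected input log-weight, in the spirit of Scibior and Wood (2021). The names `stop_gradient_decorator_sketch` and `multinomial` are hypothetical, and this is not the library's implementation.

```python
import jax
import jax.numpy as jnp


def stop_gradient_decorator_sketch(func):
    """Wrap a resampler so resampled log-weights are zero in value but keep gradients."""

    def wrapper(key, logits, positions, n):
        indices, _, positions_out = func(key, logits, positions, n)
        selected = logits[indices]
        # Forward value: exactly 0. Backward: gradient of the selected logits.
        logits_out = selected - jax.lax.stop_gradient(selected)
        return indices, logits_out, positions_out

    return wrapper


def multinomial(key, logits, positions, n):
    """Plain multinomial resampling; the discrete index draw is not differentiated."""
    indices = jax.random.categorical(key, jax.lax.stop_gradient(logits), shape=(n,))
    return indices, jnp.zeros(n), positions[indices]


key = jax.random.key(0)
x = jnp.array([1.0, 2.0, 3.0])
resample = stop_gradient_decorator_sketch(multinomial)


def total_logweight(theta):
    # Toy model: logits depend linearly on a parameter theta.
    _, logits_out, _ = resample(key, theta * x, x, 5)
    return logits_out.sum()


print(total_logweight(0.5))  # zero in the forward pass
print(jax.grad(total_logweight)(0.5))  # sum of x over the 5 resampled ancestors
```

The forward pass is identical to ordinary resampling; only the backward pass sees the extra term, which is why the wrapper can be placed around any resampler with the documented signature.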