Resampling
This sub-repository provides a unified interface to several resampling methods. Each method converts a set of weighted samples into an unweighted set, which typically contains duplicates.
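For reference, the standard textbook formulation is as follows. There are $N$ weighted samples with unnormalized log-weights $\ell_1, \dots, \ell_N$ (the `logits` argument), and resampling returns $n$ ancestor indices $A_1, \dots, A_n$. A property commonly required of such schemes, and satisfied by the standard multinomial, systematic and killing schemes, is unbiasedness: each sample is selected, in expectation, in proportion to its normalized weight. This is stated here as the general textbook property, not as a guarantee documented by the library.

$$
W_i = \frac{\exp(\ell_i)}{\sum_{j=1}^{N} \exp(\ell_j)}, \qquad
\mathbb{E}\left[\#\{k : A_k = i\}\right] = n\, W_i, \quad i = 1, \dots, N.
$$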
A typical call to the library would be:
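```python
import jax
import jax.numpy as jnp
from cuthbertlib import resampling

sampling_key, resampling_key = jax.random.split(jax.random.key(0))
particles = jax.random.normal(sampling_key, (100, 2))
# Log-weight 0 for particles in the positive quadrant, -inf (zero weight) otherwise
logits = jax.vmap(lambda x: jnp.where(jnp.all(x > 0), 0, -jnp.inf))(particles)
resampled_indices = resampling.multinomial.resampling(resampling_key, logits, 100)
resampled_particles = particles[resampled_indices]
```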
Or, for conditional resampling:
```python
# Resample while keeping the particle at index 0 fixed
conditional_resampled_indices = resampling.multinomial.conditional_resampling(
    resampling_key, logits, 100, pivot_in=0, pivot_out=0
)
conditional_resampled_particles = particles[conditional_resampled_indices]
```
cuthbertlib.resampling.protocols
Shared protocols for resampling algorithms.
Resampling
Bases: Protocol
Protocol for resampling operations.
__call__(key, logits, n)
Computes resampling indices according to the given logits.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| key | KeyArray | JAX PRNG key. | required |
| logits | ArrayLike | Unnormalized log-weights, one per sample. | required |
| n | int | Number of indices to sample. | required |
Returns:
| Type | Description |
|---|---|
| Array | Array of n resampling indices. |
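Because `Resampling` is a `typing.Protocol`, any callable with the signature `(key, logits, n) -> Array` satisfies it structurally, and no subclassing is needed. The sketch below illustrates this with a few hypothetical pieces. `ResamplingLike` is a local stand-in that mirrors the documented `__call__` signature rather than importing the library's class. `resample_particles` is a helper that accepts any conforming resampler. `categorical_resampling` is an example resampler that wraps `jax.random.categorical`, which draws indices independently with probability proportional to `exp(logits)`.

```python
from typing import Protocol

import jax
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike


class ResamplingLike(Protocol):
    """Hypothetical local stand-in mirroring the documented `Resampling.__call__`."""

    def __call__(self, key: Array, logits: ArrayLike, n: int) -> Array: ...


def categorical_resampling(key: Array, logits: ArrayLike, n: int) -> Array:
    # n independent draws with P(index = i) proportional to exp(logits[i])
    return jax.random.categorical(key, jnp.asarray(logits), shape=(n,))


def resample_particles(
    resampler: ResamplingLike, key: Array, particles: Array, logits: ArrayLike, n: int
) -> Array:
    # Any callable matching the protocol signature can be passed as `resampler`
    indices = resampler(key, logits, n)
    return particles[indices]


particles = jnp.arange(5.0)
# Particles 1 and 3 have zero weight (log-weight -inf) and can never be selected
logits = jnp.array([0.0, -jnp.inf, 0.0, -jnp.inf, 0.0])
out = resample_particles(categorical_resampling, jax.random.key(0), particles, logits, 8)
```

Swapping in a different scheme only requires passing another function with the same signature. This is the point of sharing one protocol across the resampling modules.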
cuthbertlib.resampling.multinomial
Implements multinomial resampling.
resampling(key, logits, n)
conditional_resampling(key, logits, n, pivot_in, pivot_out)
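Multinomial resampling draws each of the `n` output indices independently, choosing index `i` with probability proportional to `exp(logits[i])`. Because the draws are independent, a conditional variant can be obtained by drawing unconditionally and then overwriting one output slot with the fixed index. The sketch below is a minimal illustration, not the library's implementation. The function names are hypothetical. It also assumes that `pivot_in` is the input index to keep and `pivot_out` is the output slot it must occupy, which matches the "keep particle at index 0 fixed" example at the top of this page.

```python
import jax
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike


def multinomial_sketch(key: Array, logits: ArrayLike, n: int) -> Array:
    # n independent categorical draws, P(i) proportional to exp(logits[i])
    return jax.random.categorical(key, jnp.asarray(logits), shape=(n,))


def conditional_multinomial_sketch(
    key: Array, logits: ArrayLike, n: int, pivot_in: int, pivot_out: int
) -> Array:
    indices = multinomial_sketch(key, logits, n)
    # Assumed semantics: output slot `pivot_out` always holds input index `pivot_in`.
    # The remaining slots stay i.i.d. draws, which is their correct conditional law.
    return indices.at[pivot_out].set(pivot_in)


logits = jnp.log(jnp.array([0.1, 0.2, 0.3, 0.4]))
idx = conditional_multinomial_sketch(jax.random.key(1), logits, 6, pivot_in=3, pivot_out=2)
```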
cuthbertlib.resampling.systematic
Implements systematic resampling.
resampling(key, logits, n)
conditional_resampling(key, logits, n, pivot_in, pivot_out)
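Systematic resampling draws a single uniform offset $U \sim \mathcal{U}[0, 1)$ and places `n` evenly spaced points $(U + k)/n$, $k = 0, \dots, n-1$, on the cumulative distribution of the normalized weights. Each point selects the sample whose CDF interval contains it. The sketch below shows only the unconditional case, following that textbook description. The function name is hypothetical, and the code is not the library's implementation.

```python
import jax
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike


def systematic_sketch(key: Array, logits: ArrayLike, n: int) -> Array:
    logits = jnp.asarray(logits)
    weights = jax.nn.softmax(logits)  # normalized weights from log-weights
    cdf = jnp.cumsum(weights)
    cdf = cdf / cdf[-1]  # force the last CDF entry to exactly 1
    # One shared uniform offset, then n evenly spaced points in [0, 1)
    u = (jax.random.uniform(key) + jnp.arange(n)) / n
    # Index i is chosen when cdf[i - 1] <= u < cdf[i]; zero-weight intervals are empty
    indices = jnp.searchsorted(cdf, u, side="right")
    # Guard against floating-point round-off pushing a point onto 1.0
    return jnp.minimum(indices, logits.shape[0] - 1)
```

With equal weights and `n` equal to the number of samples, every index is selected exactly once. This illustrates why systematic resampling typically has lower variance than independent multinomial draws.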
cuthbertlib.resampling.killing
Implements killing resampling.
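Killing resampling is usually described as follows. With weights $w_i = \exp(\text{logits}_i)$, each sample $i$ survives and keeps its own index with probability $w_i / \max_j w_j$. Otherwise it is killed and replaced by an index drawn from the categorical distribution over all samples. The sketch below follows that textbook description rather than the library's code. The function name is hypothetical. The sketch also assumes that `n` equals the number of samples, since each output slot is tied to one input sample.

```python
import jax
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike


def killing_sketch(key: Array, logits: ArrayLike, n: int) -> Array:
    logits = jnp.asarray(logits)
    if n != logits.shape[0]:
        raise ValueError("this sketch assumes n equals the number of samples")
    survive_key, replace_key = jax.random.split(key)
    # Survival probability w_i / max_j w_j, computed stably from log-weights
    survive_prob = jnp.exp(logits - jnp.max(logits))
    survives = jax.random.uniform(survive_key, (n,)) < survive_prob
    # Replacement indices for killed samples, drawn i.i.d. from the weights
    replacements = jax.random.categorical(replace_key, logits, shape=(n,))
    return jnp.where(survives, jnp.arange(n), replacements)
```

In this sketch the highest-weight samples always survive, because their survival probability is 1. Zero-weight samples are always replaced.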