
Resampling

This sub-repository provides a unified interface for a variety of resampling methods, which convert a set of weighted samples into an unweighted set that typically contains duplicates.

A typical call to the library would be:

import jax
import jax.numpy as jnp

from cuthbertlib import resampling

sampling_key, resampling_key = jax.random.split(jax.random.key(0))
particles = jax.random.normal(sampling_key, (100, 2))
logits = jax.vmap(lambda x: jnp.where(jnp.all(x > 0), 0, -jnp.inf))(particles)

resampled_indices = resampling.multinomial.resampling(resampling_key, logits, 100)
resampled_particles = particles[resampled_indices]

Or for conditional resampling:

# Here we resample but keep particle at index 0 fixed
conditional_resampled_indices = resampling.multinomial.conditional_resampling(
    resampling_key, logits, 100, pivot_in=0, pivot_out=0
)
conditional_resampled_particles = particles[conditional_resampled_indices]

cuthbertlib.resampling.protocols

Shared protocols for resampling algorithms.

Resampling

Bases: Protocol

Protocol for resampling operations.

__call__(key, logits, n)

Computes resampling indices according to given logits.

Parameters:

    key (KeyArray): JAX PRNG key. Required.
    logits (ArrayLike): Logits (unnormalized log-weights). Required.
    n (int): Number of indices to sample. Required.

Returns:

    Array: Array of resampling indices.

Source code in cuthbertlib/resampling/protocols.py
def __call__(self, key: KeyArray, logits: ArrayLike, n: int) -> Array:
    """Computes resampling indices according to given logits.

    Args:
        key: JAX PRNG key.
        logits: Logits.
        n: Number of indices to sample.

    Returns:
        Array of resampling indices.
    """
    ...
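
Any function conforming to this protocol can be swapped in wherever a resampler is expected. The sketch below only illustrates that structural typing; the helper run_smc_step is hypothetical and not part of cuthbertlib.

from cuthbertlib.resampling import multinomial, systematic
from cuthbertlib.resampling.protocols import Resampling
import jax
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike


def run_smc_step(resample: Resampling, key: Array, logits: ArrayLike, n: int) -> Array:
    # Any callable matching the protocol signature works here.
    return resample(key, logits, n)


key = jax.random.key(0)
logits = jnp.zeros(100)  # uniform weights
idx = run_smc_step(multinomial.resampling, key, logits, 100)
idx = run_smc_step(systematic.resampling, key, logits, 100)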

cuthbertlib.resampling.multinomial

Implements multinomial resampling.

resampling(key, logits, n)

Source code in cuthbertlib/resampling/multinomial.py
@partial(resampling_decorator, name="Multinomial", desc=_DESCRIPTION)
def resampling(key: Array, logits: ArrayLike, n: int) -> Array:
    # In practice we don't have to sort the generated uniforms, but searchsorted
    # works faster and is more stable if both inputs are sorted, so we use the
    # _sorted_uniforms from N. Chopin, but still use searchsorted instead of his
    # O(N) loop as our code is meant to work on GPU where searchsorted is
    # O(log(N)) anyway.
    # We then permute the indices to enforce exchangeability.

    key_uniforms, key_shuffle = random.split(key)
    sorted_uniforms = _sorted_uniforms(key_uniforms, n)
    idx = inverse_cdf(sorted_uniforms, logits)
    return random.permutation(key_shuffle, idx)
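
The helpers _sorted_uniforms and inverse_cdf are not reproduced on this page. As a rough guide, here is a minimal sketch of what they plausibly compute, assuming the standard spacings trick mentioned in the comment above and an inverse-CDF lookup via searchsorted; the in-library implementations may differ.

import jax
import jax.numpy as jnp
from jax import random


def sorted_uniforms_sketch(key, n):
    # Normalized cumulative sums of i.i.d. exponentials have the same law as
    # the order statistics of n uniforms, so no explicit sort is needed.
    e = random.exponential(key, (n + 1,))
    cs = jnp.cumsum(e)
    return cs[:-1] / cs[-1]


def inverse_cdf_sketch(sorted_uniforms, logits):
    # Invert the empirical CDF of the normalized weights: each uniform u picks
    # the first index whose cumulative weight reaches u.
    weights = jax.nn.softmax(jnp.asarray(logits))
    return jnp.searchsorted(jnp.cumsum(weights), sorted_uniforms)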

conditional_resampling(key, logits, n, pivot_in, pivot_out)

Source code in cuthbertlib/resampling/multinomial.py
@partial(conditional_resampling_decorator, name="Multinomial", desc=_DESCRIPTION)
def conditional_resampling(
    key: Array, logits: ArrayLike, n: int, pivot_in: int, pivot_out: int
) -> Array:
    idx = resampling(key, logits, n)
    idx = idx.at[pivot_in].set(pivot_out)
    return idx
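
Since the unconditional indices are permuted to be exchangeable (see the comment in resampling above), the conditional variant is obtained by simply overwriting position pivot_in with pivot_out. A quick check, reusing logits and resampling_key from the introductory example:

idx = resampling.multinomial.conditional_resampling(
    resampling_key, logits, 100, pivot_in=3, pivot_out=7
)
assert idx[3] == 7  # the pivot particle is always kept at the pivot position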

cuthbertlib.resampling.systematic

Implements systematic resampling.

resampling(key, logits, n)

Source code in cuthbertlib/resampling/systematic.py
@partial(resampling_decorator, name="Systematic", desc=_DESCRIPTION)
def resampling(key: Array, logits: ArrayLike, n: int) -> Array:
    us = (random.uniform(key, ()) + jnp.arange(n)) / n
    return inverse_cdf(us, logits)
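
Systematic resampling draws a single uniform offset and spreads n evenly spaced points (u + i)/n over [0, 1), which typically reduces resampling noise compared with multinomial resampling. Usage mirrors the multinomial example; the snippet below reuses particles, logits, and resampling_key from the introduction:

systematic_indices = resampling.systematic.resampling(resampling_key, logits, 100)
systematic_particles = particles[systematic_indices]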

conditional_resampling(key, logits, n, pivot_in, pivot_out)

Source code in cuthbertlib/resampling/systematic.py
@partial(conditional_resampling_decorator, name="Systematic", desc=_DESCRIPTION)
def conditional_resampling(
    key: Array, logits: ArrayLike, n: int, pivot_in: int, pivot_out: int
) -> Array:
    logits = jnp.asarray(logits)
    # FIXME: no need for normalizing in theory
    N = logits.shape[0]
    logits -= logsumexp(logits)

    # FIXME: this rolling should be done in a single function, but this is killing me.
    # Reduce the general (pivot_in, pivot_out) case to the 0 -> 0 case: roll the
    # logits so that particle `pivot_out` sits at index 0.
    arange = jnp.arange(N)
    logits = jnp.roll(logits, -pivot_out)
    arange = jnp.roll(arange, -pivot_out)

    # Conditionally resample with the pivot fixed at index 0, map the resampled
    # indices back to the original particle labels, then roll the output so that
    # the pivot ends up at position `pivot_in`.
    idx = conditional_resampling_0_to_0(key, logits, n)
    idx = arange[idx]
    idx = jnp.roll(idx, pivot_in)
    return idx
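
The helper conditional_resampling_0_to_0 (not shown on this page) handles the special case where the pivot is kept at index 0; the rolls above reduce the general case to it and back. From user code it behaves like the multinomial variant, for example:

systematic_cond_indices = resampling.systematic.conditional_resampling(
    resampling_key, logits, 100, pivot_in=0, pivot_out=0
)
assert systematic_cond_indices[0] == 0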

cuthbertlib.resampling.killing

Implements killing resampling.

resampling(key, logits, n)

Source code in cuthbertlib/resampling/killing.py
@partial(resampling_decorator, name="Killing", desc=_DESCRIPTION)
def resampling(key: Array, logits: ArrayLike, n: int) -> Array:
    logits = jnp.asarray(logits)
    key_1, key_2 = random.split(key)
    N = logits.shape[0]
    if n != N:
        raise AssertionError(
            "The number of sampled indices must be equal to the number of "
            f"particles for `Killing` resampling. Got {n} instead of {N}."
        )

    max_logit = jnp.max(logits)
    log_uniforms = jnp.log(random.uniform(key_1, (N,)))

    survived = log_uniforms <= logits - max_logit
    if_survived = jnp.arange(N)  # If the particle survives, it keeps its index
    otherwise = multinomial.resampling(
        key_2, logits, N
    )  # otherwise, it is replaced by another particle
    idx = jnp.where(survived, if_survived, otherwise)
    return idx
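
In killing resampling each particle survives with probability w_i / max(w) and keeps its own index; killed particles are replaced by multinomial draws, which is why n must equal the number of particles. With the 100 particles from the introductory example:

killing_indices = resampling.killing.resampling(resampling_key, logits, 100)
killing_particles = particles[killing_indices]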

conditional_resampling(key, logits, n, pivot_in, pivot_out)

Source code in cuthbertlib/resampling/killing.py
@partial(conditional_resampling_decorator, name="Killing", desc=_DESCRIPTION)
def conditional_resampling(
    key: Array, logits: ArrayLike, n: int, pivot_in: int, pivot_out: int
) -> Array:
    # Unconditional resampling
    key_resample, key_shuffle = random.split(key)
    idx = resampling(key_resample, logits, n)

    # Conditional rolling pivot
    max_logit = jnp.max(logits)

    pivot_logits = _log1mexp(logits - max_logit)
    pivot_logits -= jnp.log(n)
    pivot_logits = pivot_logits.at[pivot_out].set(-jnp.inf)
    pivot_logits_i = _log1mexp(logsumexp(pivot_logits))
    pivot_logits = pivot_logits.at[pivot_out].set(pivot_logits_i)

    pivot_weights = jnp.exp(pivot_logits - logsumexp(pivot_logits))
    pivot = random.choice(key_shuffle, n, p=pivot_weights)
    idx = jnp.roll(idx, pivot_in - pivot)
    idx = idx.at[pivot_in].set(pivot_out)
    return idx
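
The helper _log1mexp is not shown above. Below is a minimal sketch of a numerically stable log(1 - exp(x)) for non-positive x, assuming the standard two-branch formulation; the in-library version may differ.

import jax.numpy as jnp


def log1mexp_sketch(x):
    # Stable log(1 - exp(x)) for x <= 0: use expm1 near zero, log1p otherwise.
    return jnp.where(
        x > -jnp.log(2.0),
        jnp.log(-jnp.expm1(x)),
        jnp.log1p(-jnp.exp(x)),
    )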