Skip to content

Resampling

This sub-repository provides a unified interface for a variety of resampling methods, which convert a set of weighted samples into an unweighted one that is likely to contain duplicates.

A typical call to the library would be:

sampling_key, resampling_key = jax.random.split(jax.random.key(0))
particles = jax.random.normal(sampling_key, (100, 2))
logits = jax.vmap(lambda x: jnp.where(jnp.all(x > 0), 0, -jnp.inf))(particles)

resampled_indices, _, resampled_particles = resampling.multinomial.resampling(resampling_key, logits, particles, 100)

Or for conditional resampling:

# Here we resample but keep particle at index 0 fixed
conditional_resampled_indices, _, conditional_resampled_particles = resampling.multinomial.conditional_resampling(
    resampling_key, logits, particles, 100, pivot_in=0, pivot_out=0
)

Adaptive resampling (i.e. resampling only when the effective sample size is below a threshold) is also supported via a decorator:

adaptive_resampling = resampling.adaptive.ess_decorator(
    resampling.multinomial.resampling,
    threshold=0.5,
)
adaptive_resampled_indices, _, adaptive_resampled_particles = adaptive_resampling(
    resampling_key, logits, particles, 100
)

For consistent gradient estimates with respect to model parameters, the stop-gradient particle filter is also implemented as a decorator.

differentiable_resampling = resampling.stop_gradient.stop_gradient_decorator(
    resampling.multinomial.resampling
)
# can be combined with adaptive resampling
adaptive_and_differentiable_resampling = resampling.adaptive.ess_decorator(
    differentiable_resampling,
    threshold=0.5,
)
resampled_indices, _, resampled_particles = adaptive_and_differentiable_resampling(
    resampling_key, logits, particles, 100
)

cuthbertlib.resampling.protocols

Shared protocols for resampling algorithms.

Resampling

Bases: Protocol

Protocol for resampling operations.

__call__(key, logits, positions, n)

Source code in cuthbertlib/resampling/protocols.py
def __call__(
    self, key: KeyArray, logits: ArrayLike, positions: ArrayTreeLike, n: int
) -> tuple[Array, Array, ArrayTree]:
    # Protocol stub: implementations compute resampling indices from the
    # given logits and return (indices, output logits, resampled positions).
    # NOTE(review): an f-string literal is NOT recognized as a docstring by
    # CPython (only plain string literals are stored in __doc__), so the
    # statement below is a no-op at runtime — presumably the resampling
    # decorators attach the real documentation; confirm.
    f"""Computes resampling indices according to given logits.
    {_RESAMPLING_DOC}
    """
    ...

cuthbertlib.resampling.multinomial

Implements multinomial resampling.

resampling(key, logits, positions, n)

Source code in cuthbertlib/resampling/multinomial.py
@partial(resampling_decorator, name="Multinomial", desc=_DESCRIPTION)
def resampling(
    key: Array, logits: ArrayLike, positions: ArrayTreeLike, n: int
) -> tuple[Array, Array, ArrayTree]:
    # Draw n sorted uniforms in one shot, invert the weight CDF, then shuffle
    # the resulting ancestor indices to restore exchangeability.
    #
    # Sorting the uniforms is not strictly required, but searchsorted is
    # faster and more stable when both inputs are sorted, hence the
    # _sorted_uniforms trick of N. Chopin.  We still use searchsorted rather
    # than his O(N) sequential loop because this code targets GPU, where
    # searchsorted is O(log N) per query anyway.
    uniforms_key, permute_key = random.split(key)
    ordered_us = _sorted_uniforms(uniforms_key, n)
    ancestors = inverse_cdf(ordered_us, logits)
    ancestors = random.permutation(permute_key, ancestors)
    # All particles carry equal weight after resampling: all-zero logits.
    flat_logits = jnp.zeros_like(ordered_us)
    return ancestors, flat_logits, apply_resampling_indices(positions, ancestors)

conditional_resampling(key, logits, positions, n, pivot_in, pivot_out)

Source code in cuthbertlib/resampling/multinomial.py
@partial(conditional_resampling_decorator, name="Multinomial", desc=_DESCRIPTION)
def conditional_resampling(
    key: Array,
    logits: ArrayLike,
    positions: ArrayTreeLike,
    n: int,
    pivot_in: ScalarArrayLike,
    pivot_out: ScalarArrayLike,
) -> tuple[Array, Array, ArrayTree]:
    # Multinomial resampling draws each ancestor i.i.d., so conditioning on
    # keeping one particle amounts to overwriting a single slot: resample
    # unconditionally, then pin ancestor `pivot_out` at position `pivot_in`.
    pivot_in = jnp.asarray(pivot_in)
    pivot_out = jnp.asarray(pivot_out)

    ancestors, flat_logits, _ = resampling(key, logits, positions, n)
    ancestors = ancestors.at[pivot_in].set(pivot_out)
    return ancestors, flat_logits, apply_resampling_indices(positions, ancestors)

cuthbertlib.resampling.systematic

Implements systematic resampling.

resampling(key, logits, positions, n)

Source code in cuthbertlib/resampling/systematic.py
@partial(resampling_decorator, name="Systematic", desc=_DESCRIPTION)
def resampling(
    key: Array, logits: ArrayLike, positions: ArrayTreeLike, n: int
) -> tuple[Array, Array, ArrayTree]:
    # Systematic resampling: a single uniform offset shared across n evenly
    # spaced strata of [0, 1), pushed through the inverse weight CDF.
    offset = random.uniform(key, ())
    strata = (offset + jnp.arange(n)) / n
    ancestors = inverse_cdf(strata, logits)
    # Equal weights after resampling.
    flat_logits = jnp.zeros_like(strata)
    return ancestors, flat_logits, apply_resampling_indices(positions, ancestors)

conditional_resampling(key, logits, positions, n, pivot_in, pivot_out)

Source code in cuthbertlib/resampling/systematic.py
@partial(conditional_resampling_decorator, name="Systematic", desc=_DESCRIPTION)
def conditional_resampling(
    key: Array,
    logits: ArrayLike,
    positions: ArrayTreeLike,
    n: int,
    pivot_in: ScalarArrayLike,
    pivot_out: ScalarArrayLike,
) -> tuple[Array, Array, ArrayTree]:
    # Reduce the general pivoted problem to the 0 -> 0 special case: rotate
    # the weights so the retained particle sits at index 0, run the
    # specialised routine, then undo the relabelling and rotation.
    logits = jnp.asarray(logits)
    pivot_in = jnp.asarray(pivot_in)
    pivot_out = jnp.asarray(pivot_out)

    num_particles = logits.shape[0]
    # TODO: normalising should be unnecessary in theory.
    logits = logits - logsumexp(logits)

    # Rotate weights so `pivot_out` lands at slot 0, and keep the original
    # particle labels so ancestors can be mapped back afterwards.
    # TODO: fold both rolls into a single helper function.
    labels = jnp.roll(jnp.arange(num_particles), -pivot_out)
    shifted_logits = jnp.roll(logits, -pivot_out)

    ancestors, flat_logits = conditional_resampling_0_to_0(key, shifted_logits, n)
    ancestors = jnp.roll(labels[ancestors], pivot_in)
    return ancestors, flat_logits, apply_resampling_indices(positions, ancestors)

cuthbertlib.resampling.killing

Implements killing resampling.

resampling(key, logits, positions, n)

Source code in cuthbertlib/resampling/killing.py
@partial(resampling_decorator, name="Killing", desc=_DESCRIPTION)
def resampling(
    key: Array, logits: ArrayLike, positions: ArrayTreeLike, n: int
) -> tuple[Array, Array, ArrayTree]:
    # Killing resampling: each particle survives with probability
    # exp(logit_i - max logit); survivors keep their own index while the
    # killed ones are replaced by a multinomial draw from the population.
    logits = jnp.asarray(logits)
    survival_key, replacement_key = random.split(key)
    num_particles = logits.shape[0]
    if n != num_particles:
        raise AssertionError(
            "The number of sampled indices must be equal to the number of "
            f"particles for `Killing` resampling. Got {n} instead of {num_particles}."
        )

    top_logit = jnp.max(logits)
    log_draws = jnp.log(random.uniform(survival_key, (num_particles,)))

    keep = log_draws <= logits - top_logit
    own_index = jnp.arange(num_particles)  # survivors retain their own index
    replacement_idx, _, _ = multinomial.resampling(
        replacement_key, logits, positions, num_particles
    )  # killed particles are replaced by a fresh multinomial draw
    ancestors = jnp.where(keep, own_index, replacement_idx)
    # After resampling, all particles have equal weight.
    flat_logits = jnp.zeros_like(logits)
    return ancestors, flat_logits, apply_resampling_indices(positions, ancestors)

conditional_resampling(key, logits, positions, n, pivot_in, pivot_out)

Source code in cuthbertlib/resampling/killing.py
@partial(conditional_resampling_decorator, name="Killing", desc=_DESCRIPTION)
def conditional_resampling(
    key: Array,
    logits: ArrayLike,
    positions: ArrayTreeLike,
    n: int,
    pivot_in: ScalarArrayLike,
    pivot_out: ScalarArrayLike,
) -> tuple[Array, Array, ArrayTree]:
    # Conditional killing resampling: run the unconditional scheme, then roll
    # the result by a randomly chosen offset so that particle `pivot_out`
    # ends up at slot `pivot_in`, and pin it there.
    pivot_in = jnp.asarray(pivot_in)
    pivot_out = jnp.asarray(pivot_out)

    # Unconditional resampling
    key_resample, key_shuffle = random.split(key)
    idx_uncond, _, _ = resampling(key_resample, logits, positions, n)

    # Conditional rolling pivot
    max_logit = jnp.max(logits)

    # Build the log-probabilities of the roll offset: each slot i receives
    # log(1 - exp(logit_i - max logit)) - log(n) (assuming `_log1mexp(x)`
    # computes log(1 - e^x) — confirm), with the `pivot_out` slot replaced by
    # the complementary mass so the distribution sums to one.
    # NOTE(review): the derivation of this correction distribution is not
    # visible here — presumably it keeps the conditional scheme's marginal law
    # consistent with the unconditional one; verify against the reference.
    pivot_logits = _log1mexp(logits - max_logit)
    pivot_logits -= jnp.log(n)
    pivot_logits = pivot_logits.at[pivot_out].set(-jnp.inf)
    pivot_logits_i = _log1mexp(logsumexp(pivot_logits))
    pivot_logits = pivot_logits.at[pivot_out].set(pivot_logits_i)

    # Sample the roll offset, align the rolled indices with `pivot_in`, and
    # force the pivot particle into place.
    pivot_weights = jnp.exp(pivot_logits - logsumexp(pivot_logits))
    pivot = random.choice(key_shuffle, n, p=pivot_weights)
    idx = jnp.roll(idx_uncond, pivot_in - pivot)
    idx = idx.at[pivot_in].set(pivot_out)
    # After resampling, all particles have equal weight
    logits_out = jnp.zeros_like(logits)
    return idx, logits_out, apply_resampling_indices(positions, idx)

cuthbertlib.resampling.adaptive

Adaptive resampling decorator.

Provides a decorator to turn any Resampling function into an adaptive resampling function which performs resampling only when the effective sample size (ESS) falls below a threshold.

ess_decorator(func, threshold)

Wrap a Resampling function so that it only resamples when ESS < threshold * n.

The returned function is jitted and has n as a static argument. The original resampler's docstring is appended to this wrapper's docstring so IDEs and users can see the underlying algorithm documentation.

Parameters:

Name Type Description Default
func Resampling

A resampling function with signature (key, logits, positions, n) -> (indices, logits_out, positions_out).

required
threshold float

Fraction of particle count specifying when to resample. Resampling is triggered when ESS < threshold * n.

required

Returns:

Type Description
Resampling

A Resampling function implementing adaptive resampling.

Source code in cuthbertlib/resampling/adaptive.py
def ess_decorator(func: Resampling, threshold: float) -> Resampling:
    """Wrap a Resampling function so that it only resamples when ESS < threshold * n.

    The returned function is jitted and has `n` as a static argument. The
    original resampler's docstring is appended to this wrapper's docstring so
    IDEs and users can see the underlying algorithm documentation.

    Args:
        func: A resampling function with signature
              (key, logits, positions, n) -> (indices, logits_out, positions_out).
        threshold: Fraction of particle count specifying when to resample.
            Resampling is triggered when ESS < threshold * n.

    Returns:
        A Resampling function implementing adaptive resampling.

    Raises:
        AssertionError: (raised by the wrapper at call time) if `n` differs
            from the number of particles, since the no-resampling branch
            returns identity indices of length `n`.
    """
    # Build a descriptive docstring that includes the wrapped function doc
    wrapped_doc = func.__doc__ or ""
    doc = f"""
    Adaptive resampling decorator (threshold={threshold}).

    This wrapper will call the provided resampling function only when the
    effective sample size (ESS) is below `ess_threshold * n`.

    Wrapped resampler documentation:
    {wrapped_doc}
    """

    @wraps(func)
    def _wrapped(
        key: Array, logits: ArrayLike, positions: ArrayTreeLike, n: int
    ) -> tuple[Array, Array, ArrayTree]:
        logits_arr = jnp.asarray(logits)
        N = logits_arr.shape[0]
        # Adaptive resampling only makes sense in-place: the identity branch
        # must produce indices matching the particle count.
        if n != N:
            raise AssertionError(
                "The number of sampled indices must be equal to the number of "
                f"particles for `adaptive` resampling. Got {n} instead of {N}."
            )

        def _do_resample():
            # ESS too low: delegate to the wrapped resampler.
            return func(key, logits_arr, positions, n)

        def _no_resample():
            # ESS high enough: keep particles and weights untouched.
            return jnp.arange(n), logits_arr, positions

        # lax.cond keeps the branch selection traceable under jit; both
        # branches must return matching pytree structures.
        return jax.lax.cond(
            log_ess(logits_arr) < jnp.log(threshold * n),
            _do_resample,
            _no_resample,
        )

    # Attach the composed docstring and return a jitted version
    _wrapped.__doc__ = doc
    return jax.jit(_wrapped, static_argnames=("n",))

cuthbertlib.resampling.autodiff

Implements decorators for automatic differentiation of resampling schemes.

Currently supported is the stop_gradient resampling scheme, which provides the classical Fisher estimates for the score function via automatic differentiation. This can be wrapped around a resampling scheme such as multinomial or systematic resampling.

See Scibior and Wood (2021) for more details.

stop_gradient_decorator(func)

Wrap a Resampling function to use stop gradient resampling.

Parameters:

Name Type Description Default
func Resampling

A resampling function with signature (key, logits, positions, n) -> (indices, logits_out, positions_out).

required

Returns:

Type Description
Resampling

A Resampling function implementing stop gradient resampling.

Source code in cuthbertlib/resampling/autodiff.py
def stop_gradient_decorator(func: Resampling) -> Resampling:
    """Wrap a Resampling function to use stop gradient resampling.

    Args:
        func: A resampling function with signature
              (key, logits, positions, n) -> (indices, logits_out, positions_out).

    Returns:
        A Resampling function implementing stop gradient resampling.
    """
    # Build a descriptive docstring that includes the wrapped function doc
    wrapped_doc = func.__doc__ or ""
    doc = f"""
    Stop gradient resampling decorator.

    This wrapper will call the provided resampling function, and then apply 
    the stop gradient trick of [Scibior and Wood (2021)](https://arxiv.org/abs/2106.10314). 
    Resulting estimates of the score function (i.e., the gradient of the 
    log-likelihood with respect to model parameters) are unbiased, 
    corresponding to the classical Fisher estimate.

    Wrapped resampler documentation:
    {wrapped_doc}
    """

    @wraps(func)
    def _wrapped(
        key: Array, logits: ArrayLike, positions: ArrayTreeLike, n: int
    ) -> tuple[Array, Array, ArrayTree]:
        # Resample with gradients blocked through the weights: the discrete
        # index draw has no useful derivative, so it must not be traced.
        idx_base, logits_base, positions_base = func(
            key, jax.lax.stop_gradient(logits), positions, n
        )

        # Re-attach weight gradients via the x - stop_gradient(x) trick:
        # the forward value is unchanged (the two terms cancel) but the
        # backward pass flows through the selected logits, yielding the
        # classical Fisher score estimate.  Computed once and reused instead
        # of evaluating apply_resampling_indices twice as before.
        selected_logits = apply_resampling_indices(logits, idx_base)
        logits = jnp.asarray(
            logits_base + selected_logits - jax.lax.stop_gradient(selected_logits)
        )
        return idx_base, logits, positions_base

    # Attach the composed docstring and return a jitted version
    _wrapped.__doc__ = doc
    return jax.jit(_wrapped, static_argnames=("n",))