
SMC

This sub-repository provides modular tools useful for constructing sequential Monte Carlo (SMC) algorithms. It pairs with the cuthbertlib.resampling sub-repository.

Backward simulation

cuthbertlib.smc.smoothing provides tools for backward simulation, i.e. converting particles x0 from the filter distribution at time t0 and particles x1 from the smoothing distribution at time t1 into joint particles (x0, x1) from the smoothing distribution. Exact backward simulation requires a log conditional density log_density(x0, x1), the log density of x1 given x0, and has computational cost O(N^2), where N is the number of particles. Ancestor tracing does not use the density at all; it reuses the ancestor indices recorded by the filter.
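For reference, the usual backward kernel behind exact backward simulation can be written as follows. The notation is assumed here rather than taken from the library: \(w_0^{(i)}\) are the normalized filter weights at time t0 and \(f\) is the conditional density whose logarithm is log_density. For each \(x_1\), an index \(J\) of a filter particle \(x_0^{(i)}\) is drawn with probability

```latex
\mathbb{P}\bigl(J = i \mid x_1\bigr)
  = \frac{w_0^{(i)}\, f\bigl(x_1 \mid x_0^{(i)}\bigr)}
         {\sum_{k=1}^{N} w_0^{(k)}\, f\bigl(x_1 \mid x_0^{(k)}\bigr)},
  \qquad i = 1, \dots, N.
```

Evaluating these N probabilities for each of N smoothed particles is where the O(N^2) cost comes from.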

cuthbertlib.smc.smoothing.protocols

Shared protocols for backward smoothing functions in SMC.

BackwardSampling

Bases: Protocol

Protocol for backward sampling functions.

__call__(key, x0_all, x1_all, log_weight_x0_all, log_density, x1_ancestor_indices)

Performs a backward sampling step.

Samples a collection of \(x_0\) that combine with the provided \(x_1\) to give a collection of pairs \((x_0, x_1)\) from the smoothing distribution.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | KeyArray | JAX PRNG key. | required |
| x0_all | ArrayTreeLike | A collection of previous states \(x_0\). | required |
| x1_all | ArrayTreeLike | A collection of current states \(x_1\). | required |
| log_weight_x0_all | ArrayLike | The log weights of \(x_0\). | required |
| log_density | LogConditionalDensity | The log density function of \(x_1\) given \(x_0\). | required |
| x1_ancestor_indices | ArrayLike | The ancestor indices of \(x_1\). | required |

Returns:

| Type | Description |
|---|---|
| tuple[ArrayTree, Array] | A collection of samples \(x_0\) and their sampled indices. |

Source code in cuthbertlib/smc/smoothing/protocols.py
def __call__(
    self,
    key: KeyArray,
    x0_all: ArrayTreeLike,
    x1_all: ArrayTreeLike,
    log_weight_x0_all: ArrayLike,
    log_density: LogConditionalDensity,
    x1_ancestor_indices: ArrayLike,
) -> tuple[ArrayTree, Array]:
    """Performs a backward sampling step.

    Samples a collection of $x_0$ that combine with the provided $x_1$ to
    give a collection of pairs $(x_0, x_1)$ from the smoothing distribution.

    Args:
        key: JAX PRNG key.
        x0_all: A collection of previous states $x_0$.
        x1_all: A collection of current states $x_1$.
        log_weight_x0_all: The log weights of $x_0$.
        log_density: The log density function of $x_1$ given $x_0$.
        x1_ancestor_indices: The ancestor indices of $x_1$.

    Returns:
        A collection of samples $x_0$ and their sampled indices.
    """
    ...
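Any function with the BackwardSampling call signature can be plugged into a backward pass over stored filter output. The sketch below is hypothetical: backward_pass and trace_ancestors are illustrative names, not cuthbertlib API, and it assumes the filter stored, for each time t, particles with a leading particle axis, their log weights, and (for t >= 1) the index of each particle's parent at time t - 1.

```python
import jax
import jax.numpy as jnp
from jax import random


def trace_ancestors(key, x0_all, x1_all, log_weight_x0_all, log_density, x1_ancestor_indices):
    # Hypothetical stand-in with the BackwardSampling call signature: it just
    # follows the recorded ancestors and ignores key, x1_all, weights and density.
    idx = jnp.asarray(x1_ancestor_indices)
    return jax.tree.map(lambda z: z[idx], x0_all), idx


def backward_pass(key, particles, log_weights, ancestors, log_density, backward_sampler):
    """Hypothetical smoothing pass driven by any BackwardSampling-style callable.

    particles[t]:   filter particles at time t, leading axis N
    log_weights[t]: their log weights, shape (N,)
    ancestors[t]:   for t >= 1, index into particles[t - 1] of each particle's parent
    """
    n_times = len(particles)
    n = log_weights[-1].shape[0]
    key, subkey = random.split(key)
    # Draw final-time particles according to the final filter weights.
    idx = random.categorical(subkey, log_weights[-1], shape=(n,))
    x1 = jax.tree.map(lambda z: z[idx], particles[-1])
    path = [x1]
    for t in range(n_times - 2, -1, -1):
        key, subkey = random.split(key)
        # Parents (at time t) of the particles currently selected at time t + 1.
        x1_ancestor_indices = ancestors[t + 1][idx]
        x0, idx = backward_sampler(
            subkey, particles[t], x1, log_weights[t], log_density, x1_ancestor_indices
        )
        path.append(x0)
        x1 = x0
    return path[::-1]  # smoothed particles ordered from time 0 to the final time
```

Keeping the sampler as a parameter is the point of a shared protocol: swapping trace_ancestors for an exact or MCMC backward sampler leaves the pass itself unchanged.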

cuthbertlib.smc.smoothing.tracing

Implements the ancestor/genealogy tracing algorithm for smoothing in SMC.

simulate(key, x0_all, x1_all, log_weight_x0_all, log_density, x1_ancestor_indices)

Implements the ancestor/genealogy tracing algorithm for smoothing in SMC.

Some arguments are included only for protocol compatibility and are not used by this implementation: key, x1_all, log_weight_x0_all and log_density.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | KeyArray | JAX PRNG key. Not used. | required |
| x0_all | ArrayTreeLike | A collection of previous states \(x_0\). | required |
| x1_all | ArrayTreeLike | A collection of current states \(x_1\). Not used. | required |
| log_weight_x0_all | ArrayLike | The log weights of \(x_0\). Not used. | required |
| log_density | LogConditionalDensity | The log density function of \(x_1\) given \(x_0\). Not used. | required |
| x1_ancestor_indices | ArrayLike | The ancestor indices of \(x_1\). | required |

Returns:

| Type | Description |
|---|---|
| tuple[ArrayTree, Array] | A collection of samples \(x_0\) and their sampled indices. |

References

https://arxiv.org/abs/2207.00976

Source code in cuthbertlib/smc/smoothing/tracing.py
def simulate(
    key: KeyArray,
    x0_all: ArrayTreeLike,
    x1_all: ArrayTreeLike,
    log_weight_x0_all: ArrayLike,
    log_density: LogConditionalDensity,
    x1_ancestor_indices: ArrayLike,
) -> tuple[ArrayTree, Array]:
    """Implements the ancestor/genealogy tracing algorithm for smoothing in SMC.

    Some arguments are only included for protocol compatibility and not used in this
    implementation.

    Args:
        key: JAX PRNG key. Not used.
        x0_all: A collection of previous states $x_0$.
        x1_all: A collection of current states $x_1$. Not used.
        log_weight_x0_all: The log weights of $x_0$. Not used.
        log_density: The log density function of $x_1$ given $x_0$. Not used.
        x1_ancestor_indices: The ancestor indices of $x_1$.

    Returns:
        A collection of samples $x_0$ and their sampled indices.

    References:
        https://arxiv.org/abs/2207.00976
    """
    x1_ancestor_indices = jnp.asarray(x1_ancestor_indices)
    x0 = jax.tree.map(lambda z: z[x1_ancestor_indices], x0_all)
    return x0, x1_ancestor_indices
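Ancestor tracing reduces to a gather along the leading (particle) axis of every leaf of the state pytree, as in the jax.tree.map line above. A small standalone illustration, assuming a hypothetical dict-valued state with a vector position and a scalar velocity per particle:

```python
import jax
import jax.numpy as jnp

# Hypothetical pytree state: N = 3 filter particles. Every leaf has the
# particle axis first; "pos" is a 2-vector and "vel" a scalar per particle.
x0_all = {
    "pos": jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]),
    "vel": jnp.array([10.0, 11.0, 12.0]),
}
# One ancestor index per current particle x1 (4 here, chosen to differ from N
# only so the output shapes are easy to tell apart).
x1_ancestor_indices = jnp.array([2, 0, 0, 1])

# Gather the parent of every x1 from each leaf of the pytree.
x0 = jax.tree.map(lambda z: z[x1_ancestor_indices], x0_all)
```

Here x0["pos"] has shape (4, 2) and x0["vel"] is [12.0, 10.0, 10.0, 11.0]: each leaf is indexed independently but with the same indices, so the components of a particle stay together.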

cuthbertlib.smc.smoothing.exact_sampling

Implements exact backward sampling for smoothing in SMC.

simulate(key, x0_all, x1_all, log_weight_x0_all, log_density, x1_ancestor_indices)

Implements the exact backward sampling algorithm for smoothing in SMC.

Some arguments are included only for protocol compatibility and are not used by this implementation: x1_ancestor_indices.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | KeyArray | JAX PRNG key. | required |
| x0_all | ArrayTreeLike | A collection of previous states \(x_0\). | required |
| x1_all | ArrayTreeLike | A collection of current states \(x_1\). | required |
| log_weight_x0_all | ArrayLike | The log weights of \(x_0\). | required |
| log_density | LogConditionalDensity | The log density function of \(x_1\) given \(x_0\). | required |
| x1_ancestor_indices | ArrayLike | The ancestor indices of \(x_1\). Not used. | required |

Returns:

| Type | Description |
|---|---|
| tuple[ArrayTree, Array] | A collection of samples \(x_0\) and their sampled indices. |

Source code in cuthbertlib/smc/smoothing/exact_sampling.py
def simulate(
    key: KeyArray,
    x0_all: ArrayTreeLike,
    x1_all: ArrayTreeLike,
    log_weight_x0_all: ArrayLike,
    log_density: LogConditionalDensity,
    x1_ancestor_indices: ArrayLike,
) -> tuple[ArrayTree, Array]:
    """Implements the exact backward sampling algorithm for smoothing in SMC.

    Some arguments are only included for protocol compatibility and not used in this
    implementation.

    Args:
        key: JAX PRNG key.
        x0_all: A collection of previous states $x_0$.
        x1_all: A collection of current states $x_1$.
        log_weight_x0_all: The log weights of $x_0$.
        log_density: The log density function of $x_1$ given $x_0$.
        x1_ancestor_indices: The ancestor indices of $x_1$. Not used.

    Returns:
        A collection of samples $x_0$ and their sampled indices.
    """
    log_weight_x0_all = jnp.asarray(log_weight_x0_all)
    n_smoother_particles = jax.tree.leaves(x1_all)[0].shape[0]
    keys = random.split(key, n_smoother_particles)

    return vmap(
        lambda k, x1: simulate_single(k, x0_all, x1, log_weight_x0_all, log_density)
    )(keys, x1_all)
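The per-particle routine simulate_single is not shown in this excerpt. The following is a minimal sketch of what a standard exact backward step computes, under the assumption that it samples one index per \(x_1\) from the normalized weights \(w_0^{(i)} f(x_1 \mid x_0^{(i)})\). It uses jax.random.categorical; the Gaussian AR(1) density and all function names are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import random


def gaussian_ar1_log_density(x0, x1):
    # Assumed toy model: x1 | x0 ~ N(0.9 * x0, 1), up to an additive constant.
    return -0.5 * (x1 - 0.9 * x0) ** 2


def simulate_single_sketch(key, x0_all, x1, log_weight_x0_all, log_density):
    # Backward-kernel logits log w0^(i) + log f(x1 | x0^(i)) for every i: O(N) work.
    logits = log_weight_x0_all + jax.vmap(log_density, in_axes=(0, None))(x0_all, x1)
    idx = random.categorical(key, logits)  # draw from the normalized backward kernel
    x0 = jax.tree.map(lambda z: z[idx], x0_all)
    return x0, idx


def exact_backward_sketch(key, x0_all, x1_all, log_weight_x0_all, log_density):
    # One independent O(N) draw per smoother particle, mirroring the vmap above.
    n_smoother_particles = jax.tree.leaves(x1_all)[0].shape[0]
    keys = random.split(key, n_smoother_particles)
    return jax.vmap(
        lambda k, x1: simulate_single_sketch(k, x0_all, x1, log_weight_x0_all, log_density)
    )(keys, x1_all)
```

Setting all but one log weight to -inf forces every draw onto the remaining particle, which is a quick sanity check of the indexing.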

cuthbertlib.smc.smoothing.mcmc

Implements MCMC backward smoothing in SMC.

simulate(key, x0_all, x1_all, log_weight_x0_all, log_density, x1_ancestor_indices, n_steps)

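Implements the independent Metropolis-Hastings (IMH) algorithm for smoothing in SMC. The chain for each \(x_1\) starts from its ancestor-tracing sample, and proposals are drawn from the filter weights of \(x_0\).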
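Under the usual IMH construction (assumed here, and consistent with the log_alpha = prop_log_p - idx_log_p line in the source below), the target for the index \(i\) of \(x_0\) given \(x_1\) is the backward kernel, while the proposal draws indices from the normalized filter weights alone. The weights then cancel in the acceptance probability:

```latex
\pi(i) \propto w_0^{(i)} f\bigl(x_1 \mid x_0^{(i)}\bigr), \qquad q(i') = w_0^{(i')},

\alpha(i \to i')
  = \min\left\{1,\ \frac{\pi(i')\, q(i)}{\pi(i)\, q(i')}\right\}
  = \min\left\{1,\ \frac{f\bigl(x_1 \mid x_0^{(i')}\bigr)}{f\bigl(x_1 \mid x_0^{(i)}\bigr)}\right\}.
```

Each step therefore needs only one density evaluation per smoother particle rather than N.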

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | KeyArray | JAX PRNG key. | required |
| x0_all | ArrayTreeLike | A collection of previous states \(x_0\). | required |
| x1_all | ArrayTreeLike | A collection of current states \(x_1\). | required |
| log_weight_x0_all | ArrayLike | The log weights of \(x_0\). | required |
| log_density | LogConditionalDensity | The log density function of \(x_1\) given \(x_0\). | required |
| x1_ancestor_indices | ArrayLike | The ancestor indices of \(x_1\). | required |
| n_steps | int | Number of MCMC steps to perform. | required |

Returns:

| Type | Description |
|---|---|
| tuple[ArrayTree, Array] | A collection of samples \(x_0\) and their sampled indices. |

References

https://arxiv.org/abs/2207.00976

Source code in cuthbertlib/smc/smoothing/mcmc.py
def simulate(
    key: KeyArray,
    x0_all: ArrayTreeLike,
    x1_all: ArrayTreeLike,
    log_weight_x0_all: ArrayLike,
    log_density: LogConditionalDensity,
    x1_ancestor_indices: ArrayLike,
    n_steps: int,
) -> tuple[ArrayTree, Array]:
    """Implements the IMH algorithm for smoothing in SMC.

    Args:
        key: JAX PRNG key.
        x0_all: A collection of previous states $x_0$.
        x1_all: A collection of current states $x_1$.
        log_weight_x0_all: The log weights of $x_0$.
        log_density: The log density function of $x_1$ given $x_0$.
        x1_ancestor_indices: The ancestor indices of $x_1$.
        n_steps: Number of MCMC steps to perform.

    Returns:
        A collection of samples $x_0$ and their sampled indices.

    References:
        https://arxiv.org/abs/2207.00976
    """
    key, subkey = random.split(key)
    x0_init, x1_ancestor_indices = ancestor_tracing_simulate(
        subkey, x0_all, x1_all, log_weight_x0_all, log_density, x1_ancestor_indices
    )
    n_samples = x1_ancestor_indices.shape[0]

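    # One (proposal, acceptance) key pair per step; passing a shape to split
    # works for both raw uint32 keys and typed key arrays.
    keys = random.split(key, (n_steps, 2))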

    def body(carry, keys_t):
        # IMH proposal
        idx, x0_res, idx_log_p = carry
        key_prop, key_acc = keys_t

        prop_idx = multinomial.resampling(key_prop, log_weight_x0_all, n_samples)
        x0_prop = jax.tree.map(lambda z: z[prop_idx], x0_all)
        prop_log_p = jax.vmap(log_density)(x0_prop, x1_all)

        log_alpha = prop_log_p - idx_log_p

        lu = jnp.log(random.uniform(key_acc, (n_samples,)))
        acc = lu < log_alpha

        idx: Array = jnp.where(acc, prop_idx, idx)
        x0_res: ArrayTreeLike = jax.tree.map(lambda z: z[idx], x0_all)
        idx_log_p: Array = jnp.where(acc, prop_log_p, idx_log_p)
        return (idx, x0_res, idx_log_p), None

    # x0_init was already gathered by ancestor tracing above. log_density takes
    # (x0, x1), matching the call in body.
    init_log_p = jax.vmap(log_density)(x0_init, x1_all)
    init = (x1_ancestor_indices, x0_init, init_log_p)
    (out_index, out_samples, _), _ = jax.lax.scan(body, init, keys)
    return out_samples, out_index
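A self-contained sketch of the same IMH backward step, for experimenting without cuthbertlib. It is an illustration under stated assumptions, not the library's code: imh_backward_sketch and init_idx (the starting index of each chain, e.g. its ancestor index) are hypothetical names, and jax.random.categorical stands in for the library's multinomial resampling.

```python
import jax
import jax.numpy as jnp
from jax import random


def imh_backward_sketch(key, x0_all, x1_all, log_weight_x0_all, log_density, init_idx, n_steps):
    """Hypothetical IMH backward sampler; log_density(x0, x1) = log f(x1 | x0)."""
    init_idx = jnp.asarray(init_idx)
    n_samples = init_idx.shape[0]

    def log_p(idx):
        # log f(x1 | x0[idx]) for each smoother particle x1, evaluated pairwise.
        x0 = jax.tree.map(lambda z: z[idx], x0_all)
        return jax.vmap(log_density)(x0, x1_all)

    def body(carry, step_key):
        idx, idx_log_p = carry
        key_prop, key_acc = random.split(step_key)
        # Independent proposal drawn from the filter weights alone
        # (random.categorical replaces the library's multinomial resampling).
        prop_idx = random.categorical(key_prop, log_weight_x0_all, shape=(n_samples,))
        prop_idx = prop_idx.astype(idx.dtype)  # keep the scan carry dtype fixed
        prop_log_p = log_p(prop_idx)
        # The filter weights cancel, leaving a ratio of conditional densities.
        log_alpha = prop_log_p - idx_log_p
        accept = jnp.log(random.uniform(key_acc, (n_samples,))) < log_alpha
        idx = jnp.where(accept, prop_idx, idx)
        idx_log_p = jnp.where(accept, prop_log_p, idx_log_p)
        return (idx, idx_log_p), None

    init = (init_idx, log_p(init_idx))
    (idx, _), _ = jax.lax.scan(body, init, random.split(key, n_steps))
    x0 = jax.tree.map(lambda z: z[idx], x0_all)
    return x0, idx
```

Carrying only the indices and gathering x0 once at the end keeps the scan carry small, whereas the library version also carries x0_res through every step. Splitting one key per step inside body avoids reshaping a flat key array.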