
Sequential Monte Carlo Factorial Models

cuthbert.factorial.smc

Factorial utilities for SMC particle-filter states.

GeneralParticleFilterState = TypeVar('GeneralParticleFilterState', ParticleFilterState, MarginalParticleFilterState)  (module attribute)

build_factorializer(get_factorial_indices, resampling_fn)

Build a factorializer for particle-filter states.

In cuthbert.smc, resampling happens before propagation/reweighting. The factorial join needs equally weighted particles, so resampling_fn is required and is applied in join whenever the local factor weights are not constant. It is therefore recommended to set the main SMC filter's resampling_fn to cuthbertlib.resampling.no_resampling.resampling to avoid redundant resampling.

Any weights passed to marginalize will be duplicated across factors.
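The "not constant" test is an effective-sample-size (ESS) check: join wraps resampling_fn with ess_decorator at a threshold of 1 - 1e-6, so only factors whose weights are numerically uniform skip resampling. A minimal NumPy sketch of that test, assuming the usual definition ESS = 1 / Σ w² on normalized weights, divided by N. `normalized_ess` and `needs_resampling` are hypothetical helpers, and whether ess_decorator normalizes by N in exactly this way is an assumption.

```python
import numpy as np


def normalized_ess(log_weights: np.ndarray) -> float:
    """ESS / N for one factor's log weights; equals 1.0 only for uniform weights."""
    # Normalize in log space first for numerical stability.
    w = np.exp(log_weights - log_weights.max())
    w = w / w.sum()
    return float(1.0 / (np.sum(w**2) * len(w)))


def needs_resampling(log_weights: np.ndarray, threshold: float = 1 - 1e-6) -> bool:
    # Hypothetical gate: resample only when ESS / N drops below the threshold,
    # i.e. when the weights are (numerically) not constant.
    return bool(normalized_ess(log_weights) < threshold)


uniform = np.zeros(4)                            # constant log weights
skewed = np.log(np.array([0.7, 0.1, 0.1, 0.1]))  # non-constant weights
print(needs_resampling(uniform))  # False
print(needs_resampling(skewed))   # True
```

Pinning the threshold just below 1 rather than at exactly 1 leaves room for floating-point round-off on weights that are meant to be uniform.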

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| get_factorial_indices | GetFactorialIndices | Function to extract factorial indices from model inputs. | required |
| resampling_fn | Resampling | Resampling function used in join when local factor weights are not constant. Consider setting the main SMC filter's resampling_fn to cuthbertlib.resampling.no_resampling.resampling. Any adaptive ESS threshold will be overwritten, and resampling will be applied with a threshold of 1. | required |

Returns:

| Type | Description |
|---|---|
| Factorializer | Factorializer for SMC states with extract, join, marginalize, and insert. |

Source code in cuthbert/factorial/smc.py
def build_factorializer(
    get_factorial_indices: GetFactorialIndices,
    resampling_fn: Resampling,
) -> Factorializer:
    """Build a factorializer for particle-filter states.

    In `cuthbert.smc`, resampling happens before propagation/reweighting.
    Factorial `join` needs unweighted particles, so `resampling_fn` is required
    and applied in `join` whenever local factor weights are not constant.
    It is therefore recommended to set the main SMC filter's `resampling_fn` to
    `cuthbertlib.resampling.no_resampling.resampling` to avoid redundant resampling.

    Any weights passed to marginalize will be duplicated across factors.

    Args:
        get_factorial_indices: Function to extract factorial indices from model inputs.
        resampling_fn: Resampling function used in `join` when local factor
            weights are not constant.
            Consider setting the main SMC filter's `resampling_fn` to
            `cuthbertlib.resampling.no_resampling.resampling`.
            Any adaptive ESS threshold will be overwritten, and resampling will be
            applied with a threshold of 1.

    Returns:
        Factorializer for SMC states with extract, join, marginalize, and insert.
    """
    return Factorializer(
        get_factorial_indices=get_factorial_indices,
        extract=extract,
        join=lambda local_factorial_state: join(local_factorial_state, resampling_fn),
        marginalize=marginalize,
        insert=insert,
        factorialize_init_state=factorialize_init_state,
    )

factorialize_init_state(init_state, model_inputs)

Convert initial SMC state particles from (N, F, ...) to (F, N, ...).

Generic SMC filters sample initial particles with a leading particle axis. The factorial SMC machinery expects the factor axis to lead instead. Initial weights and particle filter ancestor indices are broadcast from (N,) to (F, N), matching the factorial SMC state layout.
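The conversion is an axis swap on the particles plus a broadcast of the shared weights. The sketch below uses NumPy as a stand-in for jax.numpy (np.moveaxis and np.broadcast_to have the same semantics as the jnp calls in the source). The sizes and random particles are illustrative only.

```python
import numpy as np

N, F, D = 5, 3, 2  # particles, factors, per-factor state dimension (illustrative)

# A generic SMC init produces particles with the particle axis leading: (N, F, D).
particles = np.random.default_rng(0).normal(size=(N, F, D))
log_weights = np.full((N,), -np.log(N))  # uniform initial log weights, shape (N,)

# Move the factor axis to the front: (N, F, D) -> (F, N, D).
factorial_particles = np.moveaxis(particles, 0, 1)

# Broadcast the shared (N,) weights to one identical row per factor: (F, N).
factorial_log_weights = np.broadcast_to(log_weights, (F, N))

print(factorial_particles.shape)    # (3, 5, 2)
print(factorial_log_weights.shape)  # (3, 5)
```

After the swap, factorial_particles[f, n] is the same value as particles[n, f], so no particle data is copied or reordered within a factor.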

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| init_state | GeneralParticleFilterState | Output of the particle filter's init_prepare. | required |
| model_inputs | ArrayTreeLike | The model inputs at the first time point (unused). | required |
Source code in cuthbert/factorial/smc.py
def factorialize_init_state(
    init_state: GeneralParticleFilterState, model_inputs: ArrayTreeLike
) -> GeneralParticleFilterState:
    """Convert initial SMC state particles from `(N, F, ...)` to `(F, N, ...)`.

    Generic SMC filters sample initial particles with a leading particle axis.
    The factorial SMC machinery expects the factor axis to lead instead.
    Initial weights and particle filter ancestor indices are broadcast from
    `(N,)` to `(F, N)`, matching the factorial SMC state layout.

    Args:
        init_state: Output from particle filter `init_prepare`
        model_inputs: The model inputs at the first time point - unused.
    """
    particles = tree.map(lambda x: jnp.moveaxis(x, 0, 1), init_state.particles)
    n_factors = tree.leaves(particles)[0].shape[0]
    n_particles = init_state.log_weights.shape[0]

    new_state = init_state._replace(
        particles=particles,
        log_weights=jnp.broadcast_to(init_state.log_weights, (n_factors, n_particles)),
    )

    if isinstance(init_state, ParticleFilterState):
        new_state = new_state._replace(
            ancestor_indices=jnp.broadcast_to(
                init_state.ancestor_indices, (n_factors, n_particles)
            )
        )

    return new_state

extract(factorial_state, factorial_inds)

Extract selected factors from a factorial particle-filter state.
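Extraction is plain integer-array indexing on the leading factor axis of each factorized field. A NumPy sketch with illustrative shapes (NumPy standing in for jax.numpy). Passing a 1-D index array keeps the factor axis, which join relies on when it unpacks log_weights.shape into (n_factors, n_particles).

```python
import numpy as np

F, N, D = 4, 3, 2
# Global factorial state: the factor axis leads on every factorized field.
particles = np.arange(F * N * D, dtype=float).reshape(F, N, D)
log_weights = np.zeros((F, N))

factorial_inds = np.asarray([0, 2])  # illustrative: factors touched by this update

# Integer-array indexing on the leading axis keeps only the selected factors.
local_particles = particles[factorial_inds]      # (2, N, D)
local_log_weights = log_weights[factorial_inds]  # (2, N)

print(local_particles.shape, local_log_weights.shape)  # (2, 3, 2) (2, 3)
```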

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| factorial_state | GeneralParticleFilterState | Factorial particle-filter state whose factorized fields (particles, log_weights and, for ParticleFilterState, ancestor_indices) have the factor axis F leading. | required |
| factorial_inds | ArrayLike | Indices of factors to extract. | required |

Returns:

| Type | Description |
|---|---|
| GeneralParticleFilterState | Local factorial particle-filter state with the selected factors on the leading axis of factorized fields. |

Source code in cuthbert/factorial/smc.py
def extract(
    factorial_state: GeneralParticleFilterState,
    factorial_inds: ArrayLike,
) -> GeneralParticleFilterState:
    """Extract selected factors from a factorial particle-filter state.

    Args:
        factorial_state: Factorial particle-filter state with factorized fields
            on leading axis F (`particles`, `log_weights`).
        factorial_inds: Indices of factors to extract.

    Returns:
        Local factorial particle-filter state with selected factors on the
        leading axis of factorized fields.
    """
    factorial_inds = jnp.asarray(factorial_inds)
    particles = tree.map(lambda x: x[factorial_inds], factorial_state.particles)
    log_weights = factorial_state.log_weights[factorial_inds]

    new_state = factorial_state._replace(
        particles=particles,
        log_weights=log_weights,
    )

    if isinstance(factorial_state, ParticleFilterState):
        new_state = new_state._replace(
            ancestor_indices=factorial_state.ancestor_indices[factorial_inds]
        )

    return new_state

join(local_factorial_state, resampling_fn)

Join local factorial state into a single joint local particle-filter state.

Resampling is applied first, independently over factors, when local factor weights are not constant (detected via effective sample size). Then factorized particles are stacked into a local joint particle state. Joined bookkeeping uses always-resampled conventions: zero log weights.

Ancestor indices come from the resampling step but are not joined: they retain the factorial axis (F, n_particles) and are assumed not to be used by the particle filter.
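The resample-then-stack step can be sketched in NumPy as below. Two assumptions are labeled in the code: multinomial resampling via Generator.choice stands in for whatever resampling_fn is supplied, and the joint layout (factor blocks concatenated along the last axis, giving (N, F * D)) is a guess at what the private helper _join_particles_arr does, since its source is not shown here. The sketch always resamples, whereas cuthbert skips factors whose weights are already constant (the effective-sample-size check above). join_sketch and resample_factor are hypothetical names.

```python
import numpy as np


def resample_factor(rng, log_weights, particles):
    """Multinomial resampling for one factor (assumed stand-in for resampling_fn)."""
    w = np.exp(log_weights - log_weights.max())  # stable normalization of log weights
    w /= w.sum()
    idx = rng.choice(len(w), size=len(w), p=w)  # ancestor indices, shape (N,)
    return idx, particles[idx]


def join_sketch(rng, log_weights, particles):
    """(F, N) weights and (F, N, D) particles -> (N, F * D) joint, zero log weights."""
    F, N = log_weights.shape
    ancestors, resampled = [], []
    for f in range(F):  # resample each factor independently
        idx, p = resample_factor(rng, log_weights[f], particles[f])
        ancestors.append(idx)
        resampled.append(p)
    resampled = np.stack(resampled)  # (F, N, D)
    # Assumed joint layout: factor blocks concatenated along the feature axis.
    joint = np.moveaxis(resampled, 0, 1).reshape(N, -1)  # (N, F * D)
    # Ancestor indices keep the factorial axis (F, N); log weights become zeros.
    return np.stack(ancestors), joint, np.zeros(N)


rng = np.random.default_rng(0)
F, N, D = 2, 4, 3
particles = rng.normal(size=(F, N, D))
log_weights = np.log(rng.dirichlet(np.ones(N), size=F))  # non-constant weights per factor
ancestors, joint, joint_log_weights = join_sketch(rng, log_weights, particles)
print(ancestors.shape, joint.shape, joint_log_weights.shape)  # (2, 4) (4, 6) (4,)
```

Because each factor is resampled with its own draws, row n of the joint array pairs the n-th resampled particle of every factor; the source achieves the same independence by splitting the key into one resampling key per factor and vmapping over the factor axis.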

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| local_factorial_state | GeneralParticleFilterState | Local factorial particle-filter state. | required |
| resampling_fn | Resampling | Resampling function for factor-wise pre-join resampling. | required |

Returns:

| Type | Description |
|---|---|
| GeneralParticleFilterState | Joint local particle-filter state with no factorial axis on particle values. |

Source code in cuthbert/factorial/smc.py
def join(
    local_factorial_state: GeneralParticleFilterState,
    resampling_fn: Resampling,
) -> GeneralParticleFilterState:
    """Join local factorial state into a single joint local particle-filter state.

    Resampling is applied first, independently over factors, when local
    factor weights are not constant (detected via effective sample size).
    Then factorized particles are stacked into a local joint particle state.
    Joined bookkeeping uses always-resampled conventions: zero log weights.

    Ancestor indices come from the resampling step but are not joined: they
    retain the factorial axis `(F, n_particles)` and are assumed not to be
    used by the particle filter.

    Args:
        local_factorial_state: Local factorial particle-filter state.
        resampling_fn: Resampling function for factor-wise pre-join resampling.

    Returns:
        Joint local particle-filter state with no factorial axis on particle values.
    """
    n_factors, n_particles = local_factorial_state.log_weights.shape

    # Resample independently over factors
    # Applied if logits are not constant (i.e. ess_threshold = 1)
    resampling_fn = ess_decorator(resampling_fn, threshold=1 - 1e-6)
    keys = random.split(local_factorial_state.key, n_factors + 1)
    key = keys[0]
    resampling_keys = keys[1:]
    ancestor_indices, _, particles = jax.vmap(resampling_fn, in_axes=(0, 0, 0, None))(
        resampling_keys,
        local_factorial_state.log_weights,
        local_factorial_state.particles,
        n_particles,
    )  # log_weights ignored, all zeros

    # Combine factorial particles into joint
    # particles has shape e.g. (F, n_particles, d)
    # ancestor_indices has shape (F, n_particles)
    joint_state = local_factorial_state._replace(
        key=key,
        particles=tree.map(_join_particles_arr, particles),
        log_weights=jnp.zeros(
            (n_particles,), dtype=local_factorial_state.log_weights.dtype
        ),  # all zeros
    )

    if isinstance(joint_state, ParticleFilterState):
        joint_state = joint_state._replace(
            ancestor_indices=ancestor_indices
        )  # ancestor_indices retain factorial axis (F, n_particles) - assumed not used in particle filter

    return joint_state

marginalize(local_state, num_factors)

Marginalize a joint local particle state back to factorial form.

Weights are duplicated across factors. Ancestor indices are ignored for marginalization.
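Under an assumed joint layout (equal-width factor blocks concatenated along the last axis, so D_joint = F * D_local), marginalizing is the inverse reshape plus tiling the shared weights, mirroring the jnp.tile call in the source. The private helper _marginalize_particles_arr is not shown on this page, so the particle split below is an assumption; marginalize_sketch is a hypothetical name and NumPy stands in for jax.numpy.

```python
import numpy as np


def marginalize_sketch(joint_particles, log_weights, num_factors):
    """(N, F * D) joint particles -> (F, N, D); (N,) weights -> (F, N) copies."""
    N = joint_particles.shape[0]
    # Assumed layout: equal-width factor blocks concatenated along the last axis.
    per_factor = joint_particles.reshape(N, num_factors, -1)  # (N, F, D)
    particles = np.moveaxis(per_factor, 1, 0)                  # (F, N, D)
    # Same tiling as the source: one identical row of weights per factor.
    factorial_log_weights = np.tile(log_weights[None], (num_factors, 1))
    return particles, factorial_log_weights


N, F, D = 4, 2, 3
joint = np.arange(N * F * D, dtype=float).reshape(N, F * D)
log_w = np.log(np.full(N, 1.0 / N))
particles, factorial_log_weights = marginalize_sketch(joint, log_w, F)
print(particles.shape, factorial_log_weights.shape)  # (2, 4, 3) (2, 4)
```

Under this layout the operation exactly undoes the stacking in join, so a join followed by marginalize returns each factor's (resampled) particles in its own slot.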

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| local_state | GeneralParticleFilterState | Joint local particle-filter state with particle leaves shaped (N, D_joint). | required |
| num_factors | int | Number of local factors in the factorial representation. | required |

Returns:

| Type | Description |
|---|---|
| GeneralParticleFilterState | Local factorial particle-filter state where particle leaves are shaped (F, N, D_local) and bookkeeping follows always-resampled semantics with missing ancestor indices (-1). |

Source code in cuthbert/factorial/smc.py
def marginalize(
    local_state: GeneralParticleFilterState,
    num_factors: int,
) -> GeneralParticleFilterState:
    """Marginalize a joint local particle state back to factorial form.

    Weights are duplicated across factors.
    Ancestor indices are ignored for marginalization.

    Args:
        local_state: Joint local particle-filter state with particle leaves
            shaped `(N, D_joint)`.
        num_factors: Number of local factors in the factorial representation.

    Returns:
        Local factorial particle-filter state where particle leaves are shaped
        `(F, N, D_local)` and bookkeeping follows always-resampled semantics
        with missing ancestor indices (`-1`).
    """
    particles = tree.map(
        lambda x: _marginalize_particles_arr(x, num_factors), local_state.particles
    )
    log_weights = jnp.tile(local_state.log_weights[None], (num_factors, 1))
    return local_state._replace(
        particles=particles,
        log_weights=log_weights,
    )

insert(local_factorial_state, factorial_state, factorial_inds)

Insert local factorial update into the global factorial state.
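JAX arrays are immutable, so the source writes updates with x.at[inds].set(values), which returns a new array and leaves the original untouched. The NumPy sketch below shows the equivalent copy-then-assign for log_weights, with illustrative shapes; np.atleast_1d mirrors the source's normalization of factorial_inds.

```python
import numpy as np

F, N = 4, 3
global_log_weights = np.zeros((F, N))
local_log_weights = np.full((2, N), -1.0)  # updated weights for two local factors
factorial_inds = np.atleast_1d(np.asarray([1, 3]))

# NumPy analogue of JAX's functional update `x.at[inds].set(vals)`:
# copy first so the previous global state is left unchanged.
updated = global_log_weights.copy()
updated[factorial_inds] = local_log_weights

print(updated[:, 0].tolist())  # [0.0, -1.0, 0.0, -1.0]
```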

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| local_factorial_state | GeneralParticleFilterState | Updated local factorial particle-filter state. | required |
| factorial_state | GeneralParticleFilterState | Previous global factorial particle-filter state. | required |
| factorial_inds | ArrayLike | Factor indices where local updates are inserted. | required |

Returns:

| Type | Description |
|---|---|
| GeneralParticleFilterState | Updated global factorial particle-filter state. |

Source code in cuthbert/factorial/smc.py
def insert(
    local_factorial_state: GeneralParticleFilterState,
    factorial_state: GeneralParticleFilterState,
    factorial_inds: ArrayLike,
) -> GeneralParticleFilterState:
    """Insert local factorial update into the global factorial state.

    Args:
        local_factorial_state: Updated local factorial particle-filter state.
        factorial_state: Previous global factorial particle-filter state.
        factorial_inds: Factor indices where local updates are inserted.

    Returns:
        Updated global factorial particle-filter state.
    """
    factorial_inds = jnp.asarray(factorial_inds)
    factorial_inds = jnp.atleast_1d(factorial_inds)

    particles = tree.map(
        lambda loc, glob: _insert_particles_arr(loc, glob, factorial_inds),
        local_factorial_state.particles,
        factorial_state.particles,
    )
    log_weights = factorial_state.log_weights.at[factorial_inds].set(
        local_factorial_state.log_weights
    )

    new_factorial_state = local_factorial_state._replace(
        particles=particles,
        log_weights=log_weights,
    )

    if isinstance(factorial_state, ParticleFilterState) and isinstance(
        local_factorial_state, ParticleFilterState
    ):
        ancestor_indices = factorial_state.ancestor_indices.at[factorial_inds].set(
            local_factorial_state.ancestor_indices
        )
        new_factorial_state = new_factorial_state._replace(
            ancestor_indices=ancestor_indices
        )

    return new_factorial_state