
Types

cuthbert.smc.types

Provides types for representing generic Feynman–Kac models, whose target distribution over paths at time \(t\) is

\[ \mathbb{Q}_{t}(x_{0:t}) \propto M_0(x_0) \, G_0(x_0) \prod_{s=1}^{t} M_s(x_s \mid x_{s-1}) \, G_s(x_{s-1}, x_s). \]
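
To make the formula concrete, the sketch below evaluates the logarithm of the unnormalized right-hand side for one fixed path \(x_{0:t}\) in a hypothetical scalar model: \(M_0 = \mathcal{N}(0, 1)\), \(M_s(x_s \mid x_{s-1}) = \mathcal{N}(x_{s-1}, \sigma_x^2)\), and \(G_s(x_{s-1}, x_s) = \mathcal{N}(y_s; x_s, \sigma_y^2)\) for made-up observations \(y_s\). The protocols in this module only require sampling from \(M_0\) and \(M_t\); evaluating their densities here is purely illustrative, and log_unnormalized_path is not part of cuthbert.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def log_unnormalized_path(xs, ys, sigma_x=1.0, sigma_y=0.5):
    """log of M_0(x_0) G_0(x_0) prod_{s=1}^t M_s(x_s | x_{s-1}) G_s(x_{s-1}, x_s).

    Hypothetical scalar model: M_0 = N(0, 1), M_s = N(x_{s-1}, sigma_x^2),
    and G_s(x_{s-1}, x_s) = N(y_s; x_s, sigma_y^2), which ignores x_{s-1}.
    """
    # Time-0 terms: log M_0(x_0) + log G_0(x_0).
    log_q = norm.logpdf(xs[0], 0.0, 1.0) + norm.logpdf(ys[0], xs[0], sigma_y)
    # Transition terms log M_s(x_s | x_{s-1}) for s = 1, ..., t.
    log_q += jnp.sum(norm.logpdf(xs[1:], xs[:-1], sigma_x))
    # Potential terms log G_s(x_{s-1}, x_s) for s = 1, ..., t.
    log_q += jnp.sum(norm.logpdf(ys[1:], xs[1:], sigma_y))
    return log_q


xs = jnp.array([0.0, 0.3, 0.1])  # a fixed path x_{0:2}
ys = jnp.array([0.1, 0.2, 0.0])  # made-up observations y_{0:2}
print(log_unnormalized_path(xs, ys))
```

Each factor of the product becomes one additive term on the log scale, which is also why the potentials in this module are expressed as log potentials.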

InitSample

Bases: Protocol

Protocol for sampling from the initial distribution \(M_0(x_0)\).

__call__(key, model_inputs)

Samples from the initial distribution \(M_0(x_0)\).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| key | KeyArray | JAX PRNG key. | required |
| model_inputs | ArrayTreeLike | Model inputs. | required |

Returns:

| Type | Description |
| --- | --- |
| ArrayTree | A sample \(x_0\). |

Source code in cuthbert/smc/types.py
def __call__(self, key: KeyArray, model_inputs: ArrayTreeLike) -> ArrayTree:
    """Samples from the initial distribution $M_0(x_0)$.

    Args:
        key: JAX PRNG key.
        model_inputs: Model inputs.

    Returns:
        A sample $x_0$.
    """
    ...
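
A minimal sketch of a callable that should satisfy InitSample: because the protocol only declares __call__, a plain function with the same parameter names ought to be accepted by a static type checker. The Gaussian initial distribution and the "m0"/"s0" keys of model_inputs are hypothetical choices for illustration, not part of cuthbert.

```python
import jax
import jax.numpy as jnp


def init_sample(key, model_inputs):
    """Draws x_0 ~ N(m0, diag(s0**2)).

    Hypothetical model_inputs keys: "m0" (initial mean) and "s0"
    (initial standard deviation, same shape as "m0").
    """
    m0 = jnp.asarray(model_inputs["m0"])
    s0 = jnp.asarray(model_inputs["s0"])
    return m0 + s0 * jax.random.normal(key, shape=m0.shape)


# Draw a cloud of 1000 particles by vmapping over independent keys;
# model_inputs is shared, hence in_axes=None for that argument.
inputs = {"m0": jnp.zeros(2), "s0": jnp.ones(2)}
keys = jax.random.split(jax.random.PRNGKey(0), 1000)
particles = jax.vmap(init_sample, in_axes=(0, None))(keys, inputs)
print(particles.shape)  # (1000, 2)
```

Writing the sampler for a single draw and vectorizing it outside with jax.vmap keeps the user-facing function simple; whether cuthbert batches calls this way internally is not stated on this page.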

PropagateSample

Bases: Protocol

Protocol for sampling from the Markov kernel \(M_t(x_t \mid x_{t-1})\).

__call__(key, state, model_inputs)

Samples from the Markov kernel \(M_t(x_t \mid x_{t-1})\).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| key | KeyArray | JAX PRNG key. | required |
| state | ArrayTreeLike | State at the previous step \(x_{t-1}\). | required |
| model_inputs | ArrayTreeLike | Model inputs. | required |

Returns:

| Type | Description |
| --- | --- |
| ArrayTree | A sample \(x_t\). |

Source code in cuthbert/smc/types.py
def __call__(
    self, key: KeyArray, state: ArrayTreeLike, model_inputs: ArrayTreeLike
) -> ArrayTree:
    r"""Samples from the Markov kernel $M_t(x_t \mid x_{t-1})$.

    Args:
        key: JAX PRNG key.
        state: State at the previous step $x_{t-1}$.
        model_inputs: Model inputs.

    Returns:
        A sample $x_t$.
    """
    ...
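
Similarly, a hypothetical AR(1) kernel \(M_t(x_t \mid x_{t-1}) = \mathcal{N}(a x_{t-1}, q^2 I)\) can be written as a plain function with the PropagateSample signature. The "a" and "q" keys of model_inputs, and the simulate helper that rolls the kernel forward with jax.lax.scan, are illustrative assumptions rather than cuthbert API.

```python
import jax
import jax.numpy as jnp


def propagate_sample(key, state, model_inputs):
    """Draws x_t ~ N(a * x_{t-1}, q^2) independently in each coordinate.

    Hypothetical model_inputs keys: "a" (autoregressive coefficient) and
    "q" (standard deviation of the transition noise).
    """
    state = jnp.asarray(state)
    noise = jax.random.normal(key, shape=state.shape)
    return model_inputs["a"] * state + model_inputs["q"] * noise


def simulate(key, x0, model_inputs, num_steps):
    """Hypothetical helper: applies the kernel num_steps times via jax.lax.scan."""
    keys = jax.random.split(key, num_steps)

    def step(x_prev, k):
        x = propagate_sample(k, x_prev, model_inputs)
        return x, x  # carry the new state forward and also record it

    _, xs = jax.lax.scan(step, x0, keys)
    return xs


trajectory = simulate(jax.random.PRNGKey(1), jnp.zeros(3), {"a": 0.9, "q": 0.1}, 50)
print(trajectory.shape)  # (50, 3)
```

Each step consumes its own key from jax.random.split, so no key is reused across time steps.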

LogPotential

Bases: Protocol

Protocol for computing the log potential function \(\log G_t(x_{t-1}, x_t)\).

__call__(state_prev, state, model_inputs)

Computes the log potential function \(\log G_t(x_{t-1}, x_t)\).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state_prev | ArrayTreeLike | State at the previous step \(x_{t-1}\). | required |
| state | ArrayTreeLike | State at the current step \(x_{t}\). | required |
| model_inputs | ArrayTreeLike | Model inputs. | required |

Returns:

| Type | Description |
| --- | --- |
| ScalarArray | A scalar value \(\log G_t(x_{t-1}, x_t)\). |

Source code in cuthbert/smc/types.py
def __call__(
    self,
    state_prev: ArrayTreeLike,
    state: ArrayTreeLike,
    model_inputs: ArrayTreeLike,
) -> ScalarArray:
    r"""Computes the log potential function $\log G_t(x_{t-1}, x_t)$.

    Args:
        state_prev: State at the previous step $x_{t-1}$.
        state: State at the current step $x_{t}$.
        model_inputs: Model inputs.

    Returns:
        A scalar value $\log G_t(x_{t-1}, x_t)$.
    """
    ...
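
Finally, a hypothetical observation potential \(\log G_t(x_{t-1}, x_t) = \log \mathcal{N}(y_t; x_t, r^2)\) matches the LogPotential signature; because it ignores \(x_{t-1}\), it corresponds to the bootstrap choice of potential. The sketch below plugs it, together with inline samplers matching InitSample and PropagateSample, into a minimal bootstrap particle filter with multinomial resampling at every step, and returns an estimate of the log normalizing constant (the log evidence). The model, the "y"/"r" keys, and the handling of \(G_0\) (passing \(x_0\) as both arguments) are assumptions; this is not cuthbert's own filter.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def log_potential(state_prev, state, model_inputs):
    """log G_t(x_{t-1}, x_t) = log N(y_t; x_t, r^2), summed over coordinates.

    Ignores state_prev. Hypothetical model_inputs keys: "y" (observation
    at time t) and "r" (observation noise standard deviation).
    """
    return jnp.sum(norm.logpdf(model_inputs["y"], state, model_inputs["r"]))


def bootstrap_log_evidence(key, ys, num_particles=1000, a=0.9, q=0.5, r=0.3):
    """Bootstrap particle filter estimate of log p(y_{0:T}).

    Hypothetical scalar model: x_0 ~ N(0, 1), x_t ~ N(a x_{t-1}, q^2),
    with potentials given by log_potential above.
    """
    n = num_particles

    def init_sample(k, model_inputs):
        return jax.random.normal(k)

    def propagate_sample(k, x_prev, model_inputs):
        return a * x_prev + q * jax.random.normal(k)

    batched_potential = jax.vmap(log_potential, in_axes=(0, 0, None))

    key, k0 = jax.random.split(key)
    x = jax.vmap(init_sample, in_axes=(0, None))(jax.random.split(k0, n), {})
    # t = 0: G_0 depends on x_0 only, so x_0 is passed as both arguments.
    logw = batched_potential(x, x, {"y": ys[0], "r": r})
    log_z = logsumexp(logw) - jnp.log(n)

    def step(carry, y_t):
        key, x_prev, logw, log_z = carry
        key, k_res, k_prop = jax.random.split(key, 3)
        # Multinomial resampling: ancestor indices drawn from the weights.
        idx = jax.random.categorical(k_res, logw, shape=(n,))
        x_anc = x_prev[idx]
        # Propagate each resampled particle through M_t.
        x_new = jax.vmap(propagate_sample, in_axes=(0, 0, None))(
            jax.random.split(k_prop, n), x_anc, {}
        )
        # Reweight with G_t(x_{t-1}, x_t) and add the log-evidence increment.
        logw = batched_potential(x_anc, x_new, {"y": y_t, "r": r})
        log_z = log_z + logsumexp(logw) - jnp.log(n)
        return (key, x_new, logw, log_z), None

    (_, _, _, log_z), _ = jax.lax.scan(step, (key, x, logw, log_z), ys[1:])
    return log_z


ys = jnp.array([0.1, -0.2, 0.3, 0.0, 0.5])
print(bootstrap_log_evidence(jax.random.PRNGKey(0), ys))
```

Because every particle is resampled at each step, the weights entering step \(t\) are uniform, so each increment is simply the log of the average potential \(\frac{1}{N}\sum_i G_t(x_{t-1}^{(i)}, x_t^{(i)})\).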