
Factorial Types

cuthbert.factorial.types

Provides types for factorial state-space models.

GetFactorialIndices

Bases: Protocol

Protocol for getting the factorial indices.

__call__(model_inputs)

Extract the factorial indices from model inputs.

Parameters:

Name Type Description Default
model_inputs ArrayTreeLike

Model inputs.

required

Returns:

Type Description
ArrayLike

Indices of the factors to extract. Integer array.

Source code in cuthbert/factorial/types.py
def __call__(self, model_inputs: ArrayTreeLike) -> ArrayLike:
    """Extract the factorial indices from model inputs.

    Args:
        model_inputs: Model inputs.

    Returns:
        Indices of the factors to extract. Integer array.
    """
    ...
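As an illustration, the sketch below implements this protocol for a hypothetical model-input layout: a dict whose "factor_inds" entry lists the factors that the current observation depends on. The key name and the dict layout are assumptions for this example, not part of cuthbert.

```python
import jax.numpy as jnp


def get_factorial_indices(model_inputs):
    """Hypothetical GetFactorialIndices: read the indices from a dict entry."""
    # "factor_inds" is an assumed key; adapt it to your own model inputs.
    return jnp.asarray(model_inputs["factor_inds"], dtype=jnp.int32)


# Hypothetical inputs at one time step: observation y depends on factors 0 and 2.
model_inputs = {"factor_inds": [0, 2], "y": jnp.array([1.3])}
inds = get_factorial_indices(model_inputs)
```

Returning a 1D array, even for a single factor, keeps the result compatible with extract_and_join, which asserts factorial_inds.ndim == 1.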

Extract

Bases: Protocol

Protocol for extracting the relevant factors.

__call__(factorial_state, factorial_inds)

Extract factors from factorial state.

E.g. factorial_state might encode factorial means with shape (F, d) and chol_covs with shape (F, d, d). If the model inputs indicate that factors i and j are relevant (so factorial_inds contains i and j), we extract means[i], means[j], chol_covs[i] and chol_covs[j], returning means with shape (2, d) and chol_covs with shape (2, d, d).

Parameters:

Name Type Description Default
factorial_state ArrayTreeLike

Factorial state with factorial index as the first dimension. E.g. a factorial Kalman state storing means and chol_covs with shape (F, d) and (F, d, d) respectively.

required
factorial_inds ArrayLike

Indices of the factors to extract. Integer array. factorial_inds.ndim == 0 removes the factorial dimension and extracts a single factor. factorial_inds.ndim == 1 retains the factorial dimension, even if len(factorial_inds) == 1.

required

Returns:

Type Description
ArrayTree

Local factorial state with factorial dimension of length len(factorial_inds). If factorial_inds is a single integer, the local factorial state should not have a factorial dimension.

Source code in cuthbert/factorial/types.py
def __call__(
    self,
    factorial_state: ArrayTreeLike,
    factorial_inds: ArrayLike,
) -> ArrayTree:
    """Extract factors from factorial state.

    E.g. factorial_state might encode factorial `means` with shape (F, d) and
    `chol_covs` with shape (F, d, d). Then `model_inputs` tells us factors `i` and
    `j` are relevant, so we extract `means[i]` and `means[j]` and `chol_covs[i]` and
    `chol_covs[j]`. Thus we return `means` with shape (2, d) and `chol_covs` with
    shape (2, d, d).

    Args:
        factorial_state: Factorial Kalman state storing means and chol_covs
            with shape (F, d) and (F, d, d) respectively.
        factorial_inds: Indices of the factors to extract. Integer array.
            factorial_inds.ndim == 0 removes the factorial dimension and extracts
            a single factor.
            factorial_inds.ndim == 1 retains the factorial dimension,
            even if len(factorial_inds) == 1.

    Returns:
        Local factorial state with factorial dimension of length len(factorial_inds).
            If factorial_inds is a single integer, the local factorial state should
            not have a factorial dimension.
    """
    ...
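A minimal sketch of an Extract implementation for a Gaussian state, assuming the state is a NamedTuple of arrays whose leading axis is the factorial axis (GaussianState below is hypothetical, not a cuthbert type). Mapping plain integer indexing over the pytree gives both behaviours described above: a 0-d index drops the factorial dimension and a 1-d index keeps it.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class GaussianState(NamedTuple):
    """Hypothetical factorial Gaussian state (not a cuthbert type)."""

    mean: jax.Array      # (F, d)
    chol_cov: jax.Array  # (F, d, d), lower-triangular Cholesky factors


def extract(factorial_state, factorial_inds):
    """Index every leaf of the state along its leading (factorial) axis."""
    factorial_inds = jnp.asarray(factorial_inds)
    # A 0-d index drops the factorial axis; a 1-d index keeps it.
    return jax.tree_util.tree_map(lambda x: x[factorial_inds], factorial_state)


F, d = 4, 3
state = GaussianState(
    mean=jnp.arange(F * d, dtype=jnp.float32).reshape(F, d),
    chol_cov=jnp.broadcast_to(jnp.eye(d), (F, d, d)),
)
local = extract(state, jnp.array([1, 3]))  # keeps the factorial axis
single = extract(state, 2)                 # drops the factorial axis
print(local.mean.shape, local.chol_cov.shape)    # (2, 3) (2, 3, 3)
print(single.mean.shape, single.chol_cov.shape)  # (3,) (3, 3)
```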

Join

Bases: Protocol

Protocol for combining factorial states into a joint state.

__call__(local_factorial_state)

Combine a local factorial state into a joint local state.

E.g. local_factorial_state might encode factorial means with shape (2, d) and chol_covs with shape (2, d, d), which are then combined into a joint mean with shape (2 * d,) and a block-diagonal joint_chol_cov with shape (2 * d, 2 * d).

Parameters:

Name Type Description Default
local_factorial_state ArrayTreeLike

Factorial state with factorial index as the first dimension. Typically contains only a small number of factors, as it's applied after an Extract operation.

required

Returns:

Type Description
ArrayTree

Joint state with no factorial index dimension.

Source code in cuthbert/factorial/types.py
def __call__(
    self,
    local_factorial_state: ArrayTreeLike,
) -> ArrayTree:
    """Extract factors from factorial state and combine into a joint local state.

    E.g., local_factorial_state might encode factorial `means` with shape (2, d)
    and `chol_covs` with shape (2, d, d),
    which is then combined into a joint state with shape (2 * d,)
    and block diagonal `joint_chol_cov` with shape (2 * d, 2 * d).

    Args:
        local_factorial_state: Factorial state with factorial index as the first
            dimension. Typically contains only a small number of factors, as it's
            applied after an `Extract` operation.

    Returns:
        Joint state with no factorial index dimension.
    """
    ...
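For a Gaussian state, one way to implement Join is to concatenate the factor means and place the factor Cholesky factors on a block diagonal. The sketch below assumes the same hypothetical GaussianState NamedTuple (fields mean and chol_cov); it is an illustration, not cuthbert's own implementation.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.scipy.linalg import block_diag


class GaussianState(NamedTuple):
    """Hypothetical Gaussian state (not a cuthbert type)."""

    mean: jax.Array
    chol_cov: jax.Array


def join(local_factorial_state):
    """Concatenate factor means and block-diagonalize their Cholesky factors."""
    means = local_factorial_state.mean          # (n, d)
    chol_covs = local_factorial_state.chol_cov  # (n, d, d)
    joint_mean = means.reshape(-1)              # (n * d,)
    # Independent factors: a block diagonal of lower-triangular Cholesky factors
    # is itself a valid Cholesky factor of the block-diagonal joint covariance.
    joint_chol_cov = block_diag(*chol_covs)     # (n * d, n * d)
    return GaussianState(joint_mean, joint_chol_cov)


local = GaussianState(
    mean=jnp.array([[1.0, 2.0], [3.0, 4.0]]),
    chol_cov=jnp.stack([jnp.eye(2), 2.0 * jnp.eye(2)]),
)
joint = join(local)
```

Unpacking with `*` relies on the number of factors being known at trace time, which holds under jax.jit because array shapes are static.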

Marginalize

Bases: Protocol

Protocol for marginalizing a joint state into a factored state.

__call__(local_state, num_factors)

Marginalize a joint state into a factored state.

E.g. local_state might encode a joint mean with shape (2 * d,) and joint_chol_cov with shape (2 * d, 2 * d). Marginalizing the joint local state then yields two factors: means with shape (2, d) and chol_covs with shape (2, d, d).

Parameters:

Name Type Description Default
local_state ArrayTree

Joint local state with no factorial index dimension.

required
num_factors int

Number of factors to marginalize the joint state into. Integer. This is typically equal to len(factorial_inds).

required

Returns:

Type Description
ArrayTree

Factorial state with factorial index as the first dimension and num_factors factors (length of first dimension).

Source code in cuthbert/factorial/types.py
def __call__(
    self,
    local_state: ArrayTree,
    num_factors: int,
) -> ArrayTree:
    """Marginalize joint state into factored state.

    E.g. `local_state` might have shape (2 * d,) and `joint_chol_cov`
    with shape (2 * d, 2 * d). Then we marginalize out the joint local state into
    two factorial `means` with shape (2, d) and `chol_covs` with shape (2, d, d).

    Args:
        local_state: Joint local state with no factorial index dimension.
        num_factors: Number of factors to marginalize out. Integer.
            This is typically equal to len(factorial_inds).

    Returns:
        Factorial state with factorial index as the first dimension and
            `num_factors` factors (length of first dimension).
    """
    ...
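A minimal Gaussian sketch of Marginalize, again assuming the hypothetical GaussianState NamedTuple. It keeps each factor's marginal and discards cross-factor correlations. Note that the marginal Cholesky factor must be recomputed from the covariance block: for a lower-triangular joint factor, only the first diagonal block coincides with the marginal Cholesky factor in general.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class GaussianState(NamedTuple):
    """Hypothetical Gaussian state (not a cuthbert type)."""

    mean: jax.Array
    chol_cov: jax.Array


def marginalize(local_state, num_factors):
    """Split a joint Gaussian into per-factor marginals (drops cross-covariances)."""
    joint_mean, joint_chol = local_state.mean, local_state.chol_cov
    d = joint_mean.shape[0] // num_factors
    means = joint_mean.reshape(num_factors, d)
    # The marginal covariance of factor k is the k-th diagonal block of L @ L.T.
    # (The k-th diagonal block of L alone is not, in general, its Cholesky factor.)
    joint_cov = joint_chol @ joint_chol.T
    blocks = jnp.stack(
        [joint_cov[k * d:(k + 1) * d, k * d:(k + 1) * d] for k in range(num_factors)]
    )
    chol_covs = jnp.linalg.cholesky(blocks)  # batched over the factor axis
    return GaussianState(means, chol_covs)


# A correlated joint state over two 2-dimensional factors.
A = jax.random.normal(jax.random.PRNGKey(0), (4, 4))
joint_cov = A @ A.T + jnp.eye(4)
joint = GaussianState(jnp.arange(4.0), jnp.linalg.cholesky(joint_cov))
factors = marginalize(joint, num_factors=2)
```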

Insert

Bases: Protocol

Protocol for inserting a local factorial state into a factorial state.

__call__(local_factorial_state, factorial_state, factorial_inds)

Insert a local factorial state into a factorial state.

E.g. local_factorial_state might encode means with shape (2, d) and chol_covs with shape (2, d, d). Then we insert means[0] and means[1] into the factorial state's means at positions i and j respectively, and likewise chol_covs[0] and chol_covs[1] into its chol_covs. In both cases, we overwrite the existing factors i and j in the factorial state, leaving the other factors unchanged. Here i and j are determined from factorial_inds.

Parameters:

Name Type Description Default
local_factorial_state ArrayTree

Local factorial state with factorial index as the first dimension and len(factorial_inds) factors (length of first dimension).

required
factorial_state ArrayTree

Factorial state with factorial index as the first dimension.

required
factorial_inds ArrayLike

Indices of the factors to insert. Integer array. factorial_inds.ndim == 0 will be treated the same as factorial_inds.ndim == 1 with len(factorial_inds) == 1 (i.e. insert a single factor).

required

Returns:

Type Description
ArrayTree

Factorial state with factorial index as the first dimension. The updated factors are inserted into the factorial state. The remaining factors are left unchanged.

Source code in cuthbert/factorial/types.py
def __call__(
    self,
    local_factorial_state: ArrayTree,
    factorial_state: ArrayTree,
    factorial_inds: ArrayLike,
) -> ArrayTree:
    """Marginalize joint state into factored state and insert into factorial state.

    E.g. `local_factorial_state` might have shape (2, d) and `joint_chol_cov`
    with shape (2, d, d). Then we insert `means[0]` and `means[1]` into
    `state[i]` and `state[j]` respectively. Similarly, we insert `chol_covs[0]` and
    `chol_covs[1]`. In both cases, we overwrite the existing factors in the
    factorial state for `i` and `j`, leaving the other factors unchanged.
    Here `i` and `j` are determined from `factorial_inds`.

    Args:
        local_factorial_state: Local factorial state with factorial index as the first
            dimension and `len(factorial_inds)` factors (length of first dimension).
        factorial_state: Factorial state with factorial index as the first dimension.
        factorial_inds: Indices of the factors to insert. Integer array.
            factorial_inds.ndim == 0 will be treated the same as
            factorial_inds.ndim == 1 with len(factorial_inds) == 1
            (i.e. insert a single factor).

    Returns:
        Factorial state with factorial index as the first dimension.
            The updated factors are inserted into the factorial state.
            The remaining factors are left unchanged.
    """
    ...
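A minimal sketch of Insert using JAX's functional index updates, assuming the hypothetical GaussianState NamedTuple from the earlier sketches. jnp.atleast_1d implements the rule that a 0-d factorial_inds behaves like a length-1 array.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class GaussianState(NamedTuple):
    """Hypothetical Gaussian state (not a cuthbert type)."""

    mean: jax.Array
    chol_cov: jax.Array


def insert(local_factorial_state, factorial_state, factorial_inds):
    """Overwrite the selected factors; all other factors are returned unchanged."""
    # Treat a 0-d index like a length-1 index, as the protocol describes.
    inds = jnp.atleast_1d(jnp.asarray(factorial_inds))
    # JAX arrays are immutable: .at[...].set(...) returns an updated copy.
    return jax.tree_util.tree_map(
        lambda local, glob: glob.at[inds].set(local),
        local_factorial_state,
        factorial_state,
    )


F, d = 4, 2
state = GaussianState(jnp.zeros((F, d)), jnp.broadcast_to(jnp.eye(d), (F, d, d)))
local = GaussianState(jnp.ones((2, d)), jnp.broadcast_to(3.0 * jnp.eye(d), (2, d, d)))
new_state = insert(local, state, jnp.array([1, 3]))
```

This sketch assumes factorial_inds holds distinct factors: with repeated indices, JAX applies the conflicting updates in an implementation-defined order.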

FactorializeInitState

Bases: Protocol

Protocol for factorial post-processing of init_prepare.

__call__(init_state, model_inputs)

Apply any processing of the output of init_prepare needed for factorial inference.

Parameters:

Name Type Description Default
init_state ArrayTreeLike

Output from the base inference method's init_prepare.

required
model_inputs ArrayTreeLike

The model inputs at the first time point.

required

Source code in cuthbert/factorial/types.py
def __call__(
    self, init_state: ArrayTreeLike, model_inputs: ArrayTreeLike
) -> ArrayTree:
    """Any processing of the output of `init_prepare` for factorial inference.

    Args:
        init_state: Output from base inference method's `init_prepare`
        model_inputs: The model inputs at the first time point.
    """
    ...
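As a hypothetical example, suppose the base init_prepare returns a single-factor state and every factor should start from the same prior. A FactorializeInitState could then add the factorial axis by repeating that state. The factory name make_tile_init_state, the GaussianState NamedTuple and the choice to ignore model_inputs are illustrative assumptions; Factorializer's default simply returns init_state unchanged.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class GaussianState(NamedTuple):
    """Hypothetical Gaussian state (not a cuthbert type)."""

    mean: jax.Array
    chol_cov: jax.Array


def make_tile_init_state(num_factors):
    """Hypothetical factory for a FactorializeInitState-style function."""

    def factorialize_init_state(init_state, model_inputs):
        del model_inputs  # not needed in this sketch
        # Add a leading factorial axis by repeating the single-factor state.
        return jax.tree_util.tree_map(
            lambda x: jnp.broadcast_to(x, (num_factors,) + x.shape), init_state
        )

    return factorialize_init_state


init_state = GaussianState(jnp.zeros(3), jnp.eye(3))  # one factor, d = 3
factorialize = make_tile_init_state(num_factors=5)
factorial_state = factorialize(init_state, model_inputs=None)
```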

Factorializer

Bases: NamedTuple

Factorializer object bundling the functions used for factorial inference.

All functions except get_factorial_indices depend on the inference method (e.g. Gaussian or SMC); get_factorial_indices acts purely on model_inputs.

Attributes:

Name Type Description
get_factorial_indices GetFactorialIndices

Function to extract factorial indices from model inputs.

extract Extract

Function to extract the relevant factors.

join Join

Function to combine factorial states into a joint state.

marginalize Marginalize

Function to marginalize a joint state into a factored state.

insert Insert

Function to insert a local factorial state into a factorial state.

factorialize_init_state FactorializeInitState

Optional post-processing function applied to the output of init_prepare. By default, it leaves the output of init_prepare unchanged.

extract_and_join method (returns ArrayTree)

Apply extract and then join. Input: global factorial state and model inputs. Output: local joint state.

marginalize_and_insert method (returns ArrayTree)

Apply marginalize and then insert. Input: local joint state, global factorial state and model inputs. Output: global factorial state.

get_factorial_indices instance-attribute

extract instance-attribute

join instance-attribute

marginalize instance-attribute

insert instance-attribute

factorialize_init_state = lambda init_state, model_inputs: init_state class-attribute instance-attribute

extract_and_join(factorial_state, model_inputs)

Extract and join the relevant factors into a joint local state.

Parameters:

Name Type Description Default
factorial_state ArrayTreeLike

Factorial state with factorial index as the first dimension.

required
model_inputs ArrayTreeLike

Model inputs, from which the factorial indices are extracted.

required

Returns:

Type Description
ArrayTree

Joint local state with no factorial index dimension.

Source code in cuthbert/factorial/types.py
def extract_and_join(
    self, factorial_state: ArrayTreeLike, model_inputs: ArrayTreeLike
) -> ArrayTree:
    """Extract and join the relevant factors into a joint local state.

    Args:
        factorial_state: Factorial state with factorial index as the first dimension.
        model_inputs: Model inputs, from which the factorial indices are extracted.

    Returns:
        Joint local state with no factorial index dimension.
    """
    factorial_inds = self.get_factorial_indices(model_inputs)
    factorial_inds = jnp.asarray(factorial_inds)

    assert factorial_inds.ndim == 1, (
        "factorial_inds must be a 1D array to be used with join"
    )

    local_factorial_state = self.extract(factorial_state, factorial_inds)
    return self.join(local_factorial_state)

marginalize_and_insert(local_state, factorial_state, model_inputs)

Marginalize and insert the relevant factors into a factorial state.

Parameters:

Name Type Description Default
local_state ArrayTree

Joint local state with no factorial index dimension.

required
factorial_state ArrayTree

Factorial state with factorial index as the first dimension.

required
model_inputs ArrayTreeLike

Model inputs, from which the factorial indices are extracted.

required

Returns:

Type Description
ArrayTree

Factorial state with factorial index as the first dimension. The updated factors are inserted into the factorial state. The remaining factors are left unchanged.

Source code in cuthbert/factorial/types.py
def marginalize_and_insert(
    self,
    local_state: ArrayTree,
    factorial_state: ArrayTree,
    model_inputs: ArrayTreeLike,
) -> ArrayTree:
    """Marginalize and insert the relevant factors into a factorial state.

    Args:
        local_state: Joint local state with no factorial index dimension.
        factorial_state: Factorial state with factorial index as the first dimension.
        model_inputs: Model inputs, from which the factorial indices are extracted.

    Returns:
        Factorial state with factorial index as the first dimension.
            The updated factors are inserted into the factorial state.
            The remaining factors are left unchanged.
    """
    factorial_inds = self.get_factorial_indices(model_inputs)
    factorial_inds = jnp.asarray(factorial_inds)
    num_factors = len(factorial_inds)
    local_factorial_state = self.marginalize(local_state, num_factors)
    return self.insert(local_factorial_state, factorial_state, factorial_inds)
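The two convenience methods are compositions of the protocol functions. The self-contained sketch below mirrors that composition with minimal Gaussian implementations (GaussianState, the "factor_inds" key and all helper functions are hypothetical, not cuthbert's own) and shows that a local update only changes the selected factors.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.scipy.linalg import block_diag


class GaussianState(NamedTuple):
    """Hypothetical Gaussian state (not a cuthbert type)."""

    mean: jax.Array
    chol_cov: jax.Array


# Minimal Gaussian versions of the protocol functions (illustrative only).
def get_factorial_indices(model_inputs):
    return model_inputs["factor_inds"]  # assumed input layout


def extract(state, inds):
    return jax.tree_util.tree_map(lambda x: x[inds], state)


def join(local):
    return GaussianState(local.mean.reshape(-1), block_diag(*local.chol_cov))


def marginalize(joint, num_factors):
    d = joint.mean.shape[0] // num_factors
    cov = joint.chol_cov @ joint.chol_cov.T
    blocks = jnp.stack(
        [cov[k * d:(k + 1) * d, k * d:(k + 1) * d] for k in range(num_factors)]
    )
    return GaussianState(
        joint.mean.reshape(num_factors, d), jnp.linalg.cholesky(blocks)
    )


def insert(local, state, inds):
    inds = jnp.atleast_1d(inds)
    return jax.tree_util.tree_map(lambda loc, glob: glob.at[inds].set(loc), local, state)


# Same composition as Factorializer.extract_and_join / marginalize_and_insert.
def extract_and_join(state, model_inputs):
    inds = jnp.asarray(get_factorial_indices(model_inputs))
    assert inds.ndim == 1, "factorial_inds must be a 1D array to be used with join"
    return join(extract(state, inds))


def marginalize_and_insert(joint, state, model_inputs):
    inds = jnp.asarray(get_factorial_indices(model_inputs))
    return insert(marginalize(joint, len(inds)), state, inds)


F, d = 4, 3
chol = jnp.array([[1.0, 0.0, 0.0], [0.5, 2.0, 0.0], [0.1, 0.2, 1.5]])
state = GaussianState(jnp.zeros((F, d)), jnp.broadcast_to(chol, (F, d, d)))
model_inputs = {"factor_inds": [0, 2]}

joint = extract_and_join(state, model_inputs)    # mean (6,), chol_cov (6, 6)
updated = joint._replace(mean=joint.mean + 1.0)  # stand-in for a real local update
new_state = marginalize_and_insert(updated, state, model_inputs)
# Factors 0 and 2 now have mean 1.0; factors 1 and 3 keep mean 0.0.
```

Because the joint Cholesky factor here is block diagonal, recomputing each marginal Cholesky factor recovers the original chol_cov up to floating-point error.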