
Gaussian Factorial Models

cuthbert.factorial.gaussian

Factorial utilities for Kalman states.

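KalmanState = TypeVar('KalmanState', KalmanFilterState, LinearizedKalmanFilterState) (module attribute)

Type variable constrained to KalmanFilterState and LinearizedKalmanFilterState; the functions below accept and return either state type.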

build_factorializer(get_factorial_indices)

Build a factorializer for Kalman states.

Parameters:

get_factorial_indices (GetFactorialIndices, required): Function to extract the factorial indices from model inputs.

Returns:

Factorializer: Factorializer object for Kalman states, with functions to extract and join the relevant factors and to marginalize and insert the updated factors.

Source code in cuthbert/factorial/gaussian.py
def build_factorializer(get_factorial_indices: GetFactorialIndices) -> Factorializer:
    """Build a factorializer for Kalman states.

    Args:
        get_factorial_indices: Function to extract the factorial indices
            from model inputs.

    Returns:
        Factorializer object for Kalman states with functions to extract and join
        the relevant factors and marginalize and insert the updated factors.
    """
    return Factorializer(
        get_factorial_indices=get_factorial_indices,
        extract=extract,
        join=join,
        marginalize=marginalize,
        insert=insert,
    )
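The Factorializer bundles operations that are chained around a local update: extract the relevant factors, join them into a joint state, update that joint state, marginalize it back into factors, and insert those factors into the global state. The sketch below traces this cycle for the means only, using NumPy arrays and hypothetical helpers (extract_means, join_means, marginalize_means, insert_means, factorial_step) that follow the shape conventions on this page. It is an illustration under those assumptions, not cuthbert's implementation, and the local update is a placeholder for, e.g., a Kalman filter step on the joint state.

```python
import numpy as np


# Hypothetical array-level helpers mirroring the documented shapes
# (F factors, each with state dimension d); not cuthbert's own code.
def extract_means(means, inds):
    # (F, d) -> (k, d) for a 1-d index array
    return means[np.asarray(inds)]


def join_means(local_means):
    # (k, d) -> (k * d,): concatenate the factor means in order
    return local_means.reshape(-1)


def marginalize_means(joint_mean, num_factors):
    # (k * d,) -> (k, d): split the joint mean back into factors
    return joint_mean.reshape(num_factors, -1)


def insert_means(local_means, means, inds):
    # Scatter the updated factors into a copy of the (F, d) array
    out = means.copy()
    out[np.atleast_1d(inds)] = local_means
    return out


def factorial_step(means, inds, local_update):
    """Run extract -> join -> update -> marginalize -> insert on the means."""
    inds = np.atleast_1d(inds)
    local = extract_means(means, inds)
    joint = local_update(join_means(local))  # placeholder for a joint update
    local = marginalize_means(joint, len(inds))
    return insert_means(local, means, inds)


means = np.arange(8.0).reshape(4, 2)  # F = 4 factors, d = 2
updated = factorial_step(means, [1, 3], lambda m: m + 10.0)
```

Only factors 1 and 3 are touched; factors 0 and 2 are copied through unchanged, which is the point of working on a local subset of factors.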

extract(factorial_state, factorial_inds)

Extract the relevant factors from a factorial Kalman state.

One-dimensional arrays, with shape (F,), are treated as one scalar per factor, e.g. log normalizing constants. This means univariate problems must still store an explicit state dimension (e.g. means with shape (F, 1) and chol_covs with shape (F, 1, 1)). Multidimensional arrays are treated as arrays with shape (F, *): the factorial_inds indices are taken along the first (factor) dimension and the remaining dimensions are preserved.

Here F is the number of factors and d is the state dimension of each factor.

Parameters:

factorial_state (KalmanState, required): Factorial Kalman state storing means and chol_covs with shape (F, d) and (F, d, d) respectively.

factorial_inds (ArrayLike, required): Indices of the factors to extract, as an integer array. If factorial_inds.ndim == 0, a single factor is extracted and the factorial dimension is removed. If factorial_inds.ndim == 1, the factorial dimension is retained, even if len(factorial_inds) == 1.

Returns:

KalmanState: Factorial Kalman state storing means and chol_covs with shape (len(factorial_inds), d) and (len(factorial_inds), d, d). If factorial_inds is a single integer, the returned local factorial state has no factorial dimension.

Source code in cuthbert/factorial/gaussian.py
def extract(factorial_state: KalmanState, factorial_inds: ArrayLike) -> KalmanState:
    """Extract the relevant factors from a factorial Kalman state.

    Single dimensional arrays will be treated as scalars e.g. log normalizing constants.
        This means univariate problems still need to be stored with a dimension array
        (e.g. means with shape (F, 1) and chol_covs with shape (F, 1, 1)).
    Multidimensional arrays will be treated as arrays with shape (F, *).
        In this case the factorial_inds indices will be extracted from the first
        dimension and then the remaining dimensions will be preserved.

    Here F is the number of factors and d is the dimension of the state.

    Args:
        factorial_state: Factorial Kalman state storing means and chol_covs
            with shape (F, d) and (F, d, d) respectively.
        factorial_inds: Indices of the factors to extract. Integer array.
            factorial_inds.ndim == 0 removes the factorial dimension and extracts
                a single factor.
            factorial_inds.ndim == 1 retains the factorial dimension,
                even if len(factorial_inds) == 1.

    Returns:
        Factorial Kalman state storing means and chol_covs
            with shape (len(factorial_inds), d) and (len(factorial_inds), d, d).
            If factorial_inds is a single integer, the returned local factorial state
            will not have a factorial dimension.
    """
    factorial_inds = jnp.asarray(factorial_inds)
    new_elem = tree.map(lambda x: _extract_arr(x, factorial_inds), factorial_state.elem)
    new_state = factorial_state._replace(elem=new_elem)

    if isinstance(new_state, LinearizedKalmanFilterState):
        new_mean_prev = _extract_arr(factorial_state.mean_prev, factorial_inds)
        new_state = new_state._replace(mean_prev=new_mean_prev)

    return new_state
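The indexing rule for extract can be checked on plain arrays. This sketch uses NumPy and a hypothetical extract_arr helper (an assumption, not the library's private _extract_arr) that applies the documented convention to a single leaf; jax.numpy indexing should behave the same way for these cases.

```python
import numpy as np


def extract_arr(x, inds):
    # Hypothetical helper: index the leading (factor) axis only.
    # A 0-d index drops the factor axis; a 1-d index keeps it.
    return x[np.asarray(inds)]


F, d = 3, 2
means = np.arange(F * d, dtype=float).reshape(F, d)              # (F, d)
chol_covs = np.stack([(i + 1.0) * np.eye(d) for i in range(F)])  # (F, d, d)
log_norms = np.zeros(F)                                          # (F,): one scalar per factor

single = extract_arr(means, 1)                # shape (2,): factor axis removed
kept = extract_arr(means, [1])                # shape (1, 2): factor axis retained
pair_chols = extract_arr(chol_covs, [0, 2])   # shape (2, 2, 2)
pair_norms = extract_arr(log_norms, [0, 2])   # shape (2,)
```

A univariate mean stored with shape (F,) would be indistinguishable from log_norms here, which is why univariate states need the explicit trailing dimension of size 1.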

join(local_factorial_state)

Convert a factorial Kalman state into a joint local Kalman state.

One-dimensional arrays are treated as one scalar per factor, e.g. log normalizing constants. This means univariate problems must still store an explicit state dimension (e.g. means with shape (F, 1) and chol_covs with shape (F, 1, 1)). Two-dimensional arrays are treated as means with shape (F, d); the F factor means are concatenated into a single array of shape (F * d,). Three-dimensional arrays are treated as chol_covs with shape (F, d, d); the F Cholesky factors are placed on the diagonal of a block-diagonal array of shape (F * d, F * d).

Here F is the number of factors in the local state and d is the state dimension of each factor.

Parameters:

local_factorial_state (KalmanState, required): Factorial Kalman state storing means and chol_covs with shape (F, d) and (F, d, d) respectively.

Returns:

KalmanState: Joint local Kalman state with no factorial index dimension, storing means and chol_covs with shape (F * d,) and (F * d, F * d).

Source code in cuthbert/factorial/gaussian.py
def join(local_factorial_state: KalmanState) -> KalmanState:
    """Convert a factorial Kalman state into a joint local Kalman state.

    Single dimensional arrays will be treated as scalars e.g. log normalizing constants.
        This means univariate problems still need to be stored with a dimension array
        (e.g. means with shape (F, 1) and chol_covs with shape (F, 1, 1)).
    Two dimensional arrays will be treated as means with shape (F, d).
        In this case the F factor means are concatenated into a single
        array with shape (F * d,).
    Three dimensional arrays will be treated as chol_covs with shape (F, d, d).
        In this case the F Cholesky factors are placed on the diagonal
        of a block diagonal array with shape (F * d, F * d).

    Here F is the number of factors and d is the dimension of the state.

    Args:
        local_factorial_state: Factorial Kalman state storing means and chol_covs
            with shape (F, d) and (F, d, d) respectively.

    Returns:
        Joint local Kalman state with no factorial index dimension.
    """
    new_elem = tree.map(_join_arr, local_factorial_state.elem)
    new_state = local_factorial_state._replace(elem=new_elem)

    if isinstance(local_factorial_state, LinearizedKalmanFilterState):
        new_mean_prev = _join_arr(local_factorial_state.mean_prev)
        new_state = new_state._replace(mean_prev=new_mean_prev)

    return new_state
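Placing each factor's lower-triangular Cholesky factor on the diagonal yields a matrix that is itself lower triangular, and its product with its transpose is block diagonal, so the joint state treats the joined factors as independent. The NumPy sketch below uses hypothetical join_means and join_chol_covs helpers (assumptions, not the library's private _join_arr); in JAX, jax.scipy.linalg.block_diag(*chol_covs) should build the same joint factor.

```python
import numpy as np


def join_means(means):
    # (k, d) -> (k * d,): concatenate the factor means in factor order
    return means.reshape(-1)


def join_chol_covs(chol_covs):
    # (k, d, d) -> (k * d, k * d): each factor's Cholesky factor on the diagonal
    k, d, _ = chol_covs.shape
    joint = np.zeros((k * d, k * d), dtype=chol_covs.dtype)
    for i in range(k):
        joint[i * d:(i + 1) * d, i * d:(i + 1) * d] = chol_covs[i]
    return joint


chol_covs = np.array([[[1.0, 0.0], [0.5, 2.0]],
                      [[3.0, 0.0], [1.0, 1.0]]])  # k = 2 factors, d = 2
means = np.array([[0.0, 1.0], [2.0, 3.0]])

joint_mean = join_means(means)           # [0., 1., 2., 3.]
joint_chol = join_chol_covs(chol_covs)   # (4, 4), lower triangular
joint_cov = joint_chol @ joint_chol.T    # block diagonal: independent factors
```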

marginalize(local_state, num_factors)

Marginalize a joint local Kalman state into a factorial Kalman state.

Parameters:

local_state (KalmanState, required): Joint local Kalman state to marginalize, storing means and chol_covs with shape (d * num_factors,) and (d * num_factors, d * num_factors) respectively.

num_factors (int, required): Number of factors to split the joint state into, i.e. the size of the factorial dimension of the result.

Returns:

KalmanState: Local factorial Kalman state storing means and chol_covs with shape (num_factors, d) and (num_factors, d, d) respectively.

Source code in cuthbert/factorial/gaussian.py
def marginalize(local_state: KalmanState, num_factors: int) -> KalmanState:
    """Marginalize a joint local Kalman state into a factorial Kalman state.

    Args:
        local_state: Joint local Kalman state to marginalize.
            With means and chol_covs with shape (d * num_factors,)
            and (d * num_factors, d * num_factors) respectively.
        num_factors: Number of factors to split the joint state into. Integer.

    Returns:
        Local factorial Kalman state with means and chol_covs
            with shape (num_factors, d) and (num_factors, d, d) respectively.
    """
    new_elem = tree.map(
        lambda loc: _marginalize_arr(loc, num_factors),
        local_state.elem,
    )
    new_state = local_state._replace(elem=new_elem)
    if isinstance(local_state, LinearizedKalmanFilterState):
        new_mean_prev = _marginalize_arr(local_state.mean_prev, num_factors)
        new_state = new_state._replace(mean_prev=new_mean_prev)

    return new_state
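Marginalizing a joint Gaussian keeps the matching sub-vector of the mean and the matching diagonal block of the covariance. With a Cholesky parameterization there is a subtlety: after an update has correlated the factors, the diagonal blocks of the joint lower-triangular factor are generally not the Cholesky factors of the marginals (only the first block is). The sketch below, using NumPy and a hypothetical marginalize_gaussian helper, recomputes each marginal factor from its block row of the joint factor. It illustrates the mathematics under that assumption and should not be read as a description of how cuthbert's _marginalize_arr works.

```python
import numpy as np


def marginalize_gaussian(joint_mean, joint_chol, num_factors):
    """Split a joint Gaussian (mean, lower Cholesky factor) into factor marginals."""
    d = joint_mean.shape[0] // num_factors
    means = joint_mean.reshape(num_factors, d)  # (k * d,) -> (k, d)
    chols = np.empty((num_factors, d, d))
    for i in range(num_factors):
        block_row = joint_chol[i * d:(i + 1) * d, :]  # rows of L for factor i
        cov_ii = block_row @ block_row.T              # marginal covariance Sigma_ii
        chols[i] = np.linalg.cholesky(cov_ii)         # lower-triangular factor
    return means, chols


joint_chol = np.array([[2.0, 0.0, 0.0, 0.0],
                       [1.0, 1.0, 0.0, 0.0],
                       [0.5, 0.3, 1.5, 0.0],
                       [0.2, 0.1, 0.4, 1.0]])  # correlated joint, k = 2, d = 2
joint_mean = np.arange(4.0)
means, chols = marginalize_gaussian(joint_mean, joint_chol, num_factors=2)
```

If the joint factor is block diagonal with positive diagonal entries (as after a join with no update), the recomputed factors coincide with its diagonal blocks, so this marginalization inverts the join.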

insert(local_factorial_state, factorial_state, factorial_inds)

Insert a local factorial Kalman state into a factorial Kalman state.

One-dimensional arrays are treated as one scalar per factor, e.g. log normalizing constants. This means univariate problems must still store an explicit state dimension (e.g. means with shape (F, 1) and chol_covs with shape (F, 1, 1)). Multidimensional arrays are treated as arrays with shape (F, *): the factorial_inds indices select the positions along the first (factor) dimension to overwrite, and the remaining dimensions are preserved.

Here F is the number of factors and d is the state dimension of each factor.

Parameters:

local_factorial_state (KalmanState, required): Local factorial Kalman state to insert, storing means and chol_covs with shape (len(factorial_inds), d) and (len(factorial_inds), d, d) respectively.

factorial_state (KalmanState, required): Factorial Kalman state storing means and chol_covs with shape (F, d) and (F, d, d) respectively.

factorial_inds (ArrayLike, required): Indices of the factors to insert, as an integer array. A single integer is treated as a one-element array.

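Returns:

KalmanState: Updated factorial Kalman state storing means and chol_covs with shape (F, d) and (F, d, d), in which the factors at factorial_inds are replaced by those of local_factorial_state. Its model_inputs are taken from local_factorial_state.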

Source code in cuthbert/factorial/gaussian.py
def insert(
    local_factorial_state: KalmanState,
    factorial_state: KalmanState,
    factorial_inds: ArrayLike,
) -> KalmanState:
    """Insert a local factorial Kalman state into a factorial Kalman state.

    Single dimensional arrays will be treated as scalars e.g. log normalizing constants.
        This means univariate problems still need to be stored with a dimension array
        (e.g. means with shape (F, 1) and chol_covs with shape (F, 1, 1)).
    Multidimensional arrays will be treated as arrays with shape (F, *).
        In this case the factorial_inds indices will be inserted into the first
        dimension and then the remaining dimensions will be preserved.

    Here F is the number of factors and d is the dimension of the state.

    Args:
        local_factorial_state: Local factorial Kalman state to insert.
            With means and chol_covs with shape (len(factorial_inds), d)
            and (len(factorial_inds), d, d) respectively.
        factorial_state: Factorial Kalman state storing means and chol_covs
            with shape (F, d) and (F, d, d) respectively.
        factorial_inds: Indices of the factors to insert. Integer array.

    Returns:
        Factorial Kalman state with the factors at factorial_inds replaced,
            storing means and chol_covs with shape (F, d) and (F, d, d).
            model_inputs are taken from local_factorial_state.
    """
    factorial_inds = jnp.asarray(factorial_inds)
    factorial_inds = jnp.atleast_1d(factorial_inds)
    new_elem = tree.map(
        lambda loc, glob: _insert_arr(loc, glob, factorial_inds),
        local_factorial_state.elem,
        factorial_state.elem,
    )
    new_state = factorial_state._replace(
        elem=new_elem,
        model_inputs=local_factorial_state.model_inputs,
    )

    if isinstance(local_factorial_state, LinearizedKalmanFilterState) and isinstance(
        factorial_state, LinearizedKalmanFilterState
    ):
        new_mean_prev = _insert_arr(
            local_factorial_state.mean_prev, factorial_state.mean_prev, factorial_inds
        )
        new_state = new_state._replace(mean_prev=new_mean_prev)

    return new_state
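The per-leaf scatter performed by insert can be sketched with NumPy, using a hypothetical insert_arr helper (an assumption, not the library's private _insert_arr). Because insert applies jnp.atleast_1d to factorial_inds, a single integer index behaves like a one-element index array. JAX arrays are immutable, so the JAX equivalent of the assignment is x.at[inds].set(values), which returns a new array instead of modifying x in place; the copy below mimics that.

```python
import numpy as np


def insert_arr(local, glob, inds):
    # Hypothetical helper: promote a scalar index to shape (1,), as insert does
    # with jnp.atleast_1d, then overwrite those rows of a copy of the global array.
    inds = np.atleast_1d(np.asarray(inds))
    out = glob.copy()
    out[inds] = local  # in JAX: glob.at[inds].set(local)
    return out


means = np.zeros((4, 2))  # F = 4 factors, d = 2

updated = insert_arr(np.ones((2, 2)), means, [0, 3])  # two factors at once
single = insert_arr(np.full(2, 5.0), means, 2)        # scalar index, local shape (d,)
```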