
Discrete Factorial Models

cuthbert.factorial.discrete

Factorial utilities for discrete HMM states.

build_factorializer(get_factorial_indices)

Build a factorializer for discrete HMM filter states.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| get_factorial_indices | GetFactorialIndices | Function to extract the factorial indices from model inputs. | required |

Returns:

| Type | Description |
|---|---|
| Factorializer | Factorializer object for discrete states with functions to extract and join the relevant factors, and to marginalize and insert the updated factors. |

Source code in cuthbert/factorial/discrete.py
def build_factorializer(get_factorial_indices: GetFactorialIndices) -> Factorializer:
    """Build a factorializer for discrete HMM filter states.

    Args:
        get_factorial_indices: Function to extract the factorial indices
            from model inputs.

    Returns:
        Factorializer object for discrete states with functions to extract and join
        the relevant factors and marginalize and insert the updated factors.
    """
    return Factorializer(
        get_factorial_indices=get_factorial_indices,
        extract=extract,
        join=join,
        marginalize=marginalize,
        insert=insert,
    )
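The docstrings on this page imply a round trip around each local update: extract the factors a step touches, join them into a joint local state, update that state, marginalize it back into per-factor form, and insert the result. A minimal sketch of that pattern follows, assuming that call order. It does not use cuthbert: ToyFactorializer, the toy_* helpers, factorial_step, the "factor_inds" key and the couple update are all hypothetical, the state is simplified to one probability vector per factor (shape (F, K)) rather than a DiscreteFilterState, and NumPy stands in for jax.numpy.

```python
from functools import reduce
from typing import Callable, NamedTuple

import numpy as np


class ToyFactorializer(NamedTuple):
    """Hypothetical stand-in with the same five fields build_factorializer sets."""

    get_factorial_indices: Callable
    extract: Callable
    join: Callable
    marginalize: Callable
    insert: Callable


def toy_extract(state, inds):
    # Select the rows (factors) touched by this step.
    return state[np.asarray(inds)]


def toy_join(local):
    # Product distribution over K**F_local joint states, factor 0 most significant.
    return reduce(np.multiply.outer, local).ravel()


def toy_marginalize(joint, num_factors):
    # Sum out every other factor to get one K-vector per factor.
    k = round(joint.size ** (1 / num_factors))
    p = joint.reshape((k,) * num_factors)
    axes = range(num_factors)
    return np.stack([p.sum(axis=tuple(a for a in axes if a != i)) for i in axes])


def toy_insert(local, state, inds):
    # Write updated factors back without mutating the input state.
    new_state = state.copy()
    new_state[np.atleast_1d(inds)] = local
    return new_state


def factorial_step(fz, state, model_inputs, local_update):
    # Assumed order: extract -> join -> joint update -> marginalize -> insert.
    inds = np.atleast_1d(fz.get_factorial_indices(model_inputs))
    joint = local_update(fz.join(fz.extract(state, inds)), model_inputs)
    return fz.insert(fz.marginalize(joint, len(inds)), state, inds)


def couple(joint, model_inputs):
    # Hypothetical likelihood (K = 2) favouring joint states where both factors agree.
    post = joint * np.array([1.0, 0.1, 0.1, 1.0])
    return post / post.sum()


fz = ToyFactorializer(
    # Hypothetical: model inputs carry the affected factor indices under "factor_inds".
    get_factorial_indices=lambda mi: np.asarray(mi["factor_inds"]),
    extract=toy_extract,
    join=toy_join,
    marginalize=toy_marginalize,
    insert=toy_insert,
)
state = np.array([[0.5, 0.5], [0.9, 0.1], [0.2, 0.8]])
new_state = factorial_step(fz, state, {"factor_inds": [0, 2]}, couple)
```

With factors 0 and 2 coupled by the update, factor 0 moves from [0.5, 0.5] to roughly [0.25, 0.75], factor 2 happens to keep its marginal [0.2, 0.8], and factor 1 is left untouched because it was never extracted.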

extract(factorial_state, factorial_inds)

Extract the relevant factors from a factorial discrete filter state.

Here F is the number of factors and K is the number of states per factor. The relevant leaves are:
  • elem.f with shape (F, K, K)
  • elem.log_g with shape (F, K)

Both leaves are indexed on the leading factorial axis.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| factorial_state | DiscreteFilterState | Factorial discrete filter state storing transition-like arrays and log-normalizing vectors with leading factorial dimension F. | required |
| factorial_inds | ArrayLike | Indices of the factors to extract (integer array). If factorial_inds.ndim == 0, a single factor is extracted and the factorial dimension is removed. If factorial_inds.ndim == 1, the factorial dimension is retained, even if len(factorial_inds) == 1. | required |

Returns:

| Type | Description |
|---|---|
| DiscreteFilterState | Local factorial discrete filter state with elem.f of shape (len(factorial_inds), K, K) and elem.log_g of shape (len(factorial_inds), K). If factorial_inds is a single integer, the returned state has no factorial dimension, so elem.f has shape (K, K) and elem.log_g has shape (K,). |

Source code in cuthbert/factorial/discrete.py
def extract(
    factorial_state: DiscreteFilterState, factorial_inds: ArrayLike
) -> DiscreteFilterState:
    """Extract the relevant factors from a factorial discrete filter state.

    Here F is the number of factors and K is the number of states per factor.
    The relevant leaves are:
        - elem.f with shape (F, K, K)
        - elem.log_g with shape (F, K)
    Both leaves are indexed on the leading factorial axis.

    Args:
        factorial_state: Factorial discrete filter state storing transition-like
            arrays and log normalizing vectors with leading factorial dimension F.
        factorial_inds: Indices of the factors to extract. Integer array.
            factorial_inds.ndim == 0 removes the factorial dimension and extracts
                a single factor.
            factorial_inds.ndim == 1 retains the factorial dimension,
                even if len(factorial_inds) == 1.

    Returns:
        Factorial discrete filter state with:
            - elem.f of shape (len(factorial_inds), K, K)
            - elem.log_g of shape (len(factorial_inds), K)
            If factorial_inds is a single integer, the returned local factorial
            state will not have a factorial dimension.
    """
    factorial_inds = jnp.asarray(factorial_inds)
    f = factorial_state.elem.f[factorial_inds]
    log_g = factorial_state.elem.log_g[factorial_inds]
    return factorial_state._replace(
        elem=factorial_state.elem._replace(f=f, log_g=log_g)
    )
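The indexing rule for factorial_inds can be checked on plain arrays. This sketch assumes NumPy in place of jax.numpy (both follow the same integer-indexing rules here) and uses a hypothetical extract_leaves helper on bare f and log_g arrays instead of a DiscreteFilterState.

```python
import numpy as np

F, K = 4, 3
rng = np.random.default_rng(0)
# Toy leaves with the documented shapes; the values are arbitrary.
f = rng.random((F, K, K))
log_g = np.zeros((F, K))


def extract_leaves(f, log_g, factorial_inds):
    """Hypothetical helper: index both leaves on the leading factorial axis."""
    inds = np.asarray(factorial_inds)
    return f[inds], log_g[inds]


pair_f, pair_log_g = extract_leaves(f, log_g, [0, 2])  # 1-D: keeps the factorial axis
one_f, one_log_g = extract_leaves(f, log_g, [1])  # 1-D of length 1: still keeps it
scalar_f, scalar_log_g = extract_leaves(f, log_g, 1)  # 0-D: drops the factorial axis
```

One difference when running under JAX: an out-of-range factor index does not raise an IndexError, because JAX clamps out-of-bounds indices on retrieval.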

join(local_factorial_state)

Convert a local factorial discrete state into a joint local state.

This operation is applied to:
  • elem.f with shape (F, K, K), combined into a joint transition matrix of shape (K**F, K**F) via Kronecker products.
  • elem.log_g with shape (F, K), mapped to a joint vector of shape (K**F,). In this implementation, log_g is treated as a shared log-normalizing scalar and broadcast to the joint state dimension.

Here F is the number of local factors and K is the number of states per factor.
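The Kronecker combination described above can be written entrywise. Assuming the factors are combined in index order, with factor 1 as the leftmost (most significant) term, and writing 0-based per-factor states i_m, j_m in {0, ..., K-1} for multi-indices i = (i_1, ..., i_F) and j = (j_1, ..., j_F):

```latex
f^{\mathrm{joint}} = f^{(1)} \otimes f^{(2)} \otimes \cdots \otimes f^{(F)}, \qquad
f^{\mathrm{joint}}_{\iota(i),\,\iota(j)} = \prod_{m=1}^{F} f^{(m)}_{i_m j_m}, \qquad
\iota(i) = \sum_{m=1}^{F} i_m \, K^{F-m}.
```

Summing a row over j gives the product of the factor row sums, so row-stochastic factors yield a row-stochastic joint matrix of shape (K**F, K**F).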

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| local_factorial_state | DiscreteFilterState | Local factorial discrete state storing leaves with leading factorial dimension F. | required |

Returns:

| Type | Description |
|---|---|
| DiscreteFilterState | Joint local discrete state with no factorial index dimension. |

Source code in cuthbert/factorial/discrete.py
def join(local_factorial_state: DiscreteFilterState) -> DiscreteFilterState:
    """Convert a local factorial discrete state into a joint local state.

    This operation is applied to:
        - elem.f with shape (F, K, K), combined into a joint transition matrix
          of shape (K**F, K**F) via Kronecker products.
        - elem.log_g with shape (F, K), mapped to a joint vector of shape (K**F,).
          In this implementation, log_g is treated as a shared log-normalizing
          scalar and broadcast to the joint state dimension.

    Here F is the number of local factors and K is the number of states per factor.

    Args:
        local_factorial_state: Local factorial discrete state storing leaves with
            leading factorial dimension F.

    Returns:
        Joint local discrete state with no factorial index dimension.
    """
    f = _join_matrices(local_factorial_state.elem.f)
    log_g = _join_log_vectors(local_factorial_state.elem.log_g)
    return local_factorial_state._replace(
        elem=local_factorial_state.elem._replace(f=f, log_g=log_g)
    )
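_join_matrices and _join_log_vectors are private and not shown on this page. The following is a minimal NumPy sketch of what the docstring describes, assuming the factors are combined with np.kron in index order (factor 0 leftmost) and, for log_g, that every entry holds the same shared scalar so its first entry can be broadcast. join_matrices_sketch and join_log_vectors_sketch are hypothetical names, not cuthbert functions.

```python
from functools import reduce

import numpy as np


def join_matrices_sketch(f):
    """Combine (F, K, K) factor matrices into one (K**F, K**F) matrix."""
    # reduce computes np.kron(np.kron(f[0], f[1]), f[2]) ...; kron is associative.
    return reduce(np.kron, f)


def join_log_vectors_sketch(log_g):
    """Broadcast the shared log-normalizer in an (F, K) array to shape (K**F,)."""
    # Assumption: all entries of log_g hold the same scalar.
    num_factors, k = log_g.shape
    return np.full((k**num_factors,), log_g[0, 0])


A = np.array([[0.9, 0.1], [0.2, 0.8]])
B = np.array([[0.7, 0.3], [0.4, 0.6]])
joint_f = join_matrices_sketch(np.stack([A, B]))
joint_log_g = join_log_vectors_sketch(np.full((2, 2), -1.5))
```

Row state (0, 1) maps to joint row 0*2 + 1 = 1 and column state (1, 0) to joint column 1*2 + 0 = 2, so joint_f[1, 2] equals A[0, 1] * B[1, 0] = 0.04. Because the joint size grows as K**F, this step is only practical for the few factors a single local update touches.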

marginalize(local_state, num_factors)

Marginalize a joint local discrete state into a local factorial state.

A joint local state stores:
  • elem.f with shape (K**F, K**F)
  • elem.log_g with shape (K**F,)

This function returns:
  • elem.f with shape (F, K, K), obtained by summing out all non-target factors.
  • elem.log_g with shape (F, K), obtained by broadcasting the shared log-normalizer.

Parameters:

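| Name | Type | Description | Default |
|---|---|---|---|
| local_state | DiscreteFilterState | Joint local discrete state with no factorial index dimension. | required |
| num_factors | int | Number of factors F in the joint state, i.e. the size of the leading factorial dimension of the result. Integer. | required |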

Returns:

| Type | Description |
|---|---|
| DiscreteFilterState | Local factorial discrete state with leading factorial dimension num_factors. |

Source code in cuthbert/factorial/discrete.py
def marginalize(
    local_state: DiscreteFilterState, num_factors: int
) -> DiscreteFilterState:
    """Marginalize a joint local discrete state into a local factorial state.

    A joint local state stores:
        - elem.f with shape (K**F, K**F)
        - elem.log_g with shape (K**F,)
    This function returns:
        - elem.f with shape (F, K, K) by summing out all non-target factors.
        - elem.log_g with shape (F, K) by broadcasting the shared log-normalizer.

    Args:
        local_state: Joint local discrete state with no factorial index dimension.
        num_factors: Number of factors F in the joint state. Integer.

    Returns:
        Local factorial discrete state with leading factorial dimension
            num_factors.
    """
    f = _marginalize_matrix(local_state.elem.f, num_factors)
    log_g = _marginalize_log_vector(local_state.elem.log_g, num_factors)
    return local_state._replace(elem=local_state.elem._replace(f=f, log_g=log_g))
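_marginalize_matrix and _marginalize_log_vector are private and not shown here, so the following is only one plausible NumPy sketch of "summing out all non-target factors", not cuthbert's implementation. It assumes the row-major factor ordering produced by np.kron, sums over both the row and column axes of every other factor, and then renormalizes rows so that row-stochastic inputs stay row-stochastic. The renormalization step and both function names are assumptions.

```python
import numpy as np


def marginalize_matrix_sketch(joint_f, num_factors):
    """Split a (K**F, K**F) joint matrix into (F, K, K) per-factor matrices.

    Hypothetical: sums over the row and column axes of every other factor,
    then renormalizes each row. cuthbert's _marginalize_matrix may differ.
    """
    k = round(joint_f.shape[0] ** (1 / num_factors))
    # Axes 0..F-1 split the row index and axes F..2F-1 split the column index
    # (row-major, factor 0 most significant, matching np.kron ordering).
    t = joint_f.reshape((k,) * (2 * num_factors))
    out = []
    for m in range(num_factors):
        others = [a for a in range(num_factors) if a != m]
        summed = t.sum(axis=tuple(others + [num_factors + a for a in others]))
        out.append(summed / summed.sum(axis=1, keepdims=True))
    return np.stack(out)


def marginalize_log_vector_sketch(joint_log_g, num_factors):
    """Broadcast the shared log-normalizer (first entry) to shape (F, K)."""
    k = round(joint_log_g.shape[0] ** (1 / num_factors))
    return np.full((num_factors, k), joint_log_g[0])


# Round trip: a Kronecker-product joint matrix splits back into its factors.
A = np.array([[0.9, 0.1], [0.2, 0.8]])
B = np.array([[0.7, 0.3], [0.4, 0.6]])
factors = marginalize_matrix_sketch(np.kron(A, B), num_factors=2)
log_g = marginalize_log_vector_sketch(np.full((4,), -1.5), num_factors=2)
```

For a joint matrix that is not a Kronecker product, the result only approximates it and discards the dependence between factors, which is the usual trade-off of a factorial representation.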

insert(local_factorial_state, factorial_state, factorial_inds)

Insert a local factorial discrete state into a factorial discrete state.

This operation is applied to:
  • elem.f: local factors are inserted at factorial_inds in the leading axis.
  • elem.log_g: treated as a shared scalar log-normalizer and broadcast across all global factors and states.

Here F is the number of factors and K is the number of states per factor.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| local_factorial_state | DiscreteFilterState | Local factorial discrete state to insert. Leaves with a factorial axis should have first dimension len(factorial_inds). | required |
| factorial_state | DiscreteFilterState | Global factorial discrete state with first dimension F on leaves that carry a factorial axis. | required |
| factorial_inds | ArrayLike | Indices of the factors to insert. Integer array. | required |

Returns:

| Type | Description |
|---|---|
| DiscreteFilterState | Updated factorial discrete state with inserted factors. |

Source code in cuthbert/factorial/discrete.py
def insert(
    local_factorial_state: DiscreteFilterState,
    factorial_state: DiscreteFilterState,
    factorial_inds: ArrayLike,
) -> DiscreteFilterState:
    """Insert a local factorial discrete state into a factorial discrete state.

    This operation is applied to:
        - elem.f: local factors are inserted at factorial_inds in the leading axis.
        - elem.log_g: treated as a shared scalar log-normalizer and broadcast
          across all global factors/states.

    Here F is the number of factors and K is the number of states per factor.

    Args:
        local_factorial_state: Local factorial discrete state to insert.
            Leaves with a factorial axis should have first dimension
            len(factorial_inds).
        factorial_state: Global factorial discrete state with first dimension F
            on leaves that carry a factorial axis.
        factorial_inds: Indices of the factors to insert. Integer array.

    Returns:
        Updated factorial discrete state with inserted factors.
    """
    factorial_inds = jnp.asarray(factorial_inds)
    factorial_inds = jnp.atleast_1d(factorial_inds)
    new_f = factorial_state.elem.f.at[factorial_inds].set(local_factorial_state.elem.f)
    new_log_g = jnp.full_like(
        factorial_state.elem.log_g, local_factorial_state.elem.log_g[0, 0]
    )
    return local_factorial_state._replace(
        elem=factorial_state.elem._replace(f=new_f, log_g=new_log_g),
    )
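The two leaf updates in insert can be mirrored in NumPy. In this hypothetical insert_leaves_sketch, a copy followed by in-place assignment plays the role of JAX's functional f.at[inds].set(...), and log_g is overwritten with the first entry of the local log_g, as in the source.

```python
import numpy as np


def insert_leaves_sketch(local_f, local_log_g, global_f, global_log_g, factorial_inds):
    """Hypothetical NumPy version of the two leaf updates performed by insert."""
    inds = np.atleast_1d(np.asarray(factorial_inds))
    # JAX arrays are immutable, so the source uses f.at[inds].set(...);
    # copying first gives the same non-mutating behavior in NumPy.
    new_f = global_f.copy()
    new_f[inds] = local_f
    # The shared scalar log-normalizer overwrites every global factor and state.
    new_log_g = np.full_like(global_log_g, local_log_g[0, 0])
    return new_f, new_log_g


F, K = 3, 2
global_f = np.zeros((F, K, K))
global_log_g = np.zeros((F, K))
local_f = np.ones((2, K, K))
local_log_g = np.full((2, K), -0.7)
new_f, new_log_g = insert_leaves_sketch(
    local_f, local_log_g, global_f, global_log_g, [0, 2]
)
```

Two details of the source are worth noting. The returned state takes its non-elem fields from local_factorial_state and its elem leaves from factorial_state. Because log_g is read as local_factorial_state.elem.log_g[0, 0], the local log_g must keep a leading factorial axis, as marginalize produces. Under JAX, .at[...].set skips updates at out-of-range indices rather than raising.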