Skip to content

Factorial Utilities

cuthbert.factorial.utils

Utility functions to convert between serial and factorial trees.

serial_to_factorial(extract, serial_tree, factorial_inds, init_factorial_tree=None)

Convert a serial tree into a list of trees, one for each factor.

Parameters:

Name Type Description Default
extract Extract

Function to extract the relevant factors from the serial tree.

required
serial_tree ArrayTreeLike

The serial tree to convert. Each leaf of the tree should have shape (T, F, ...) where T is the number of time steps and F is the number of factors. Some leaves may lack the factorial dimension F; this is controlled by the extract function.

required
factorial_inds ArrayLike

The indices of the factors used in each element of the serial tree. Shape (T, F).

required
init_factorial_tree ArrayTree

Optional initial factorial tree. Its entry for each factor is used as the first element (along the time axis) of that factor's returned tree. Leaves with shape (F, ...).

None

Returns:

Type Description
list[ArrayTree]

A list of trees, one for each factor. Length max(factorial_inds) + 1. Each element has shape (T_i, ...) where T_i is the number of occurrences of index i in factorial_inds (which may be zero), or (T_i + 1, ...) when init_factorial_tree is given.

Source code in cuthbert/factorial/utils.py
def serial_to_factorial(
    extract: Extract,
    serial_tree: ArrayTreeLike,
    factorial_inds: ArrayLike,
    init_factorial_tree: ArrayTree = None,
) -> list[ArrayTree]:
    """Convert a serial tree into a list of trees, one for each factor.

    Args:
        extract: Function to extract the relevant factors from the serial tree.
            It is called as ``extract(tree, index)`` with an integer index, an index
            array, or (under ``vmap``) a batch of integer indices.
        serial_tree: The serial tree to convert.
            Each leaf of the tree should have shape (T, F, ...) where T is the number of
            time steps and F is the number of factors.
            Some leaves may lack the factorial dimension F; this is controlled
            by the `extract` function.
        factorial_inds: The indices of the factors used in each element of the serial
            tree. Shape (T, F). Must be concrete (not traced), because it drives
            Python-level control flow.
        init_factorial_tree: Optional initial factorial tree. Its entry for factor i,
            obtained via ``extract(init_factorial_tree, i)``, becomes the first
            element of the i-th returned tree.
            Leaves with shape (F, ...).

    Returns:
        A list of trees, one for each factor.
            Length max(factorial_inds) + 1.
            Each element has shape (T_i, ...) where T_i is the number of occurrences of
            index i in factorial_inds (which may be zero), or (T_i + 1, ...) when
            `init_factorial_tree` is given.
            When no initial tree is given, leaf dtypes are those of the extracted
            leaves, including for factors that never occur.
    """
    # TODO: Still a Python loop over time steps (not jit/scan friendly) because the
    # returned trees have different lengths per factor.

    factorial_inds = jnp.asarray(factorial_inds)
    num_factors = int(jnp.max(factorial_inds)) + 1
    T = tree.leaves(serial_tree)[0].shape[0]

    # Move the indices to the host once. Iterating a device array element by element
    # (and using its scalars as list indices) would sync with the device per element.
    host_inds = factorial_inds.tolist()

    # pieces[i] collects trees with a leading time axis for factor i. They are
    # concatenated once at the end rather than on every occurrence, which avoids
    # quadratic copying.
    if init_factorial_tree is None:
        # A length-0 tree with the right trailing shapes, so factors that never occur
        # still get a well-formed result. The dtype is taken from the leaf so the
        # final concatenation does not promote integer leaves to float.
        init_state = tree.map(lambda x: x[0], serial_tree)
        init_single_factor_state = extract(init_state, jnp.array([0]))
        empty = tree.map(
            lambda x: jnp.zeros((0,) + x.shape[1:], dtype=x.dtype),
            init_single_factor_state,
        )
        pieces = [[empty] for _ in range(num_factors)]
    else:
        # Add a temporal dimension to each factor's initial tree.
        pieces = [
            [tree.map(lambda x: x[None], extract(init_factorial_tree, i))]
            for i in range(num_factors)
        ]

    vmapped_extract = vmap(extract, in_axes=(None, 0))
    for t in range(T):
        joint_factor_t = tree.map(lambda x: x[t], serial_tree)
        local_factors_t = vmapped_extract(
            joint_factor_t, jnp.arange(len(host_inds[t]))
        )
        for j, ind in enumerate(host_inds[t]):
            pieces[ind].append(tree.map(lambda y: y[j][None], local_factors_t))

    return [
        tree.map(lambda *leaves: jnp.concatenate(leaves), *factor_pieces)
        for factor_pieces in pieces
    ]

serial_to_single_factor(extract, serial_tree, factorial_inds, factorial_index, init_factorial_tree=None)

Convert a serial tree into a single factor tree.

Parameters:

Name Type Description Default
extract Extract

Function to extract the relevant factors from the serial tree.

required
serial_tree ArrayTreeLike

The serial tree to convert. Each leaf of the tree should have shape (T, F, ...) where T is the number of time steps and F is the number of factors.

required
factorial_inds ArrayLike

The indices of the factors used in each element of the serial tree. Shape (T, F).

required
factorial_index int

Single integer index of the factor to extract.

required
init_factorial_tree ArrayTree

Optional initial factorial tree whose factorial_index entry is used as the first element (along the time axis) of the returned tree. Leaves with shape (F, ...), of which only the factorial_index element will be used.

None

Returns:

Type Description
ArrayTree

A single ArrayTree with shape (T_i, ...) where T_i is the number of occurrences of the factorial index in factorial_inds, or (T_i + 1, ...) when init_factorial_tree is given.

Source code in cuthbert/factorial/utils.py
def serial_to_single_factor(
    extract: Extract,
    serial_tree: ArrayTreeLike,
    factorial_inds: ArrayLike,
    factorial_index: int,
    init_factorial_tree: ArrayTree = None,
) -> ArrayTree:
    """Convert a serial tree into a single factor tree.

    Args:
        extract: Function to extract the relevant factors from the serial tree.
            It is called as ``extract(tree, index)`` with an integer index, an index
            array, or (under ``vmap``) a batch of integer indices.
        serial_tree: The serial tree to convert.
            Each leaf of the tree should have shape (T, F, ...) where T is the number of
            time steps and F is the number of factors.
        factorial_inds: The indices of the factors used in each element of the serial
            tree. Shape (T, F). Must be concrete (not traced), because it drives
            Python-level control flow.
        factorial_index: Single integer index of the factor to extract.
        init_factorial_tree: Optional initial factorial tree whose
            ``extract(init_factorial_tree, factorial_index)`` entry becomes the first
            element (along the time axis) of the returned tree.
            Leaves with shape (F, ...) of which only the factorial_index element will be
            used.

    Returns:
        A single ArrayTree with shape (T_i, ...) where T_i is the number of occurrences of
        the factorial index in factorial_inds, or (T_i + 1, ...) when
        `init_factorial_tree` is given. When no initial tree is given, leaf dtypes are
        those of the extracted leaves, even if the factor never occurs.
    """
    # TODO: The time loop is still Python-level; since a single tree is returned it
    # could potentially be replaced by a scan.

    factorial_inds = jnp.asarray(factorial_inds)
    T = tree.leaves(serial_tree)[0].shape[0]

    # Move the indices to the host once instead of syncing per element.
    host_inds = factorial_inds.tolist()

    if init_factorial_tree is None:
        # A length-0 tree with the right trailing shapes. The dtype is taken from the
        # leaf so the final concatenation does not promote integer leaves to float.
        init_state = tree.map(lambda x: x[0], serial_tree)
        init_single_factor_state = extract(init_state, jnp.array([0]))
        first = tree.map(
            lambda x: jnp.zeros((0,) + x.shape[1:], dtype=x.dtype),
            init_single_factor_state,
        )
    else:
        # Add a temporal dimension to the requested factor's initial tree.
        first = tree.map(
            lambda x: x[None], extract(init_factorial_tree, factorial_index)
        )

    # Collect pieces and concatenate once at the end, avoiding quadratic copying.
    pieces = [first]
    vmapped_extract = vmap(extract, in_axes=(None, 0))
    for t in range(T):
        matches = [j for j, ind in enumerate(host_inds[t]) if ind == factorial_index]
        if not matches:
            # This time step does not involve the requested factor; skip extraction.
            continue
        joint_factor_t = tree.map(lambda x: x[t], serial_tree)
        local_factors_t = vmapped_extract(
            joint_factor_t, jnp.arange(len(host_inds[t]))
        )
        for j in matches:
            pieces.append(tree.map(lambda y: y[j][None], local_factors_t))

    return tree.map(lambda *leaves: jnp.concatenate(leaves), *pieces)