Factorial Types
cuthbert.factorial.types
Provides types for factorial state-space models.
GetFactorialIndices
Bases: Protocol
Protocol for getting the factorial indices.
__call__(model_inputs)
Extract the factorial indices from model inputs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_inputs | ArrayTreeLike | Model inputs. | required |
Returns:
| Type | Description |
|---|---|
| ArrayLike | Indices of the factors to extract. Integer array. |
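To illustrate the protocol, here is a minimal sketch of a function that satisfies GetFactorialIndices structurally. It assumes a hypothetical convention in which model_inputs is a dict that carries the relevant indices under the key "factorial_inds". The protocol class below is a local stand-in that mirrors the documented signature, not the library's own class.

```python
from typing import Protocol

import numpy as np


class GetFactorialIndices(Protocol):
    # Local stand-in mirroring the documented __call__ signature.
    def __call__(self, model_inputs) -> np.ndarray: ...


def get_factorial_indices(model_inputs: dict) -> np.ndarray:
    # Hypothetical convention: the relevant factor indices are stored
    # under the key "factorial_inds" of a dict of model inputs.
    return np.asarray(model_inputs["factorial_inds"], dtype=np.int32)


# A plain function is accepted wherever the protocol is expected.
fn: GetFactorialIndices = get_factorial_indices

inds = fn({"factorial_inds": [0, 2], "obs": np.zeros(3)})
print(inds)  # [0 2]
```

Because Protocol uses structural typing, no subclassing is needed: any callable with a matching signature type-checks.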
Extract
Bases: Protocol
Protocol for extracting the relevant factors.
__call__(factorial_state, factorial_inds)
Extract factors from factorial state.
E.g. factorial_state might encode factorial means with shape (F, d) and chol_covs with shape (F, d, d). If factorial_inds indicates that factors i and j are relevant, we extract means[i], means[j], chol_covs[i] and chol_covs[j]. Thus we return means with shape (2, d) and chol_covs with shape (2, d, d).
Parameters:
factorial_state: Factorial state, e.g. a factorial Kalman state storing means and chol_covs with shapes (F, d) and (F, d, d) respectively.
factorial_inds: Indices of the factors to extract. Integer array. If factorial_inds.ndim == 0, a single factor is extracted and the factorial dimension is removed. If factorial_inds.ndim == 1, the factorial dimension is retained, even if len(factorial_inds) == 1.
Returns:
| Type | Description |
|---|---|
| ArrayTree | Local factorial state with factorial dimension of length len(factorial_inds). If factorial_inds is a single integer, the local factorial state should not have a factorial dimension. |
Source code in cuthbert/factorial/types.py
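The Kalman example above can be sketched as follows. GaussState is a hypothetical container holding stacked means and Cholesky factors, and NumPy is used so the sketch runs without extra dependencies. Plain integer-array indexing gives exactly the documented behaviour: a 0-d index drops the factorial axis, a 1-d index keeps it. For a general ArrayTree, the same indexing would be applied to every leaf.

```python
from typing import NamedTuple

import numpy as np


class GaussState(NamedTuple):
    # Hypothetical factorial Gaussian state: one mean and one Cholesky
    # factor per factor, stacked along the leading (factorial) axis.
    mean: np.ndarray      # shape (F, d)
    chol_cov: np.ndarray  # shape (F, d, d)


def extract(factorial_state: GaussState, factorial_inds) -> GaussState:
    inds = np.asarray(factorial_inds)
    # Integer indexing drops the leading axis when inds.ndim == 0 and
    # keeps it (with length len(inds)) when inds.ndim == 1.
    return GaussState(
        mean=factorial_state.mean[inds],
        chol_cov=factorial_state.chol_cov[inds],
    )


F, d = 4, 3
state = GaussState(
    mean=np.arange(F * d, dtype=float).reshape(F, d),
    chol_cov=np.broadcast_to(np.eye(d), (F, d, d)).copy(),
)

pair = extract(state, np.array([1, 3]))
single = extract(state, np.array(2))
print(pair.mean.shape, pair.chol_cov.shape)      # (2, 3) (2, 3, 3)
print(single.mean.shape, single.chol_cov.shape)  # (3,) (3, 3)
```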
Join
Bases: Protocol
Protocol for combining factorial states into a joint state.
__call__(local_factorial_state)
Combine the factors of a local factorial state into a joint local state.
E.g. local_factorial_state might encode factorial means with shape (2, d) and chol_covs with shape (2, d, d), which are then combined into a joint mean with shape (2 * d,) and a block-diagonal joint_chol_cov with shape (2 * d, 2 * d).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| local_factorial_state | ArrayTreeLike | Factorial state with factorial index as the first dimension. Typically contains only a small number of factors, as it is applied after extract. | required |
Returns:
| Type | Description |
|---|---|
| ArrayTree | Joint state with no factorial index dimension. |
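A minimal sketch of the Gaussian join described above, assuming the same hypothetical GaussState container (means and lower Cholesky factors). The factor means are concatenated, and each factor's Cholesky factor is placed on the diagonal of a larger matrix. A block-diagonal matrix of lower-triangular blocks is itself lower-triangular, so the result is a valid Cholesky factor of the block-diagonal joint covariance; the zero off-diagonal blocks encode independence between factors.

```python
from typing import NamedTuple

import numpy as np


class GaussState(NamedTuple):
    # Hypothetical Gaussian state; with a leading factorial axis the
    # shapes are (k, d) and (k, d, d), without it (n,) and (n, n).
    mean: np.ndarray
    chol_cov: np.ndarray


def join(local_factorial_state: GaussState) -> GaussState:
    means, chols = local_factorial_state
    k, d = means.shape
    # Concatenate the factor means into one joint mean of shape (k * d,).
    joint_mean = means.reshape(k * d)
    # Put each factor's Cholesky factor on the block diagonal.
    joint_chol = np.zeros((k * d, k * d), dtype=chols.dtype)
    for i in range(k):
        joint_chol[i * d:(i + 1) * d, i * d:(i + 1) * d] = chols[i]
    return GaussState(mean=joint_mean, chol_cov=joint_chol)


local = GaussState(
    mean=np.array([[1.0, 2.0], [3.0, 4.0]]),
    chol_cov=np.array([np.eye(2), 2.0 * np.eye(2)]),
)
joint = join(local)
print(joint.mean)               # [1. 2. 3. 4.]
print(np.diag(joint.chol_cov))  # [1. 1. 2. 2.]
```

The explicit loop keeps the sketch dependency-free; a block-diagonal helper such as scipy.linalg.block_diag would do the same job.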
Marginalize
Bases: Protocol
Protocol for marginalizing a joint state into a factored state.
__call__(local_state, num_factors)
Marginalize a joint state into a factored state.
E.g. local_state might encode a joint mean with shape (2 * d,) and a joint_chol_cov with shape (2 * d, 2 * d). Marginalizing this joint local state gives two factorial means with shape (2, d) and chol_covs with shape (2, d, d).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| local_state | ArrayTree | Joint local state with no factorial index dimension. | required |
| num_factors | int | Number of factors to marginalize the joint state into. Typically equal to len(factorial_inds). | required |
required |
Returns:
| Type | Description |
|---|---|
| ArrayTree | Factorial state with factorial index as the first dimension, of length num_factors. |
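A minimal sketch of Gaussian marginalization under the same hypothetical GaussState container. One subtlety worth showing: once the joint state has been updated, its factors are generally correlated, and the diagonal blocks of the joint Cholesky factor are then not the Cholesky factors of the marginal covariances. This sketch therefore rebuilds the joint covariance and re-factorizes each diagonal block; the cross-covariances between factors are discarded.

```python
from typing import NamedTuple

import numpy as np


class GaussState(NamedTuple):
    # Hypothetical Gaussian state container (mean, lower Cholesky factor).
    mean: np.ndarray
    chol_cov: np.ndarray


def marginalize(local_state: GaussState, num_factors: int) -> GaussState:
    joint_mean, joint_chol = local_state
    d = joint_mean.shape[0] // num_factors
    # Each factor's marginal mean is its slice of the joint mean.
    means = joint_mean.reshape(num_factors, d)
    # Rebuild the joint covariance and factorize each diagonal block, since
    # the joint Cholesky factor's diagonal blocks are not the marginal ones
    # when factors are correlated.
    joint_cov = joint_chol @ joint_chol.T
    chols = np.stack([
        np.linalg.cholesky(joint_cov[i * d:(i + 1) * d, i * d:(i + 1) * d])
        for i in range(num_factors)
    ])
    # Cross-covariances between factors are dropped here.
    return GaussState(mean=means, chol_cov=chols)


rng = np.random.default_rng(0)
d = 2
A = rng.normal(size=(2 * d, 2 * d))
joint_cov = A @ A.T + 2 * d * np.eye(2 * d)  # correlated, positive definite
joint = GaussState(mean=np.arange(2.0 * d), chol_cov=np.linalg.cholesky(joint_cov))

factors = marginalize(joint, num_factors=2)
print(factors.mean.shape, factors.chol_cov.shape)  # (2, 2) (2, 2, 2)
```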
Insert
Bases: Protocol
Protocol for inserting a local factorial state into a factorial state.
__call__(local_factorial_state, factorial_state, factorial_inds)
Insert a local factorial state into a factorial state.
E.g. local_factorial_state might encode means with shape (2, d) and chol_covs with shape (2, d, d). If factorial_inds selects factors i and j, we insert means[0] and means[1] into the factorial state at i and j respectively, and likewise chol_covs[0] and chol_covs[1]. In both cases, we overwrite the existing factors i and j in the factorial state, leaving the other factors unchanged.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| local_factorial_state | ArrayTree | Local factorial state with factorial index as the first dimension, one entry per inserted factor. | required |
| factorial_state | ArrayTree | Factorial state with factorial index as the first dimension. | required |
| factorial_inds | ArrayLike | Indices of the factors to insert. Integer array. factorial_inds.ndim == 0 is treated the same as factorial_inds.ndim == 1 with len(factorial_inds) == 1 (i.e. a single factor is inserted). | required |
Returns:
| Type | Description |
|---|---|
| ArrayTree | Factorial state with factorial index as the first dimension. The updated factors are inserted into the factorial state. The remaining factors are left unchanged. |
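A minimal sketch of insertion for the hypothetical GaussState container. A scalar index is promoted to a length-1 index array, matching the documented treatment of factorial_inds.ndim == 0. The state is copied before writing so the input is left untouched; with immutable JAX arrays the equivalent update would be written as x.at[inds].set(values).

```python
from typing import NamedTuple

import numpy as np


class GaussState(NamedTuple):
    # Hypothetical factorial Gaussian state (leading axis indexes factors).
    mean: np.ndarray
    chol_cov: np.ndarray


def insert(local_factorial_state: GaussState,
           factorial_state: GaussState,
           factorial_inds) -> GaussState:
    # Treat a scalar index like a length-1 index array.
    inds = np.atleast_1d(np.asarray(factorial_inds))
    mean = factorial_state.mean.copy()
    chol = factorial_state.chol_cov.copy()
    # Overwrite only the selected factors; broadcasting also accepts a local
    # state without a factorial axis when a single index is given.
    mean[inds] = local_factorial_state.mean
    chol[inds] = local_factorial_state.chol_cov
    return GaussState(mean=mean, chol_cov=chol)


F, d = 4, 2
state = GaussState(mean=np.zeros((F, d)),
                   chol_cov=np.broadcast_to(np.eye(d), (F, d, d)).copy())
local = GaussState(mean=np.array([[1.0, 1.0], [3.0, 3.0]]),
                   chol_cov=np.array([2.0 * np.eye(d), 3.0 * np.eye(d)]))
new_state = insert(local, state, np.array([1, 3]))
print(new_state.mean[:, 0])  # [0. 1. 0. 3.]
```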
FactorializeInitState
Bases: Protocol
Protocol for factorial post-processing of init_prepare.
__call__(init_state, model_inputs)
Apply any processing to the output of init_prepare that factorial inference requires.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| init_state | ArrayTreeLike | Output from the base inference method's init_prepare. | required |
| model_inputs | ArrayTreeLike | The model inputs at the first time point. | required |
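As one possible use of this hook, the sketch below assumes a hypothetical scenario in which the base init_prepare returns a single shared Gaussian prior with no factorial axis, and model_inputs (a dict here) holds the number of factors under a hypothetical "num_factors" key. The post-processing then stacks the prior once per factor to produce a factorial state.

```python
from typing import NamedTuple

import numpy as np


class GaussState(NamedTuple):
    # Hypothetical Gaussian state container (mean, lower Cholesky factor).
    mean: np.ndarray
    chol_cov: np.ndarray


def factorialize_init_state(init_state: GaussState, model_inputs: dict) -> GaussState:
    # Hypothetical scenario: one shared prior without a factorial axis, and
    # model_inputs["num_factors"] says how many factors start from it.
    F = int(model_inputs["num_factors"])
    return GaussState(
        mean=np.broadcast_to(init_state.mean, (F,) + init_state.mean.shape).copy(),
        chol_cov=np.broadcast_to(init_state.chol_cov, (F,) + init_state.chol_cov.shape).copy(),
    )


prior = GaussState(mean=np.zeros(3), chol_cov=np.eye(3))
init = factorialize_init_state(prior, {"num_factors": 5})
print(init.mean.shape, init.chol_cov.shape)  # (5, 3) (5, 3, 3)
```

The .copy() calls turn the read-only broadcast views into independent, writable arrays.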
Factorializer
Bases: NamedTuple
Factorializer object: a NamedTuple bundling the functions used for factorial inference.
All functions depend on the inference method (e.g. Gaussian or SMC), except get_factorial_indices, which acts purely on model_inputs.
Attributes:
| Name | Type | Description |
|---|---|---|
| get_factorial_indices | GetFactorialIndices | Function to extract factorial indices from model inputs. |
| extract | Extract | Function to extract the relevant factors. |
| join | Join | Function to combine factorial states into a joint state. |
| marginalize | Marginalize | Function to marginalize a joint state into a factored state. |
| insert | Insert | Function to insert a local factorial state into a factorial state. |
| factorialize_init_state | FactorializeInitState | Optional post-processing function applied to the output of init_prepare. |
| extract_and_join | ArrayTree | Apply extract and then join. Input: global factorial state. Output: local joint state. |
| marginalize_and_insert | ArrayTree | Apply marginalize and then insert. Input: local joint state and global factorial state. Output: global factorial state. |
get_factorial_indices (instance attribute)
extract (instance attribute)
join (instance attribute)
marginalize (instance attribute)
insert (instance attribute)
factorialize_init_state = lambda init_state, model_inputs: init_state (class attribute, instance attribute)
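The bundling can be sketched with a hypothetical MiniFactorializer that mirrors the documented fields; it is not the library's Factorializer class. It shows how a NamedTuple field can carry the identity lambda as its default, so factorialize_init_state may be omitted. The toy functions act on a state holding only factor means, shape (F, d), and are for illustration only.

```python
from typing import Any, Callable, NamedTuple

import numpy as np


class MiniFactorializer(NamedTuple):
    # Hypothetical stand-in with the documented fields; Callable types are
    # deliberately loose.
    get_factorial_indices: Callable[[Any], Any]
    extract: Callable[[Any, Any], Any]
    join: Callable[[Any], Any]
    marginalize: Callable[[Any, int], Any]
    insert: Callable[[Any, Any, Any], Any]
    # Optional field: by default the init state is returned unchanged.
    factorialize_init_state: Callable[[Any, Any], Any] = (
        lambda init_state, model_inputs: init_state
    )


def _insert(local, state, inds):
    # Copy, then overwrite the selected rows of a means-only state.
    out = state.copy()
    out[inds] = local
    return out


# Toy functions for a means-only factorial state of shape (F, d).
fz = MiniFactorializer(
    get_factorial_indices=lambda model_inputs: np.asarray(model_inputs["factorial_inds"]),
    extract=lambda state, inds: state[inds],
    join=lambda local: local.reshape(-1),
    marginalize=lambda joint, num_factors: joint.reshape(num_factors, -1),
    insert=_insert,
)

init = np.zeros((4, 2))
print(fz.factorialize_init_state(init, {}) is init)  # True
```

Because the fields are stored as tuple items rather than methods, each function is called without a self argument.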
extract_and_join(factorial_state, model_inputs)
Extract and join the relevant factors into a joint local state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| factorial_state | ArrayTreeLike | Factorial state with factorial index as the first dimension. | required |
| model_inputs | ArrayTreeLike | Model inputs, from which the factorial indices are extracted. | required |
Returns:
| Type | Description |
|---|---|
| ArrayTree | Joint local state with no factorial index dimension. |
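The composition performed by extract_and_join can be sketched for a Gaussian state as below: the indices are read from model_inputs, the selected factors are extracted, and the result is joined into one block-diagonal Gaussian. GaussState and the "factorial_inds" key are hypothetical conventions for this sketch, not the library's.

```python
from typing import NamedTuple

import numpy as np


class GaussState(NamedTuple):
    # Hypothetical Gaussian state container (mean, lower Cholesky factor).
    mean: np.ndarray
    chol_cov: np.ndarray


def get_factorial_indices(model_inputs):
    # Hypothetical convention: indices live under "factorial_inds".
    return np.asarray(model_inputs["factorial_inds"])


def extract(state, inds):
    return GaussState(state.mean[inds], state.chol_cov[inds])


def join(local):
    k, d = local.mean.shape
    chol = np.zeros((k * d, k * d))
    for i in range(k):
        chol[i * d:(i + 1) * d, i * d:(i + 1) * d] = local.chol_cov[i]
    return GaussState(local.mean.reshape(k * d), chol)


def extract_and_join(factorial_state, model_inputs):
    # Indices from model_inputs, then extract, then join.
    inds = get_factorial_indices(model_inputs)
    return join(extract(factorial_state, inds))


F, d = 5, 2
state = GaussState(
    mean=np.arange(F * d, dtype=float).reshape(F, d),
    chol_cov=np.stack([(i + 1.0) * np.eye(d) for i in range(F)]),
)
joint = extract_and_join(state, {"factorial_inds": [1, 4]})
print(joint.mean)               # [2. 3. 8. 9.]
print(np.diag(joint.chol_cov))  # [2. 2. 5. 5.]
```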
marginalize_and_insert(local_state, factorial_state, model_inputs)
Marginalize and insert the relevant factors into a factorial state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| local_state | ArrayTree | Joint local state with no factorial index dimension. | required |
| factorial_state | ArrayTree | Factorial state with factorial index as the first dimension. | required |
| model_inputs | ArrayTreeLike | Model inputs, from which the factorial indices are extracted. | required |
Returns:
| Type | Description |
|---|---|
| ArrayTree | Factorial state with factorial index as the first dimension. The updated factors are inserted into the factorial state. The remaining factors are left unchanged. |
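The reverse composition, marginalize_and_insert, can be sketched for a Gaussian state under the same hypothetical GaussState and "factorial_inds" conventions. The joint local state below is deliberately correlated (as it might be after a joint update over factors 0 and 3). After marginalization, each inserted factor carries its marginal covariance, while untouched factors keep their previous values. Here num_factors is taken as the number of indices, matching the note that it is typically len(factorial_inds).

```python
from typing import NamedTuple

import numpy as np


class GaussState(NamedTuple):
    # Hypothetical Gaussian state container (mean, lower Cholesky factor).
    mean: np.ndarray
    chol_cov: np.ndarray


def get_factorial_indices(model_inputs):
    # Hypothetical convention: indices live under "factorial_inds".
    return np.asarray(model_inputs["factorial_inds"])


def marginalize(local_state, num_factors):
    d = local_state.mean.shape[0] // num_factors
    cov = local_state.chol_cov @ local_state.chol_cov.T
    # Re-factorize each diagonal block to get the marginal Cholesky factors.
    chols = np.stack([np.linalg.cholesky(cov[i * d:(i + 1) * d, i * d:(i + 1) * d])
                      for i in range(num_factors)])
    return GaussState(local_state.mean.reshape(num_factors, d), chols)


def insert(local, state, inds):
    mean, chol = state.mean.copy(), state.chol_cov.copy()
    mean[inds] = local.mean
    chol[inds] = local.chol_cov
    return GaussState(mean, chol)


def marginalize_and_insert(local_state, factorial_state, model_inputs):
    inds = get_factorial_indices(model_inputs)
    local_factorial_state = marginalize(local_state, np.size(inds))
    return insert(local_factorial_state, factorial_state, inds)


F, d = 4, 2
state = GaussState(np.zeros((F, d)), np.broadcast_to(np.eye(d), (F, d, d)).copy())
# Correlated joint local state over factors 0 and 3.
joint_cov = np.array([[2.0, 0.0, 0.5, 0.0],
                      [0.0, 2.0, 0.0, 0.5],
                      [0.5, 0.0, 2.0, 0.0],
                      [0.0, 0.5, 0.0, 2.0]])
joint = GaussState(np.ones(2 * d), np.linalg.cholesky(joint_cov))
new_state = marginalize_and_insert(joint, state, {"factorial_inds": [0, 3]})
print(new_state.mean[:, 0])  # [1. 0. 0. 1.]
```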