
Discrete HMMs

cuthbertlib.discrete.filtering

Implements the discrete HMM filtering associative operator.

FilterScanElement

Bases: NamedTuple

Elements carried through the discrete HMM filtering scan.

f instance-attribute

log_g instance-attribute

condition_on_obs(state_probs, log_likelihoods)

Conditions a state distribution on an observation.

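Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state_probs | Array | Either the state transition probabilities (an (N, N) matrix, conditioned row by row) or the initial distribution (a vector of length N). | required |
| log_likelihoods | Array | Vector of \(\log p(y_t \mid x_t)\) for each possible state \(x_t\). | required |

Returns:

| Type | Description |
| --- | --- |
| tuple[Array, Array] | The conditioned (renormalized) distribution and the log normalizing constant. For a transition matrix, both are computed per row. |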

Source code in cuthbertlib/discrete/filtering.py
def condition_on_obs(state_probs: Array, log_likelihoods: Array) -> tuple[Array, Array]:
    r"""Conditions a state distribution on an observation.

    Args:
        state_probs: Either the state transition probabilities or the initial distribution.
        log_likelihoods: Vector of $\log p(y_t \mid x_t)$ for each possible state $x_t$.

    Returns:
        The conditioned state and the log normalizing constant.
    """
    ll_max = log_likelihoods.max(axis=-1)
    A_cond = state_probs * jnp.exp(log_likelihoods - ll_max)
    norm = A_cond.sum(axis=-1)
    A_cond /= jnp.expand_dims(norm, axis=-1)
    return A_cond, jnp.log(norm) + ll_max
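A quick way to sanity-check condition_on_obs is to compare it with Bayes' rule applied by hand. The sketch below copies the function so it runs without cuthbertlib installed; the three-state prior m0, the transition matrix A and the likelihood values are made up for illustration. It also shows why the largest log-likelihood is subtracted before exponentiating: shifting every log-likelihood by -1000 makes a direct exp() underflow to zero, yet the conditioned distribution is unchanged.

```python
import jax.numpy as jnp
from jax import Array


def condition_on_obs(state_probs: Array, log_likelihoods: Array) -> tuple[Array, Array]:
    # Copy of the function above. Subtracting the largest log-likelihood
    # before exponentiating keeps the weights in a safe floating-point range.
    ll_max = log_likelihoods.max(axis=-1)
    A_cond = state_probs * jnp.exp(log_likelihoods - ll_max)
    norm = A_cond.sum(axis=-1)
    A_cond /= jnp.expand_dims(norm, axis=-1)
    return A_cond, jnp.log(norm) + ll_max


# Illustrative three-state model (values are made up).
m0 = jnp.array([0.5, 0.3, 0.2])                 # initial distribution p(x_0)
A = jnp.array([[0.8, 0.1, 0.1],
               [0.2, 0.6, 0.2],
               [0.3, 0.3, 0.4]])                # A[i, j] = p(x_t = j | x_{t-1} = i)
log_lik = jnp.log(jnp.array([0.1, 0.7, 0.2]))  # log p(y | x = i)

# Conditioning the initial distribution: compare with Bayes' rule done by hand.
post, log_z = condition_on_obs(m0, log_lik)
unnorm = m0 * jnp.exp(log_lik)
naive_post, naive_log_z = unnorm / unnorm.sum(), jnp.log(unnorm.sum())

# Conditioning a transition matrix works row by row: row i becomes
# p(x_t | x_{t-1} = i, y_t) and log_norms[i] = log p(y_t | x_{t-1} = i).
A_cond, log_norms = condition_on_obs(A, log_lik)

# Shifting all log-likelihoods by -1000 makes exp() underflow to zero,
# but the max-shift leaves the conditioned distribution unchanged.
post_shifted, log_z_shifted = condition_on_obs(m0, log_lik - 1000.0)
```

With these numbers the posterior is \((0.05, 0.21, 0.04)/0.30 \approx (0.167, 0.7, 0.133)\) and log_z is \(\log 0.3\).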

filtering_operator(elem_ij, elem_jk)

Binary associative operator for filtering in discrete HMMs.

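Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| elem_ij | FilterScanElement | Filter scan element for the earlier segment (from i to j). | required |
| elem_jk | FilterScanElement | Filter scan element for the later segment (from j to k). | required |

Returns:

| Type | Description |
| --- | --- |
| FilterScanElement | The output of the associative operator applied to the input elements, i.e. the combined element covering i to k. |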
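Read as equations, the source below does the following. Write \(f_{ij}\) for elem_ij.f and \(g_{ij}\) for the elementwise exponential of elem_ij.log_g (likewise for elem_jk), and treat each f as a matrix indexed by [row, column]:

\[
\begin{aligned}
c(z) &= \sum_{y} f_{ij}(z, y)\, g_{jk}(y), \\
f_{ik}(z, x) &= \frac{1}{c(z)} \sum_{y} f_{ij}(z, y)\, g_{jk}(y)\, f_{jk}(y, x), \\
\log g_{ik}(z) &= \log g_{ij}(z) + \log c(z).
\end{aligned}
\]

Here \(\log c\) is the lognorm returned by condition_on_obs, which evaluates it with the max-shift (log-sum-exp) trick rather than by exponentiating elem_jk.log_g directly.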

Source code in cuthbertlib/discrete/filtering.py
def filtering_operator(
    elem_ij: FilterScanElement, elem_jk: FilterScanElement
) -> FilterScanElement:
    """Binary associative operator for filtering in discrete HMMs.

    Args:
        elem_ij: Filter scan element.
        elem_jk: Filter scan element.

    Returns:
        The output of the associative operator applied to the input elements.
    """
    f, lognorm = condition_on_obs(elem_ij.f, elem_jk.log_g)
    f = f @ elem_jk.f
    log_g = elem_ij.log_g + lognorm
    return FilterScanElement(f, log_g)
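To see the operator at work, the sketch below runs it through jax.lax.associative_scan and checks the result against an ordinary sequential forward pass. How the per-time-step elements are built is our assumption, following the standard parallel-scan formulation of HMM filtering rather than anything shown on this page: the first element holds \(p(x_0 \mid y_0)\) in every row together with \(\log p(y_0)\), and each later element is the transition matrix conditioned row by row on \(y_t\) via condition_on_obs. build_filter_elements and forward_filter are hypothetical helpers, the model values are made up, and FilterScanElement, condition_on_obs and filtering_operator are copied from above so the block runs on its own. Because associative_scan passes batches of elements to the operator, the operator is wrapped in jax.vmap.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax import Array


class FilterScanElement(NamedTuple):
    f: Array
    log_g: Array


def condition_on_obs(state_probs: Array, log_likelihoods: Array) -> tuple[Array, Array]:
    # Copied from above.
    ll_max = log_likelihoods.max(axis=-1)
    A_cond = state_probs * jnp.exp(log_likelihoods - ll_max)
    norm = A_cond.sum(axis=-1)
    A_cond /= jnp.expand_dims(norm, axis=-1)
    return A_cond, jnp.log(norm) + ll_max


def filtering_operator(elem_ij: FilterScanElement, elem_jk: FilterScanElement) -> FilterScanElement:
    # Copied from above.
    f, lognorm = condition_on_obs(elem_ij.f, elem_jk.log_g)
    f = f @ elem_jk.f
    log_g = elem_ij.log_g + lognorm
    return FilterScanElement(f, log_g)


def build_filter_elements(m0: Array, A: Array, log_liks: Array) -> FilterScanElement:
    """Hypothetical helper: one scan element per time step (assumed layout)."""
    n = m0.shape[0]
    # t = 0: p(x_0 | y_0) in every row, log p(y_0) in every entry of log_g.
    f0, log_z0 = condition_on_obs(m0, log_liks[0])
    f0 = jnp.broadcast_to(f0, (n, n))
    log_g0 = jnp.full((n,), log_z0)
    # t >= 1: row i is p(x_t | x_{t-1} = i, y_t); log_g[i] = log p(y_t | x_{t-1} = i).
    f_rest, log_g_rest = jax.vmap(condition_on_obs, in_axes=(None, 0))(A, log_liks[1:])
    return FilterScanElement(
        jnp.concatenate([f0[None], f_rest]),
        jnp.concatenate([log_g0[None], log_g_rest]),
    )


def forward_filter(m0: Array, A: Array, log_liks: Array) -> tuple[Array, Array]:
    """Hypothetical helper: sequential forward algorithm for comparison."""
    def step(post, ll):
        new_post, log_z = condition_on_obs(post @ A, ll)  # predict, then update
        return new_post, (new_post, log_z)

    post0, log_z0 = condition_on_obs(m0, log_liks[0])
    _, (posts, log_zs) = jax.lax.scan(step, post0, log_liks[1:])
    return jnp.concatenate([post0[None], posts]), log_z0 + log_zs.sum()


# Made-up three-state model with T = 50 random log-likelihood vectors.
m0 = jnp.array([0.5, 0.3, 0.2])
A = jnp.array([[0.8, 0.1, 0.1], [0.2, 0.6, 0.2], [0.3, 0.3, 0.4]])
log_liks = jax.random.normal(jax.random.PRNGKey(0), (50, 3))

elems = build_filter_elements(m0, A, log_liks)
scanned = jax.lax.associative_scan(jax.vmap(filtering_operator), elems)
filt_parallel = scanned.f[:, 0, :]        # all rows are equal: p(x_t | y_{0:t})
log_lik_parallel = scanned.log_g[-1, 0]   # log p(y_{0:T-1})

filt_seq, log_lik_seq = forward_filter(m0, A, log_liks)
```

The parallel scan has depth logarithmic in T, which is what makes the filter parallel in time; the sequential lax.scan is kept only as a reference.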

cuthbertlib.discrete.smoothing

Implements the discrete HMM smoothing associative operator.

get_reverse_kernel(x_t_dist, trans_matrix)

Computes reverse transition probabilities \(p(x_{t-1} \mid x_{t}, \dots)\) for a discrete HMM.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x_t_dist | ArrayLike | Array of shape (N,) holding the marginal of the earlier state, x_t_dist[i] = \(p(x_{t-1} = i \mid \dots)\). | required |
| trans_matrix | ArrayLike | Array of shape (N, N) where trans_matrix[i, j] = \(p(x_{t} = j \mid x_{t-1} = i)\). | required |

Returns:

| Type | Description |
| --- | --- |
| Array | An (N, N) matrix x_tm1_dist[i, j] = \(p(x_{t-1} = j \mid x_{t} = i, \dots)\). |

Source code in cuthbertlib/discrete/smoothing.py
def get_reverse_kernel(x_t_dist: ArrayLike, trans_matrix: ArrayLike) -> Array:
    r"""Computes reverse transition probabilities $p(x_{t-1} \mid x_{t}, \dots)$ for a discrete HMM.

    Args:
        x_t_dist: Array of shape (N,) holding the marginal of the earlier state,
            `x_t_dist[i]` = $p(x_{t-1} = i \mid \dots)$.
        trans_matrix: Array of shape (N, N) where
            `trans_matrix[i, j]` = $p(x_{t} = j \mid x_{t-1} = i)$.

    Returns:
        An (N, N) matrix `x_tm1_dist[i, j]` = $p(x_{t-1} = j \mid x_{t} = i, \dots)$.
    """
    x_t_dist, trans_matrix = jnp.asarray(x_t_dist), jnp.asarray(trans_matrix)
    pred = jnp.dot(trans_matrix.T, x_t_dist)
    x_tm1_dist = trans_matrix.T * x_t_dist[None, :] / pred[:, None]
    return x_tm1_dist
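One way to check get_reverse_kernel is through Bayes' rule: weighting row i of the result by the predicted probability \(p(x_t = i)\) must reproduce the joint \(p(x_{t-1} = j)\,p(x_t = i \mid x_{t-1} = j)\), and summing those weighted rows must give back the marginal of \(x_{t-1}\). Note that the code uses its first argument as the marginal of the earlier state \(x_{t-1}\). The function is copied so the sketch runs standalone; the model values are made up.

```python
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike


def get_reverse_kernel(x_t_dist: ArrayLike, trans_matrix: ArrayLike) -> Array:
    # Copied from above.
    x_t_dist, trans_matrix = jnp.asarray(x_t_dist), jnp.asarray(trans_matrix)
    pred = jnp.dot(trans_matrix.T, x_t_dist)
    x_tm1_dist = trans_matrix.T * x_t_dist[None, :] / pred[:, None]
    return x_tm1_dist


# Made-up three-state example.
A = jnp.array([[0.8, 0.1, 0.1], [0.2, 0.6, 0.2], [0.3, 0.3, 0.4]])  # p(x_t = j | x_{t-1} = i)
m_prev = jnp.array([0.5, 0.3, 0.2])                                  # p(x_{t-1} = j)

R = get_reverse_kernel(m_prev, A)  # R[i, j] = p(x_{t-1} = j | x_t = i)
pred = m_prev @ A                  # p(x_t = i)

# Both factorizations of the joint p(x_{t-1} = j, x_t = i) must agree.
joint_reverse = pred[:, None] * R
joint_forward = (m_prev[:, None] * A).T

# Each row of R is a distribution, and averaging the rows under p(x_t)
# recovers the marginal of x_{t-1}.
row_sums = R.sum(axis=1)
recovered_prev = pred @ R
```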

smoothing_operator(elem_ij, elem_jk)

Binary associative operator for smoothing in discrete HMMs.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| elem_ij | Array | Smoothing scan element. | required |
| elem_jk | Array | Smoothing scan element. | required |

Returns:

| Type | Description |
| --- | --- |
| Array | The output of the associative operator applied to the input elements. |

Source code in cuthbertlib/discrete/smoothing.py
def smoothing_operator(elem_ij: Array, elem_jk: Array) -> Array:
    """Binary associative operator for smoothing in discrete HMMs.

    Args:
        elem_ij: Smoothing scan element.
        elem_jk: Smoothing scan element.

    Returns:
        The output of the associative operator applied to the input elements.
    """
    return elem_jk @ elem_ij
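The sketch below runs smoothing_operator through a reversed jax.lax.associative_scan and compares the result with a sequential backward pass. The element layout is our assumption, chosen so that elem_jk @ elem_ij composes kernels correctly when JAX scans from the end (reverse=True): for t < T-1 the element is the transpose of the get_reverse_kernel output, so entry [j, i] is \(p(x_t = j \mid x_{t+1} = i, y_{0:t})\), and the final element stores the last filtering distribution in every column. Filtering uses a simple sequential loop in the hypothetical helper forward_filter, the likelihood values are random placeholders, and smoothing_operator and get_reverse_kernel are copied from above.

```python
import jax
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike


def get_reverse_kernel(x_t_dist: ArrayLike, trans_matrix: ArrayLike) -> Array:
    # Copied from above.
    x_t_dist, trans_matrix = jnp.asarray(x_t_dist), jnp.asarray(trans_matrix)
    pred = jnp.dot(trans_matrix.T, x_t_dist)
    return trans_matrix.T * x_t_dist[None, :] / pred[:, None]


def smoothing_operator(elem_ij: Array, elem_jk: Array) -> Array:
    # Copied from above.
    return elem_jk @ elem_ij


def forward_filter(m0: Array, A: Array, liks: Array) -> Array:
    """Hypothetical helper: sequential filter returning p(x_t | y_{0:t}) for every t."""
    filt, pred = [], m0
    for lik in liks:
        post = pred * lik
        post = post / post.sum()
        filt.append(post)
        pred = post @ A
    return jnp.stack(filt)


# Made-up three-state model with T = 20 random (positive) likelihood vectors.
n, T = 3, 20
m0 = jnp.array([0.5, 0.3, 0.2])
A = jnp.array([[0.8, 0.1, 0.1], [0.2, 0.6, 0.2], [0.3, 0.3, 0.4]])
liks = jax.random.uniform(jax.random.PRNGKey(1), (T, n), minval=0.1, maxval=1.0)

filt = forward_filter(m0, A, liks)
# R[t, i, j] = p(x_t = j | x_{t+1} = i, y_{0:t}) for t = 0, ..., T-2.
R = jax.vmap(get_reverse_kernel, in_axes=(0, None))(filt[:-1], A)

# Assumed element layout: transposed reverse kernels, then the final filter in every column.
last = jnp.broadcast_to(filt[-1][:, None], (n, n))
elems = jnp.concatenate([jnp.swapaxes(R, 1, 2), last[None]])
scanned = jax.lax.associative_scan(jax.vmap(smoothing_operator), elems, reverse=True)
smooth_parallel = scanned[:, :, 0]  # every column holds p(x_t | y_{0:T-1})

# Sequential backward pass for comparison: s_t = s_{t+1} @ R_t.
s = filt[-1]
smooth_seq = [s]
for t in range(T - 2, -1, -1):
    s = s @ R[t]
    smooth_seq.append(s)
smooth_seq = jnp.stack(smooth_seq[::-1])
```

With reverse=True, the scan output at time t is the product \(E_t E_{t+1} \cdots E_{T-1}\) of the assumed elements, which is why any column of it equals the smoothed marginal.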