
ensemble_kalman_filter

cuthbert.enkf.ensemble_kalman_filter

High-level implementation of the Ensemble Kalman Filter (EnKF).

See Algorithm 10.2, Sanz-Alonso et al., Inverse Problems and Data Assimilation. Based in part on the CD-Dynamax implementation.
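
For reference, the stochastic (perturbed-observation) EnKF analysis step is commonly written as below. The notation is ours: a forecast ensemble \hat{x}^{(i)}, i = 1, ..., N, an observation function h, an observation y, and a noise covariance R = L_R L_R^T, where L_R plays the role of chol_R. This is a generic summary rather than a transcription of Algorithm 10.2, and it does not describe the perturbed_obs=False variant. The 1/(N-1) normalisation is the reason at least two particles are required.

```latex
\begin{aligned}
\bar{x} &= \frac{1}{N}\sum_{i=1}^{N} \hat{x}^{(i)}, \qquad
\bar{y} = \frac{1}{N}\sum_{i=1}^{N} h\big(\hat{x}^{(i)}\big), \\
C_{xy} &= \frac{1}{N-1}\sum_{i=1}^{N} \big(\hat{x}^{(i)} - \bar{x}\big)\big(h(\hat{x}^{(i)}) - \bar{y}\big)^{\top}, \\
C_{yy} &= \frac{1}{N-1}\sum_{i=1}^{N} \big(h(\hat{x}^{(i)}) - \bar{y}\big)\big(h(\hat{x}^{(i)}) - \bar{y}\big)^{\top}, \\
K &= C_{xy}\,\big(C_{yy} + R\big)^{-1}, \qquad R = L_R L_R^{\top}, \\
x^{(i)} &= \hat{x}^{(i)} + K\big(y + \eta^{(i)} - h(\hat{x}^{(i)})\big), \qquad \eta^{(i)} \sim \mathcal{N}(0, R).
\end{aligned}
```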

EnKFState

Bases: NamedTuple

Ensemble Kalman filter state.

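key (instance attribute): JAX PRNG key, split again at every filter_combine step.

ensemble (instance attribute): Ensemble of particles, an array of shape (n_particles, x_dim).

model_inputs (instance attribute): Model inputs for the current time step.

log_normalizing_constant (instance attribute): Running sum of the log-likelihood increments returned by each EnKF update; 0.0 at initialisation.

n_particles (property): Number of particles.

mean (property): Ensemble mean.

chol_cov (property): Generalised Cholesky factor of the ensemble sample covariance.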
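
The mean and chol_cov properties can be computed from the ensemble array alone. The sketch below assumes the ensemble is a single array of shape (n_particles, x_dim), as built in filter_prepare. It uses the scaled deviations from the mean as the generalised factor. The helper names ensemble_mean and ensemble_chol_cov are hypothetical, and cuthbert's property may return a different factor (for example a square or triangular one) with the same product.

```python
import jax.numpy as jnp
from jax import random


def ensemble_mean(ensemble: jnp.ndarray) -> jnp.ndarray:
    """Mean over the particle axis; ensemble has shape (n_particles, x_dim)."""
    return jnp.mean(ensemble, axis=0)


def ensemble_chol_cov(ensemble: jnp.ndarray) -> jnp.ndarray:
    """Generalised Cholesky factor L of the sample covariance, so L @ L.T == cov.

    L has shape (x_dim, n_particles): the centred deviations scaled by
    1 / sqrt(n_particles - 1). It is neither square nor triangular in general,
    which is what makes it a *generalised* Cholesky factor.
    """
    n_particles = ensemble.shape[0]
    deviations = ensemble - ensemble_mean(ensemble)
    return deviations.T / jnp.sqrt(n_particles - 1)


ensemble = random.normal(random.PRNGKey(0), (50, 3))
L = ensemble_chol_cov(ensemble)
# L @ L.T reproduces the unbiased sample covariance (rows are particles).
print(jnp.allclose(L @ L.T, jnp.cov(ensemble, rowvar=False), atol=1e-5))
```

Storing a factor instead of the full covariance avoids forming an x_dim by x_dim matrix when n_particles is much smaller than x_dim.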

build_filter(init_sample, get_dynamics, get_observations, n_particles, inflation=0.0, perturbed_obs=True)

Builds an Ensemble Kalman Filter object.

Parameters:

init_sample (InitSample, required): Function that draws one sample from the initial distribution, given a key and the model inputs.

get_dynamics (GetEnKFDynamics, required): Function that maps model inputs to a dynamics function (x_t, key) -> x_{t+1} ~ p(x_{t+1} | x_t).

get_observations (GetEnKFObservations, required): Function that maps model inputs to the observation function, chol_R (the Cholesky factor of the observation noise covariance R), and the observation y.

n_particles (int, required): Number of particles (ensemble members). Must be at least 2.

inflation (float, default 0.0): Multiplicative inflation factor for the ensemble deviations.

perturbed_obs (bool, default True): If True, use perturbed observations (stochastic EnKF).

Returns:

Filter: Filter object for the EnKF. It is built with associative=False, so filtering steps are combined sequentially.

Raises:

ValueError: If n_particles is less than 2, since the ensemble sample covariance needs at least two particles.

Source code in cuthbert/enkf/ensemble_kalman_filter.py
def build_filter(
    init_sample: InitSample,
    get_dynamics: GetEnKFDynamics,
    get_observations: GetEnKFObservations,
    n_particles: int,
    inflation: float = 0.0,
    perturbed_obs: bool = True,
) -> Filter:
    """Builds an Ensemble Kalman Filter object.

    Args:
        init_sample: Function to sample from the initial distribution from key and model inputs.
        get_dynamics: Function to get dynamics function (x_t, key) -> x_{t+1} ~ p(x_{t+1} | x_t) from model inputs.
        get_observations: Function to get observation function, chol_R, and y from model inputs.
        n_particles: Number of particles.
        inflation: Multiplicative inflation factor for ensemble deviations.
        perturbed_obs: If True, use perturbed observations (stochastic EnKF).

    Returns:
        Filter object for the EnKF.

    Raises:
        ValueError: If ``n_particles`` is less than 2.
    """
    if n_particles < 2:
        raise ValueError("n_particles must be at least 2 for EnKF.")

    return Filter(
        init_prepare=partial(
            init_prepare,
            init_sample=init_sample,
            n_particles=n_particles,
        ),
        filter_prepare=partial(
            filter_prepare,
            init_sample=init_sample,
            n_particles=n_particles,
        ),
        filter_combine=partial(
            filter_combine,
            get_dynamics=get_dynamics,
            get_observations=get_observations,
            inflation=inflation,
            perturbed_obs=perturbed_obs,
        ),
        associative=False,
    )
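
To make the expected signatures concrete, here is a hypothetical 2-D linear-Gaussian model written as the three callables build_filter takes. The dynamics signature (x_t, key) -> x_{t+1} follows the parameter description above. Two points are assumptions not confirmed by this page: that the observation function is a noise-free map x -> h(x), and that model_inputs is a dict carrying y. All names and matrices are illustrative.

```python
import jax.numpy as jnp
from jax import random

# Hypothetical 2-D linear-Gaussian model (illustrative values only).
A = jnp.array([[1.0, 0.1], [0.0, 1.0]])  # state transition matrix
chol_Q = 0.1 * jnp.eye(2)                # Cholesky factor of process noise covariance
H = jnp.array([[1.0, 0.0]])              # observe the first state component
chol_R = jnp.array([[0.5]])              # Cholesky factor of observation noise covariance


def init_sample(key, model_inputs):
    # One particle from N(0, I); called once per particle key.
    return random.normal(key, (2,))


def get_dynamics(model_inputs):
    def dynamics_fn(x, key):
        # x_{t+1} = A x_t + chol_Q @ eps, with eps ~ N(0, I)
        return A @ x + chol_Q @ random.normal(key, (2,))
    return dynamics_fn


def get_observations(model_inputs):
    def observation_fn(x):
        # Assumed noise-free observation map h(x) = H x; noise enters via chol_R.
        return H @ x
    # Assumption: model_inputs is a dict holding this step's observation y.
    return observation_fn, chol_R, model_inputs["y"]


# Quick shape check of what each callable produces.
k0, k1 = random.split(random.PRNGKey(0))
inputs = {"y": jnp.array([0.3])}
x0 = init_sample(k0, inputs)
x1 = get_dynamics(inputs)(x0, k1)
obs_fn, R_chol, y = get_observations(inputs)
print(x0.shape, x1.shape, obs_fn(x1).shape, y.shape)  # (2,) (2,) (1,) (1,)
```

Under these assumptions, the callables would be passed as build_filter(init_sample, get_dynamics, get_observations, n_particles=100).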

init_prepare(model_inputs, init_sample, n_particles, key=None)

Prepare the initial state for the EnKF.

Parameters:

model_inputs (ArrayTreeLike, required): Model inputs; every leaf is converted to a JAX array.

init_sample (InitSample, required): Function that draws one sample from the initial distribution, given a key and the model inputs.

n_particles (int, required): Number of particles.

key (KeyArray | None, default None): JAX PRNG key. Despite the None default, a key must be provided.

Returns:

EnKFState: Initial EnKF state, with the ensemble sampled from init_sample and log_normalizing_constant set to 0.0.

Raises:

ValueError: If key is None.

Source code in cuthbert/enkf/ensemble_kalman_filter.py
def init_prepare(
    model_inputs: ArrayTreeLike,
    init_sample: InitSample,
    n_particles: int,
    key: KeyArray | None = None,
) -> EnKFState:
    """Prepare the initial state for the EnKF.

    Args:
        model_inputs: Model inputs.
        init_sample: Function to sample from the initial distribution from key and model inputs.
        n_particles: Number of particles.
        key: JAX random key.

    Returns:
        Initial EnKF state.

    Raises:
        ValueError: If key is None.
    """
    model_inputs = tree.map(lambda x: jnp.asarray(x), model_inputs)
    if key is None:
        raise ValueError("A JAX PRNG key must be provided.")

    # Sample ensemble from initial distribution
    keys = random.split(key, n_particles)
    ensemble = jax.vmap(init_sample, (0, None))(keys, model_inputs)

    return EnKFState(
        key=key,
        ensemble=ensemble,
        model_inputs=model_inputs,
        log_normalizing_constant=jnp.array(0.0),
    )
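
The sampling step above draws one key per particle and maps init_sample over the keys while sharing the model inputs. A self-contained sketch of the same split-and-vmap pattern follows; the Gaussian init_sample and the "m0" entry in model_inputs are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import random


def init_sample(key, model_inputs):
    # Hypothetical initial distribution N(m0, I); "m0" is an illustrative key.
    return model_inputs["m0"] + random.normal(key, model_inputs["m0"].shape)


n_particles = 100
model_inputs = {"m0": jnp.array([1.0, -1.0, 0.5])}

keys = random.split(random.PRNGKey(0), n_particles)  # one key per particle
# in_axes=(0, None): map over the keys, broadcast the shared model inputs.
ensemble = jax.vmap(init_sample, (0, None))(keys, model_inputs)
print(ensemble.shape)  # (100, 3)
```

Using in_axes=None for model_inputs lets any pytree of inputs be shared across particles without copying it n_particles times.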

filter_prepare(model_inputs, init_sample, n_particles, key=None)

Prepare the state for an EnKF step. Of this state, filter_combine only uses the model inputs; the ensemble is a placeholder.

Parameters:

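model_inputs (ArrayTreeLike, required): Model inputs; every leaf is converted to a JAX array.

init_sample (InitSample, required): Function that draws one sample from the initial distribution, given a key and the model inputs. Here it is only used to infer the state dimension.

n_particles (int, required): Number of particles.

key (KeyArray | None, default None): JAX PRNG key. Despite the None default, a key must be provided.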

Returns:

EnKFState: Prepared EnKF state whose ensemble is a dummy placeholder of shape (n_particles, x_dim), with x_dim inferred from init_sample via jax.eval_shape.

Raises:

ValueError: If key is None.

Source code in cuthbert/enkf/ensemble_kalman_filter.py
def filter_prepare(
    model_inputs: ArrayTreeLike,
    init_sample: InitSample,
    n_particles: int,
    key: KeyArray | None = None,
) -> EnKFState:
    """Prepare a state for an EnKF step.

    Args:
        model_inputs: Model inputs.
        init_sample: Function to sample from the initial distribution from key and model inputs.
        n_particles: Number of particles.
        key: JAX random key.

    Returns:
        Prepared EnKF state with dummy ensemble.

    Raises:
        ValueError: If key is None.
    """
    model_inputs = tree.map(lambda x: jnp.asarray(x), model_inputs)
    if key is None:
        raise ValueError("A JAX PRNG key must be provided.")

    # Infer state shape from init_sample
    dummy_particle = jax.eval_shape(init_sample, key, model_inputs)
    x_dim = dummy_particle.shape[0]
    ensemble = jnp.empty((n_particles, x_dim))
    ensemble = dummy_tree_like(ensemble)

    return EnKFState(
        key=key,
        ensemble=ensemble,
        model_inputs=model_inputs,
        log_normalizing_constant=jnp.array(0.0),
    )
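
filter_prepare infers the state dimension with jax.eval_shape, which traces a function abstractly without drawing any random numbers. A minimal stand-alone illustration, with a hypothetical 4-dimensional init_sample:

```python
import jax
import jax.numpy as jnp
from jax import random


def init_sample(key, model_inputs):
    # Hypothetical sampler for a 4-dimensional state.
    return random.normal(key, (4,)) + model_inputs


key = random.PRNGKey(0)
model_inputs = jnp.zeros(4)

# eval_shape traces init_sample abstractly: no sampling happens and the result
# is a jax.ShapeDtypeStruct carrying only shape and dtype.
dummy_particle = jax.eval_shape(init_sample, key, model_inputs)
x_dim = dummy_particle.shape[0]
print(type(dummy_particle).__name__, x_dim)  # ShapeDtypeStruct 4

# Placeholder ensemble of the right shape, as filter_prepare builds.
n_particles = 10
placeholder = jnp.empty((n_particles, x_dim))
print(placeholder.shape)  # (10, 4)
```

Because x_dim is read from shape[0] and the placeholder is a 2-D array, this preparation step assumes each particle is a 1-D array. A matrix- or pytree-valued state would need a different placeholder.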

filter_combine(state_1, state_2, get_dynamics, get_observations, inflation=0.0, perturbed_obs=True)

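Combine the previous EnKF state with the state prepared for the current step.

Implements one EnKF predict-and-update cycle. The dynamics and observation model are taken from state_2.model_inputs, and the update's log-likelihood increment is added to log_normalizing_constant.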
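
The predict half of the cycle can be sketched from scratch as below. This is not the code of enkf_lib.predict. It assumes inflation scales each particle's deviation from the ensemble mean by (1 + inflation) after propagation, so the default 0.0 means no inflation. Where and how cuthbert applies the factor may differ. The random-walk dynamics_fn is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import random


def predict(key, ensemble, dynamics_fn, inflation=0.0):
    """Propagate every particle through the dynamics, then inflate deviations.

    Assumption: deviations from the ensemble mean are scaled by (1 + inflation),
    so inflation=0.0 leaves the propagated ensemble unchanged.
    """
    keys = random.split(key, ensemble.shape[0])
    # dynamics_fn has signature (x_t, key) -> x_{t+1}; vmap over particles and keys.
    propagated = jax.vmap(dynamics_fn)(ensemble, keys)
    mean = propagated.mean(axis=0)
    return mean + (1.0 + inflation) * (propagated - mean)


def dynamics_fn(x, key):
    # Hypothetical random walk with unit-variance Gaussian noise.
    return x + random.normal(key, x.shape)


ens = random.normal(random.PRNGKey(0), (200, 2))
k = random.PRNGKey(1)
plain = predict(k, ens, dynamics_fn)
inflated = predict(k, ens, dynamics_fn, inflation=0.1)
# Same key, same propagated particles: inflation only rescales the spread.
print(jnp.allclose(inflated - inflated.mean(0), 1.1 * (plain - plain.mean(0)), atol=1e-5))
```

Inflating around the mean keeps the ensemble mean fixed while widening the spread, which counteracts the variance underestimation that small ensembles tend to suffer from.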

Parameters:

state_1 (EnKFState, required): EnKF state from the previous time step.

state_2 (EnKFState, required): EnKF state prepared for the current step.

get_dynamics (GetEnKFDynamics, required): Function that maps model inputs to a dynamics function (x_t, key) -> x_{t+1} ~ p(x_{t+1} | x_t).

get_observations (GetEnKFObservations, required): Function that maps model inputs to the observation function, chol_R, and the observation y.

inflation (float, default 0.0): Multiplicative inflation factor for the ensemble deviations.

perturbed_obs (bool, default True): If True, use perturbed observations (stochastic EnKF).

Returns:

EnKFState: Updated EnKF state.

Source code in cuthbert/enkf/ensemble_kalman_filter.py
def filter_combine(
    state_1: EnKFState,
    state_2: EnKFState,
    get_dynamics: GetEnKFDynamics,
    get_observations: GetEnKFObservations,
    inflation: float = 0.0,
    perturbed_obs: bool = True,
) -> EnKFState:
    """Combine previous EnKF state with prepared state for current step.

    Implements the EnKF predict + update cycle.

    Args:
        state_1: EnKF state from the previous time step.
        state_2: EnKF state prepared for the current step.
        get_dynamics: Function to get dynamics function (x_t, key) -> x_{t+1} ~ p(x_{t+1} | x_t) from model inputs.
        get_observations: Function to get observation function, chol_R, and y from model inputs.
        inflation: Multiplicative inflation factor.
        perturbed_obs: If True, use perturbed observations.

    Returns:
        Updated EnKF state.
    """
    key_pred, key_update, key_next = random.split(state_1.key, 3)

    # Predict
    dynamics_fn = get_dynamics(state_2.model_inputs)
    predicted = enkf_lib.predict(
        key_pred,
        state_1.ensemble,
        dynamics_fn,
        inflation,
    )

    # Update
    observation_fn, chol_R, y = get_observations(state_2.model_inputs)
    updated, ll = enkf_lib.update(
        key_update,
        predicted,
        observation_fn,
        chol_R,
        y,
        perturbed_obs,
    )

    return EnKFState(
        key=key_next,
        ensemble=updated,
        model_inputs=state_2.model_inputs,
        log_normalizing_constant=state_1.log_normalizing_constant + ll,
    )
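
A from-scratch sketch of the update half with perturbed observations, showing what the stochastic case computes. The helper update and the scalar toy model are hypothetical. The claim that the returned ll is the Gaussian log-likelihood of y under the predicted-observation ensemble is an assumption about enkf_lib.update. The perturbed_obs=False branch is not covered. In filter_combine, the returned increment is added to log_normalizing_constant.

```python
import jax
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import multivariate_normal


def update(key, ensemble, observation_fn, chol_R, y):
    """Stochastic EnKF analysis step with perturbed observations.

    Assumption: the log-likelihood increment is log N(y; y_mean, S), where S is
    the sample covariance of the predicted observations plus R = chol_R chol_R^T.
    """
    n = ensemble.shape[0]
    y_pred = jax.vmap(observation_fn)(ensemble)   # (n, y_dim)
    x_dev = ensemble - ensemble.mean(axis=0)
    y_mean = y_pred.mean(axis=0)
    y_dev = y_pred - y_mean
    R = chol_R @ chol_R.T
    C_xy = x_dev.T @ y_dev / (n - 1)              # (x_dim, y_dim)
    S = y_dev.T @ y_dev / (n - 1) + R             # (y_dim, y_dim)
    K = jnp.linalg.solve(S, C_xy.T).T             # K = C_xy S^{-1} (S symmetric)
    # Each particle is nudged towards its own noisy copy of y, noise ~ N(0, R).
    noise = random.normal(key, y_pred.shape) @ chol_R.T
    updated = ensemble + (y + noise - y_pred) @ K.T
    ll = multivariate_normal.logpdf(y, y_mean, S)
    return updated, ll


# Scalar check: prior N(0, 1), identity observation, R = 1, y = 1.
prior = random.normal(random.PRNGKey(0), (5000, 1))
post, ll = update(random.PRNGKey(1), prior, lambda x: x, jnp.eye(1), jnp.array([1.0]))
print(post.mean(), post.var(), ll)
```

For this scalar model the exact Kalman answers are posterior mean 0.5, posterior variance 0.5 and log-likelihood log N(1; 0, 2) ≈ -1.52, which the ensemble estimates approach as the number of particles grows.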