filtering

cuthbertlib.enkf.filtering

Implements the Ensemble Kalman Filter (EnKF) predict and update steps.

See Algorithm 10.2, Sanz-Alonso et al., Inverse Problems and Data Assimilation. Based in part on the CD-Dynamax implementation.

ObservationFn = Callable[[Array], Array] module-attribute

DynamicsFn = Callable[[Array, KeyArray], Array] module-attribute
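These aliases pin down the call signatures that `predict` and `update` expect. A minimal sketch of conforming functions (the damped linear dynamics and first-coordinate observation below are invented for illustration, not part of the library):

```python
import jax.numpy as jnp
from jax import random

# Hypothetical dynamics matching DynamicsFn: (state, key) -> state.
# Here: a damped linear map with additive Gaussian noise.
def dynamics_fn(x, key):
    return 0.9 * x + 0.1 * random.normal(key, x.shape)

# Hypothetical observation matching ObservationFn: state -> obs.
# Here: observe only the first coordinate.
def observation_fn(x):
    return x[:1]

key = random.PRNGKey(0)
x = jnp.zeros(3)
print(dynamics_fn(x, key).shape)  # (3,)
print(observation_fn(x).shape)    # (1,)
```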

predict(key, ensemble, dynamics_fn, inflation=0.0)

Propagate ensemble members through an arbitrary simulator p(x_{t+1} | x_t).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `key` | `KeyArray` | JAX PRNG key. | *required* |
| `ensemble` | `Array` | Ensemble of state vectors, shape `(N, x_dim)`. | *required* |
| `dynamics_fn` | `DynamicsFn` | Dynamics function mapping `(state, key) -> state`. | *required* |
| `inflation` | `float` | Multiplicative inflation factor applied to ensemble deviations. | `0.0` |

Returns:

| Type | Description |
| --- | --- |
| `Array` | Predicted ensemble, shape `(N, x_dim)`. |

Source code in `cuthbertlib/enkf/filtering.py`

```python
def predict(
    key: KeyArray,
    ensemble: Array,
    dynamics_fn: DynamicsFn,
    inflation: float = 0.0,
) -> Array:
    """Propagate ensemble members through an arbitrary simulator p(x_{t+1} | x_t).

    Args:
        key: JAX PRNG key.
        ensemble: Ensemble of state vectors, shape (N, x_dim).
        dynamics_fn: Dynamics function mapping (state, key) -> state.
        inflation: Multiplicative inflation factor applied to ensemble deviations.

    Returns:
        Predicted ensemble, shape (N, x_dim).
    """
    N, x_dim = ensemble.shape

    # Propagate each member through the dynamics
    keys = random.split(key, N)
    propagated = jax.vmap(dynamics_fn, (0, 0))(ensemble, keys)

    # Apply multiplicative inflation
    mean = jnp.mean(propagated, axis=0)
    propagated = mean + (1 + inflation) * (propagated - mean)

    return propagated
```
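The inflation step can be checked in isolation. A NumPy sketch (the dynamics map `0.5 * x + noise` is invented for illustration) showing that multiplicative inflation preserves the ensemble mean while scaling the ensemble covariance by `(1 + inflation)**2`:

```python
import numpy as np

rng = np.random.default_rng(0)
N, x_dim = 100, 2
ensemble = rng.normal(size=(N, x_dim))

# Propagate each member through hypothetical dynamics x -> 0.5*x + noise.
propagated = 0.5 * ensemble + 0.1 * rng.normal(size=(N, x_dim))

# Multiplicative inflation: deviations from the ensemble mean are scaled
# by (1 + inflation), so the sample covariance grows by (1 + inflation)**2
# while the mean is unchanged.
inflation = 0.1
mean = propagated.mean(axis=0)
inflated = mean + (1 + inflation) * (propagated - mean)

print(np.allclose(inflated.mean(axis=0), mean))  # True
print(np.allclose(np.cov(inflated.T), (1 + inflation) ** 2 * np.cov(propagated.T)))  # True
```

Inflation counteracts the variance underestimation that finite ensembles accumulate over repeated predict/update cycles.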

update(key, predicted_ensemble, observation_fn, chol_R, y, perturbed_obs=True)

Update ensemble members with an observation using the EnKF update.

NaNs in y are treated as missing dimensions and are excluded from the update. When y is entirely NaN, the update is a no-op: the predicted ensemble is returned unchanged with zero log-likelihood contribution.
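A sketch of the masking idea (the actual implementation reorders dimensions with `collect_nans_chol` rather than boolean indexing; the numbers are illustrative):

```python
import numpy as np

# Observation with a missing (NaN) middle dimension.
y = np.array([0.5, np.nan, -1.2])
observed = ~np.isnan(y)

# Only the observed dimensions enter the update and the log-likelihood.
y_obs = y[observed]
print(y_obs)  # [ 0.5 -1.2]
```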

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `key` | `KeyArray` | JAX PRNG key. | *required* |
| `predicted_ensemble` | `Array` | Predicted ensemble, shape `(N, x_dim)`. | *required* |
| `observation_fn` | `ObservationFn` | Observation function mapping `state -> obs`. | *required* |
| `chol_R` | `Array` | Cholesky factor of the observation noise covariance, shape `(y_dim, y_dim)`. | *required* |
| `y` | `Array` | Observation vector, shape `(y_dim,)`. NaNs indicate missing dimensions. | *required* |
| `perturbed_obs` | `bool` | If True, use perturbed observations (stochastic EnKF). If False, use deterministic update. | `True` |

Returns:

| Type | Description |
| --- | --- |
| `tuple[Array, ScalarArray]` | Tuple of `(updated_ensemble, log_likelihood)`. |

Source code in `cuthbertlib/enkf/filtering.py`

```python
def update(
    key: KeyArray,
    predicted_ensemble: Array,
    observation_fn: ObservationFn,
    chol_R: Array,
    y: Array,
    perturbed_obs: bool = True,
) -> tuple[Array, ScalarArray]:
    """Update ensemble members with an observation using the EnKF update.

    NaNs in ``y`` are treated as missing dimensions and are excluded from the
    update. When ``y`` is entirely NaN, the update is a no-op: the predicted
    ensemble is returned unchanged with zero log-likelihood contribution.

    Args:
        key: JAX PRNG key.
        predicted_ensemble: Predicted ensemble, shape (N, x_dim).
        observation_fn: Observation function mapping state -> obs.
        chol_R: Cholesky factor of the observation noise covariance, shape (y_dim, y_dim).
        y: Observation vector, shape (y_dim,). NaNs indicate missing dimensions.
        perturbed_obs: If True, use perturbed observations (stochastic EnKF).
            If False, use deterministic update.

    Returns:
        Tuple of (updated_ensemble, log_likelihood).
    """
    N, x_dim = predicted_ensemble.shape

    # Map ensemble to observation space
    y_pred = jax.vmap(observation_fn, (0,))(predicted_ensemble)

    # Handle partially-missing observations by reordering and zeroing missing dims.
    # Use y_pred.T because y_pred is (N, y_dim) and we want to reorder along axis 0.
    flag = jnp.isnan(y)
    flag, chol_R, y, y_pred = collect_nans_chol(flag, chol_R, y, y_pred.T)
    y_pred = y_pred.T
    y_dim = y.shape[0]

    # Ensemble means
    x_mean = jnp.mean(predicted_ensemble, axis=0)
    y_mean = jnp.mean(y_pred, axis=0)

    # Deviations from ensemble mean
    x_dev = predicted_ensemble - x_mean
    y_dev = y_pred - y_mean

    # Square-root innovation covariance via tria
    chol_S = tria(jnp.concatenate([y_dev.T / jnp.sqrt(N - 1), chol_R], axis=1))

    # Cross-covariance
    C_xy = x_dev.T @ y_dev / (N - 1)

    # Kalman gain: K = C_xy @ S^{-1} = C_xy @ cho_solve(chol_S, I)
    K = cho_solve((chol_S, True), C_xy.T).T

    # Innovation per member
    if perturbed_obs:
        y_n = y[None, :] + (chol_R @ random.normal(key, (y_dim, N))).T
    else:
        y_n = jnp.broadcast_to(y[None, :], (N, y_dim))

    # Update ensemble
    updated = predicted_ensemble + (y_n - y_pred) @ K.T

    # Log-likelihood
    ll = multivariate_normal.logpdf(y, y_mean, chol_S, nan_support=False)

    return updated, jnp.asarray(ll)
```
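For intuition, here is a NumPy sketch of the same stochastic (perturbed-observation) update in the simplest setting: an identity observation function, no missing dimensions, and the innovation covariance `S = Y'Y/(N-1) + R` formed densely rather than via the square-root `tria` factorization. All numbers and seeds are illustrative only:

```python
import numpy as np

rng = np.random.default_rng(1)
N, x_dim = 500, 2
predicted = rng.normal(size=(N, x_dim))

# Hypothetical setup: observe both coordinates directly.
y = np.array([1.0, -1.0])
R = 0.25 * np.eye(2)
chol_R = np.linalg.cholesky(R)

y_pred = predicted  # identity observation_fn
x_mean = predicted.mean(axis=0)
y_mean = y_pred.mean(axis=0)
x_dev = predicted - x_mean
y_dev = y_pred - y_mean

# Dense innovation covariance and cross-covariance.
S = y_dev.T @ y_dev / (N - 1) + R
C_xy = x_dev.T @ y_dev / (N - 1)
K = C_xy @ np.linalg.inv(S)

# Perturbed observations: each member gets its own noisy copy of y.
y_n = y + rng.normal(size=(N, 2)) @ chol_R.T
updated = predicted + (y_n - y_pred) @ K.T

# The updated ensemble mean moves from the prior mean toward y.
print(np.linalg.norm(updated.mean(axis=0) - y) < np.linalg.norm(x_mean - y))  # True
```

The square-root form used in the library computes the Cholesky factor of the same `S` directly from `[y_dev.T / sqrt(N-1), chol_R]`, which is better conditioned than forming and inverting `S` explicitly.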