A JAX library for state-space model inference (filtering, smoothing, static parameter estimation).
Disclaimer: The name cuthbert was chosen as a playful nod to the well-known caterpillar cake rivalry between Aldi and M&S in the UK, as the classic state-space model diagram looks vaguely like a caterpillar. However, this software project has no formal connection to Aldi, M&S, or any food products (notwithstanding the coffee drunk during its write-up). cuthbert is simply a fun name for this state-space model library and should not be interpreted as an endorsement of, association with, or affiliation with any brand or animal-themed baked goods.
Codebase structure
The codebase is structured as follows:
- cuthbert: The main package, with a unified interface for filtering and smoothing.
- cuthbertlib: A collection of atomic, smaller-scoped tools for state-space model inference. These are the building blocks that power the main cuthbert package.
Goals
- Simple, flexible and performant interface for state-space model inference.
- Decoupling of model specification and inference: cuthbert is built to swap between different inference methods without being tied to a specific model specification.
- Composability with the JAX ecosystem, giving access to a wide range of external tools.
- Functional API: the only classes in cuthbert are NamedTuples and Protocols. All functions are pure and work seamlessly with jax.grad, jax.jit, jax.vmap, etc.
- Methods for filtering: \(p(x_t \mid y_{0:t}, \theta)\).
- Methods for smoothing: \(p(x_{0:T} \mid y_{0:T}, \theta)\) or \(p(x_{t} \mid y_{0:T}, \theta)\).
- Methods for static parameter estimation: \(p(\theta \mid y_{0:T})\) or \(\operatorname{argmax}_{\theta} p(y_{0:T} \mid \theta)\).
- This includes support for forward-backward/Baum-Welch, particle filtering/sequential Monte Carlo, Kalman filtering (+ extended/unscented/ensemble), expectation-maximization and more!
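To make the "functional API" goal concrete, here is a standalone sketch in plain JAX. It is not cuthbert code and does not use cuthbert's API: the names LGSSMParams and kalman_filter are hypothetical. It writes a Kalman filter for a linear-Gaussian model as a pure function of a NamedTuple of parameters and the observations, using jax.lax.scan. It returns the filtering moments of \(p(x_t \mid y_{0:t}, \theta)\) and \(\log p(y_{0:T} \mid \theta)\), so the same function can be jitted, differentiated for static parameter estimation, and vmapped over datasets.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


class LGSSMParams(NamedTuple):
    """Hypothetical parameter container for a linear-Gaussian state-space model."""
    A: jax.Array   # transition matrix: x_t = A x_{t-1} + N(0, Q)
    Q: jax.Array   # transition noise covariance
    H: jax.Array   # observation matrix: y_t = H x_t + N(0, R)
    R: jax.Array   # observation noise covariance
    m0: jax.Array  # prior mean of x_0
    P0: jax.Array  # prior covariance of x_0


def kalman_filter(params: LGSSMParams, ys: jax.Array):
    """Filtering means/covariances of p(x_t | y_{0:t}) and log p(y_{0:T})."""

    def step(carry, y):
        m_pred, P_pred = carry  # moments of p(x_t | y_{0:t-1})
        # Update: condition the predictive distribution on y_t.
        y_pred = params.H @ m_pred
        S = params.H @ P_pred @ params.H.T + params.R
        log_inc = multivariate_normal.logpdf(y, y_pred, S)  # log p(y_t | y_{0:t-1})
        K = jnp.linalg.solve(S, params.H @ P_pred).T  # Kalman gain P H^T S^{-1}
        m_filt = m_pred + K @ (y - y_pred)
        P_filt = P_pred - K @ S @ K.T
        # Predict: push the filtering distribution through the dynamics.
        m_next = params.A @ m_filt
        P_next = params.A @ P_filt @ params.A.T + params.Q
        return (m_next, P_next), (m_filt, P_filt, log_inc)

    _, (ms, Ps, log_incs) = jax.lax.scan(step, (params.m0, params.P0), ys)
    return ms, Ps, jnp.sum(log_incs)


# Toy 1D random walk observed in noise; 20 observations y_0, ..., y_19.
params = LGSSMParams(
    A=jnp.eye(1), Q=0.1 * jnp.eye(1), H=jnp.eye(1), R=0.5 * jnp.eye(1),
    m0=jnp.zeros(1), P0=jnp.eye(1),
)
ys = jnp.cumsum(jax.random.normal(jax.random.PRNGKey(0), (20, 1)), axis=0)

# jit: compile the whole filter.
ms, Ps, log_lik = jax.jit(kalman_filter)(params, ys)


# grad: derivative of log p(y_{0:T} | theta) w.r.t. theta = log of the transition noise.
def log_lik_fn(log_q):
    return kalman_filter(params._replace(Q=jnp.exp(log_q) * jnp.eye(1)), ys)[2]


dlog_lik = jax.grad(log_lik_fn)(jnp.log(0.1))

# vmap: filter a batch of 8 sequences with shared parameters.
batch_ys = jnp.stack([ys + i for i in range(8)])
batch_log_lik = jax.vmap(kalman_filter, in_axes=(None, 0))(params, batch_ys)[2]
```

Because the filter is a pure function of (params, ys), the same code serves filtering, gradient-based parameter estimation and batching without modification. This is the property the NamedTuple-plus-pure-functions design is aiming for.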
Non-goals
- Tools for defining models and distributions. cuthbert is not a probabilistic programming language (PPL), but it can easily compose with dynamax, distrax, numpyro and pymc, in a similar way to how blackjax does.
- "SMC samplers" that sample from a posterior distribution which is not (necessarily) a state-space model; blackjax is great for this.
Installation
cuthbert depends on JAX, so you'll need to install JAX for the available hardware (CPU, GPU, or TPU).
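For example, on computers with NVIDIA GPUs, follow the CUDA instructions in the JAX installation guide (at the time of writing: pip install -U "jax[cuda12]").
Now install cuthbert from PyPI:
pip install cuthbert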
Installing cuthbert will also install cuthbertlib.
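After installing, you can check which hardware JAX will use. The sketch below uses only the standard jax.default_backend and jax.devices functions. It exercises JAX alone and does not import cuthbert.

```python
import jax

# "gpu" (or "tpu") if an accelerator backend was found, otherwise "cpu".
backend = jax.default_backend()
print(backend)

# The devices JAX can place arrays on.
print(jax.devices())
```

If an NVIDIA machine reports "cpu", the accelerator-enabled JAX install most likely did not succeed.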
Ecosystem
cuthbert is built on top of jax and composes easily with other JAX packages, e.g. optax for optimization, flax for neural networks, and blackjax for (SG)MCMC, as well as the PPLs mentioned above.
- What about dynamax? dynamax is a great library for state-space model specification and inference with discrete or Gaussian state-space models. cuthbert is focused on inference with arbitrary state-space models, e.g. via SMC, which is not supported in dynamax. However, as both are built on jax, they can be used together: a dynamax model can be passed to cuthbert for inference.
- And particles? particles and its accompanying book, An Introduction to Sequential Monte Carlo (Chopin and Papaspiliopoulos), are wonderful learning materials for state-space models and SMC. cuthbert is more focused on performance and composability with the JAX ecosystem.
- Much of the code in cuthbert is built on work from sqrt-parallel-smoothers, mocat and abile.
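The comparison with dynamax rests on SMC, which handles state-space models beyond the discrete and linear-Gaussian cases. The following sketch illustrates that technique only: a bootstrap particle filter with multinomial resampling in plain JAX. It is not cuthbert's implementation or API. The function bootstrap_filter and its parameters are hypothetical, and the linear-Gaussian toy model is chosen only so that the likelihood estimate can be checked against the exact value.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def bootstrap_filter(key, ys, n_particles=1000, q=1.0, r=0.5, p0=1.0):
    """Bootstrap particle filter for a 1D random walk observed in Gaussian noise:
    x_0 ~ N(0, p0), x_t = x_{t-1} + N(0, q), y_t = x_t + N(0, r).
    Returns filtering means E[x_t | y_{0:t}] and an estimate of log p(y_{0:T})."""
    key, init_key = jax.random.split(key)
    x0 = jnp.sqrt(p0) * jax.random.normal(init_key, (n_particles,))

    def step(xs, inputs):
        key_t, y = inputs
        resample_key, move_key = jax.random.split(key_t)
        # Weight particles from p(x_t | y_{0:t-1}) by the likelihood p(y_t | x_t).
        log_w = norm.logpdf(y, loc=xs, scale=jnp.sqrt(r))
        log_inc = logsumexp(log_w) - jnp.log(n_particles)  # estimate of log p(y_t | y_{0:t-1})
        w = jnp.exp(log_w - logsumexp(log_w))  # normalised weights
        mean = jnp.sum(w * xs)
        # Multinomial resampling, then propagate through the random-walk dynamics.
        idx = jax.random.choice(resample_key, n_particles, (n_particles,), p=w)
        xs_next = xs[idx] + jnp.sqrt(q) * jax.random.normal(move_key, (n_particles,))
        return xs_next, (mean, log_inc)

    keys = jax.random.split(key, ys.shape[0])
    _, (means, log_incs) = jax.lax.scan(step, x0, (keys, ys))
    return means, jnp.sum(log_incs)


# 20 scalar observations; n_particles fixes array shapes, so it must be static under jit.
ys = jnp.cumsum(jax.random.normal(jax.random.PRNGKey(0), (20,)))
pf = jax.jit(bootstrap_filter, static_argnames="n_particles")
means, log_lik_hat = pf(jax.random.PRNGKey(1), ys, n_particles=5000)
```

Only the transition sampler and the observation log-density are model-specific. Replacing them with nonlinear or non-Gaussian choices is what takes SMC beyond Kalman-type methods. Resampling multinomially at every step is the simplest option. Systematic resampling, or resampling only when the effective sample size drops, typically reduces variance.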