Temporally-Parallelized Kalman Filter
In cuthbert, we provide an implementation of the Kalman filter that can be
executed in parallel across time steps. For a problem with \(T\) time steps, if
there are \(T\) available parallel workers, this implementation achieves
logarithmic time complexity \(\mathcal{O}(\log T)\) as opposed to the standard
linear \(\mathcal{O}(T)\) time complexity, as shown by Särkkä and García-Fernández. Users can decide whether to run the filter in parallel
via the parallel argument to the filter
function. In this example, we demonstrate the usage and performance of the
temporally-parallelized Kalman filter.
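The key idea, following Särkkä and García-Fernández, is to express the filtering recursion as an associative operation and evaluate it with a parallel prefix scan, whose parallel depth is logarithmic in the number of elements. As a toy illustration of such a scan (not cuthbert's actual internals), here is a cumulative sum computed with jax.lax.associative_scan:
import jax
import jax.numpy as jnp

# An associative scan combines T elements under an associative operator
# in O(log T) parallel depth. Here the operator is addition, so the scan
# returns all prefix sums; the parallel Kalman filter applies the same
# machinery to a more elaborate associative combination of filtering terms.
xs = jnp.arange(1, 9)
prefix_sums = jax.lax.associative_scan(jnp.add, xs)
print(prefix_sums)  # [ 1  3  6 10 15 21 28 36]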
Setup and imports
We first import the necessary libraries and specify a linear-Gaussian
state-space model. We use a helper function called generate_lgssm to create
example model parameters and observations, and then build the Kalman filter
object, as covered in the Kalman tracking example.
import timeit
import jax
import jax.numpy as jnp
import numpy as np
from cuthbert import filter
from cuthbert.gaussian import kalman
from cuthbertlib.kalman.generate import generate_lgssm
seed = 0
x_dim = 20  # state dimension
y_dim = 10  # observation dimension
num_time_steps = 1000
m0, chol_P0, Fs, cs, chol_Qs, Hs, ds, chol_Rs, ys = generate_lgssm(
    seed, x_dim, y_dim, num_time_steps
)
def get_init_params(model_inputs):
    # Initial state mean and Cholesky factor of the initial covariance.
    return m0, chol_P0

def get_dynamics_params(model_inputs):
    # Transition parameters for the step into time index `model_inputs`.
    return Fs[model_inputs - 1], cs[model_inputs - 1], chol_Qs[model_inputs - 1]

def get_observation_params(model_inputs):
    # Observation model parameters and the observation at time index `model_inputs`.
    return Hs[model_inputs], ds[model_inputs], chol_Rs[model_inputs], ys[model_inputs]
filter_obj = kalman.build_filter(
    get_init_params, get_dynamics_params, get_observation_params
)
model_inputs = jnp.arange(num_time_steps + 1)
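Before timing anything, it is worth checking that the sequential and parallel modes agree up to floating-point differences. A minimal sketch, assuming both modes return results with the same pytree structure:
out_seq = filter(filter_obj, model_inputs, parallel=False)
out_par = filter(filter_obj, model_inputs, parallel=True)
# Compare the two results leaf by leaf (pytree structures assumed identical).
for a, b in zip(
    jax.tree_util.tree_leaves(out_seq), jax.tree_util.tree_leaves(out_par)
):
    assert jnp.allclose(a, b, atol=1e-5)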
Time Everything
We JIT-compile the filter function, making
sure to mark the filter_obj and parallel arguments as static. We then
measure the compilation times using the
timeit module for both the
sequential and parallel implementations. Note that the first call to a jitted function compiles and then executes it, so each "compile time" below also includes one run; subsequent calls reuse the cached compiled function.
jitted_filter = jax.jit(filter, static_argnames=("filter_obj", "parallel"))
seq_compile_time = timeit.Timer(
    lambda: jax.block_until_ready(
        jitted_filter(filter_obj, model_inputs, parallel=False)
    )
).timeit(number=1)
par_compile_time = timeit.Timer(
    lambda: jax.block_until_ready(
        jitted_filter(filter_obj, model_inputs, parallel=True)
    )
).timeit(number=1)
Let's do the same for the runtimes. We run each implementation 10 times and report the minimum and median runtimes.
num_runs = 10
seq_runtimes = timeit.Timer(
    lambda: jax.block_until_ready(
        jitted_filter(filter_obj, model_inputs, parallel=False)
    )
).repeat(repeat=num_runs, number=1)
par_runtimes = timeit.Timer(
    lambda: jax.block_until_ready(
        jitted_filter(filter_obj, model_inputs, parallel=True)
    )
).repeat(repeat=num_runs, number=1)
print(" Sequential | Parallel")
print("-" * 35)
print(f"Compile time : {seq_compile_time: >7.3f}s | {par_compile_time: >7.3f}s")
print(f"Min runtime : {np.min(seq_runtimes): >7.3f}s | {np.min(par_runtimes): >7.3f}s")
print(f"Median runtime: {np.median(seq_runtimes): >7.3f}s | {np.median(par_runtimes): >7.3f}s")
Example Results
Running the above code on an AMD Ryzen 7 PRO 7840U CPU yields:
              Sequential | Parallel
-----------------------------------
Compile time  :   0.422s |   4.932s
Min runtime   :   0.042s |   0.071s
Median runtime:   0.043s |   0.076s
We highlight two things. First, the compile time for the parallel version is much higher: the parallel implementation is more complex and contains more operations, which means more work for the compiler. Second, since this CPU exposes only 16 hardware threads, there is not enough opportunity for parallelism, so the parallel version ends up slower because of its higher total operation count.
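To see which backend XLA is targeting and how much hardware parallelism is available, one can check, for example:
import os

print(jax.default_backend())  # e.g. "cpu" or "gpu"
print(os.cpu_count())         # hardware threads, e.g. 16 on the CPU above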
The benefit of the parallel version becomes clear when we run it on a GPU, in this case on an NVIDIA A100-SXM4-80GB:
              Sequential | Parallel
-----------------------------------
Compile time  :   2.541s |  15.345s
Min runtime   :   0.597s |   0.022s
Median runtime:   0.598s |   0.022s
The parallel implementation is now about 27 times faster than the sequential one, and this difference will only increase with increasing \(T\). So if you have a problem where you have to run the Kalman filter (or smoother) repeatedly for the same model, and you have a GPU available, it might be beneficial to pay the higher compilation cost and use the parallel implementation.
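As a rough back-of-the-envelope check on the numbers above: the parallel version costs about \(15.345 - 2.541 \approx 12.8\) extra seconds of compilation but saves about \(0.597 - 0.022 \approx 0.575\) seconds per filter call, so it breaks even after roughly \(12.8 / 0.575 \approx 22\) runs.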
Key Takeaways
- Temporal parallelization: cuthbert provides parallel-in-time filtering that reduces time complexity from \(\mathcal{O}(T)\) to \(\mathcal{O}(\log T)\) when sufficient parallel workers are available.
- Hardware-dependent performance: The parallel implementation shows significant speedups on GPUs/TPUs (27x faster in the example), but may be slower on CPUs with limited parallelism due to higher computational overhead.
- Compilation trade-off: The parallel version has higher compilation time, but this cost is amortized when running the filter multiple times on the same model.
- Simple API: Parallelization is enabled with a single parallel=True argument to the filter function, making it easy to experiment with both implementations.
- JIT compilation: Both sequential and parallel filters can be JIT-compiled for optimal performance, with static arguments properly marked.
Next Steps
- Smoothing: Apply temporal parallelization to smoothing with cuthbert.smoother for backward pass efficiency.
- More examples: Explore other examples including Kalman tracking and dynamax integration.
- Performance tuning: Experiment with different time series lengths and hardware configurations to find the optimal parallelization strategy for your use case.