SMC
This sub-repository provides modular tools useful for constructing sequential Monte
Carlo (SMC) algorithms. It pairs with the cuthbertlib.resampling sub-repository.
Backward simulation
cuthbertlib.smc.smoothing provides tools for backward simulation, i.e. converting
particles x0 from the filtering distribution at time t0 and particles x1 from the smoothing
distribution at time t1 into joint particles (x0, x1) from the smoothing distribution.
Backward simulation generally requires a log conditional density log_density(x0, x1);
exact backward sampling has computational cost O(N^2), where N is the number of particles.
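For reference, here is the distribution an exact backward sampler draws from. This sketch assumes that log_density(x0, x1) denotes \(\log p(x_1 \mid x_0)\) and that \(w_0^{(j)}\) are the normalised weights of the filter particles \(x_0^{(j)}\), \(j = 1, \dots, N\). For each smoothing particle \(x_1\), an ancestor index \(a\) is drawn with

\[
\mathbb{P}(a = j \mid x_1) = \frac{w_0^{(j)}\, p(x_1 \mid x_0^{(j)})}{\sum_{k=1}^{N} w_0^{(k)}\, p(x_1 \mid x_0^{(k)})}, \qquad j = 1, \dots, N,
\]

and the pair \((x_0^{(a)}, x_1)\) is kept. Normalising this distribution needs all \(N\) terms for each of the \(N\) particles \(x_1\), which is where the \(O(N^2)\) cost comes from.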
cuthbertlib.smc.smoothing.protocols
Shared protocols for backward smoothing functions in SMC.
BackwardSampling
Bases: Protocol
Protocol for backward sampling functions.
__call__(key, x0_all, x1_all, log_weight_x0_all, log_density, x1_ancestor_indices)
Performs a backward sampling step.
Samples a collection of \(x_0\) that combine with the provided \(x_1\) to give a collection of pairs \((x_0, x_1)\) from the smoothing distribution.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | KeyArray | JAX PRNG key. | required |
| x0_all | ArrayTreeLike | A collection of previous states \(x_0\). | required |
| x1_all | ArrayTreeLike | A collection of current states \(x_1\). | required |
| log_weight_x0_all | ArrayLike | The log weights of \(x_0\). | required |
| log_density | LogConditionalDensity | The log density function of \(x_1\) given \(x_0\). | required |
| x1_ancestor_indices | ArrayLike | The ancestor indices of \(x_1\). | required |

Returns:

| Type | Description |
|---|---|
| tuple[ArrayTree, Array] | A collection of samples \(x_0\) and their sampled indices. |
Source code in cuthbertlib/smc/smoothing/protocols.py
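All three implementations share the log_density argument, so it helps to see what such a function might look like. The following is a minimal sketch under stated assumptions: LogConditionalDensity is taken to be a callable (x0, x1) -> scalar returning \(\log p(x_1 \mid x_0)\) for a single pair, the model is a hypothetical Gaussian random walk \(x_1 \mid x_0 \sim \mathcal{N}(x_0, \sigma^2)\), and particles are plain arrays rather than general pytrees. The names random_walk_log_density, pairwise_log_density, and SIGMA are hypothetical and not part of cuthbertlib. The nested jax.vmap evaluates every (x0, x1) pair, which is the \(N \times N\) work that exact backward sampling performs.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

SIGMA = 0.5  # hypothetical transition noise scale


def random_walk_log_density(x0, x1):
    """Hypothetical log p(x1 | x0) for a Gaussian random walk x1 ~ N(x0, SIGMA^2).

    Summing over the last axis makes it work for scalar or vector states.
    """
    return jnp.sum(norm.logpdf(x1, loc=x0, scale=SIGMA))


def pairwise_log_density(log_density, x0_all, x1_all):
    """Return M with M[i, j] = log_density(x0_all[j], x1_all[i]).

    This costs N1 * N0 evaluations of log_density.
    """
    over_x0 = jax.vmap(log_density, in_axes=(0, None))  # fix one x1, vary x0
    return jax.vmap(over_x0, in_axes=(None, 0))(x0_all, x1_all)  # vary x1
```

Putting x1 on the rows means each row holds the unnormalised backward log weights for one smoothing particle once the filter log weights are added.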
cuthbertlib.smc.smoothing.tracing
Implements the ancestor/genealogy tracing algorithm for smoothing in SMC.
simulate(key, x0_all, x1_all, log_weight_x0_all, log_density, x1_ancestor_indices)
Implements the ancestor/genealogy tracing algorithm for smoothing in SMC.
Some arguments are only included for protocol compatibility and not used in this implementation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | KeyArray | JAX PRNG key. Not used. | required |
| x0_all | ArrayTreeLike | A collection of previous states \(x_0\). | required |
| x1_all | ArrayTreeLike | A collection of current states \(x_1\). Not used. | required |
| log_weight_x0_all | ArrayLike | The log weights of \(x_0\). Not used. | required |
| log_density | LogConditionalDensity | The log density function of \(x_1\) given \(x_0\). Not used. | required |
| x1_ancestor_indices | ArrayLike | The ancestor indices of \(x_1\). | required |

Returns:

| Type | Description |
|---|---|
| tuple[ArrayTree, Array] | A collection of samples \(x_0\) and their sampled indices. |
References
https://arxiv.org/abs/2207.00976
Source code in cuthbertlib/smc/smoothing/tracing.py
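Only x0_all and x1_ancestor_indices are used here, so genealogy tracing amounts to an indexing operation. It costs O(N) and never evaluates log_density. The following is a minimal sketch that assumes x0_all is a pytree whose leaves share a leading particle axis. The name trace_genealogy is hypothetical and this is not the library's source.

```python
import jax
import jax.numpy as jnp


def trace_genealogy(x0_all, x1_ancestor_indices):
    """Hypothetical sketch: pair each x1 particle with its own ancestor in x0_all."""
    indices = jnp.asarray(x1_ancestor_indices)
    # Gather along the leading particle axis of every leaf of the pytree.
    x0 = jax.tree_util.tree_map(lambda leaf: leaf[indices], x0_all)
    return x0, indices
```

Because each x1 simply reuses its filter ancestor, applying this step over many time steps tends to collapse onto a few distinct early-time paths (path degeneracy). The backward-sampling implementations below avoid this by re-drawing x0.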
cuthbertlib.smc.smoothing.exact_sampling
Implements exact backward sampling for smoothing in SMC.
simulate(key, x0_all, x1_all, log_weight_x0_all, log_density, x1_ancestor_indices)
Implements the exact backward sampling algorithm for smoothing in SMC.
The x1_ancestor_indices argument is included only for protocol compatibility and is not used in this implementation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | KeyArray | JAX PRNG key. | required |
| x0_all | ArrayTreeLike | A collection of previous states \(x_0\). | required |
| x1_all | ArrayTreeLike | A collection of current states \(x_1\). | required |
| log_weight_x0_all | ArrayLike | The log weights of \(x_0\). | required |
| log_density | LogConditionalDensity | The log density function of \(x_1\) given \(x_0\). | required |
| x1_ancestor_indices | ArrayLike | The ancestor indices of \(x_1\). Not used. | required |

Returns:

| Type | Description |
|---|---|
| tuple[ArrayTree, Array] | A collection of samples \(x_0\) and their sampled indices. |
Source code in cuthbertlib/smc/smoothing/exact_sampling.py
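The following is a minimal JAX sketch of exact backward sampling. It makes three assumptions: log_density(x0, x1) returns a scalar \(\log p(x_1 \mid x_0)\) for one pair, x0_all and x1_all are pytrees with a leading particle axis, and log_weight_x0_all may be unnormalised. The name exact_backward_sample is hypothetical and this is not the library's source. The sketch builds the full matrix of log_weight_x0_all[j] + log_density(x0_all[j], x1_all[i]) and draws one index per row. Building that matrix is the O(N^2) step.

```python
import jax
import jax.numpy as jnp


def exact_backward_sample(key, x0_all, x1_all, log_weight_x0_all, log_density):
    """Hypothetical sketch: draw one x0 index per x1 particle from the exact backward kernel."""
    # Row i, column j holds log_density(x0_all[j], x1_all[i]): N1 * N0 evaluations.
    over_x0 = jax.vmap(log_density, in_axes=(0, None))
    log_dens = jax.vmap(over_x0, in_axes=(None, 0))(x0_all, x1_all)
    # Backward logits log w0_j + log p(x1_i | x0_j); categorical normalises each row.
    logits = jnp.asarray(log_weight_x0_all)[None, :] + log_dens
    indices = jax.random.categorical(key, logits, axis=-1)  # shape (N1,)
    # Gather the chosen x0 particles from every leaf of the pytree.
    x0 = jax.tree_util.tree_map(lambda leaf: leaf[indices], x0_all)
    return x0, indices
```

Working in log space and letting jax.random.categorical handle the normalisation avoids exponentiating very small densities directly.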
cuthbertlib.smc.smoothing.mcmc
Implements MCMC backward smoothing in SMC.
simulate(key, x0_all, x1_all, log_weight_x0_all, log_density, x1_ancestor_indices, n_steps)
Implements the independent Metropolis-Hastings (IMH) algorithm for backward smoothing in SMC.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | KeyArray | JAX PRNG key. | required |
| x0_all | ArrayTreeLike | A collection of previous states \(x_0\). | required |
| x1_all | ArrayTreeLike | A collection of current states \(x_1\). | required |
| log_weight_x0_all | ArrayLike | The log weights of \(x_0\). | required |
| log_density | LogConditionalDensity | The log density function of \(x_1\) given \(x_0\). | required |
| x1_ancestor_indices | ArrayLike | The ancestor indices of \(x_1\). | required |
| n_steps | int | Number of MCMC steps to perform. | required |

Returns:

| Type | Description |
|---|---|
| tuple[ArrayTree, Array] | A collection of samples \(x_0\) and their sampled indices. |
References
https://arxiv.org/abs/2207.00976
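The following is a minimal JAX sketch of an IMH backward kernel in the spirit of the reference above. It makes these assumptions:

- Each chain for x1_all[i] starts at its ancestor x1_ancestor_indices[i].
- Proposals are drawn independently from the filter weights.
- log_density(x0, x1) returns a scalar \(\log p(x_1 \mid x_0)\).

The index target is proportional to \(w_0^{(j)} p(x_1 \mid x_0^{(j)})\) and the proposal is proportional to \(w_0^{(j)}\). The weights therefore cancel, and the acceptance ratio reduces to \(p(x_1 \mid x_0^{\text{prop}}) / p(x_1 \mid x_0^{\text{curr}})\). The name imh_backward_sample is hypothetical and this is not the library's source. Each step evaluates log_density N times, so the sketch costs O(N n_steps) rather than O(N^2).

```python
import jax
import jax.numpy as jnp


def imh_backward_sample(key, x0_all, x1_all, log_weight_x0_all, log_density,
                        x1_ancestor_indices, n_steps):
    """Hypothetical sketch: n_steps of IMH per x1 particle, proposing x0 from the filter weights."""
    n1 = jax.tree_util.tree_leaves(x1_all)[0].shape[0]
    log_w0 = jnp.asarray(log_weight_x0_all)

    def take(tree, idx):
        # Gather particles along the leading axis of every leaf.
        return jax.tree_util.tree_map(lambda leaf: leaf[idx], tree)

    # Evaluate log p(x1_i | x0_{idx_i}) for all i at once (N evaluations).
    pair_log_density = jax.vmap(log_density)

    def step(carry, step_key):
        idx, log_p = carry
        prop_key, acc_key = jax.random.split(step_key)
        # Independent proposal: one index per chain, drawn from the filter weights.
        prop_idx = jax.random.categorical(prop_key, log_w0, shape=(n1,))
        prop_log_p = pair_log_density(take(x0_all, prop_idx), x1_all)
        # Weights cancel, so accept with probability min(1, p_prop / p_curr).
        log_u = jnp.log(jax.random.uniform(acc_key, (n1,)))
        accept = log_u < prop_log_p - log_p
        idx = jnp.where(accept, prop_idx, idx)
        log_p = jnp.where(accept, prop_log_p, log_p)
        return (idx, log_p), None

    idx0 = jnp.asarray(x1_ancestor_indices)
    log_p0 = pair_log_density(take(x0_all, idx0), x1_all)
    (idx, _), _ = jax.lax.scan(step, (idx0, log_p0), jax.random.split(key, n_steps))
    return take(x0_all, idx), idx
```

Starting each chain at the genealogy-traced ancestor gives a reasonable initial state at no extra cost. The MCMC steps then decorrelate the smoothing paths from the filter genealogy.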