Skip to content
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
021550d
support Array and lambda diffusions (special support for scalar/diag/…
mattlevine22 Apr 27, 2026
9652cbf
adding back padding of diffusion-coefficient in CD-Dynamax integration
mattlevine22 Apr 27, 2026
8c6726e
Update dynestyx/inference/integrations/cd_dynamax/utils.py
mattlevine22 Apr 27, 2026
38df332
Update dynestyx/models/diffusions.py
mattlevine22 Apr 27, 2026
aa50c38
docs: rename misleading test names to reflect bm_dim > state_dim cons…
Copilot Apr 27, 2026
ecd987e
plz lint
mattlevine22 Apr 27, 2026
bc0c252
Merge branch 'main' into ml-feature-212
mattlevine22 Apr 28, 2026
375cf59
get diffusion info at init
mattlevine22 Apr 28, 2026
0fa5667
give discretizer fallback if dynamicalModel init hadnt been run (supp…
mattlevine22 Apr 28, 2026
bf1f456
new Diffusion class with subtypes.
mattlevine22 May 10, 2026
ef5dab6
merge with main (plate upgrades)
mattlevine22 May 10, 2026
20d1d37
dont use cast, just assert bm_dim is not None (fight w linter...could…
mattlevine22 May 11, 2026
e9ea930
merge with main (smoothers)
mattlevine22 May 14, 2026
bc41705
please lint
mattlevine22 May 14, 2026
3f37300
rename as gram
mattlevine22 May 14, 2026
657a190
rename func to _coerce_to_param_dtype
mattlevine22 May 14, 2026
958a2a7
simplify edit to Quick example
mattlevine22 May 14, 2026
0e04441
dont change faq
mattlevine22 May 14, 2026
2cfab46
simplify notebook changes for PR
mattlevine22 May 14, 2026
90a2813
simplify notebook changes for PR
mattlevine22 May 14, 2026
c8c5cd8
simplify notebook changes for PR
mattlevine22 May 14, 2026
2e02f2f
improve API documentation for diffusions
mattlevine22 May 14, 2026
740750e
improve API documentation for diffusions
mattlevine22 May 14, 2026
2ebb45c
Update dynestyx/models/diffusions.py
mattlevine22 May 15, 2026
3612c92
simplify code and improve docs
mattlevine22 May 15, 2026
a4cee35
clarify resolve metadata docstring
mattlevine22 May 15, 2026
19a344b
Update dynestyx/inference/integrations/cd_dynamax/utils.py
mattlevine22 May 15, 2026
6d39004
probing and errors for cd-dynamax diffusion plumbing
mattlevine22 May 15, 2026
d2f8564
fix lint
mattlevine22 May 15, 2026
83641cb
Update dynestyx/models/core.py
mattlevine22 May 15, 2026
8501a17
streamline DynamicalModel init
mattlevine22 May 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/api_reference/developer/models/core_models.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# Core Models

This page includes the developer-facing refined continuous-time state-evolution
classes:

- `DeterministicContinuousTimeStateEvolution`
- `StochasticContinuousTimeStateEvolution`

Most users should continue to construct `ContinuousTimeStateEvolution` directly and let
`DynamicalModel` canonicalize it internally.

::: dynestyx.models.core
options:
filters: []
59 changes: 59 additions & 0 deletions docs/api_reference/developer/models/diffusions.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Diffusions

This page documents the diffusion internals used by continuous-time solvers,
discretizers, and backend integrations.

## Overview

Public model code should usually work with `FullDiffusion`,
`DiagonalDiffusion`, and `ScalarDiffusion` through
`ContinuousTimeStateEvolution(diffusion=...)`.

Developer-facing code sometimes needs the lower-level evaluation layer:

- `Diffusion.evaluate(...)` evaluates a structured diffusion object at a
concrete `(x, u, t)`.
- `EvaluatedDiffusion.as_matrix(...)` returns the corresponding matrix `L`.
- `EvaluatedDiffusion.apply(...)` applies that matrix to a Brownian increment.
- `resolve_metadata(...)` canonicalizes `bm_dim` and validates the coefficient
shape against `state_dim`.

In other words, the public classes describe the structure of the diffusion,
while `EvaluatedDiffusion` is the solver-facing object used after a concrete
state, control, and time have been chosen.

## API

### `EvaluatedDiffusion`

::: dynestyx.models.diffusions.EvaluatedDiffusion
options:
show_root_heading: false

### `Diffusion`

::: dynestyx.models.diffusions.Diffusion
options:
show_root_heading: false
filters: []

### `FullDiffusion`

::: dynestyx.models.diffusions.FullDiffusion
options:
show_root_heading: false
filters: []

### `DiagonalDiffusion`

::: dynestyx.models.diffusions.DiagonalDiffusion
options:
show_root_heading: false
filters: []

### `ScalarDiffusion`

::: dynestyx.models.diffusions.ScalarDiffusion
options:
show_root_heading: false
filters: []
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
# ContinuousTimeStateEvolution

`ContinuousTimeStateEvolution` is the public entry point for defining continuous-time
state evolution. Most users should instantiate this class directly and pass an optional
`diffusion=` built from [`FullDiffusion`](./diffusion.md),
[`DiagonalDiffusion`](./diffusion.md), or [`ScalarDiffusion`](./diffusion.md).

Internally, `DynamicalModel` canonicalizes continuous-time dynamics to refined
deterministic and stochastic subclasses. Those specialized classes are intended for
developer-facing integrations and are documented in the developer API rather than the
public tutorials.

::: dynestyx.models.core.ContinuousTimeStateEvolution
options:
show_root_heading: false
show_root_toc_entry: false

233 changes: 233 additions & 0 deletions docs/api_reference/public/models/core/diffusion.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
# Diffusion

`Diffusion` objects define the stochastic term in a continuous-time state evolution

\[
dx_t = f(x_t, u_t, t)\,dt + L(x_t, u_t, t)\,dW_t,
\]

where:

- \(x_t \in \mathbb{R}^{d_x}\) is the latent state,
- \(u_t\) is an optional control,
- \(W_t \in \mathbb{R}^{d_w}\) is Brownian motion,
- `bm_dim = d_w`,
- and \(L(x_t, u_t, t)\) is the diffusion coefficient.

Dynestyx exposes three structured diffusion classes:

- `FullDiffusion`: general matrix-valued diffusion \(L \in \mathbb{R}^{d_x \times d_w}\)
- `DiagonalDiffusion`: vector-valued diffusion \(v \in \mathbb{R}^{d_x}\)
- `ScalarDiffusion`: scalar-valued diffusion \(\sigma \in \mathbb{R}\)

In practice, many models use a constant diffusion coefficient. In those cases,
pass the matrix/vector/scalar value directly. Reserve callable diffusion
coefficients for cases where the coefficient genuinely depends on state,
control, or time.
Comment thread
DanWaxman marked this conversation as resolved.

## `Diffusion`

`Diffusion` is the common base class for structured diffusion coefficients.
Most users should instantiate one of its public subclasses instead of using
`Diffusion` directly:

- Use `ScalarDiffusion` when one scalar scale is enough. This is the most common choice for isotropic diffusion.
- Use `DiagonalDiffusion` when each state coordinate should have its own scale but the loading remains axis-aligned.
- Use `FullDiffusion` when you need to specify a genuinely matrix-valued loading.

Each class accepts either:

- a constant coefficient, or
- a callable `(x, u, t) -> value`.

### `FullDiffusion`

Use `FullDiffusion(coefficient, bm_dim=None)` when you want to specify the
matrix-valued diffusion coefficient directly.

Accepted `coefficient` forms:

- a constant array with trailing shape `(state_dim, bm_dim)`, or
- a callable `(x, u, t) -> array` with trailing shape `(state_dim, bm_dim)`.

If `coefficient` is constant, `bm_dim` is inferred automatically when omitted.

Mathematically, `FullDiffusion` represents

\[
L(x_t, u_t, t) \in \mathbb{R}^{d_x \times d_w}.
\]

Example:

```python
import jax.numpy as jnp
from dynestyx import ContinuousTimeStateEvolution, FullDiffusion

state_evolution = ContinuousTimeStateEvolution(
drift=lambda x, u, t: -x,
diffusion=FullDiffusion(jnp.eye(2)),
)
```

### `DiagonalDiffusion`

Use `DiagonalDiffusion(coefficient, bm_dim)` when the diffusion is naturally
parameterized by a vector of state-wise loadings.

Accepted `coefficient` forms:

- a constant vector with trailing shape `(state_dim,)`, or
- a callable `(x, u, t) -> array` with trailing shape `(state_dim,)`.

`bm_dim` is required and must be either `1` or `state_dim`.

Mathematically, `DiagonalDiffusion` represents a vector-valued coefficient

\[
v(x_t, u_t, t) \in \mathbb{R}^{d_x}.
\]

If `bm_dim = d_x`, the vector is interpreted as a diagonal matrix:

\[
L = \mathrm{diag}(v(x_t, u_t, t)).
\]

If `bm_dim = 1`, the same vector is interpreted as a column loading vector:

\[
L = v(x_t, u_t, t) \in \mathbb{R}^{d_x \times 1}.
\]

Example:

```python
import jax.numpy as jnp
from dynestyx import ContinuousTimeStateEvolution, DiagonalDiffusion

state_evolution = ContinuousTimeStateEvolution(
drift=lambda x, u, t: -x,
diffusion=DiagonalDiffusion(jnp.array([0.1, 0.2]), bm_dim=2),
)
```

### `ScalarDiffusion`

Use `ScalarDiffusion(coefficient, bm_dim)` when one scalar scale is enough.
This is usually the simplest choice for isotropic diffusion.

Accepted `coefficient` forms:

- a scalar,
- a constant array with trailing shape `(1,)`, or
- a callable `(x, u, t) -> scalar_or_length_1_array`.

`bm_dim` is required and must be either `1` or `state_dim`.

Mathematically, `ScalarDiffusion` represents a scalar-valued coefficient

\[
\sigma(x_t, u_t, t) \in \mathbb{R}.
\]

If `bm_dim = d_x`, it is interpreted as isotropic independent noise:

\[
L = \sigma(x_t, u_t, t)\,I_{d_w}
\]

with \(d_w = d_x\).

If `bm_dim = 1`, it is interpreted as a shared scalar driver applied equally to
every state coordinate:

\[
L = \sigma(x_t, u_t, t)\,\mathbf{1}_{d_x},
\]

viewed as a column vector in \(\mathbb{R}^{d_x \times 1}\).

Constant example:

```python
from dynestyx import ContinuousTimeStateEvolution, ScalarDiffusion

state_evolution = ContinuousTimeStateEvolution(
drift=lambda x, u, t: -x,
diffusion=ScalarDiffusion(0.1, bm_dim=2),
)
```

Callable example:

```python
import jax.numpy as jnp
from dynestyx import ContinuousTimeStateEvolution, ScalarDiffusion

state_evolution = ContinuousTimeStateEvolution(
drift=lambda x, u, t: -x,
diffusion=ScalarDiffusion(
lambda x, u, t: 0.1 + 0.05 * jnp.tanh(x[0]),
bm_dim=2,
),
)
```

## API

### `Diffusion`

::: dynestyx.models.diffusions.Diffusion
options:
show_root_heading: false
show_root_toc_entry: false
filters:
- "!^evaluate_value$"
- "!^resolve_metadata$"
- "!^evaluate$"
- "!^as_matrix$"
- "!^gram_matrix$"
- "!^apply$"

### `FullDiffusion`

::: dynestyx.models.diffusions.FullDiffusion
options:
show_root_heading: false
show_root_toc_entry: false
filters:
- "!^evaluate_value$"
- "!^resolve_metadata$"
- "!^evaluate$"
- "!^as_matrix$"
- "!^gram_matrix$"
- "!^apply$"

### `DiagonalDiffusion`

::: dynestyx.models.diffusions.DiagonalDiffusion
options:
show_root_heading: false
show_root_toc_entry: false
filters:
- "!^evaluate_value$"
- "!^resolve_metadata$"
- "!^evaluate$"
- "!^as_matrix$"
- "!^gram_matrix$"
- "!^apply$"

### `ScalarDiffusion`

::: dynestyx.models.diffusions.ScalarDiffusion
options:
show_root_heading: false
show_root_toc_entry: false
filters:
- "!^evaluate_value$"
- "!^resolve_metadata$"
- "!^evaluate$"
- "!^as_matrix$"
- "!^gram_matrix$"
- "!^apply$"
4 changes: 2 additions & 2 deletions docs/api_reference/public/models/core/dynamical_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from dynestyx import (
DynamicalModel,
ContinuousTimeStateEvolution,
FullDiffusion,
LinearGaussianObservation,
)

Expand All @@ -47,12 +48,11 @@
),
state_evolution=ContinuousTimeStateEvolution(
drift=lambda x, u, t: -x + u,
diffusion_coefficient=lambda x, u, t: jnp.eye(state_dim, bm_dim),
diffusion=FullDiffusion(jnp.eye(state_dim, bm_dim)),
),
observation_model=LinearGaussianObservation(
H=jnp.eye(observation_dim, state_dim),
R=jnp.eye(observation_dim),
),
)
```

1 change: 0 additions & 1 deletion docs/api_reference/public/models/core_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,3 @@
- ObservationModel
- Drift
- Potential

4 changes: 2 additions & 2 deletions docs/api_reference/public/models/specialized/affine_drift.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
??? example "Ornstein–Uhlenbeck (OU) process"
```python
import jax.numpy as jnp
from dynestyx import AffineDrift, ContinuousTimeStateEvolution
from dynestyx import AffineDrift, ContinuousTimeStateEvolution, FullDiffusion

# OU SDE: dX_t = -theta (X_t - mu) dt + sigma dW_t
theta = 0.7
Expand All @@ -38,6 +38,6 @@

ou_sde = ContinuousTimeStateEvolution(
drift=drift,
diffusion_coefficient=lambda x, u, t: jnp.array([[sigma]]),
diffusion=FullDiffusion(jnp.array([[sigma]])),
)
```
9 changes: 7 additions & 2 deletions docs/api_reference/public/simulators/sde_simulator.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
import jax.random as jr
import numpyro
import numpyro.distributions as dist
from dynestyx import ContinuousTimeStateEvolution, DynamicalModel, SDESimulator
from dynestyx import (
ContinuousTimeStateEvolution,
DynamicalModel,
FullDiffusion,
SDESimulator,
)
from numpyro.infer import Predictive

state_dim = 1
Expand All @@ -33,7 +38,7 @@
),
state_evolution=ContinuousTimeStateEvolution(
drift=lambda x, u, t: -theta * x,
diffusion_coefficient=lambda x, u, t: sigma_x * jnp.eye(state_dim, bm_dim),
diffusion=FullDiffusion(sigma_x * jnp.eye(state_dim, bm_dim)),
),
observation_model=lambda x, u, t: dist.MultivariateNormal(
x,
Expand Down
Loading
Loading