diff --git a/docs/api_reference/developer/models/core_models.md b/docs/api_reference/developer/models/core_models.md index fdae00ea..f7626794 100644 --- a/docs/api_reference/developer/models/core_models.md +++ b/docs/api_reference/developer/models/core_models.md @@ -1,5 +1,14 @@ # Core Models +This page includes the developer-facing refined continuous-time state-evolution +classes: + +- `DeterministicContinuousTimeStateEvolution` +- `StochasticContinuousTimeStateEvolution` + +Most users should continue to construct `ContinuousTimeStateEvolution` directly and let +`DynamicalModel` refine it internally. + ::: dynestyx.models.core options: filters: [] diff --git a/docs/api_reference/developer/models/diffusions.md b/docs/api_reference/developer/models/diffusions.md new file mode 100644 index 00000000..0c02900c --- /dev/null +++ b/docs/api_reference/developer/models/diffusions.md @@ -0,0 +1,59 @@ +# Diffusions + +This page documents the diffusion internals used by continuous-time solvers, +discretizers, and backend integrations. + +## Overview + +Public model code should usually work with `FullDiffusion`, +`DiagonalDiffusion`, and `ScalarDiffusion` through +`ContinuousTimeStateEvolution(diffusion=...)`. + +Developer-facing code sometimes needs the lower-level evaluation layer: + +- `Diffusion.evaluate(...)` evaluates a structured diffusion object at a + concrete `(x, u, t)`. +- `EvaluatedDiffusion.as_matrix(...)` returns the corresponding matrix `L`. +- `EvaluatedDiffusion.apply(...)` applies that matrix to a Brownian increment. +- `resolve_metadata(...)` resolves `bm_dim` and validates the coefficient + shape against `state_dim`. + +In other words, the public classes describe the structure of the diffusion, +while `EvaluatedDiffusion` is the solver-facing object used after a concrete +state, control, and time have been chosen. 
+ +## API + +### `EvaluatedDiffusion` + +::: dynestyx.models.diffusions.EvaluatedDiffusion + options: + show_root_heading: false + +### `Diffusion` + +::: dynestyx.models.diffusions.Diffusion + options: + show_root_heading: false + filters: [] + +### `FullDiffusion` + +::: dynestyx.models.diffusions.FullDiffusion + options: + show_root_heading: false + filters: [] + +### `DiagonalDiffusion` + +::: dynestyx.models.diffusions.DiagonalDiffusion + options: + show_root_heading: false + filters: [] + +### `ScalarDiffusion` + +::: dynestyx.models.diffusions.ScalarDiffusion + options: + show_root_heading: false + filters: [] diff --git a/docs/api_reference/public/models/core/continuous_time_state_evolution.md b/docs/api_reference/public/models/core/continuous_time_state_evolution.md index bff371c8..1d15e0b3 100644 --- a/docs/api_reference/public/models/core/continuous_time_state_evolution.md +++ b/docs/api_reference/public/models/core/continuous_time_state_evolution.md @@ -1,7 +1,16 @@ # ContinuousTimeStateEvolution +`ContinuousTimeStateEvolution` is the public entry point for defining continuous-time +state evolution. Most users should instantiate this class directly and pass an optional +`diffusion=` built from [`FullDiffusion`](./diffusion.md), +[`DiagonalDiffusion`](./diffusion.md), or [`ScalarDiffusion`](./diffusion.md). + +Internally, `DynamicalModel` refines continuous-time dynamics to +deterministic and stochastic subclasses. Those specialized classes are intended for +developer-facing integrations and are documented in the developer API rather than the +public tutorials. 
+ ::: dynestyx.models.core.ContinuousTimeStateEvolution options: show_root_heading: false show_root_toc_entry: false - diff --git a/docs/api_reference/public/models/core/diffusion.md b/docs/api_reference/public/models/core/diffusion.md new file mode 100644 index 00000000..d1ee5918 --- /dev/null +++ b/docs/api_reference/public/models/core/diffusion.md @@ -0,0 +1,207 @@ +# Diffusion + +`Diffusion` objects define the stochastic term in a continuous-time state evolution: + +\[ +dx_t = f(x_t, u_t, t)\,dt + L(x_t, u_t, t)\,dW_t, +\] + +where: + +- \(x_t \in \mathbb{R}^{d_x}\) is the latent state, +- \(u_t\) is an optional control, +- \(W_t \in \mathbb{R}^{d_w}\) is Brownian motion, +- `bm_dim = d_w`, +- and \(L(x_t, u_t, t)\) is the diffusion coefficient. + +Dynestyx exposes three structured diffusion classes (`ScalarDiffusion`, `DiagonalDiffusion`, and `FullDiffusion`); choose among them as follows: + +- If your diffusion has the form \(L = \sigma I_{d_w}\) with \(d_w = d_x\), use `ScalarDiffusion(sigma, bm_dim=state_dim)`. +- If your diffusion has the form \(L = \sigma \mathbf{1}_{d_x}\), use `ScalarDiffusion(sigma, bm_dim=1)`. +- If your diffusion has the form \(L = \mathrm{diag}(v)\), use `DiagonalDiffusion(v, bm_dim=state_dim)`. +- If your diffusion has the form \(L = v \in \mathbb{R}^{d_x \times 1}\), use `DiagonalDiffusion(v, bm_dim=1)`. +- If your diffusion is a general matrix \(L \in \mathbb{R}^{d_x \times d_w}\), use `FullDiffusion(L)`. + +The same constructors also accept callables `(x, u, t) -> value` instead of +constants. For example, if \(L(x_t, u_t, t) = \sigma(x_t) I_{d_w}\), use +`ScalarDiffusion(lambda x, u, t: sigma(x), bm_dim=state_dim)`. The same pattern +applies across `ScalarDiffusion`, `DiagonalDiffusion`, and `FullDiffusion`. + +In practice, many models use a constant diffusion coefficient. In those cases, +pass the matrix/vector/scalar value directly. Reserve callable diffusion +coefficients for cases where the coefficient genuinely depends on state, +control, or time. 
+ +## `Diffusion` + +`Diffusion` is the common base class for structured diffusion coefficients. +Most users should instantiate one of its public subclasses instead of using +`Diffusion` directly: + +- Use `ScalarDiffusion` when one scalar scale is enough. This is the most common choice for isotropic diffusion. +- Use `DiagonalDiffusion` when each state coordinate should have its own scale but the loading remains axis-aligned. +- Use `FullDiffusion` when you need to specify a genuinely matrix-valued loading. + +Each class accepts either: + +- a constant coefficient, or +- a callable `(x, u, t) -> value`. + +### `FullDiffusion` + +Mathematically, `FullDiffusion` represents +\(L(x_t, u_t, t) \in \mathbb{R}^{d_x \times d_w}\). + +Use `FullDiffusion(coefficient, bm_dim=None)` when you want to specify this +matrix-valued diffusion coefficient directly. + +Accepted `coefficient` forms: + +- a constant array with trailing shape `(state_dim, bm_dim)`, or +- a callable `(x, u, t) -> array` with trailing shape `(state_dim, bm_dim)`. + +If `coefficient` is constant, `bm_dim` is inferred automatically when omitted. + +Example: + +```python +import jax.numpy as jnp +from dynestyx import ContinuousTimeStateEvolution, FullDiffusion + +state_evolution = ContinuousTimeStateEvolution( + drift=lambda x, u, t: -x, + diffusion=FullDiffusion(jnp.eye(2)), +) +``` + +### `DiagonalDiffusion` + +Mathematically, `DiagonalDiffusion` represents a vector-valued coefficient +\(v(x_t, u_t, t) \in \mathbb{R}^{d_x}\). If `bm_dim = d_x`, this means +\(L = \mathrm{diag}(v(x_t, u_t, t))\). If `bm_dim = 1`, this means +\(L = v(x_t, u_t, t) \in \mathbb{R}^{d_x \times 1}\). + +Use `DiagonalDiffusion(coefficient, bm_dim)` when the diffusion is naturally +parameterized by a vector of state-wise loadings. + +Accepted `coefficient` forms: + +- a constant vector with trailing shape `(state_dim,)`, or +- a callable `(x, u, t) -> array` with trailing shape `(state_dim,)`. 
+ +`bm_dim` is required and must be either `1` or `state_dim`. + +Example: + +```python +import jax.numpy as jnp +from dynestyx import ContinuousTimeStateEvolution, DiagonalDiffusion + +state_evolution = ContinuousTimeStateEvolution( + drift=lambda x, u, t: -x, + diffusion=DiagonalDiffusion(jnp.array([0.1, 0.2]), bm_dim=2), +) +``` + +### `ScalarDiffusion` + +Mathematically, `ScalarDiffusion` represents a scalar-valued coefficient +\(\sigma(x_t, u_t, t) \in \mathbb{R}\). If `bm_dim = d_x`, this means +\(L = \sigma(x_t, u_t, t)\,I_{d_w}\) with \(d_w = d_x\). If `bm_dim = 1`, this +means \(L = \sigma(x_t, u_t, t)\,\mathbf{1}_{d_x}\), viewed as a column vector +in \(\mathbb{R}^{d_x \times 1}\). + +Use `ScalarDiffusion(coefficient, bm_dim)` when one scalar scale is enough. +This is usually the simplest choice for isotropic diffusion. + +Accepted `coefficient` forms: + +- a scalar, +- a constant array with trailing shape `(1,)`, or +- a callable `(x, u, t) -> scalar_or_length_1_array`. + +`bm_dim` is required and must be either `1` or `state_dim`. 
+ +Constant example: + +```python +from dynestyx import ContinuousTimeStateEvolution, ScalarDiffusion + +state_evolution = ContinuousTimeStateEvolution( + drift=lambda x, u, t: -x, + diffusion=ScalarDiffusion(0.1, bm_dim=2), +) +``` + +Callable example: + +```python +import jax.numpy as jnp +from dynestyx import ContinuousTimeStateEvolution, ScalarDiffusion + +state_evolution = ContinuousTimeStateEvolution( + drift=lambda x, u, t: -x, + diffusion=ScalarDiffusion( + lambda x, u, t: 0.1 + 0.05 * jnp.tanh(x[0]), + bm_dim=2, + ), +) +``` + +## API + +### `Diffusion` + +::: dynestyx.models.diffusions.Diffusion + options: + show_root_heading: false + show_root_toc_entry: false + filters: + - "!^evaluate_value$" + - "!^resolve_metadata$" + - "!^evaluate$" + - "!^as_matrix$" + - "!^gram_matrix$" + - "!^apply$" + +### `FullDiffusion` + +::: dynestyx.models.diffusions.FullDiffusion + options: + show_root_heading: false + show_root_toc_entry: false + filters: + - "!^evaluate_value$" + - "!^resolve_metadata$" + - "!^evaluate$" + - "!^as_matrix$" + - "!^gram_matrix$" + - "!^apply$" + +### `DiagonalDiffusion` + +::: dynestyx.models.diffusions.DiagonalDiffusion + options: + show_root_heading: false + show_root_toc_entry: false + filters: + - "!^evaluate_value$" + - "!^resolve_metadata$" + - "!^evaluate$" + - "!^as_matrix$" + - "!^gram_matrix$" + - "!^apply$" + +### `ScalarDiffusion` + +::: dynestyx.models.diffusions.ScalarDiffusion + options: + show_root_heading: false + show_root_toc_entry: false + filters: + - "!^evaluate_value$" + - "!^resolve_metadata$" + - "!^evaluate$" + - "!^as_matrix$" + - "!^gram_matrix$" + - "!^apply$" diff --git a/docs/api_reference/public/models/core/dynamical_model.md b/docs/api_reference/public/models/core/dynamical_model.md index e437aba2..4c4acc42 100644 --- a/docs/api_reference/public/models/core/dynamical_model.md +++ b/docs/api_reference/public/models/core/dynamical_model.md @@ -33,6 +33,7 @@ from dynestyx import ( DynamicalModel, 
ContinuousTimeStateEvolution, + FullDiffusion, LinearGaussianObservation, ) @@ -47,7 +48,7 @@ ), state_evolution=ContinuousTimeStateEvolution( drift=lambda x, u, t: -x + u, - diffusion_coefficient=lambda x, u, t: jnp.eye(state_dim, bm_dim), + diffusion=FullDiffusion(jnp.eye(state_dim, bm_dim)), ), observation_model=LinearGaussianObservation( H=jnp.eye(observation_dim, state_dim), @@ -55,4 +56,3 @@ ), ) ``` - diff --git a/docs/api_reference/public/models/core_models.md b/docs/api_reference/public/models/core_models.md index 673d35b1..9abf8bdf 100644 --- a/docs/api_reference/public/models/core_models.md +++ b/docs/api_reference/public/models/core_models.md @@ -9,4 +9,3 @@ - ObservationModel - Drift - Potential - diff --git a/docs/api_reference/public/models/specialized/affine_drift.md b/docs/api_reference/public/models/specialized/affine_drift.md index 993c17ee..b9865e32 100644 --- a/docs/api_reference/public/models/specialized/affine_drift.md +++ b/docs/api_reference/public/models/specialized/affine_drift.md @@ -26,7 +26,7 @@ ??? 
example "Ornstein–Uhlenbeck (OU) process" ```python import jax.numpy as jnp - from dynestyx import AffineDrift, ContinuousTimeStateEvolution + from dynestyx import AffineDrift, ContinuousTimeStateEvolution, FullDiffusion # OU SDE: dX_t = -theta (X_t - mu) dt + sigma dW_t theta = 0.7 @@ -38,6 +38,6 @@ ou_sde = ContinuousTimeStateEvolution( drift=drift, - diffusion_coefficient=lambda x, u, t: jnp.array([[sigma]]), + diffusion=FullDiffusion(jnp.array([[sigma]])), ) ``` diff --git a/docs/api_reference/public/simulators/sde_simulator.md b/docs/api_reference/public/simulators/sde_simulator.md index 70b8e64d..8cd1f309 100644 --- a/docs/api_reference/public/simulators/sde_simulator.md +++ b/docs/api_reference/public/simulators/sde_simulator.md @@ -14,7 +14,12 @@ import jax.random as jr import numpyro import numpyro.distributions as dist - from dynestyx import ContinuousTimeStateEvolution, DynamicalModel, SDESimulator + from dynestyx import ( + ContinuousTimeStateEvolution, + DynamicalModel, + FullDiffusion, + SDESimulator, + ) from numpyro.infer import Predictive state_dim = 1 @@ -33,7 +38,7 @@ ), state_evolution=ContinuousTimeStateEvolution( drift=lambda x, u, t: -theta * x, - diffusion_coefficient=lambda x, u, t: sigma_x * jnp.eye(state_dim, bm_dim), + diffusion=FullDiffusion(sigma_x * jnp.eye(state_dim, bm_dim)), ), observation_model=lambda x, u, t: dist.MultivariateNormal( x, diff --git a/docs/deep_dives/fhn_sparse_id.ipynb b/docs/deep_dives/fhn_sparse_id.ipynb index 22a6b036..5cd59767 100644 --- a/docs/deep_dives/fhn_sparse_id.ipynb +++ b/docs/deep_dives/fhn_sparse_id.ipynb @@ -98,6 +98,7 @@ " DynamicalModel,\n", " Filter,\n", " LinearGaussianObservation,\n", + " ScalarDiffusion,\n", " SDESimulator,\n", ")\n", "\n", @@ -215,7 +216,7 @@ " return ContinuousTimeStateEvolution(\n", " drift=lambda x, u, t: adjust_rhs(x, drift_fn(x, u, t)),\n", " # sigma_x I_2 (isotropic, constant diffusion)\n", - " diffusion_coefficient=lambda x, u, t: diffusion_coeff * 
jnp.eye(state_dim),\n", + " diffusion=ScalarDiffusion(diffusion_coeff, bm_dim=state_dim),\n", " )\n" ] }, diff --git a/docs/deep_dives/gp_drift.ipynb b/docs/deep_dives/gp_drift.ipynb index 19c9c082..9967f137 100644 --- a/docs/deep_dives/gp_drift.ipynb +++ b/docs/deep_dives/gp_drift.ipynb @@ -94,6 +94,7 @@ " ContinuousTimeStateEvolution,\n", " LinearGaussianObservation,\n", " Filter,\n", + " ScalarDiffusion,\n", " SDESimulator,\n", ")\n", "from dynestyx.diagnostics.plotting_utils import plot_drift_field\n", @@ -194,7 +195,7 @@ "def make_state_evolution(drift_fn):\n", " return ContinuousTimeStateEvolution(\n", " drift=drift_fn,\n", - " diffusion_coefficient=lambda x, u, t: diffusion_coeff * jnp.eye(state_dim),\n", + " diffusion=ScalarDiffusion(diffusion_coeff, bm_dim=state_dim),\n", " )" ] }, diff --git a/docs/deep_dives/l63_speedup_dirac_vs_enkf.ipynb b/docs/deep_dives/l63_speedup_dirac_vs_enkf.ipynb index 9dd69fd9..dae9208f 100644 --- a/docs/deep_dives/l63_speedup_dirac_vs_enkf.ipynb +++ b/docs/deep_dives/l63_speedup_dirac_vs_enkf.ipynb @@ -58,6 +58,7 @@ " LinearGaussianObservation,\n", " DiscreteTimeSimulator,\n", " Filter,\n", + " ScalarDiffusion,\n", " SDESimulator,\n", " Discretizer,\n", ")\n" @@ -97,7 +98,7 @@ " x[0] * x[1] - (8.0 / 3.0) * x[2],\n", " ]\n", " ),\n", - " diffusion_coefficient=lambda x, u, t: jnp.eye(state_dim),\n", + " diffusion=ScalarDiffusion(1.0, bm_dim=state_dim),\n", " ),\n", " observation_model=observation_model,\n", " )\n", diff --git a/docs/logo/make_logo.py b/docs/logo/make_logo.py index 7b21deed..484be407 100644 --- a/docs/logo/make_logo.py +++ b/docs/logo/make_logo.py @@ -15,23 +15,24 @@ import argparse import os -import dynestyx as dsx import imageio.v2 as imageio import jax.numpy as jnp import jax.random as jr import numpy as np import numpyro.distributions as dist +from numpyro.infer import Predictive +from PIL import Image, ImageDraw, ImageFont +from scipy import ndimage +from scipy.ndimage import binary_erosion + +import 
dynestyx as dsx from dynestyx import ( ContinuousTimeStateEvolution, DiracIdentityObservation, DynamicalModel, + FullDiffusion, SDESimulator, ) -from numpyro.infer import Predictive -from PIL import Image, ImageDraw, ImageFont -from scipy import ndimage -from scipy.ndimage import binary_erosion - # ----------------------------- # Fonts (cross-platform) @@ -286,9 +287,6 @@ def potential_fn(x, u, t): dist_to_boundary = bilinear_sample_jax(dist_grid_j, x[0], x[1]) return 0.5 * kappa * (dist_to_boundary - ring) ** 2 - def diffusion_fn(x, u, t): - return sigma * jnp.eye(2, dtype=jnp.float32) - dynamics = DynamicalModel( control_dim=0, initial_condition=dist.Uniform( @@ -299,7 +297,7 @@ def diffusion_fn(x, u, t): drift=drift_fn, potential=potential_fn, use_negative_gradient=True, - diffusion_coefficient=diffusion_fn, + diffusion=FullDiffusion(sigma * jnp.eye(2, dtype=jnp.float32)), ), observation_model=DiracIdentityObservation(), ) diff --git a/docs/math_intro.md b/docs/math_intro.md index ec905601..d1fe6f10 100644 --- a/docs/math_intro.md +++ b/docs/math_intro.md @@ -44,7 +44,7 @@ dynamics = DynamicalModel( initial_condition=dist.MultivariateNormal(...), state_evolution=ContinuousTimeStateEvolution( drift=lambda x, u, t: ..., - diffusion_coefficient=lambda x, u, t: ..., + diffusion=dsx.FullDiffusion(...), ), observation_model=lambda x, u, t: ..., ) @@ -136,4 +136,4 @@ Detailed API documentation is available for all modules, classes, and functions [^1]: We also allow for some more general state space models than are described here. For example, the state need not be real (see, for example, the Hidden Markov Model tutorials). -[^2]: As another example of support for general models, we also allow for deterministic dynamical systems (i.e., ODEs). \ No newline at end of file +[^2]: As another example of support for general models, we also allow for deterministic dynamical systems (i.e., ODEs). 
diff --git a/docs/quick_example.ipynb b/docs/quick_example.ipynb index 9944d2c8..6a72a506 100644 --- a/docs/quick_example.ipynb +++ b/docs/quick_example.ipynb @@ -47,7 +47,7 @@ "import numpyro\n", "import numpyro.distributions as dist\n", "from dynestyx import (\n", - " ContinuousTimeStateEvolution, DynamicalModel,\n", + " ContinuousTimeStateEvolution, DynamicalModel, FullDiffusion,\n", " LinearGaussianObservation, SDESimulator, Filter,\n", " sample,\n", ")\n", @@ -74,7 +74,7 @@ " x[0] * (rho - x[2]) - x[1],\n", " x[0] * x[1] - (8.0 / 3.0) * x[2],\n", " ]),\n", - " diffusion_coefficient=lambda x, u, t: diff_std * jnp.eye(state_dim),\n", + " diffusion=FullDiffusion(diff_std * jnp.eye(state_dim)),\n", " )\n", " \n", " # define the model\n", diff --git a/docs/tutorials/gentle_intro/02_dynestyx_discrete_intro.ipynb b/docs/tutorials/gentle_intro/02_dynestyx_discrete_intro.ipynb index 4eedb22f..e23ad9a0 100644 --- a/docs/tutorials/gentle_intro/02_dynestyx_discrete_intro.ipynb +++ b/docs/tutorials/gentle_intro/02_dynestyx_discrete_intro.ipynb @@ -131,8 +131,7 @@ " - A callable $(x, u, t_\\text{now}, t_\\text{next}) \\mapsto \\ $ `numpyro.distributions` object (which can include things like categorical distributions)\n", " - `ContinuousTimeStateEvolution`: has fields\n", " - `drift`: A callable $(x, u, t) \\mapsto \\ \\mathbb{R}^{d_x}$\n", - " - `diffusion_coefficient`: A callable $(x, u, t) \\mapsto \\ \\mathbb{R}^{d_x \\times d_x}$\n", - " - `diffusion_covariance`: A callable $(x, u, t) \\mapsto \\ \\mathbb{R}^{d_x \\times d_x}$\n", + " - `diffusion`: a `Diffusion` object, e.g. 
`ScalarDiffusion`, `DiagonalDiffusion`, or `FullDiffusion`\n", "- `observation_model`: A callable $(x, u, t) \\mapsto \\ $ `numpyro.distributions` object\n", "\n", "It also requires specification of the key dimensions in the problem:\n", diff --git a/docs/tutorials/gentle_intro/06_continuous_time.ipynb b/docs/tutorials/gentle_intro/06_continuous_time.ipynb index 944f97e2..c0164462 100644 --- a/docs/tutorials/gentle_intro/06_continuous_time.ipynb +++ b/docs/tutorials/gentle_intro/06_continuous_time.ipynb @@ -22,7 +22,7 @@ "where $W_t$ is a vector Brownian motion. We specify:\n", "\n", "- **`drift`**: $f(x, u, t)$ — the deterministic part (vector of same dimension as state).\n", - "- **`diffusion_coefficient`**: $L(x, u, t)$ — matrix such that the diffusion term is $L\\,dW_t$. Shape is `(state_dim, brownian_dim)`.\n", + "- **`diffusion`**: a `Diffusion` object (e.g. `ScalarDiffusion`, `DiagonalDiffusion`, or `FullDiffusion`) that encodes $L(x, u, t)$, the matrix such that the diffusion term is $L\\,dW_t$.\n", "\n", "All three are callables with signature `(x, u, t)`: state `x`, control `u` (or `None`), and time `t`." 
] @@ -68,6 +68,7 @@ " DynamicalModel,\n", " LinearGaussianObservation,\n", " SDESimulator,\n", + " ScalarDiffusion,\n", ")\n", "\n", "state_dim = 3\n", @@ -88,7 +89,7 @@ " x[0] * x[1] - (8.0 / 3.0) * x[2],\n", " ]\n", " ),\n", - " diffusion_coefficient=lambda x, u, t: jnp.eye(3),\n", + " diffusion=ScalarDiffusion(1.0, bm_dim=state_dim),\n", " ),\n", " observation_model=LinearGaussianObservation(\n", " H=jnp.eye(observation_dim, state_dim), # observe only x[0]\n", diff --git a/docs/tutorials/quickstart.ipynb b/docs/tutorials/quickstart.ipynb index 5b64310c..7cc4184e 100644 --- a/docs/tutorials/quickstart.ipynb +++ b/docs/tutorials/quickstart.ipynb @@ -50,7 +50,7 @@ "import numpyro\n", "import numpyro.distributions as dist\n", "import jax.numpy as jnp\n", - "from dynestyx import DynamicalModel, ContinuousTimeStateEvolution, LinearGaussianObservation\n", + "from dynestyx import DynamicalModel, ContinuousTimeStateEvolution, LinearGaussianObservation, ScalarDiffusion\n", "import dynestyx as dsx\n", "\n", "def continuous_time_lti_gaussian_model(rho=None, predict_times=None, obs_times=None, obs_values=None):\n", @@ -66,7 +66,7 @@ " ),\n", " state_evolution=ContinuousTimeStateEvolution(\n", " drift=lambda x, u, t: A @ x,\n", - " diffusion_coefficient=lambda x, u, t: jnp.eye(2),\n", + " diffusion=ScalarDiffusion(1.0, bm_dim=2),\n", " ),\n", " observation_model=LinearGaussianObservation(\n", " H=jnp.array([[0.0, 1.0]]), R=jnp.array([[0.15**2]])\n", diff --git a/docs/tutorials/sde_non_gaussian_observations.ipynb b/docs/tutorials/sde_non_gaussian_observations.ipynb index 1b2d0263..9469aeb8 100644 --- a/docs/tutorials/sde_non_gaussian_observations.ipynb +++ b/docs/tutorials/sde_non_gaussian_observations.ipynb @@ -47,7 +47,7 @@ "import jax.random as jr\n", "import numpyro\n", "import numpyro.distributions as dist\n", - "from dynestyx import DynamicalModel, ContinuousTimeStateEvolution, ObservationModel\n", + "from dynestyx import DynamicalModel, ContinuousTimeStateEvolution, 
ObservationModel, ScalarDiffusion\n", "import dynestyx as dsx\n", "import equinox as eqx\n", "\n", @@ -80,7 +80,7 @@ " ),\n", " state_evolution=ContinuousTimeStateEvolution(\n", " drift=drift,\n", - " diffusion_coefficient=lambda x, u, t: sigma * jnp.eye(1),\n", + " diffusion=ScalarDiffusion(sigma, bm_dim=1),\n", " ),\n", " observation_model=lambda x, u, t: dist.Poisson(rate=dt * jnp.exp(x[0] + bias)),\n", " )\n", diff --git a/dynestyx/__init__.py b/dynestyx/__init__.py index 287bd127..84435882 100644 --- a/dynestyx/__init__.py +++ b/dynestyx/__init__.py @@ -10,9 +10,13 @@ from dynestyx.inference.smoothers import Smoother from dynestyx.models import ( ContinuousTimeStateEvolution, + DeterministicContinuousTimeStateEvolution, + DiagonalDiffusion, + Diffusion, DiracIdentityObservation, DiscreteTimeStateEvolution, DynamicalModel, + FullDiffusion, GaussianObservation, GaussianStateEvolution, LinearGaussianObservation, @@ -20,6 +24,8 @@ LTI_continuous, LTI_discrete, ObservationModel, + ScalarDiffusion, + StochasticContinuousTimeStateEvolution, ) from dynestyx.simulators import ( DiscreteTimeSimulator, @@ -32,6 +38,12 @@ __all__ = [ "__version__", "ContinuousTimeStateEvolution", + "DeterministicContinuousTimeStateEvolution", + "Diffusion", + "FullDiffusion", + "DiagonalDiffusion", + "ScalarDiffusion", + "StochasticContinuousTimeStateEvolution", "DiscreteTimeStateEvolution", "DynamicalModel", "AffineDrift", diff --git a/dynestyx/discretizers.py b/dynestyx/discretizers.py index cb9a6d39..bd75c6d2 100644 --- a/dynestyx/discretizers.py +++ b/dynestyx/discretizers.py @@ -1,58 +1,40 @@ -import jax.numpy as jnp import numpyro.distributions as dist from effectful.ops.semantics import fwd from effectful.ops.syntax import ObjectInterpretation, implements from dynestyx.handlers import HandlesSelf, _sample_intp from dynestyx.models import ( - ContinuousTimeStateEvolution, DynamicalModel, GaussianStateEvolution, + StochasticContinuousTimeStateEvolution, ) -from 
dynestyx.models.checkers import _infer_bm_dim from dynestyx.solvers import euler_maruyama_loc_cov from dynestyx.types import FunctionOfTime -def _ensure_ctse_bm_dim(dynamics: DynamicalModel) -> DynamicalModel: - """Infer and set bm_dim when CT dynamics are built under active plates.""" - if not isinstance(dynamics.state_evolution, ContinuousTimeStateEvolution): - return dynamics - - cte = dynamics.state_evolution - if cte.diffusion_coefficient is None or cte.bm_dim is not None: - return dynamics - - x0 = jnp.zeros((dynamics.state_dim,)) - u0 = None if dynamics.control_dim == 0 else jnp.zeros((dynamics.control_dim,)) - t0 = jnp.array(0.0) if dynamics.t0 is None else jnp.asarray(dynamics.t0) - inferred_bm_dim = _infer_bm_dim(cte, dynamics.state_dim, x0, u0, t0) - if inferred_bm_dim is not None: - object.__setattr__(cte, "bm_dim", inferred_bm_dim) - return dynamics - - class EulerMaruyamaGaussianStateEvolution(GaussianStateEvolution): """`GaussianStateEvolution` backed by Euler-Maruyama moments.""" - cte: ContinuousTimeStateEvolution + cte: StochasticContinuousTimeStateEvolution def __init__( self, - cte: ContinuousTimeStateEvolution, + cte: StochasticContinuousTimeStateEvolution, F=None, cov=None, ): # Accept these for reconstruction paths, but derive both from `cte`. 
- del F, cov self.cte = cte + + def _loc(x, u, t_now, t_next): + return euler_maruyama_loc_cov(cte, x, u, t_now, t_next)["loc"] + + def _cov(x, u, t_now, t_next): + return euler_maruyama_loc_cov(cte, x, u, t_now, t_next)["cov"] + super().__init__( - F=lambda x, u, t_now, t_next: euler_maruyama_loc_cov( - cte, x, u, t_now, t_next - )["loc"], - cov=lambda x, u, t_now, t_next: euler_maruyama_loc_cov( - cte, x, u, t_now, t_next - )["cov"], + F=_loc, + cov=_cov, ) def __call__(self, x, u, t_now, t_next): @@ -63,7 +45,9 @@ def __call__(self, x, u, t_now, t_next): ) -def euler_maruyama(cte: ContinuousTimeStateEvolution) -> GaussianStateEvolution: +def euler_maruyama( + cte: StochasticContinuousTimeStateEvolution, +) -> GaussianStateEvolution: """Discretize continuous-time state evolution via Euler-Maruyama. Euler-Maruyama is a first-order discrete approximation of a continuous-time @@ -73,7 +57,7 @@ def euler_maruyama(cte: ContinuousTimeStateEvolution) -> GaussianStateEvolution: (depends on `t_next - t_now`) and passed as a callable `cov`. Args: - cte: `ContinuousTimeStateEvolution` to discretize. + cte: `StochasticContinuousTimeStateEvolution` to discretize. Returns: GaussianStateEvolution: Discrete-time Gaussian transition with the same Euler–Maruyama semantics as before this refactor. 
@@ -124,6 +108,7 @@ class Discretizer(ObjectInterpretation, HandlesSelf): ContinuousTimeStateEvolution, DiscreteTimeStateEvolution, DynamicalModel, + FullDiffusion, ) def model_with_ctse(obs_times=None, obs_values=None): @@ -135,7 +120,9 @@ def model_with_ctse(obs_times=None, obs_values=None): ), state_evolution=ContinuousTimeStateEvolution( drift=lambda x, u, t: x, - diffusion_coefficient=lambda x, u, t: jnp.eye(state_dim, bm_dim), + diffusion=FullDiffusion( + lambda x, u, t: jnp.eye(state_dim, bm_dim) + ), ), observation_model=lambda x, u, t: dist.MultivariateNormal( x, @@ -177,8 +164,7 @@ def _sample_ds( ctrl_values=None, **kwargs, ) -> FunctionOfTime: - if isinstance(dynamics.state_evolution, ContinuousTimeStateEvolution): - dynamics = _ensure_ctse_bm_dim(dynamics) + if isinstance(dynamics.state_evolution, StochasticContinuousTimeStateEvolution): discrete_evolution = self.discretize(dynamics.state_evolution) dynamics = DynamicalModel( initial_condition=dynamics.initial_condition, diff --git a/dynestyx/inference/filters.py b/dynestyx/inference/filters.py index 1a782543..364e0afe 100644 --- a/dynestyx/inference/filters.py +++ b/dynestyx/inference/filters.py @@ -53,7 +53,6 @@ from dynestyx.inference.plate_utils import _array_plate_axis, _make_plate_in_axes from dynestyx.models import DynamicalModel from dynestyx.types import FunctionOfTime -from dynestyx.utils import _ensure_continuous_bm_dim type SSMType = ContDiscreteNonlinearGaussianSSM | ContDiscreteNonlinearSSM @@ -208,8 +207,6 @@ def _add_log_factors( if obs_times is None or obs_values is None: raise ValueError("obs_times and obs_values are required for filtering.") - dynamics = _ensure_continuous_bm_dim(dynamics) - config = ( self.filter_config if self.filter_config is not None @@ -235,8 +232,10 @@ def _add_log_factors( if not isinstance(config, ContinuousTimeConfigs): valid = [c.__name__ for c in ContinuousTimeConfigs] raise ValueError( - f"Invalid filter config: {type(config).__name__}. 
" - f"Valid config types: {valid}" + "Continuous-time models require a continuous-time filter config. " + "If you want to use a discrete-time filter, nest `Discretizer()` " + "inside `Filter()`. " + f"Got {type(config).__name__}; valid continuous-time config types: {valid}." ) return _filter_continuous_time( name, diff --git a/dynestyx/inference/integrations/cd_dynamax/utils.py b/dynestyx/inference/integrations/cd_dynamax/utils.py index e1883a25..fb41e80b 100644 --- a/dynestyx/inference/integrations/cd_dynamax/utils.py +++ b/dynestyx/inference/integrations/cd_dynamax/utils.py @@ -16,43 +16,49 @@ from dynestyx.inference.integrations.utils import squeeze_leading_singletons from dynestyx.models import ( AffineDrift, - ContinuousTimeStateEvolution, + DeterministicContinuousTimeStateEvolution, DynamicalModel, GaussianObservation, GaussianStateEvolution, LinearGaussianObservation, LinearGaussianStateEvolution, + StochasticContinuousTimeStateEvolution, ) type SSMType = ContDiscreteNonlinearGaussianSSM | ContDiscreteNonlinearSSM -def _normalize_cd_dynamax_diffusion( - diffusion_coefficient, +def _as_cd_dynamax_diffusion_coefficient( + state_evolution: StochasticContinuousTimeStateEvolution, state_dim: int, ): - """Return a diffusion coeff compatible with cd-dynamax's EnKF SDE solve. - - cd-dynamax's internal diffrax wrapper builds Brownian controls with shape - equal to `y0.shape` (state_dim). For non-square diffusion coefficients - (state_dim, bm_dim) with bm_dim != state_dim, pad/truncate columns so the - returned matrix is always (state_dim, state_dim). + """Return a full-matrix diffusion coefficient for cd-dynamax. + + This adapts any Dynestyx diffusion shorthand (scalar/diagonal/full, constant + or callable) into a function ``(x, u, t) -> L(x, u, t)`` returning a full + matrix. cd-dynamax's continuous solver paths assume Brownian controls have + dimension ``state_dim``, so rectangular diffusions with ``bm_dim < + state_dim`` are padded with trailing zero columns. 
``bm_dim > state_dim`` is + rejected. """ + if state_evolution.bm_dim > state_dim: + raise ValueError( + "Continuous cd-dynamax filters require bm_dim <= state_dim. " + f"Got state_dim={state_dim}, bm_dim={state_evolution.bm_dim}." + ) def _wrapped(x, u, t): - L = diffusion_coefficient(x, u, t) - if L.ndim == 1: - L = jnp.diag(L) - if L.ndim != 2: + L = state_evolution.diffusion.as_matrix(x=x, u=u, t=t, state_dim=state_dim) + n_cols = L.shape[-1] + if n_cols > state_dim: raise ValueError( - "diffusion_coefficient must return a vector or matrix for cd-dynamax." + "cd-dynamax continuous diffusion requires bm_dim <= state_dim. " + f"Got state_dim={state_dim}, bm_dim={n_cols}." ) - n_cols = L.shape[-1] - if n_cols == state_dim: - return L if n_cols < state_dim: - return jnp.pad(L, ((0, 0), (0, state_dim - n_cols))) - return L[:, :state_dim] + pad_width = ((0, 0),) * (L.ndim - 1) + ((0, state_dim - n_cols),) + L = jnp.pad(L, pad_width) + return L return _wrapped @@ -181,30 +187,41 @@ def dsx_to_cdlgssm_params(dsx_model: DynamicalModel) -> ParamsCDLGSSM: Requires: - drift is AffineDrift (A, B, b) - - diffusion_coefficient is constant (callable returning same value for any x, u, t) - returning same value for any x, u, t) + - diffusion_coefficient is constant (array/scalar-valued Dynestyx diffusion) - observation_model is LinearGaussianObservation - initial_condition is MultivariateNormal """ state_evo = dsx_model.state_evolution - if not isinstance(state_evo, ContinuousTimeStateEvolution): - raise TypeError("dsx_to_cdlgssm_params requires ContinuousTimeStateEvolution.") + if not isinstance(state_evo, StochasticContinuousTimeStateEvolution): + raise TypeError( + "dsx_to_cdlgssm_params requires StochasticContinuousTimeStateEvolution. You probably tried to call a continuous-time CD-Dynamax filter with a DeterministicContinuousTimeStateEvolution or DiscreteTimeStateEvolution." 
+ ) drift = state_evo.drift if not isinstance(drift, AffineDrift): raise TypeError( f"dsx_to_cdlgssm_params requires AffineDrift, got {type(drift).__name__}." ) - if state_evo.diffusion_coefficient is None: - raise ValueError("dsx_to_cdlgssm_params requires diffusion_coefficient.") - # Extract constant L and use inferred Brownian dimension. - x0 = jnp.zeros(dsx_model.state_dim) - L = state_evo.diffusion_coefficient(x0, None, jnp.array(0.0)) - if state_evo.bm_dim is None: - raise ValueError( - "state_evolution.bm_dim is not set on ContinuousTimeStateEvolution." + if callable(state_evo.diffusion.coefficient): + raise TypeError( + "Callable diffusion is not supported by the continuous-time exact " + "Kalman filter path. If you need callable diffusion, use a " + "continuous-time nonlinear filter config such as " + "ContinuousTimeEKFConfig instead of ContinuousTimeKFConfig. " + "When the goal is exact filtering, CT-EKF is the appropriate " + "fallback here because in the linear-Gaussian case it emulates " + "the KF via auto-diff and cd_dynamax supports callable diffusion " + "in the nonlinear case (not directly in the linear-Gaussian case)." ) - Q = jnp.eye(state_evo.bm_dim) + + # Extract constant L and use resolved Brownian dimension. 
+ L = state_evo.diffusion.as_matrix( + x=jnp.zeros(dsx_model.state_dim), + u=None if dsx_model.control_dim == 0 else jnp.zeros((dsx_model.control_dim,)), + t=jnp.array(0.0), + state_dim=dsx_model.state_dim, + ) + Q = jnp.eye(dsx_model.state_dim, state_evo.diffusion.bm_dim) ic = dsx_model.initial_condition if not isinstance(ic, dist.MultivariateNormal): @@ -260,30 +277,44 @@ def dsx_to_cd_dynamax( ## Map state evolution ## state_evo = dsx_model.state_evolution - if isinstance(state_evo, ContinuousTimeStateEvolution): - if state_evo.drift is not None or state_evo.potential is not None: - shared_params.update( - { - "dynamics_drift": state_evo.total_drift, - } - ) - else: + if isinstance( + state_evo, + ( + DeterministicContinuousTimeStateEvolution, + StochasticContinuousTimeStateEvolution, + ), + ): + if state_evo.drift is None and state_evo.potential is None: raise ValueError("Both drift and potential are None; define at least one.") - if state_evo.diffusion_coefficient is not None: - if state_evo.bm_dim is None: - raise ValueError( - "state_evolution.bm_dim is not set on ContinuousTimeStateEvolution." 
- ) - diffusion_coeff = _normalize_cd_dynamax_diffusion( - state_evo.diffusion_coefficient, - dsx_model.state_dim, - ) - shared_params.update( - { - "dynamics_diffusion_coefficient": diffusion_coeff, - "dynamics_diffusion_cov": jnp.eye(dsx_model.state_dim), - } - ) + shared_params.update( + { + "dynamics_drift": state_evo.total_drift, + } + ) + if isinstance(state_evo, StochasticContinuousTimeStateEvolution): + diffusion_coeff = _as_cd_dynamax_diffusion_coefficient( + state_evo, + dsx_model.state_dim, + ) + shared_params.update( + { + "dynamics_diffusion_coefficient": diffusion_coeff, + "dynamics_diffusion_cov": jnp.eye(dsx_model.state_dim), + } + ) + elif isinstance(state_evo, DeterministicContinuousTimeStateEvolution): + shared_params.update( + { + "dynamics_diffusion_coefficient": jnp.zeros( + (dsx_model.state_dim, dsx_model.state_dim) + ), + "dynamics_diffusion_cov": jnp.eye(dsx_model.state_dim), + } + ) + elif isinstance(state_evo, LinearGaussianStateEvolution): + raise NotImplementedError( + f"State evolution of type {type(state_evo)} is not supported yet." + ) else: raise NotImplementedError( f"State evolution of type {type(state_evo)} is not supported yet." 
diff --git a/dynestyx/inference/smoothers.py b/dynestyx/inference/smoothers.py index 2b15f99a..a214292d 100644 --- a/dynestyx/inference/smoothers.py +++ b/dynestyx/inference/smoothers.py @@ -46,7 +46,6 @@ ) from dynestyx.models import DynamicalModel from dynestyx.types import FunctionOfTime -from dynestyx.utils import _ensure_continuous_bm_dim DiscreteSmootherConfig = ( KFSmootherConfig | EKFSmootherConfig | UKFSmootherConfig | PFSmootherConfig @@ -190,8 +189,6 @@ def _add_log_factors( if obs_times is None or obs_values is None: raise ValueError("obs_times and obs_values are required for smoothing.") - dynamics = _ensure_continuous_bm_dim(dynamics) - config = ( self.smoother_config if self.smoother_config is not None diff --git a/dynestyx/models/__init__.py b/dynestyx/models/__init__.py index 6b2f1f56..2085abe9 100644 --- a/dynestyx/models/__init__.py +++ b/dynestyx/models/__init__.py @@ -5,10 +5,18 @@ from dynestyx.models.core import ( ContinuousTimeStateEvolution, + DeterministicContinuousTimeStateEvolution, DiscreteTimeStateEvolution, Drift, DynamicalModel, ObservationModel, + StochasticContinuousTimeStateEvolution, +) +from dynestyx.models.diffusions import ( + DiagonalDiffusion, + Diffusion, + FullDiffusion, + ScalarDiffusion, ) from dynestyx.models.lti_dynamics import LTI_continuous, LTI_discrete from dynestyx.models.observations import ( @@ -24,16 +32,22 @@ __all__ = [ "ContinuousTimeStateEvolution", + "DeterministicContinuousTimeStateEvolution", "AffineDrift", "DiracIdentityObservation", + "Diffusion", "DiscreteTimeStateEvolution", + "DiagonalDiffusion", "DynamicalModel", "Drift", + "FullDiffusion", "GaussianObservation", "GaussianStateEvolution", "LinearGaussianObservation", "LinearGaussianStateEvolution", "ObservationModel", + "StochasticContinuousTimeStateEvolution", "LTI_continuous", "LTI_discrete", + "ScalarDiffusion", ] diff --git a/dynestyx/models/checkers.py b/dynestyx/models/checkers.py index 634cb885..669f456c 100644 --- 
a/dynestyx/models/checkers.py +++ b/dynestyx/models/checkers.py @@ -101,111 +101,46 @@ def _make_probe_state(initial_condition: Any, state_dim: int) -> jax.Array: return jnp.zeros((state_dim,)) -def _infer_bm_dim( - state_evolution: Any, - state_dim: int, - x0: State, - u0: Control | None, - t0: Time, -) -> int | None: - """Infer bm_dim from diffusion coefficient output shape. - - Tolerates leading batch dimensions (e.g. from plate-batched parameters) - by inspecting only the trailing two dimensions (..., state_dim, bm_dim). - - Returns the inferred bm_dim, or None if there is no diffusion coefficient. - """ - if state_evolution.diffusion_coefficient is None: - if state_evolution.bm_dim is not None: - raise ValueError("bm_dim cannot be set when diffusion_coefficient is None.") - return None - - diffusion_shape = jax.eval_shape( - lambda: state_evolution.diffusion_coefficient(x0, u0, t0) - ).shape - if len(diffusion_shape) < 2: - raise ValueError( - "diffusion_coefficient must return shape (..., state_dim, bm_dim). " - f"Got shape {diffusion_shape}." - ) - if int(diffusion_shape[-2]) != state_dim: - raise ValueError( - "diffusion_coefficient penultimate dimension must match state_dim. " - f"Got diffusion shape {diffusion_shape}, state_dim={state_dim}." - ) - inferred_bm_dim = int(diffusion_shape[-1]) - if ( - state_evolution.bm_dim is not None - and int(state_evolution.bm_dim) != inferred_bm_dim - ): - raise ValueError( - "bm_dim does not match inferred diffusion_coefficient output shape. " - f"Got bm_dim={state_evolution.bm_dim}, inferred={inferred_bm_dim}." - ) - return inferred_bm_dim - - def _validate_continuous_state_evolution( state_evolution: Any, state_dim: int, - x0: State, - u0: Control | None, - t0: Time, -) -> int | None: - """Validate the shape of the continuous-time state evolution w.r.t. state_dim and bm_dim. - - Returns the inferred bm_dim (or None if no diffusion coefficient). 
- """ - drift_shape = jax.eval_shape(lambda: state_evolution.total_drift(x0, u0, t0)).shape + x_probe: State, + u_probe: Control | None, + t_probe: Time, +) -> None: + """Validate the drift shape of a continuous-time state evolution.""" + drift_shape = jax.eval_shape( + lambda: state_evolution.total_drift(x_probe, u_probe, t_probe) + ).shape if drift_shape != (state_dim,): raise ValueError( "State drift shape is inconsistent with state_dim. " f"Expected {(state_dim,)}, got {drift_shape}." ) - return _infer_bm_dim(state_evolution, state_dim, x0, u0, t0) - -def _validate_state_evolution_output_shape( +def _validate_discrete_state_evolution_output_shape( state_evolution: Callable[[State, Control, Time], State] | Callable[[State, Control, Time, Time], State], state_dim: int, - x0: State, - u0: Control | None, - t0: Time, - *, - continuous_time: bool, -) -> int | None: - """Validate the shape of the state evolution w.r.t. state_dim (and bm_dim for continuous-time models). - - Returns the inferred bm_dim for continuous-time models, or None otherwise. - """ - if continuous_time: - return _validate_continuous_state_evolution( - state_evolution=state_evolution, - state_dim=state_dim, - x0=x0, - u0=u0, - t0=t0, - ) - else: - if getattr(state_evolution, "bm_dim", None) is not None: - raise ValueError( - "bm_dim can only be set for continuous-time models with " - "diffusion_coefficient." 
- ) - t_now = t0 - t_next = t0 + 1.0 - transition_dist = state_evolution(x=x0, u=u0, t_now=t_now, t_next=t_next) # type: ignore[misc,call-arg] - inferred_state_dim = _infer_vector_dim_from_distribution( - transition_dist, "state_evolution(x, u, t_now, t_next)" + x_probe: State, + u_probe: Control | None, + t_probe: Time, +) -> None: + """Validate a discrete-time state evolution against the inferred state dimension.""" + if getattr(state_evolution, "diffusion", None) is not None: + raise ValueError("diffusion can only be set for continuous-time models.") + t_now = t_probe + t_next = t_probe + 1.0 + transition_dist = state_evolution(x=x_probe, u=u_probe, t_now=t_now, t_next=t_next) # type: ignore[misc,call-arg] + inferred_state_dim = _infer_vector_dim_from_distribution( + transition_dist, "state_evolution(x, u, t_now, t_next)" + ) + if inferred_state_dim != state_dim: + raise ValueError( + "State transition shape is inconsistent with state_dim. " + f"state_dim={state_dim}, inferred={inferred_state_dim}." ) - if inferred_state_dim != state_dim: - raise ValueError( - "State transition shape is inconsistent with state_dim. " - f"state_dim={state_dim}, inferred={inferred_state_dim}." 
- ) - return None def _validate_continuous_time_flag( @@ -255,25 +190,18 @@ def _inside_numpyro_plate_context() -> bool: def _infer_observation_dim_in_plate_context( *, - initial_condition: Any, observation_model: Callable[[State, Control | None, Time], Any], - inferred_state_dim: int, - control_dim: int, - t0: Time | None, + x_probe: State, + u_probe: Control | None, + t_probe: Time, observation_dim: int | None, ) -> int: """Infer observation dimension in plate context, falling back to explicit value.""" if observation_dim is not None: return int(observation_dim) - x0 = _make_probe_state( - initial_condition=initial_condition, - state_dim=inferred_state_dim, - ) - u0 = None if control_dim == 0 else jnp.zeros((control_dim,)) - dummy_t0 = jnp.array(0.0) if t0 is None else t0 try: - obs_dist = observation_model(x0, u0, dummy_t0) + obs_dist = observation_model(x_probe, u_probe, t_probe) return int( _infer_vector_dim_from_distribution( obs_dist, diff --git a/dynestyx/models/core.py b/dynestyx/models/core.py index 47c3d87c..37b475a0 100644 --- a/dynestyx/models/core.py +++ b/dynestyx/models/core.py @@ -1,5 +1,7 @@ """Core interfaces and base classes for dynamical models.""" +from __future__ import annotations + from collections.abc import Callable from typing import Any, Protocol @@ -9,19 +11,27 @@ from numpyro.distributions import Distribution from dynestyx.models.checkers import ( - _infer_bm_dim, _infer_observation_dim_in_plate_context, _infer_vector_dim_from_distribution, _inside_numpyro_plate_context, _is_categorical_distribution, _make_probe_state, _validate_categorical_state, + _validate_continuous_state_evolution, _validate_continuous_time_flag, + _validate_discrete_state_evolution_output_shape, _validate_observation_dim, _validate_state_dim, - _validate_state_evolution_output_shape, ) -from dynestyx.types import Control, State, Time, TimeLike, as_scalar_time_array, dState +from dynestyx.models.diffusions import Diffusion +from dynestyx.types import ( + Control, 
+ State, + Time, + TimeLike, + as_scalar_time_array, + dState, +) class DynamicalModel(eqx.Module): @@ -84,7 +94,9 @@ class DynamicalModel(eqx.Module): initial_condition: Distribution state_evolution: ( - Callable[[State, Control, Time], State] + ContinuousTimeStateEvolution + | DiscreteTimeStateEvolution + | Callable[[State, Control, Time], State] | Callable[[State, Control, Time, Time], State] ) observation_model: Callable[[State, Control, Time], Distribution] @@ -151,69 +163,94 @@ def __init__( # Skip shape validation when inside a numpyro plate context, since # batched parameters produce shapes that don't match unbatched expectations. + + def _make_probes() -> tuple[State, Control | None, Time]: + """Build synthetic inputs for validation/metadata resolution.""" + x_probe = _make_probe_state( + initial_condition=initial_condition, + state_dim=inferred_state_dim, + ) + u_probe = None if control_dim == 0 else jnp.zeros((control_dim,)) + t_probe = jnp.array(0.0) if self.t0 is None else self.t0 + return x_probe, u_probe, t_probe + + x_probe, u_probe, t_probe = _make_probes() + + def _resolve_continuous_state_evolution( + current_state_evolution: ContinuousTimeStateEvolution, + ) -> ( + DeterministicContinuousTimeStateEvolution + | StochasticContinuousTimeStateEvolution + ): + """Return either a DeterministicContinuousTimeStateEvolution or a StochasticContinuousTimeStateEvolution. + If diffusion is present, lazily build probes to resolve its metadata (e.g., bm_dim). 
+ """ + diffusion = current_state_evolution.diffusion + if diffusion is None: + if isinstance( + current_state_evolution, DeterministicContinuousTimeStateEvolution + ): + return current_state_evolution + return DeterministicContinuousTimeStateEvolution( + drift=current_state_evolution.drift, + potential=current_state_evolution.potential, + use_negative_gradient=current_state_evolution.use_negative_gradient, + ) + + resolved_diffusion = diffusion.resolve_metadata( + state_dim=inferred_state_dim, + x_probe=x_probe, + u_probe=u_probe, + t_probe=t_probe, + ) + return StochasticContinuousTimeStateEvolution( + drift=current_state_evolution.drift, + potential=current_state_evolution.potential, + use_negative_gradient=current_state_evolution.use_negative_gradient, + diffusion=resolved_diffusion, + ) + if _inside_plate: # Cannot validate shapes with batched parameters; trust the user. # Infer observation_dim from observation model if not explicitly provided. inferred_obs_dim = _infer_observation_dim_in_plate_context( - initial_condition=initial_condition, observation_model=observation_model, - inferred_state_dim=inferred_state_dim, - control_dim=control_dim, - t0=self.t0, + x_probe=x_probe, + u_probe=u_probe, + t_probe=t_probe, observation_dim=observation_dim, ) - self.state_dim = int(inferred_state_dim) - self.observation_dim = inferred_obs_dim - self.control_dim = int(control_dim) - self.categorical_state = bool(inferred_categorical_state) - - # Infer bm_dim for continuous-time models - if inferred_continuous_time and isinstance( - state_evolution, ContinuousTimeStateEvolution - ): - x0 = jnp.zeros((inferred_state_dim,)) - u0 = None if control_dim == 0 else jnp.zeros((control_dim,)) - dummy_t0 = jnp.array(0.0) if self.t0 is None else self.t0 - inferred_bm_dim = _infer_bm_dim( - state_evolution, inferred_state_dim, x0, u0, dummy_t0 + else: + if self.continuous_time: + _validate_continuous_state_evolution( + state_evolution=state_evolution, + state_dim=inferred_state_dim, + 
x_probe=x_probe, + u_probe=u_probe, + t_probe=t_probe, + ) + else: + _validate_discrete_state_evolution_output_shape( + state_evolution=state_evolution, + state_dim=inferred_state_dim, + x_probe=x_probe, + u_probe=u_probe, + t_probe=t_probe, ) - if inferred_bm_dim is not None: - if ( - state_evolution.bm_dim is not None - and inferred_bm_dim != state_evolution.bm_dim - ): - raise ValueError( - "bm_dim does not match inferred diffusion_coefficient output shape. " - f"Got bm_dim={state_evolution.bm_dim}, inferred={inferred_bm_dim}." - ) - object.__setattr__(state_evolution, "bm_dim", inferred_bm_dim) - return - - x0 = _make_probe_state( - initial_condition=initial_condition, state_dim=inferred_state_dim - ) - u0 = None if control_dim == 0 else jnp.zeros((control_dim,)) - dummy_t0 = jnp.array(0.0) if self.t0 is None else self.t0 - - inferred_bm_dim = _validate_state_evolution_output_shape( - state_evolution=state_evolution, - state_dim=inferred_state_dim, - x0=x0, - u0=u0, - t0=dummy_t0, - continuous_time=self.continuous_time, - ) - if self.continuous_time and inferred_bm_dim != state_evolution.bm_dim: - object.__setattr__(state_evolution, "bm_dim", inferred_bm_dim) - obs_dist = observation_model(x0, u0, dummy_t0) - inferred_observation_dim = _infer_vector_dim_from_distribution( - obs_dist, "observation_model(x, u, t)" - ) - _validate_observation_dim(observation_dim, inferred_observation_dim) + obs_dist = observation_model(x_probe, u_probe, t_probe) + inferred_obs_dim = _infer_vector_dim_from_distribution( + obs_dist, "observation_model(x, u, t)" + ) + _validate_observation_dim(observation_dim, inferred_obs_dim) + + if self.continuous_time: + self.state_evolution = _resolve_continuous_state_evolution( + state_evolution, + ) self.state_dim = int(inferred_state_dim) - self.observation_dim = int(inferred_observation_dim) + self.observation_dim = int(inferred_obs_dim) self.control_dim = int(control_dim) self.categorical_state = bool(inferred_categorical_state) @@ -321,19 
+358,15 @@ class ContinuousTimeStateEvolution(eqx.Module): At least one of `drift` or `potential` must be non-None. use_negative_gradient (bool): If True, use $-\\nabla_x V$ (e.g., gradient descent on potential); otherwise use $+\\nabla_x V$. Default is False. - diffusion_coefficient (Drift | None): Diffusion coefficient $L(x, u, t)$ mapping to a matrix; - multiplies the Brownian increment $dW_t$. - Defaults to zero if None (i.e., deterministic ODE). - bm_dim (int | None): Dimension of the Brownian motion $W_t$. - Inferred automatically from the output shape of `diffusion_coefficient`; - if passed by the user, it must match diffusion_coefficient(...).shape[1]. + diffusion (Diffusion | None): Diffusion coefficient object. + Use `FullDiffusion`, `DiagonalDiffusion`, or `ScalarDiffusion` to define + the stochastic part of the SDE. Pass `None` for deterministic dynamics. """ drift: Drift | None = None potential: Potential | None = None use_negative_gradient: bool = eqx.field(static=True, default=False) - diffusion_coefficient: Drift | None = None - bm_dim: int | None = eqx.field(static=True, default=None) + diffusion: Diffusion | None = None def total_drift(self, x: State, u: Control | None, t: Time) -> dState: base = self.drift(x, u, t) if self.drift is not None else None @@ -354,6 +387,73 @@ def total_drift(self, x: State, u: Control | None, t: Time) -> dState: return base + grad_term +class DeterministicContinuousTimeStateEvolution(ContinuousTimeStateEvolution): + """Continuous-time state evolution with no diffusion term, i.e., describing an ODE. + + This is a refined form of :class:`ContinuousTimeStateEvolution` used when a + model has deterministic continuous-time dynamics, i.e. an ODE rather than an + SDE. In most user code you should construct + :class:`ContinuousTimeStateEvolution` directly and let + :class:`DynamicalModel` refine it into this subclass when ``diffusion=None``. + Its main semantic guarantee is that ``diffusion`` is always ``None``. 
+ """ + + diffusion: None = eqx.field(static=True, default=None) + + def __init__( + self, + drift: Drift | None = None, + potential: Potential | None = None, + use_negative_gradient: bool = False, + diffusion: None = None, + ): + if diffusion is not None: + raise ValueError( + "DeterministicContinuousTimeStateEvolution does not accept diffusion." + ) + self.drift = drift + self.potential = potential + self.use_negative_gradient = use_negative_gradient + self.diffusion = None + + +class StochasticContinuousTimeStateEvolution(ContinuousTimeStateEvolution): + """Continuous-time state evolution with resolved stochastic diffusion. + + This is a refined form of :class:`ContinuousTimeStateEvolution` used for SDE + models after the diffusion metadata has been resolved. In practice that means + the attached :class:`~dynestyx.models.diffusions.Diffusion` has a known + ``bm_dim`` and can therefore be used safely by downstream SDE solvers, + discretizers, and inference backends. + """ + + diffusion: Diffusion = eqx.field(static=True, kw_only=True) + + def __init__( + self, + *, + drift: Drift | None = None, + potential: Potential | None = None, + use_negative_gradient: bool = False, + diffusion: Diffusion, + ): + if diffusion.bm_dim is None: + raise ValueError( + "StochasticContinuousTimeStateEvolution requires diffusion with " + "resolved bm_dim." + ) + self.drift = drift + self.potential = potential + self.use_negative_gradient = use_negative_gradient + self.diffusion = diffusion + + @property + def bm_dim(self) -> int: + bm_dim = self.diffusion.bm_dim + assert bm_dim is not None + return bm_dim + + class DiscreteTimeStateEvolution(eqx.Module): """ Discrete-time state evolution via Markov transition distributions. 
diff --git a/dynestyx/models/diffusions.py b/dynestyx/models/diffusions.py new file mode 100644 index 00000000..81b580fe --- /dev/null +++ b/dynestyx/models/diffusions.py @@ -0,0 +1,412 @@ +"""Diffusion objects for continuous-time state evolution.""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import NamedTuple, cast + +import equinox as eqx +import jax +import jax.numpy as jnp +from jax import Array + +from dynestyx.types import Control, State, Time + +DiffusionValue = Array | float | int +DiffusionSpec = Callable[[State, Control | None, Time], DiffusionValue] | DiffusionValue + + +class EvaluatedDiffusion(NamedTuple): + """A diffusion coefficient evaluated at a specific ``(x, u, t)``. + + This is primarily a developer-facing helper used by solvers and backend + integrations. It pairs a structured ``Diffusion`` object with the concrete + value of its coefficient at one state, control, and time. + """ + + diffusion: Diffusion + value: Array + + def as_matrix(self, *, state_dim: int) -> Array: + """Return the evaluated diffusion coefficient as a matrix ``L``.""" + return self.diffusion._value_as_matrix(self.value, state_dim=state_dim) + + def gram_matrix(self, *, state_dim: int) -> Array: + """L L^T.""" + return self.diffusion._value_gram_matrix(self.value, state_dim=state_dim) + + def apply(self, dw: Array, *, state_dim: int) -> Array: + """Return ``L @ dw`` using the structured diffusion representation.""" + return self.diffusion._apply_value(self.value, dw, state_dim=state_dim) + + +class Diffusion(eqx.Module): + """Base class for diffusion coefficients in SDEs. + + A diffusion encapsulates both the numeric coefficient ``L(x, u, t)`` and the + structural interpretation of that coefficient inside the SDE + + ``dx_t = f(x_t, u_t, t) dt + L(x_t, u_t, t) dW_t``. + + Most users should instantiate one of the concrete subclasses: + + - ``FullDiffusion`` for an arbitrary matrix-valued coefficient. 
+ - ``DiagonalDiffusion`` for axis-aligned loadings. + - ``ScalarDiffusion`` for isotropic or shared-driver noise. + + The ``coefficient`` may be either: + + - a constant array or scalar, for time-homogeneous diffusion, or + - a callable ``(x, u, t) -> value`` for state-, control-, or time-dependent diffusion. + + ``bm_dim`` is the Brownian dimension ``d_w``. For ``FullDiffusion`` it can be + inferred from the matrix shape. For ``DiagonalDiffusion`` and + ``ScalarDiffusion`` it must be specified explicitly and must be either ``1`` or + ``state_dim``. + """ + + coefficient: DiffusionSpec + bm_dim: int | None = eqx.field(static=True, default=None) + + def __init__( + self, + coefficient: DiffusionSpec, + bm_dim: int | None = None, + ): + self.coefficient = ( + coefficient if callable(coefficient) else jnp.asarray(coefficient) + ) + self.bm_dim = None if bm_dim is None else int(bm_dim) + self._validate_init() + + def evaluate_value( + self, + *, + x: State, + u: Control | None, + t: Time, + ) -> Array: + """Return the raw coefficient value at ``(x, u, t)``.""" + if callable(self.coefficient): + return jnp.asarray(self.coefficient(x, u, t)) + return cast(Array, self.coefficient) + + def resolve_metadata( + self, + *, + state_dim: int, + x_probe: State, + u_probe: Control | None, + t_probe: Time, + ) -> Diffusion: + """Check coefficient shape and resolve ``bm_dim`` if needed.""" + raise NotImplementedError + + def evaluate( + self, + *, + x: State, + u: Control | None, + t: Time, + ) -> EvaluatedDiffusion: + """Evaluate the diffusion at ``(x, u, t)``.""" + return EvaluatedDiffusion(self, self.evaluate_value(x=x, u=u, t=t)) + + def as_matrix( + self, + *, + x: State, + u: Control | None, + t: Time, + state_dim: int, + ) -> Array: + """Return the matrix-valued diffusion coefficient ``L(x, u, t)``.""" + return self.evaluate(x=x, u=u, t=t).as_matrix(state_dim=state_dim) + + def gram_matrix( + self, + *, + x: State, + u: Control | None, + t: Time, + state_dim: int, + ) -> 
Array: + """Return the diffusion Gram matrix ``L(x,u,t) L(x,u,t)^T``.""" + return self.evaluate(x=x, u=u, t=t).gram_matrix(state_dim=state_dim) + + def apply( + self, + dw: Array, + *, + x: State, + u: Control | None, + t: Time, + state_dim: int, + ) -> Array: + """Apply the diffusion coefficient to a Brownian increment ``dw``.""" + return self.evaluate(x=x, u=u, t=t).apply(dw, state_dim=state_dim) + + def _with_bm_dim(self, bm_dim: int) -> Diffusion: + return type(self)(self.coefficient, bm_dim=int(bm_dim)) + + def _constant_shape(self) -> tuple[int, ...] | None: + if callable(self.coefficient): + return None + return tuple(int(d) for d in jnp.shape(self.coefficient)) + + def _validate_init(self) -> None: + if self.bm_dim is not None and int(self.bm_dim) <= 0: + raise ValueError(f"bm_dim must be positive. Got bm_dim={self.bm_dim}.") + + def _value_as_matrix(self, value: Array, *, state_dim: int) -> Array: + raise NotImplementedError( + "Please don't construct `Diffusion` directly; instead instantiate one of its subclasses (e.g., `FullDiffusion`, `DiagonalDiffusion`, or `ScalarDiffusion`)" + ) + + def _value_gram_matrix(self, value: Array, *, state_dim: int) -> Array: + raise NotImplementedError + + def _apply_value(self, value: Array, dw: Array, *, state_dim: int) -> Array: + raise NotImplementedError + + +class FullDiffusion(Diffusion): + """General matrix-valued diffusion coefficient. + + Use ``FullDiffusion`` when you want to specify the diffusion matrix + ``L(x, u, t)`` directly. + + Args: + coefficient: Either a constant array with trailing shape + ``(..., state_dim, bm_dim)`` or a callable ``(x, u, t) -> array`` + with that trailing shape. + bm_dim: Optional Brownian dimension. If omitted for a constant + coefficient, it is inferred from the trailing matrix dimension. + + This is the most general public diffusion class. Prefer it when your model + genuinely needs a dense or otherwise unstructured loading matrix. 
+ """ + + def _validate_init(self) -> None: + super()._validate_init() + shape = self._constant_shape() + if shape is not None and len(shape) < 2: + raise ValueError( + "Full diffusion requires a matrix-valued constant coefficient with " + "trailing shape (..., state_dim, bm_dim). " + f"Got shape {shape}." + ) + if shape is not None and self.bm_dim is None: + self.bm_dim = int(shape[-1]) + + def resolve_metadata( + self, + *, + state_dim: int, + x_probe: State, + u_probe: Control | None, + t_probe: Time, + ) -> FullDiffusion: + """Check matrix shape and infer ``bm_dim`` from the trailing dimension if needed.""" + shape = jax.eval_shape( + lambda: self.evaluate_value(x=x_probe, u=u_probe, t=t_probe) + ).shape + if len(shape) < 2 or int(shape[-2]) != state_dim: + raise ValueError( + "Full diffusion must have trailing shape (..., state_dim, bm_dim). " + f"Got shape {shape} with state_dim={state_dim}." + ) + inferred_bm_dim = int(shape[-1]) + if self.bm_dim is not None and int(self.bm_dim) != inferred_bm_dim: + raise ValueError( + "bm_dim does not match inferred diffusion output shape. " + f"Got bm_dim={self.bm_dim}, inferred={inferred_bm_dim} from shape {shape}." + ) + return ( + self + if self.bm_dim == inferred_bm_dim + else cast(FullDiffusion, self._with_bm_dim(inferred_bm_dim)) + ) + + def _value_as_matrix(self, value: Array, *, state_dim: int) -> Array: + return value + + def _value_gram_matrix(self, value: Array, *, state_dim: int) -> Array: + return value @ jnp.swapaxes(value, -1, -2) + + def _apply_value(self, value: Array, dw: Array, *, state_dim: int) -> Array: + return value @ dw + + +class DiagonalDiffusion(Diffusion): + """Vector-valued diffusion with diagonal or shared-driver interpretation. + + Use ``DiagonalDiffusion(v, bm_dim=...)`` when the diffusion is naturally + parameterized by a vector ``v`` of length ``state_dim``. 
+ + Args: + coefficient: Either a constant vector with trailing shape + ``(..., state_dim)`` or a callable ``(x, u, t) -> array`` with that + trailing shape. + bm_dim: Brownian dimension. Must be either ``1`` or ``state_dim``. + + - If ``bm_dim = state_dim``, the vector is interpreted as the diagonal of + ``L = diag(v)``. + - If ``bm_dim = 1``, the vector is interpreted as a column loading vector, + so all state coordinates share one Brownian driver. + + This is often the right public API choice when each state coordinate has its + own scale but you do not want to write out a full matrix. + """ + + def _validate_init(self) -> None: + super()._validate_init() + if self.bm_dim is None: + raise ValueError( + "Diagonal diffusion requires explicit bm_dim. " + "For diagonal diffusion, bm_dim must be either 1 or state_dim." + ) + shape = self._constant_shape() + if shape is not None and len(shape) == 0: + raise ValueError( + "Diagonal diffusion requires a vector-valued constant coefficient " + "with trailing shape (..., state_dim). " + f"Got shape {shape}." + ) + + def resolve_metadata( + self, + *, + state_dim: int, + x_probe: State, + u_probe: Control | None, + t_probe: Time, + ) -> DiagonalDiffusion: + """Check vector shape and verify that ``bm_dim`` is either 1 or ``state_dim``.""" + shape = jax.eval_shape( + lambda: self.evaluate_value(x=x_probe, u=u_probe, t=t_probe) + ).shape + if len(shape) == 0 or int(shape[-1]) != state_dim: + raise ValueError( + "Diagonal diffusion must have trailing shape (..., state_dim). " + f"Got shape {shape} with state_dim={state_dim}." + ) + bm_dim = self.bm_dim + assert bm_dim is not None + if bm_dim not in (1, state_dim): + raise ValueError( + "Diagonal diffusion requires bm_dim to be either 1 or state_dim. " + f"Got bm_dim={bm_dim}, state_dim={state_dim}." 
+ ) + return self + + def _value_as_matrix(self, value: Array, *, state_dim: int) -> Array: + assert self.bm_dim is not None + if self.bm_dim == 1: + return value[..., :, None] + return value[..., :, None] * jnp.eye(state_dim, dtype=value.dtype) + + def _value_gram_matrix(self, value: Array, *, state_dim: int) -> Array: + assert self.bm_dim is not None + if self.bm_dim == 1: + return value[..., :, None] * value[..., None, :] + return jnp.square(value)[..., :, None] * jnp.eye(state_dim, dtype=value.dtype) + + def _apply_value(self, value: Array, dw: Array, *, state_dim: int) -> Array: + assert self.bm_dim is not None + if self.bm_dim == 1: + return value * dw[..., 0] + return value * dw + + +class ScalarDiffusion(Diffusion): + """Scalar-valued diffusion with isotropic or shared-driver interpretation. + + Use ``ScalarDiffusion(sigma, bm_dim=...)`` when the diffusion is controlled + by a single scalar scale ``sigma``. + + Args: + coefficient: Either a scalar, a constant array with trailing shape + ``(..., 1)``, or a callable ``(x, u, t) -> scalar_or_length_1_array``. + bm_dim: Brownian dimension. Must be either ``1`` or ``state_dim``. + + - If ``bm_dim = state_dim``, this represents isotropic diffusion + ``L = sigma I``. + - If ``bm_dim = 1``, this represents a shared scalar driver applied equally + to every state coordinate. + + This is typically the simplest public API choice for constant isotropic + diffusion, and is usually preferable to writing ``sigma * eye(state_dim)`` + by hand. + """ + + def _validate_init(self) -> None: + super()._validate_init() + if self.bm_dim is None: + raise ValueError( + "Scalar diffusion requires explicit bm_dim. " + "For scalar diffusion, bm_dim must be either 1 or state_dim." + ) + shape = self._constant_shape() + if shape is not None and len(shape) != 0 and int(shape[-1]) != 1: + raise ValueError( + "Scalar diffusion requires a scalar constant coefficient or trailing " + "shape (..., 1). " + f"Got shape {shape}." 
+ ) + + def resolve_metadata( + self, + *, + state_dim: int, + x_probe: State, + u_probe: Control | None, + t_probe: Time, + ) -> ScalarDiffusion: + """Check scalar shape and verify that ``bm_dim`` is either 1 or ``state_dim``.""" + shape = jax.eval_shape( + lambda: self.evaluate_value(x=x_probe, u=u_probe, t=t_probe) + ).shape + if len(shape) != 0 and int(shape[-1]) != 1: + raise ValueError( + "Scalar diffusion must have shape () or trailing shape (..., 1). " + f"Got shape {shape}." + ) + bm_dim = self.bm_dim + assert bm_dim is not None + if bm_dim not in (1, state_dim): + raise ValueError( + "Scalar diffusion requires bm_dim to be either 1 or state_dim. " + f"Got bm_dim={bm_dim}, state_dim={state_dim}." + ) + return self + + def _scalar_value(self, value: Array) -> Array: + return value if value.ndim == 0 else jnp.squeeze(value, axis=-1) + + def _value_as_matrix(self, value: Array, *, state_dim: int) -> Array: + scalar = self._scalar_value(value) + assert self.bm_dim is not None + if self.bm_dim == 1: + return jnp.broadcast_to( + scalar[..., None, None], scalar.shape + (state_dim, 1) + ) + return scalar[..., None, None] * jnp.eye(state_dim, dtype=value.dtype) + + def _value_gram_matrix(self, value: Array, *, state_dim: int) -> Array: + sigma_sq = jnp.square(self._scalar_value(value)) + assert self.bm_dim is not None + if self.bm_dim == 1: + return sigma_sq[..., None, None] * jnp.ones( + (state_dim, state_dim), dtype=value.dtype + ) + return sigma_sq[..., None, None] * jnp.eye(state_dim, dtype=value.dtype) + + def _apply_value(self, value: Array, dw: Array, *, state_dim: int) -> Array: + scalar = self._scalar_value(value) + assert self.bm_dim is not None + if self.bm_dim == 1: + return jnp.broadcast_to( + (scalar * dw[..., 0])[..., None], scalar.shape + (state_dim,) + ) + return scalar[..., None] * dw diff --git a/dynestyx/models/lti_dynamics.py b/dynestyx/models/lti_dynamics.py index e7c98776..a0a5bcb5 100644 --- a/dynestyx/models/lti_dynamics.py +++ 
b/dynestyx/models/lti_dynamics.py @@ -6,6 +6,7 @@ ContinuousTimeStateEvolution, DynamicalModel, ) +from dynestyx.models.diffusions import FullDiffusion from dynestyx.models.observations import LinearGaussianObservation from dynestyx.models.state_evolution import AffineDrift, LinearGaussianStateEvolution @@ -175,7 +176,7 @@ def LTI_continuous( state_evolution = ContinuousTimeStateEvolution( drift=drift, - diffusion_coefficient=lambda x, u, t: L, + diffusion=FullDiffusion(L), ) observation_model = LinearGaussianObservation(H=H, R=R, D=D, bias=d) diff --git a/dynestyx/simulators.py b/dynestyx/simulators.py index 6869349a..d231e2ac 100644 --- a/dynestyx/simulators.py +++ b/dynestyx/simulators.py @@ -21,10 +21,11 @@ from dynestyx.handlers import HandlesSelf, _sample_intp from dynestyx.inference.integrations.utils import WeightedParticles from dynestyx.models import ( - ContinuousTimeStateEvolution, + DeterministicContinuousTimeStateEvolution, DiracIdentityObservation, DiscreteTimeStateEvolution, DynamicalModel, + StochasticContinuousTimeStateEvolution, ) from dynestyx.solvers import solve_ode, solve_sde from dynestyx.types import FunctionOfTime, State, Time, TimeLike, as_scalar_time_array @@ -32,7 +33,6 @@ _array_has_plate_dims, _build_control_path, _dist_has_plate_batch_dims, - _ensure_continuous_bm_dim, _get_val_or_None, _has_any_batched_plate_source, _leaf_is_plate_batched, @@ -260,8 +260,6 @@ def _run_single_member_simulation( **kwargs, ) -> dict[str, Array] | None: """Run simulator logic for one unbatched member and return trajectories.""" - dynamics = _ensure_continuous_bm_dim(dynamics) - use_smoothed_rollout = smoothed_times is not None or smoothed_dists is not None if use_smoothed_rollout and ( filtered_times is not None or filtered_dists is not None @@ -274,7 +272,6 @@ def _run_single_member_simulation( rollout_times = smoothed_times if use_smoothed_rollout else filtered_times rollout_dists = smoothed_dists if use_smoothed_rollout else filtered_dists 
rollout_label = "smoothed" if use_smoothed_rollout else "filtered" - if ( rollout_times is not None and rollout_dists is None @@ -832,7 +829,7 @@ def _simulate( Args: dynamics: A `DynamicalModel` whose `state_evolution` is a - `ContinuousTimeStateEvolution` with a non-None diffusion coefficient + `ContinuousTimeStateEvolution` with a non-None diffusion and inferred `bm_dim` (set during `DynamicalModel` construction). obs_times: Times at which to save the latent state and emit observations. Required. @@ -851,16 +848,12 @@ def _simulate( parameter inference for SDEs, because it introduces an explicit, high- dimensional latent path. Prefer filtering (`Filter`) or particle methods. """ - if not isinstance(dynamics.state_evolution, ContinuousTimeStateEvolution): + if not isinstance( + dynamics.state_evolution, StochasticContinuousTimeStateEvolution + ): raise NotImplementedError( - f"SDESimulator only works with ContinuousTimeStateEvolution, got {type(dynamics.state_evolution)}" - ) - - if dynamics.state_evolution.diffusion_coefficient is None: - raise ValueError( - "SDESimulator requires diffusion_coefficient to be defined " - f"(got coeff={dynamics.state_evolution.diffusion_coefficient}). " - "Use ODESimulator for deterministic dynamics." + "SDESimulator only works with StochasticContinuousTimeStateEvolution, got " + f"{type(dynamics.state_evolution)}" ) if obs_times is not None: @@ -1255,7 +1248,7 @@ def _simulate( Args: dynamics: A `DynamicalModel` whose `state_evolution` is a - `ContinuousTimeStateEvolution` with deterministic dynamics. + `DeterministicContinuousTimeStateEvolution`. obs_times: Times at which to save the latent state and emit observations. obs_values: Optional observation array. If provided, observation sites are conditioned via `obs=obs_values[i]`. 
@@ -1379,11 +1372,14 @@ def _simulate( **kwargs, ) -> dict[str, State]: if self.simulator is None: - if isinstance(dynamics.state_evolution, ContinuousTimeStateEvolution): - if dynamics.state_evolution.diffusion_coefficient is None: - self.simulator = ODESimulator(*self.args, **self.kwargs) - else: - self.simulator = SDESimulator(*self.args, **self.kwargs) + if isinstance( + dynamics.state_evolution, StochasticContinuousTimeStateEvolution + ): + self.simulator = SDESimulator(*self.args, **self.kwargs) + elif isinstance( + dynamics.state_evolution, DeterministicContinuousTimeStateEvolution + ): + self.simulator = ODESimulator(*self.args, **self.kwargs) elif isinstance(dynamics.state_evolution, DiscreteTimeStateEvolution): self.simulator = DiscreteTimeSimulator(*self.args, **self.kwargs) else: diff --git a/dynestyx/solvers/sde.py b/dynestyx/solvers/sde.py index 09e0da53..7f4986be 100644 --- a/dynestyx/solvers/sde.py +++ b/dynestyx/solvers/sde.py @@ -11,29 +11,11 @@ import jax.random as jr from jax import Array, lax, vmap -from dynestyx.models import ContinuousTimeStateEvolution, DynamicalModel +from dynestyx.models import DynamicalModel, StochasticContinuousTimeStateEvolution +from dynestyx.models.diffusions import EvaluatedDiffusion from dynestyx.types import State, Time, TimeLike, as_scalar_time_array -def _apply_diffusion(diffusion_term: Array, dw: Array) -> Array: - """Apply diffusion operator to a Brownian increment. - - Args: - diffusion_term: Diffusion coefficient with shape compatible with `dw`. - dw: Brownian increment vector. - - Returns: - State increment induced by the diffusion term. - """ - if diffusion_term.ndim == 0: - return diffusion_term * dw[0] - if diffusion_term.ndim == 1: - if dw.shape[0] == 1: - return diffusion_term * dw[0] - return diffusion_term * dw - return diffusion_term @ dw - - def _early_return_states(x0: State, saveat_times: Array) -> Array: """Build no-op solve output by repeating the initial state. 
@@ -47,43 +29,13 @@ def _early_return_states(x0: State, saveat_times: Array) -> Array: return jnp.broadcast_to(x0, (len(saveat_times),) + jnp.shape(x0)) -def _require_bm_dim(state_evolution: ContinuousTimeStateEvolution) -> int: - """Return Brownian dimension or raise if unspecified. - - Args: - state_evolution: Continuous-time state evolution. - - Returns: - Brownian motion dimension used by EM sampling. - """ - if state_evolution.bm_dim is None: - raise ValueError("SDE sampling requires state_evolution.bm_dim to be set.") - return int(state_evolution.bm_dim) - - -def _require_diffusion_fn( - state_evolution: ContinuousTimeStateEvolution, -) -> Callable[[Array, Array | None, Array], Array]: - """Get diffusion callable or raise if unavailable. - - Args: - state_evolution: Continuous-time state evolution. - - Returns: - Diffusion function with signature `(x, u, t) -> diffusion`. - """ - diffusion_fn = state_evolution.diffusion_coefficient - if diffusion_fn is None: - raise ValueError("SDE solver requires diffusion_coefficient to be defined.") - return diffusion_fn - - def _em_local_terms( - state_evolution: ContinuousTimeStateEvolution, + state_evolution: StochasticContinuousTimeStateEvolution, + diffusion, x: Array, u: Array | None, t_now: Array, -) -> tuple[Array, Array]: +) -> tuple[Array, EvaluatedDiffusion]: """Compute local EM drift and diffusion terms. Args: @@ -96,13 +48,11 @@ def _em_local_terms( Tuple `(drift, diffusion)` at `(x, u, t_now)`. """ drift = state_evolution.total_drift(x=x, u=u, t=t_now) - diffusion_fn = _require_diffusion_fn(state_evolution) - diffusion = jnp.asarray(diffusion_fn(x, u, t_now)) - return drift, diffusion + return drift, diffusion.evaluate(x=x, u=u, t=t_now) def _em_moments_from_terms( - x: Array, dt: Array, drift: Array, diffusion: Array + x: Array, dt: Array, drift: Array, diffusion: EvaluatedDiffusion ) -> tuple[Array, Array]: """Convert local EM terms to one-step Gaussian moments. 
@@ -116,7 +66,7 @@ def _em_moments_from_terms( Tuple `(loc, cov)` for the EM Gaussian approximation. """ loc = x + drift * dt - cov = diffusion @ diffusion.T * dt + cov = diffusion.gram_matrix(state_dim=x.shape[-1]) * dt return loc, cov @@ -124,7 +74,7 @@ def _em_sample_from_terms( x: Array, dt: Array, drift: Array, - diffusion: Array, + diffusion: EvaluatedDiffusion, *, key: Array, bm_dim: int, @@ -145,12 +95,12 @@ def _em_sample_from_terms( key_next, k_step = jr.split(key) z = jr.normal(k_step, shape=(bm_dim,), dtype=jnp.asarray(x).dtype) dw = jnp.sqrt(dt) * z - x_next = x + drift * dt + _apply_diffusion(diffusion, dw) + x_next = x + drift * dt + diffusion.apply(dw, state_dim=x.shape[-1]) return x_next, key_next def euler_maruyama_integrate_state_to_time( - state_evolution: ContinuousTimeStateEvolution, + state_evolution: StochasticContinuousTimeStateEvolution, x_init: Array, t_init: Time, key_init: Array, @@ -177,7 +127,8 @@ def euler_maruyama_integrate_state_to_time( dt0, dt0 <= 0, f"EM integration requires dt0 > 0, got dt0={dt0!r}." 
) - bm_dim = _require_bm_dim(state_evolution) + diffusion = state_evolution.diffusion + bm_dim = state_evolution.bm_dim def _cond_fn(carry): _, t_curr, _, t_end = carry @@ -187,9 +138,11 @@ def _body_fn(carry): x_curr, t_curr, key_curr, t_end = carry h = jnp.minimum(dt0, t_end - t_curr) u_t = control_path_eval(t_curr) if control_path_eval is not None else None - drift, diffusion = _em_local_terms(state_evolution, x_curr, u_t, t_curr) + drift, evaluated_diffusion = _em_local_terms( + state_evolution, diffusion, x_curr, u_t, t_curr + ) x_next, key_next = _em_sample_from_terms( - x_curr, h, drift, diffusion, key=key_curr, bm_dim=bm_dim + x_curr, h, drift, evaluated_diffusion, key=key_curr, bm_dim=bm_dim ) return x_next, t_curr + h, key_next, t_end @@ -199,7 +152,7 @@ def _body_fn(carry): def euler_maruyama_loc_cov( - state_evolution: ContinuousTimeStateEvolution, + state_evolution: StochasticContinuousTimeStateEvolution, x: Array, u: Array | None, t_now: Array, @@ -234,10 +187,13 @@ def euler_maruyama_loc_cov( the time axis to the front as a side effect of `vmap`. 
""" x_arr = jnp.asarray(x) + diffusion = state_evolution.diffusion def _step_interval(_x, _u, _t_now, _t_next): - drift, diffusion = _em_local_terms(state_evolution, _x, _u, _t_now) - return _em_moments_from_terms(_x, _t_next - _t_now, drift, diffusion) + drift, evaluated_diffusion = _em_local_terms( + state_evolution, diffusion, _x, _u, _t_now + ) + return _em_moments_from_terms(_x, _t_next - _t_now, drift, evaluated_diffusion) if x_arr.ndim == 1: loc, cov = _step_interval(x_arr, u, jnp.asarray(t_now), jnp.asarray(t_next)) @@ -324,9 +280,9 @@ def _solve_sde_scan( """ if key is None: raise ValueError("PRNG key is required for em_scan SDE solves.") - if not isinstance(dynamics.state_evolution, ContinuousTimeStateEvolution): + if not isinstance(dynamics.state_evolution, StochasticContinuousTimeStateEvolution): raise TypeError( - "SDE solver requires ContinuousTimeStateEvolution, got " + "SDE solver requires StochasticContinuousTimeStateEvolution, got " f"{type(dynamics.state_evolution)}" ) @@ -391,13 +347,14 @@ def _solve_sde_diffrax( """ if key is None: raise ValueError("PRNG key is required for diffrax SDE solves.") - if not isinstance(dynamics.state_evolution, ContinuousTimeStateEvolution): + if not isinstance(dynamics.state_evolution, StochasticContinuousTimeStateEvolution): raise TypeError( - "SDE solver requires ContinuousTimeStateEvolution, got " + "SDE solver requires StochasticContinuousTimeStateEvolution, got " f"{type(dynamics.state_evolution)}" ) state_evolution = dynamics.state_evolution + diffusion = state_evolution.diffusion def _drift(t, y, args): u_t = args(t) if args is not None else None @@ -405,7 +362,7 @@ def _drift(t, y, args): def _diffusion(t, y, args): u_t = args(t) if args is not None else None - return state_evolution.diffusion_coefficient(x=y, u=u_t, t=t) + return diffusion.as_matrix(x=y, u=u_t, t=t, state_dim=y.shape[-1]) k_bm, _ = jr.split(key, 2) bm = dfx.VirtualBrownianTree( @@ -415,11 +372,8 @@ def _diffusion(t, y, args): 
shape=(state_evolution.bm_dim,), key=k_bm, ) - terms = dfx.MultiTerm( # type: ignore[arg-type] - dfx.ODETerm(_drift), dfx.ControlTerm(_diffusion, bm) - ) sol = dfx.diffeqsolve( - terms, + dfx.MultiTerm(dfx.ODETerm(_drift), dfx.ControlTerm(_diffusion, bm)), t0=t0, t1=saveat_times[-1], y0=x0, diff --git a/dynestyx/utils.py b/dynestyx/utils.py index 0e57802b..7490ad41 100644 --- a/dynestyx/utils.py +++ b/dynestyx/utils.py @@ -9,8 +9,7 @@ from cd_dynamax import ContDiscreteNonlinearSSM as CDNLSSM from jax import Array, lax -from dynestyx.models import ContinuousTimeStateEvolution, DynamicalModel -from dynestyx.models.checkers import _infer_bm_dim +from dynestyx.models import DynamicalModel def flatten_draws(arr: Array) -> Array: @@ -170,28 +169,6 @@ def _has_any_batched_plate_source( return False -def _ensure_continuous_bm_dim(dynamics: DynamicalModel) -> DynamicalModel: - """Infer and set bm_dim when continuous dynamics were constructed in plates.""" - if not dynamics.continuous_time: - return dynamics - - state_evolution = dynamics.state_evolution - if ( - not isinstance(state_evolution, ContinuousTimeStateEvolution) - or state_evolution.diffusion_coefficient is None - or state_evolution.bm_dim is not None - ): - return dynamics - - x0 = jnp.zeros((dynamics.state_dim,)) - u0 = None if dynamics.control_dim == 0 else jnp.zeros((dynamics.control_dim,)) - t0 = jnp.array(0.0) if dynamics.t0 is None else jnp.asarray(dynamics.t0) - inferred_bm_dim = _infer_bm_dim(state_evolution, dynamics.state_dim, x0, u0, t0) - if inferred_bm_dim is not None: - object.__setattr__(state_evolution, "bm_dim", inferred_bm_dim) - return dynamics - - def _should_record_field( record_val: bool | None, shape: tuple[int, ...], max_elems: int ) -> bool: diff --git a/mkdocs.yml b/mkdocs.yml index 16efd33f..184b12bd 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -51,6 +51,7 @@ nav: - Core Models: - DynamicalModel: api_reference/public/models/core/dynamical_model.md - ContinuousTimeStateEvolution: 
api_reference/public/models/core/continuous_time_state_evolution.md + - Diffusion: api_reference/public/models/core/diffusion.md - DiscreteTimeStateEvolution: api_reference/public/models/core/discrete_time_state_evolution.md - ObservationModel: api_reference/public/models/core/observation_model.md - Drift: api_reference/public/models/core/drift.md @@ -88,6 +89,7 @@ nav: - Overview: api_reference/developer/index.md - Models: - Core Models: api_reference/developer/models/core_models.md + - Diffusions: api_reference/developer/models/diffusions.md - Specialized Models: api_reference/developer/models/specialized_models.md - Checkers: api_reference/developer/models/checkers.md - Inference: diff --git a/tests/models.py b/tests/models.py index 5eb8612f..ab09e4e1 100644 --- a/tests/models.py +++ b/tests/models.py @@ -7,6 +7,7 @@ ContinuousTimeStateEvolution, DiracIdentityObservation, DynamicalModel, + FullDiffusion, LinearGaussianObservation, LinearGaussianStateEvolution, ) @@ -173,7 +174,7 @@ def continuous_time_stochastic_l63_model( ) + (10 * u if u is not None else jnp.zeros(3)) ), - diffusion_coefficient=lambda x, u, t: jnp.eye(3), + diffusion=FullDiffusion(lambda x, u, t: jnp.eye(3)), ), observation_model=LinearGaussianObservation( H=jnp.array([[1.0, 0.0, 0.0]]), R=jnp.array([[1.0**2]]) @@ -229,7 +230,7 @@ def continuous_time_stochastic_l63_model_dirac_obs( ) + (10 * u if u is not None else jnp.zeros_like(x)) ), - diffusion_coefficient=lambda x, u, t: jnp.eye(3), + diffusion=FullDiffusion(lambda x, u, t: jnp.eye(3)), ), observation_model=DiracIdentityObservation(), ) @@ -290,7 +291,7 @@ def continuous_time_LTI_gaussian( ), state_evolution=ContinuousTimeStateEvolution( drift=lambda x, u, t: A @ x + (10 * u if u is not None else jnp.zeros(2)), - diffusion_coefficient=lambda x, u, t: jnp.eye(2), + diffusion=FullDiffusion(lambda x, u, t: jnp.eye(2)), ), observation_model=LinearGaussianObservation( H=jnp.array([[0.0, 1.0]]), R=jnp.array([[1.0**2]]) @@ -543,7 +544,7 @@ def 
jumpy_controls_model_sde( ): state_evolution = ContinuousTimeStateEvolution( drift=lambda x, u, t: x + u, - diffusion_coefficient=lambda x, u, t: 0.01 * jnp.eye(1), + diffusion=FullDiffusion(lambda x, u, t: 0.01 * jnp.eye(1)), ) dynamics = DynamicalModel( control_dim=1, diff --git a/tests/test_bm_dim_plate.py b/tests/test_bm_dim_plate.py index 77c367bd..6521da44 100644 --- a/tests/test_bm_dim_plate.py +++ b/tests/test_bm_dim_plate.py @@ -3,18 +3,69 @@ import jax.numpy as jnp import numpyro import numpyro.distributions as dist - -from dynestyx.models import ContinuousTimeStateEvolution, DynamicalModel - - -def test_bm_dim_inferred_outside_plate(): - """bm_dim is correctly inferred when not in a plate context.""" +import pytest + +from dynestyx.models import ( + ContinuousTimeStateEvolution, + DiagonalDiffusion, + DynamicalModel, + FullDiffusion, + ScalarDiffusion, +) + + +def _make_diffusion_spec( + diffusion_form: str, + *, + state_dim: int, + sigma=None, +): + if diffusion_form == "full": + return FullDiffusion(jnp.ones((state_dim, 1))) + if diffusion_form == "diag": + return DiagonalDiffusion(jnp.ones((state_dim,)), bm_dim=state_dim) + if diffusion_form == "scalar": + return ScalarDiffusion(jnp.array(1.0), bm_dim=state_dim) + if diffusion_form == "callable_full": + return FullDiffusion( + (lambda x, u, t: sigma[..., None, None] * jnp.ones((state_dim, 1))) + if sigma is not None + else (lambda x, u, t: jnp.ones((state_dim, 1))) + ) + if diffusion_form == "callable_diag": + return DiagonalDiffusion( + (lambda x, u, t: sigma[..., None] * jnp.ones((state_dim,))) + if sigma is not None + else (lambda x, u, t: jnp.ones((state_dim,))), + bm_dim=state_dim, + ) + if diffusion_form == "callable_scalar": + return ScalarDiffusion( + (lambda x, u, t: sigma[..., None]) + if sigma is not None + else (lambda x, u, t: jnp.array([1.0])), + bm_dim=state_dim, + ) + raise ValueError(f"Unknown diffusion form: {diffusion_form}") + + +@pytest.mark.parametrize( + "diffusion_form", + 
["full", "diag", "scalar", "callable_full", "callable_diag", "callable_scalar"], +) +def test_bm_dim_resolved_outside_plate(diffusion_form): + """bm_dim is resolved correctly when not in a plate context.""" state_dim = 2 - bm_dim = 1 + expected_bm_dim = 1 if "full" in diffusion_form else state_dim + + diffusion = _make_diffusion_spec( + diffusion_form, + state_dim=state_dim, + ) state_evo = ContinuousTimeStateEvolution( drift=lambda x, u, t: -x, - diffusion_coefficient=lambda x, u, t: jnp.ones((state_dim, bm_dim)), + diffusion=diffusion, ) dynamics = DynamicalModel( initial_condition=dist.MultivariateNormal( @@ -25,62 +76,34 @@ def test_bm_dim_inferred_outside_plate(): x, 0.1 * jnp.eye(state_dim) ), ) - assert dynamics.state_evolution.bm_dim == bm_dim, ( - f"Expected bm_dim={bm_dim}, got {dynamics.state_evolution.bm_dim}" + assert dynamics.state_evolution.diffusion is not None + assert dynamics.state_evolution.diffusion.bm_dim == expected_bm_dim, ( + f"Expected bm_dim={expected_bm_dim}, got {dynamics.state_evolution.diffusion.bm_dim}" ) -def test_bm_dim_inferred_inside_plate(): - """bm_dim should be inferred when model is constructed inside a plate context.""" +@pytest.mark.parametrize( + "diffusion_form", + ["full", "diag", "scalar", "callable_full", "callable_diag", "callable_scalar"], +) +def test_bm_dim_resolved_inside_plate(diffusion_form): + """bm_dim should be resolved when model is constructed inside a plate context.""" state_dim = 2 - bm_dim = 1 + expected_bm_dim = 1 if "full" in diffusion_form else state_dim M = 3 def model(): with numpyro.plate("trajectories", M): sigma = numpyro.sample("sigma", dist.HalfNormal(1.0)) - - state_evo = ContinuousTimeStateEvolution( - drift=lambda x, u, t: -x, - diffusion_coefficient=lambda x, u, t: ( - sigma[..., None, None] * jnp.ones((state_dim, bm_dim)) - ), + diffusion = _make_diffusion_spec( + diffusion_form, + state_dim=state_dim, + sigma=sigma, ) - dynamics = DynamicalModel( - 
initial_condition=dist.MultivariateNormal( - jnp.zeros(state_dim), jnp.eye(state_dim) - ), - state_evolution=state_evo, - observation_model=lambda x, u, t: dist.MultivariateNormal( - x, 0.1 * jnp.eye(state_dim) - ), - ) - # This is the bug: bm_dim stays None inside plate context - assert dynamics.state_evolution.bm_dim is not None, ( - "bm_dim should not be None after DynamicalModel construction in plate" - ) - assert dynamics.state_evolution.bm_dim == bm_dim, ( - f"Expected bm_dim={bm_dim}, got {dynamics.state_evolution.bm_dim}" - ) - - # Run the model with seed to trigger numpyro.sample - with numpyro.handlers.seed(rng_seed=0): - model() - - -def test_bm_dim_inferred_inside_plate_unbatched_diffusion(): - """bm_dim should be inferred even when diffusion doesn't use batched params.""" - state_dim = 2 - bm_dim = 1 - M = 3 - - def model(): - with numpyro.plate("trajectories", M): - _ = numpyro.sample("sigma", dist.HalfNormal(1.0)) state_evo = ContinuousTimeStateEvolution( drift=lambda x, u, t: -x, - diffusion_coefficient=lambda x, u, t: jnp.ones((state_dim, bm_dim)), + diffusion=diffusion, ) dynamics = DynamicalModel( initial_condition=dist.MultivariateNormal( @@ -91,11 +114,11 @@ def model(): x, 0.1 * jnp.eye(state_dim) ), ) - assert dynamics.state_evolution.bm_dim is not None, ( + assert dynamics.state_evolution.diffusion is not None, ( "bm_dim should not be None after DynamicalModel construction in plate" ) - assert dynamics.state_evolution.bm_dim == bm_dim, ( - f"Expected bm_dim={bm_dim}, got {dynamics.state_evolution.bm_dim}" + assert dynamics.state_evolution.diffusion.bm_dim == expected_bm_dim, ( + f"Expected bm_dim={expected_bm_dim}, got {dynamics.state_evolution.diffusion.bm_dim}" ) with numpyro.handlers.seed(rng_seed=0): @@ -103,20 +126,17 @@ def model(): if __name__ == "__main__": - print("Test 1: bm_dim outside plate...") - test_bm_dim_inferred_outside_plate() - print(" PASSED") - - print("Test 2: bm_dim inside plate (unbatched diffusion)...") - try: - 
test_bm_dim_inferred_inside_plate_unbatched_diffusion() + for form in [ + "full", + "diag", + "scalar", + "callable_full", + "callable_diag", + "callable_scalar", + ]: + print(f"Testing {form} outside plate...") + test_bm_dim_resolved_outside_plate(form) print(" PASSED") - except AssertionError as e: - print(f" FAILED: {e}") - - print("Test 3: bm_dim inside plate (batched diffusion)...") - try: - test_bm_dim_inferred_inside_plate() + print(f"Testing {form} inside plate...") + test_bm_dim_resolved_inside_plate(form) print(" PASSED") - except AssertionError as e: - print(f" FAILED: {e}") diff --git a/tests/test_discretizers.py b/tests/test_discretizers.py index c73f5906..05dc2067 100644 --- a/tests/test_discretizers.py +++ b/tests/test_discretizers.py @@ -3,6 +3,7 @@ import jax.numpy as jnp import jax.random as jr import numpyro.distributions as dist +import pytest from numpyro.handlers import seed, trace from numpyro.infer import Predictive @@ -16,9 +17,12 @@ from dynestyx.inference.filters import Filter from dynestyx.models import ( ContinuousTimeStateEvolution, + DiagonalDiffusion, DiracIdentityObservation, DynamicalModel, + FullDiffusion, GaussianStateEvolution, + ScalarDiffusion, ) from dynestyx.models.observations import LinearGaussianObservation from dynestyx.solvers import euler_maruyama_loc_cov @@ -26,11 +30,68 @@ def _ctse_1d_zero_drift_unit_diffusion() -> ContinuousTimeStateEvolution: - return ContinuousTimeStateEvolution( + cte = ContinuousTimeStateEvolution( drift=lambda x, u, t: jnp.zeros_like(x), - diffusion_coefficient=lambda x, u, t: jnp.ones((1, 1)), - bm_dim=1, + diffusion=FullDiffusion(lambda x, u, t: jnp.ones((1, 1)), bm_dim=1), ) + dynamics = DynamicalModel( + initial_condition=dist.MultivariateNormal(jnp.zeros(1), jnp.eye(1)), + state_evolution=cte, + observation_model=LinearGaussianObservation( + H=jnp.array([[1.0]]), + R=jnp.array([[1.0]]), + ), + control_dim=0, + ) + assert isinstance(dynamics.state_evolution, ContinuousTimeStateEvolution) + 
return dynamics.state_evolution + + +def _ctse_2d_zero_drift(diffusion_form: str) -> ContinuousTimeStateEvolution: + if diffusion_form == "full": + cte = ContinuousTimeStateEvolution( + drift=lambda x, u, t: jnp.zeros_like(x), + diffusion=FullDiffusion(jnp.eye(2), bm_dim=2), + ) + elif diffusion_form == "diag": + cte = ContinuousTimeStateEvolution( + drift=lambda x, u, t: jnp.zeros_like(x), + diffusion=DiagonalDiffusion(jnp.ones((2,)), bm_dim=2), + ) + elif diffusion_form == "scalar": + cte = ContinuousTimeStateEvolution( + drift=lambda x, u, t: jnp.zeros_like(x), + diffusion=ScalarDiffusion(jnp.array(1.0), bm_dim=2), + ) + elif diffusion_form == "callable_full": + cte = ContinuousTimeStateEvolution( + drift=lambda x, u, t: jnp.zeros_like(x), + diffusion=FullDiffusion(lambda x, u, t: jnp.eye(2), bm_dim=2), + ) + elif diffusion_form == "callable_diag": + cte = ContinuousTimeStateEvolution( + drift=lambda x, u, t: jnp.zeros((2,)), + diffusion=DiagonalDiffusion(lambda x, u, t: jnp.ones((2,)), bm_dim=2), + ) + elif diffusion_form == "callable_scalar": + cte = ContinuousTimeStateEvolution( + drift=lambda x, u, t: jnp.zeros_like(x), + diffusion=ScalarDiffusion(lambda x, u, t: jnp.array(1.0), bm_dim=2), + ) + else: + raise ValueError(f"Unknown diffusion form: {diffusion_form}") + + dynamics = DynamicalModel( + initial_condition=dist.MultivariateNormal(jnp.zeros(2), jnp.eye(2)), + state_evolution=cte, + observation_model=LinearGaussianObservation( + H=jnp.eye(2), + R=jnp.eye(2), + ), + control_dim=0, + ) + assert isinstance(dynamics.state_evolution, ContinuousTimeStateEvolution) + return dynamics.state_evolution def test_euler_maruyama_returns_gaussian_state_evolution_with_callable_cov(): @@ -78,6 +139,21 @@ def test_euler_maruyama_loc_cov_single_pass_consistent_with_gaussian_state_evolu assert jnp.allclose(d_dict["cov"], d.covariance_matrix) +@pytest.mark.parametrize( + "diffusion_form", + ["full", "diag", "scalar", "callable_full", "callable_diag", "callable_scalar"], +) 
+def test_euler_maruyama_structured_diffusions_match_dense_covariance(diffusion_form): + cte = _ctse_2d_zero_drift(diffusion_form) + x = jnp.array([0.3, -0.1]) + t0 = jnp.array(1.0) + t1 = jnp.array(3.0) + out = euler_maruyama_loc_cov(cte, x, None, t0, t1) + expected_cov = (t1 - t0) * jnp.eye(2) + assert jnp.allclose(out["loc"], x) + assert jnp.allclose(out["cov"], expected_cov) + + def test_discretized_dirac_observations_preserve_state_dimension(): obs_times = jnp.arange(0.0, 0.05, 0.01) predictive = Predictive( diff --git a/tests/test_filters.py b/tests/test_filters.py index e52987db..048aa0f5 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -21,7 +21,7 @@ from dynestyx.inference.integrations.cuthbert.discrete import ( run_discrete_filter as run_cuthbert_discrete_filter, ) -from dynestyx.models import ContinuousTimeStateEvolution, DynamicalModel +from dynestyx.models import ContinuousTimeStateEvolution, DynamicalModel, FullDiffusion from dynestyx.simulators import DiscreteTimeSimulator from tests.fixtures import ( _squeeze_sim_dims, @@ -102,7 +102,7 @@ def model(): initial_condition=dist.LogNormal(loc=jnp.zeros(1), scale=jnp.ones(1)), state_evolution=ContinuousTimeStateEvolution( drift=lambda x, u, t: -0.3 * jnp.sin(x), - diffusion_coefficient=lambda x, u, t: 0.1 * jnp.eye(1), + diffusion=FullDiffusion(lambda x, u, t: 0.1 * jnp.eye(1)), ), observation_model=lambda x, u, t: dist.Poisson(rate=jnp.exp(x[0] + bias)), ) diff --git a/tests/test_hierarchical_simulator_discretizer_smokes.py b/tests/test_hierarchical_simulator_discretizer_smokes.py index 9d50e8f5..6749f046 100644 --- a/tests/test_hierarchical_simulator_discretizer_smokes.py +++ b/tests/test_hierarchical_simulator_discretizer_smokes.py @@ -30,6 +30,7 @@ ContinuousTimeStateEvolution, DiracIdentityObservation, DynamicalModel, + FullDiffusion, GaussianStateEvolution, LinearGaussianObservation, ) @@ -102,6 +103,7 @@ def _plate_continuous_sde_model( alpha = numpyro.sample("alpha", 
dist.Uniform(0.1, 0.8)) A_base = jnp.array([[0.0, 0.1], [0.1, 0.8]]) A = jnp.repeat(A_base[None], M, axis=0).at[:, 0, 0].set(alpha) + # Rectangular diffusion (bm_dim < state_dim) is padded in cd-dynamax integration. L = 0.2 * jnp.array([[1.0], [0.5]]) H = jnp.array([[1.0, 0.0]]) R = jnp.array([[0.25]]) @@ -131,9 +133,7 @@ def _plate_continuous_ode_model( initial_condition=dist.MultivariateNormal( loc=jnp.zeros(2), covariance_matrix=jnp.eye(2) ), - state_evolution=ContinuousTimeStateEvolution( - drift=drift, diffusion_coefficient=None - ), + state_evolution=ContinuousTimeStateEvolution(drift=drift), observation_model=LinearGaussianObservation( H=jnp.array([[1.0, 0.0]]), R=jnp.array([[0.25]]) ), @@ -557,7 +557,7 @@ def _plate_continuous_dirac_for_discretizer_model( loc=jnp.zeros(2), covariance_matrix=jnp.eye(2) ), state_evolution=ContinuousTimeStateEvolution( - drift=AffineDrift(A=A), diffusion_coefficient=lambda x, u, t: L + drift=AffineDrift(A=A), diffusion=FullDiffusion(lambda x, u, t: L) ), observation_model=DiracIdentityObservation(), ) @@ -592,7 +592,7 @@ def _nested_plate_continuous_dirac_for_discretizer_model( loc=jnp.zeros(2), covariance_matrix=jnp.eye(2) ), state_evolution=ContinuousTimeStateEvolution( - drift=AffineDrift(A=A), diffusion_coefficient=lambda x, u, t: L + drift=AffineDrift(A=A), diffusion=FullDiffusion(lambda x, u, t: L) ), observation_model=DiracIdentityObservation(), ) diff --git a/tests/test_hierarchical_simulator_inference_smokes.py b/tests/test_hierarchical_simulator_inference_smokes.py index 9eeb46d0..5105592a 100644 --- a/tests/test_hierarchical_simulator_inference_smokes.py +++ b/tests/test_hierarchical_simulator_inference_smokes.py @@ -123,9 +123,7 @@ def _hierarchical_ode_model( initial_condition=dist.MultivariateNormal( loc=jnp.zeros(2), covariance_matrix=jnp.eye(2) ), - state_evolution=ContinuousTimeStateEvolution( - drift=drift, diffusion_coefficient=None - ), + state_evolution=ContinuousTimeStateEvolution(drift=drift), 
observation_model=LinearGaussianObservation( H=jnp.array([[0.0, 1.0]]), R=jnp.eye(1) ), diff --git a/tests/test_models_core.py b/tests/test_models_core.py index 2f97300b..7080bb7d 100644 --- a/tests/test_models_core.py +++ b/tests/test_models_core.py @@ -6,8 +6,23 @@ import pytest import dynestyx as dsx -from dynestyx.inference.integrations.cd_dynamax.utils import gaussian_to_nlgssm_params -from dynestyx.models.core import ContinuousTimeStateEvolution, DynamicalModel +from dynestyx.inference.integrations.cd_dynamax.utils import ( + _as_cd_dynamax_diffusion_coefficient, + dsx_to_cd_dynamax, + gaussian_to_nlgssm_params, +) +from dynestyx.models.core import ( + ContinuousTimeStateEvolution, + DeterministicContinuousTimeStateEvolution, + DynamicalModel, + StochasticContinuousTimeStateEvolution, +) +from dynestyx.models.diffusions import ( + DiagonalDiffusion, + Diffusion, + FullDiffusion, + ScalarDiffusion, +) from dynestyx.simulators import DiscreteTimeSimulator @@ -233,14 +248,14 @@ def test_categorical_state_override_incompatible_raises() -> None: def test_continuous_state_evolution_infers_bm_dim() -> None: state_evolution = ContinuousTimeStateEvolution( drift=lambda x, u, t: x, - diffusion_coefficient=lambda x, u, t: jnp.eye(2, 3), + diffusion=FullDiffusion(lambda x, u, t: jnp.eye(2, 3)), ) def observation_model(x, u, t): del u, t return dist.MultivariateNormal(loc=x, covariance_matrix=jnp.eye(2)) - _ = DynamicalModel( + dynamics = DynamicalModel( initial_condition=dist.MultivariateNormal( loc=jnp.zeros(2), covariance_matrix=jnp.eye(2), @@ -250,21 +265,21 @@ def observation_model(x, u, t): control_dim=0, ) - assert state_evolution.bm_dim == 3 + assert isinstance(dynamics.state_evolution, StochasticContinuousTimeStateEvolution) + assert dynamics.state_evolution.diffusion.bm_dim == 3 def test_continuous_state_evolution_bm_dim_override_compatible() -> None: state_evolution = ContinuousTimeStateEvolution( drift=lambda x, u, t: x, - diffusion_coefficient=lambda x, u, t: 
jnp.eye(2, 3), - bm_dim=3, + diffusion=FullDiffusion(lambda x, u, t: jnp.eye(2, 3), bm_dim=3), ) def observation_model(x, u, t): del u, t return dist.MultivariateNormal(loc=x, covariance_matrix=jnp.eye(2)) - _ = DynamicalModel( + dynamics = DynamicalModel( initial_condition=dist.MultivariateNormal( loc=jnp.zeros(2), covariance_matrix=jnp.eye(2), @@ -274,14 +289,14 @@ def observation_model(x, u, t): control_dim=0, ) - assert state_evolution.bm_dim == 3 + assert isinstance(dynamics.state_evolution, StochasticContinuousTimeStateEvolution) + assert dynamics.state_evolution.diffusion.bm_dim == 3 def test_continuous_state_evolution_bm_dim_override_mismatch_raises() -> None: state_evolution = ContinuousTimeStateEvolution( drift=lambda x, u, t: x, - diffusion_coefficient=lambda x, u, t: jnp.eye(2, 3), - bm_dim=2, + diffusion=FullDiffusion(lambda x, u, t: jnp.eye(2, 3), bm_dim=2), ) def observation_model(x, u, t): @@ -289,7 +304,7 @@ def observation_model(x, u, t): return dist.MultivariateNormal(loc=x, covariance_matrix=jnp.eye(2)) with pytest.raises( - ValueError, match="bm_dim does not match inferred diffusion_coefficient" + ValueError, match="bm_dim does not match inferred diffusion output shape" ): DynamicalModel( initial_condition=dist.MultivariateNormal( @@ -302,39 +317,183 @@ def observation_model(x, u, t): ) -def test_continuous_state_evolution_bm_dim_without_diffusion_raises() -> None: +def test_continuous_state_evolution_without_diffusion_is_allowed() -> None: state_evolution = ContinuousTimeStateEvolution( drift=lambda x, u, t: x, - bm_dim=2, ) def observation_model(x, u, t): del u, t return dist.MultivariateNormal(loc=x, covariance_matrix=jnp.eye(2)) - with pytest.raises( - ValueError, match="bm_dim cannot be set when diffusion_coefficient" - ): + dynamics = DynamicalModel( + initial_condition=dist.MultivariateNormal( + loc=jnp.zeros(2), + covariance_matrix=jnp.eye(2), + ), + state_evolution=state_evolution, + observation_model=observation_model, + 
control_dim=0, + ) + + assert isinstance( + dynamics.state_evolution, DeterministicContinuousTimeStateEvolution + ) + + +def test_continuous_state_evolution_requires_explicit_bm_dim_for_diag_diffusion() -> ( + None +): + with pytest.raises(ValueError, match="Diagonal diffusion requires explicit bm_dim"): + DiagonalDiffusion(jnp.ones((2,)), bm_dim=None) # type: ignore[arg-type] + + +def test_continuous_state_evolution_requires_explicit_bm_dim_for_scalar_diffusion() -> ( + None +): + with pytest.raises(ValueError, match="Scalar diffusion requires explicit bm_dim"): + ScalarDiffusion(jnp.array(0.3), bm_dim=None) # type: ignore[arg-type] + + +def test_continuous_state_evolution_rejects_invalid_shorthand_trailing_dim() -> None: + state_evolution = ContinuousTimeStateEvolution( + drift=lambda x, u, t: x, + diffusion=DiagonalDiffusion(jnp.ones((3,)), bm_dim=2), + ) + + with pytest.raises(ValueError, match="Diagonal diffusion must have trailing shape"): DynamicalModel( initial_condition=dist.MultivariateNormal( loc=jnp.zeros(2), covariance_matrix=jnp.eye(2), ), state_evolution=state_evolution, - observation_model=observation_model, + observation_model=_observation_model_2d, control_dim=0, ) -def test_discrete_state_evolution_bm_dim_override_raises() -> None: +@pytest.mark.parametrize( + ( + "diffusion", + "expected_matrix", + "expected_cov", + ), + [ + ( + ScalarDiffusion(jnp.array(0.5), bm_dim=1), + 0.5 * jnp.ones((2, 1)), + 0.25 * jnp.ones((2, 2)), + ), + ( + ScalarDiffusion(jnp.array(0.5), bm_dim=2), + 0.5 * jnp.eye(2), + 0.25 * jnp.eye(2), + ), + ( + DiagonalDiffusion(jnp.array([0.2, 0.4]), bm_dim=1), + jnp.array([[0.2], [0.4]]), + jnp.array([[0.04, 0.08], [0.08, 0.16]]), + ), + ( + DiagonalDiffusion(jnp.array([0.2, 0.4]), bm_dim=2), + jnp.diag(jnp.array([0.2, 0.4])), + jnp.diag(jnp.array([0.04, 0.16])), + ), + ], +) +def test_diffusion_helpers_preserve_shorthand_semantics( + diffusion, + expected_matrix, + expected_cov, +) -> None: + evaluated = 
diffusion.evaluate(x=jnp.zeros(2), u=None, t=jnp.array(0.0)) + + assert jnp.allclose( + evaluated.as_matrix(state_dim=2), + expected_matrix, + ) + assert jnp.allclose( + evaluated.gram_matrix(state_dim=2), + expected_cov, + ) + + +@pytest.mark.parametrize("callable_form", ["scalar", "diag", "full"]) +def test_diffusion_helpers_support_callable_shorthands(callable_form: str) -> None: + diffusion: Diffusion + if callable_form == "scalar": + diffusion = ScalarDiffusion(lambda x, u, t: 0.5 + 0.0 * t, bm_dim=2) + expected = 0.5 * jnp.eye(2) + elif callable_form == "diag": + diffusion = DiagonalDiffusion( + lambda x, u, t: jnp.array([0.2, 0.4]) + 0.0 * t, + bm_dim=2, + ) + expected = jnp.diag(jnp.array([0.2, 0.4])) + else: + diffusion = FullDiffusion(lambda x, u, t: jnp.eye(2) * (1.0 + 0.0 * t)) + diffusion = diffusion.resolve_metadata( + state_dim=2, + x_probe=jnp.zeros(2), + u_probe=None, + t_probe=jnp.array(1.0), + ) + expected = jnp.eye(2) + + evaluated = diffusion.evaluate(x=jnp.zeros(2), u=None, t=jnp.array(1.0)) + + assert jnp.allclose(evaluated.as_matrix(state_dim=2), expected) + + +def test_cd_dynamax_rejects_diffusion_with_bm_dim_exceeds_state_dim() -> None: + state_evolution = StochasticContinuousTimeStateEvolution( + drift=lambda x, u, t: x, + diffusion=FullDiffusion(jnp.ones((1, 2)), bm_dim=2), + ) + with pytest.raises(ValueError, match="bm_dim <= state_dim"): + _as_cd_dynamax_diffusion_coefficient(state_evolution, state_dim=1)( + jnp.zeros(1), + None, + jnp.array(0.0), + ) + + +def test_continuous_cd_dynamax_rejects_diffusion_with_bm_dim_exceeds_state_dim_early() -> ( + None +): + state_evolution = ContinuousTimeStateEvolution( + drift=lambda x, u, t: x, + diffusion=FullDiffusion(jnp.ones((1, 2)), bm_dim=2), + ) + dynamics = DynamicalModel( + initial_condition=dist.MultivariateNormal( + loc=jnp.zeros(1), + covariance_matrix=jnp.eye(1), + ), + state_evolution=state_evolution, + observation_model=dsx.LinearGaussianObservation( + H=jnp.array([[1.0]]), + 
R=jnp.array([[1.0]]), + ), + control_dim=0, + ) + + with pytest.raises( + ValueError, match="Continuous cd-dynamax filters require bm_dim <= state_dim" + ): + dsx_to_cd_dynamax(dynamics) + + +def test_discrete_state_evolution_diffusion_override_raises() -> None: def state_evolution(x, u, t_now, t_next): del u, t_now, t_next return dist.MultivariateNormal(loc=x, covariance_matrix=jnp.eye(2)) - state_evolution.bm_dim = 2 # type: ignore[attr-defined] + state_evolution.diffusion = ScalarDiffusion(0.1, bm_dim=2) # type: ignore[attr-defined] with pytest.raises( - ValueError, match="bm_dim can only be set for continuous-time models" + ValueError, match="diffusion can only be set for continuous-time models" ): DynamicalModel( initial_condition=_initial_condition_2d(), diff --git a/tests/test_science/test_discreteTime_generic.py b/tests/test_science/test_discreteTime_generic.py index 31444a61..ea73b6b7 100644 --- a/tests/test_science/test_discreteTime_generic.py +++ b/tests/test_science/test_discreteTime_generic.py @@ -31,14 +31,17 @@ def test_mcmc_inference(data_conditioned_discrete_time_l63, num_samples): # noq if SAVE_FIG and OUTPUT_DIR is not None: import matplotlib.pyplot as plt + states = synthetic["states"] + observations = synthetic["observations"] + plt.plot( - obs_times.squeeze(0), - synthetic["states"].squeeze(0)[:, 0], + obs_times, + states if states.ndim == 1 else states[:, 0], label="x[0]", ) plt.plot( - obs_times.squeeze(0), - synthetic["observations"].squeeze(0)[:, 0], + obs_times, + observations if observations.ndim == 1 else observations[:, 0], label="observations", linestyle="--", ) diff --git a/tests/test_science/test_discreteTime_generic_auto.py b/tests/test_science/test_discreteTime_generic_auto.py index 6aa0b7a0..d881e78d 100644 --- a/tests/test_science/test_discreteTime_generic_auto.py +++ b/tests/test_science/test_discreteTime_generic_auto.py @@ -31,14 +31,17 @@ def test_mcmc_inference(data_conditioned_discrete_time_l63_auto, num_samples): if 
SAVE_FIG and OUTPUT_DIR is not None: import matplotlib.pyplot as plt + states = synthetic["states"] + observations = synthetic["observations"] + plt.plot( - obs_times.squeeze(0), - synthetic["states"].squeeze(0)[:, 0], + obs_times, + states if states.ndim == 1 else states[:, 0], label="x[0]", ) plt.plot( - obs_times.squeeze(0), - synthetic["observations"].squeeze(0)[:, 0], + obs_times, + observations if observations.ndim == 1 else observations[:, 0], label="observations", linestyle="--", ) diff --git a/tests/test_science/test_discrete_time_l63_mcmc.py b/tests/test_science/test_discrete_time_l63_mcmc.py index c2e4d186..959dbd74 100644 --- a/tests/test_science/test_discrete_time_l63_mcmc.py +++ b/tests/test_science/test_discrete_time_l63_mcmc.py @@ -34,14 +34,17 @@ def test_mcmc_inference(data_conditioned_discrete_time_l63_filter, num_samples): if SAVE_FIG and OUTPUT_DIR is not None: import matplotlib.pyplot as plt + states = synthetic["states"] + observations = synthetic["observations"] + plt.plot( - obs_times.squeeze(0), - synthetic["states"].squeeze(0)[:, 0], + obs_times, + states if states.ndim == 1 else states[:, 0], label="x[0]", ) plt.plot( - obs_times.squeeze(0), - synthetic["observations"].squeeze(0)[:, 0], + obs_times, + observations if observations.ndim == 1 else observations[:, 0], label="observations", linestyle="--", ) diff --git a/tests/test_science/test_discrete_time_l63_svi_pf.py b/tests/test_science/test_discrete_time_l63_svi_pf.py index 1504432f..6c3bd9e6 100644 --- a/tests/test_science/test_discrete_time_l63_svi_pf.py +++ b/tests/test_science/test_discrete_time_l63_svi_pf.py @@ -36,14 +36,17 @@ def test_svi_inference(data_conditioned_discrete_time_l63_filter_pf, num_steps): if SAVE_FIG and OUTPUT_DIR is not None: import matplotlib.pyplot as plt + states = synthetic["states"] + observations = synthetic["observations"] + plt.plot( - obs_times.squeeze(0), - synthetic["states"].squeeze(0)[:, 0], + obs_times, + states if states.ndim == 1 else states[:, 
0], label="x[0]", ) plt.plot( - obs_times.squeeze(0), - synthetic["observations"].squeeze(0)[:, 0], + obs_times, + observations if observations.ndim == 1 else observations[:, 0], label="observations", linestyle="--", ) diff --git a/tests/test_science/test_hmm.py b/tests/test_science/test_hmm.py index d78d889f..130a72b4 100644 --- a/tests/test_science/test_hmm.py +++ b/tests/test_science/test_hmm.py @@ -27,9 +27,9 @@ def test_mcmc_inference(data_conditioned_hmm, num_samples): # noqa: F811 if SAVE_FIG and OUTPUT_DIR is not None: plot_hmm_states_and_observations( - times=obs_times.squeeze(0), - x=synthetic["states"].squeeze(0), - y=synthetic["observations"].squeeze(0), + times=obs_times, + x=synthetic["states"], + y=synthetic["observations"], save_path=OUTPUT_DIR / "data_generation.png", ) diff --git a/tests/test_science/test_l63_ODE_mcmc.py b/tests/test_science/test_l63_ODE_mcmc.py index f91cdc77..37476a59 100644 --- a/tests/test_science/test_l63_ODE_mcmc.py +++ b/tests/test_science/test_l63_ODE_mcmc.py @@ -37,14 +37,17 @@ def test_mcmc_inference( if SAVE_FIG and OUTPUT_DIR is not None: import matplotlib.pyplot as plt + states = synthetic["states"] + observations = synthetic["observations"] + plt.plot( - obs_times.squeeze(0), - synthetic["states"].squeeze(0)[:, 0], + obs_times, + states if states.ndim == 1 else states[:, 0], label="x[0]", ) plt.plot( - obs_times.squeeze(0), - synthetic["observations"].squeeze(0)[:, 0], + obs_times, + observations if observations.ndim == 1 else observations[:, 0], label="observations", linestyle="--", ) diff --git a/tests/test_science/test_l63_SDE_mcmc.py b/tests/test_science/test_l63_SDE_mcmc.py index bc6d8b4e..cf67de30 100644 --- a/tests/test_science/test_l63_SDE_mcmc.py +++ b/tests/test_science/test_l63_SDE_mcmc.py @@ -37,14 +37,17 @@ def test_mcmc_inference(data_conditioned_continuous_time_stochastic_l63, num_sam if SAVE_FIG and OUTPUT_DIR is not None: import matplotlib.pyplot as plt + states = synthetic["states"] + observations = 
synthetic["observations"] + plt.plot( - obs_times.squeeze(0), - synthetic["states"].squeeze(0)[:, 0], + obs_times, + states if states.ndim == 1 else states[:, 0], label="x[0]", ) plt.plot( - obs_times.squeeze(0), - synthetic["observations"].squeeze(0)[:, 0], + obs_times, + observations if observations.ndim == 1 else observations[:, 0], label="observations", linestyle="--", ) diff --git a/tests/test_science/test_l63_mcmc_dpf.py b/tests/test_science/test_l63_mcmc_dpf.py index 8cc70995..fe531e5f 100644 --- a/tests/test_science/test_l63_mcmc_dpf.py +++ b/tests/test_science/test_l63_mcmc_dpf.py @@ -29,14 +29,17 @@ def test_mcmc_inference(data_conditioned_continuous_time_l63_dpf, num_samples): if SAVE_FIG and OUTPUT_DIR is not None: import matplotlib.pyplot as plt + states = synthetic["states"] + observations = synthetic["observations"] + plt.plot( - obs_times.squeeze(0), - synthetic["states"].squeeze(0)[:, 0], + obs_times, + states if states.ndim == 1 else states[:, 0], label="x[0]", ) plt.plot( - obs_times.squeeze(0), - synthetic["observations"].squeeze(0)[:, 0], + obs_times, + observations if observations.ndim == 1 else observations[:, 0], label="observations", linestyle="--", ) diff --git a/tests/test_science/test_lti_continuous_simplified.py b/tests/test_science/test_lti_continuous_simplified.py index d8130e52..529e71fe 100644 --- a/tests/test_science/test_lti_continuous_simplified.py +++ b/tests/test_science/test_lti_continuous_simplified.py @@ -40,19 +40,22 @@ def test_mcmc_inference( if SAVE_FIG and OUTPUT_DIR is not None: import matplotlib.pyplot as plt + states = synthetic["states"] + observations = synthetic["observations"] + plt.plot( obs_times, - synthetic["states"][:, 0], + states[:, 0], label="x[0]", ) plt.plot( obs_times, - synthetic["states"][:, 1], + states[:, 1], label="x[1]", ) plt.plot( obs_times, - synthetic["observations"][:, 0], + observations if observations.ndim == 1 else observations[:, 0], label="observations", linestyle="--", ) diff --git 
a/tests/test_science/test_lti_discrete_simplified.py b/tests/test_science/test_lti_discrete_simplified.py index bb8aaaad..e7715a06 100644 --- a/tests/test_science/test_lti_discrete_simplified.py +++ b/tests/test_science/test_lti_discrete_simplified.py @@ -40,19 +40,22 @@ def test_mcmc_inference( if SAVE_FIG and OUTPUT_DIR is not None: import matplotlib.pyplot as plt + states = synthetic["states"] + observations = synthetic["observations"] + plt.plot( - obs_times.squeeze(0), - synthetic["states"].squeeze(0)[:, 0], + obs_times, + states[:, 0], label="x[0]", ) plt.plot( - obs_times.squeeze(0), - synthetic["states"].squeeze(0)[:, 1], + obs_times, + states[:, 1], label="x[1]", ) plt.plot( - obs_times.squeeze(0), - synthetic["observations"].squeeze(0)[:, 0], + obs_times, + observations if observations.ndim == 1 else observations[:, 0], label="observations", linestyle="--", ) diff --git a/tests/test_science/test_lti_gaussian.py b/tests/test_science/test_lti_gaussian.py index a603a616..5ec6c061 100644 --- a/tests/test_science/test_lti_gaussian.py +++ b/tests/test_science/test_lti_gaussian.py @@ -36,19 +36,22 @@ def test_mcmc_inference(data_conditioned_continuous_time_lti_gaussian, num_sampl if SAVE_FIG and OUTPUT_DIR is not None: import matplotlib.pyplot as plt + states = synthetic["states"] + observations = synthetic["observations"] + plt.plot( - obs_times.squeeze(0), - synthetic["states"].squeeze(0)[:, 0], + obs_times, + states[:, 0], label="x[0]", ) plt.plot( - obs_times.squeeze(0), - synthetic["states"].squeeze(0)[:, 1], + obs_times, + states[:, 1], label="x[1]", ) plt.plot( - obs_times.squeeze(0), - synthetic["observations"].squeeze(0)[:, 0], + obs_times, + observations if observations.ndim == 1 else observations[:, 0], label="observations", linestyle="--", ) diff --git a/tests/test_science/test_lti_gaussian_dpf.py b/tests/test_science/test_lti_gaussian_dpf.py index eeac060d..f564796e 100644 --- a/tests/test_science/test_lti_gaussian_dpf.py +++ 
b/tests/test_science/test_lti_gaussian_dpf.py @@ -33,19 +33,22 @@ def test_mcmc_inference(data_conditioned_continuous_time_lti_gaussian_dpf, num_s if SAVE_FIG and OUTPUT_DIR is not None: import matplotlib.pyplot as plt + states = synthetic["states"] + observations = synthetic["observations"] + plt.plot( - obs_times.squeeze(0), - synthetic["states"].squeeze(0)[:, 0], + obs_times, + states[:, 0], label="x[0]", ) plt.plot( - obs_times.squeeze(0), - synthetic["states"].squeeze(0)[:, 1], + obs_times, + states[:, 1], label="x[1]", ) plt.plot( - obs_times.squeeze(0), - synthetic["observations"].squeeze(0)[:, 0], + obs_times, + observations if observations.ndim == 1 else observations[:, 0], label="observations", linestyle="--", ) diff --git a/tests/test_science/test_nonlinear_discretized_hierarchical_multitraj_ekf.py b/tests/test_science/test_nonlinear_discretized_hierarchical_multitraj_ekf.py index efdd2360..82bae61d 100644 --- a/tests/test_science/test_nonlinear_discretized_hierarchical_multitraj_ekf.py +++ b/tests/test_science/test_nonlinear_discretized_hierarchical_multitraj_ekf.py @@ -16,7 +16,7 @@ from dynestyx.discretizers import Discretizer from dynestyx.inference.filter_configs import EKFConfig from dynestyx.inference.filters import Filter -from dynestyx.models import ContinuousTimeStateEvolution, DynamicalModel +from dynestyx.models import ContinuousTimeStateEvolution, DynamicalModel, FullDiffusion from dynestyx.models.observations import LinearGaussianObservation from dynestyx.simulators import DiscreteTimeSimulator from tests.test_utils import get_output_dir @@ -43,7 +43,7 @@ def _make_ct_nonlinear_dynamics(beta: jnp.ndarray) -> DynamicalModel: initial_condition=dist.MultivariateNormal(jnp.zeros(1), 0.35 * jnp.eye(1)), state_evolution=ContinuousTimeStateEvolution( drift=_PlateNonlinearDrift(beta=beta), - diffusion_coefficient=lambda x, u, t: jnp.sqrt(0.06) * jnp.eye(1), + diffusion=FullDiffusion(lambda x, u, t: jnp.sqrt(0.06) * jnp.eye(1)), ), 
observation_model=LinearGaussianObservation( H=jnp.array([[1.0]]), diff --git a/tests/test_science/test_ode_hierarchical_simulator_inference.py b/tests/test_science/test_ode_hierarchical_simulator_inference.py index f48d122b..b751e318 100644 --- a/tests/test_science/test_ode_hierarchical_simulator_inference.py +++ b/tests/test_science/test_ode_hierarchical_simulator_inference.py @@ -64,7 +64,6 @@ def hierarchical_ode_model( ), state_evolution=ContinuousTimeStateEvolution( drift=_DampedNonlinearOscillatorDrift(alpha=alpha), - diffusion_coefficient=None, ), observation_model=LinearGaussianObservation( H=jnp.eye(2), R=(0.05**2) * jnp.eye(2)