Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
021550d
support Array and lambda diffusions (special support for scalar/diag/…
mattlevine22 Apr 27, 2026
9652cbf
adding back padding of diffusion-coefficient in CD-Dynamax integration
mattlevine22 Apr 27, 2026
8c6726e
Update dynestyx/inference/integrations/cd_dynamax/utils.py
mattlevine22 Apr 27, 2026
38df332
Update dynestyx/models/diffusions.py
mattlevine22 Apr 27, 2026
aa50c38
docs: rename misleading test names to reflect bm_dim > state_dim cons…
Copilot Apr 27, 2026
ecd987e
plz lint
mattlevine22 Apr 27, 2026
bc0c252
Merge branch 'main' into ml-feature-212
mattlevine22 Apr 28, 2026
375cf59
get diffusion info at init
mattlevine22 Apr 28, 2026
0fa5667
give discretizer fallback if dynamicalModel init hadnt been run (supp…
mattlevine22 Apr 28, 2026
bf1f456
new Diffusion class with subtypes.
mattlevine22 May 10, 2026
ef5dab6
merge with main (plate upgrades)
mattlevine22 May 10, 2026
20d1d37
dont use cast, just assert bm_dim is not None (fight w linter...could…
mattlevine22 May 11, 2026
e9ea930
merge with main (smoothers)
mattlevine22 May 14, 2026
bc41705
please lint
mattlevine22 May 14, 2026
3f37300
rename as gram
mattlevine22 May 14, 2026
657a190
rename func to _coerce_to_param_dtype
mattlevine22 May 14, 2026
958a2a7
simplify edit to Quick example
mattlevine22 May 14, 2026
0e04441
dont change faq
mattlevine22 May 14, 2026
2cfab46
simplify notebook changes for PR
mattlevine22 May 14, 2026
90a2813
simplify notebook changes for PR
mattlevine22 May 14, 2026
c8c5cd8
simplify notebook changes for PR
mattlevine22 May 14, 2026
2e02f2f
improve API documentation for diffusions
mattlevine22 May 14, 2026
740750e
improve API documentation for diffusions
mattlevine22 May 14, 2026
2ebb45c
Update dynestyx/models/diffusions.py
mattlevine22 May 15, 2026
3612c92
simplify code and improve docs
mattlevine22 May 15, 2026
a4cee35
clarify resolve metadata docstring
mattlevine22 May 15, 2026
19a344b
Update dynestyx/inference/integrations/cd_dynamax/utils.py
mattlevine22 May 15, 2026
6d39004
probing and errors for cd-dynamax diffusion plumbing
mattlevine22 May 15, 2026
d2f8564
fix lint
mattlevine22 May 15, 2026
83641cb
Update dynestyx/models/core.py
mattlevine22 May 15, 2026
8501a17
streamline DynamicalModel init
mattlevine22 May 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 49 additions & 17 deletions dynestyx/discretizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,47 @@
DynamicalModel,
GaussianStateEvolution,
)
from dynestyx.models.checkers import _infer_bm_dim
from dynestyx.models.checkers import _resolve_ctse_diffusion_metadata
from dynestyx.solvers import euler_maruyama_loc_cov
from dynestyx.types import FunctionOfTime


def _ensure_ctse_diffusion_metadata(
cte: ContinuousTimeStateEvolution,
*,
state_dim: int,
control_dim: int = 0,
t0=None,
) -> ContinuousTimeStateEvolution:
"""Resolve and set diffusion metadata on a CTSE when missing."""
if cte.diffusion_coefficient is None:
return cte
if cte.diffusion_type is not None and cte.bm_dim is not None:
return cte

x0 = jnp.zeros((state_dim,))
u0 = None if control_dim == 0 else jnp.zeros((control_dim,))
probe_t0 = jnp.array(0.0) if t0 is None else jnp.asarray(t0)
resolved = _resolve_ctse_diffusion_metadata(cte, state_dim, x0, u0, probe_t0)
if resolved is not None:
resolved_type, resolved_bm_dim = resolved
object.__setattr__(cte, "diffusion_type", resolved_type)
object.__setattr__(cte, "bm_dim", resolved_bm_dim)
return cte


def _ensure_ctse_bm_dim(dynamics: DynamicalModel) -> DynamicalModel:
"""Infer and set bm_dim when CT dynamics are built under active plates."""
"""Resolve diffusion metadata when CT dynamics are built under active plates."""
if not isinstance(dynamics.state_evolution, ContinuousTimeStateEvolution):
return dynamics

cte = dynamics.state_evolution
if cte.diffusion_coefficient is None or cte.bm_dim is not None:
return dynamics

x0 = jnp.zeros((dynamics.state_dim,))
u0 = None if dynamics.control_dim == 0 else jnp.zeros((dynamics.control_dim,))
t0 = jnp.array(0.0) if dynamics.t0 is None else jnp.asarray(dynamics.t0)
inferred_bm_dim = _infer_bm_dim(cte, dynamics.state_dim, x0, u0, t0)
if inferred_bm_dim is not None:
object.__setattr__(cte, "bm_dim", inferred_bm_dim)
_ensure_ctse_diffusion_metadata(
cte,
state_dim=dynamics.state_dim,
control_dim=dynamics.control_dim,
t0=dynamics.t0,
)
return dynamics


Expand All @@ -46,13 +67,24 @@ def __init__(
# Accept these for reconstruction paths, but derive both from `cte`.
del F, cov
self.cte = cte

def _loc(x, u, t_now, t_next):
_ensure_ctse_diffusion_metadata(
cte,
state_dim=jnp.asarray(x).shape[-1],
)
return euler_maruyama_loc_cov(cte, x, u, t_now, t_next)["loc"]

def _cov(x, u, t_now, t_next):
_ensure_ctse_diffusion_metadata(
cte,
state_dim=jnp.asarray(x).shape[-1],
)
return euler_maruyama_loc_cov(cte, x, u, t_now, t_next)["cov"]

super().__init__(
F=lambda x, u, t_now, t_next: euler_maruyama_loc_cov(
cte, x, u, t_now, t_next
)["loc"],
cov=lambda x, u, t_now, t_next: euler_maruyama_loc_cov(
cte, x, u, t_now, t_next
)["cov"],
F=_loc,
cov=_cov,
)

def __call__(self, x, u, t_now, t_next):
Expand Down
72 changes: 56 additions & 16 deletions dynestyx/inference/integrations/cd_dynamax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,40 +22,67 @@
LinearGaussianObservation,
LinearGaussianStateEvolution,
)
from dynestyx.models.diffusions import diffusion_as_matrix, evaluate_diffusion

type SSMType = ContDiscreteNonlinearGaussianSSM | ContDiscreteNonlinearSSM


def _normalize_cd_dynamax_diffusion(
diffusion_coefficient,
state_evolution: ContinuousTimeStateEvolution,
state_dim: int,
):
"""Return a diffusion coeff compatible with cd-dynamax's EnKF SDE solve.

cd-dynamax's internal diffrax wrapper builds Brownian controls with shape
equal to `y0.shape` (state_dim). For non-square diffusion coefficients
(state_dim, bm_dim) with bm_dim != state_dim, pad/truncate columns so the
returned matrix is always (state_dim, state_dim).
equal to `y0.shape` (state_dim). For diffusion with `bm_dim < state_dim`,
pad trailing Brownian columns with zeros to match `(state_dim, state_dim)`.
Diffusion with `bm_dim > state_dim` is rejected.
"""

def _wrapped(x, u, t):
L = diffusion_coefficient(x, u, t)
if L.ndim == 1:
L = jnp.diag(L)
if L.ndim != 2:
diffusion = evaluate_diffusion(
state_evolution.diffusion_coefficient,
diffusion_type=state_evolution.diffusion_type,
bm_dim=state_evolution.bm_dim,
x=x,
u=u,
t=t,
state_dim=state_dim,
)
L = diffusion_as_matrix(diffusion, state_dim=state_dim)
n_cols = L.shape[-1]
if n_cols > state_dim:
raise ValueError(
"diffusion_coefficient must return a vector or matrix for cd-dynamax."
"cd-dynamax continuous diffusion requires bm_dim <= state_dim. "
f"Got state_dim={state_dim}, bm_dim={n_cols}."
)
n_cols = L.shape[-1]
if n_cols == state_dim:
return L
if n_cols < state_dim:
return jnp.pad(L, ((0, 0), (0, state_dim - n_cols)))
return L[:, :state_dim]
pad_width = ((0, 0),) * (L.ndim - 1) + ((0, state_dim - n_cols),)
L = jnp.pad(L, pad_width)
return L

return _wrapped


def _validate_cd_dynamax_continuous_diffusion(
state_evolution: ContinuousTimeStateEvolution,
state_dim: int,
) -> None:
"""Eagerly validate diffusion shape constraints for cd-dynamax continuous filters."""
if state_evolution.diffusion_coefficient is None:
return
if state_evolution.bm_dim is None:
raise ValueError(
"Continuous cd-dynamax filters require resolved bm_dim on "
"ContinuousTimeStateEvolution."
)
if state_evolution.bm_dim > state_dim:
raise ValueError(
"Continuous cd-dynamax filters require bm_dim <= state_dim. "
f"Got state_dim={state_dim}, bm_dim={state_evolution.bm_dim}."
)


class _ConstantFunction(eqx.Module):
value: Any

Expand Down Expand Up @@ -198,7 +225,16 @@ def dsx_to_cdlgssm_params(dsx_model: DynamicalModel) -> ParamsCDLGSSM:

# Extract constant L and use inferred Brownian dimension.
x0 = jnp.zeros(dsx_model.state_dim)
L = state_evo.diffusion_coefficient(x0, None, jnp.array(0.0))
diffusion = evaluate_diffusion(
state_evo.diffusion_coefficient,
diffusion_type=state_evo.diffusion_type,
bm_dim=state_evo.bm_dim,
x=x0,
u=None,
Comment thread
DanWaxman marked this conversation as resolved.
Outdated
t=jnp.array(0.0),
state_dim=dsx_model.state_dim,
)
L = diffusion_as_matrix(diffusion, state_dim=dsx_model.state_dim)
if state_evo.bm_dim is None:
raise ValueError(
"state_evolution.bm_dim is not set on ContinuousTimeStateEvolution."
Expand Down Expand Up @@ -273,8 +309,12 @@ def dsx_to_cd_dynamax(
raise ValueError(
"state_evolution.bm_dim is not set on ContinuousTimeStateEvolution."
)
_validate_cd_dynamax_continuous_diffusion(
state_evo,
dsx_model.state_dim,
)
diffusion_coeff = _normalize_cd_dynamax_diffusion(
state_evo.diffusion_coefficient,
state_evo,
dsx_model.state_dim,
)
shared_params.update(
Expand Down
2 changes: 2 additions & 0 deletions dynestyx/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
DynamicalModel,
ObservationModel,
)
from dynestyx.models.diffusions import DiffusionType
from dynestyx.models.lti_dynamics import LTI_continuous, LTI_discrete
from dynestyx.models.observations import (
DiracIdentityObservation,
Expand All @@ -27,6 +28,7 @@
"AffineDrift",
"DiracIdentityObservation",
"DiscreteTimeStateEvolution",
"DiffusionType",
"DynamicalModel",
"Drift",
"GaussianObservation",
Expand Down
77 changes: 45 additions & 32 deletions dynestyx/models/checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
import numpyro.distributions as dist
import numpyro.primitives

from dynestyx.models.diffusions import (
DiffusionType,
evaluate_diffusion_value,
resolve_diffusion_metadata,
)
from dynestyx.types import Control, State, Time


Expand Down Expand Up @@ -80,60 +85,67 @@ def _make_probe_state(initial_condition: Any, state_dim: int) -> jax.Array:
return jnp.zeros((state_dim,))


def _infer_bm_dim(
def _resolve_ctse_diffusion_metadata(
state_evolution: Any,
state_dim: int,
x0: State,
u0: Control | None,
t0: Time,
) -> int | None:
"""Infer bm_dim from diffusion coefficient output shape.
) -> tuple[DiffusionType, int] | None:
"""Resolve diffusion metadata from one probe evaluation.

Tolerates leading batch dimensions (e.g. from plate-batched parameters)
by inspecting only the trailing two dimensions (..., state_dim, bm_dim).

Returns the inferred bm_dim, or None if there is no diffusion coefficient.
Returns the resolved diffusion type and bm_dim, or None if there is no
diffusion coefficient.
"""
if state_evolution.diffusion_coefficient is None:
if state_evolution.bm_dim is not None:
raise ValueError("bm_dim cannot be set when diffusion_coefficient is None.")
if (
state_evolution.bm_dim is not None
or state_evolution.diffusion_type is not None
):
raise ValueError(
"diffusion_type and bm_dim cannot be set when "
"diffusion_coefficient is None."
)
return None

diffusion_shape = jax.eval_shape(
lambda: state_evolution.diffusion_coefficient(x0, u0, t0)
).shape
if len(diffusion_shape) < 2:
raise ValueError(
"diffusion_coefficient must return shape (..., state_dim, bm_dim). "
f"Got shape {diffusion_shape}."
)
if int(diffusion_shape[-2]) != state_dim:
raise ValueError(
"diffusion_coefficient penultimate dimension must match state_dim. "
f"Got diffusion shape {diffusion_shape}, state_dim={state_dim}."
)
inferred_bm_dim = int(diffusion_shape[-1])
if (
state_evolution.bm_dim is not None
and int(state_evolution.bm_dim) != inferred_bm_dim
):
raise ValueError(
"bm_dim does not match inferred diffusion_coefficient output shape. "
f"Got bm_dim={state_evolution.bm_dim}, inferred={inferred_bm_dim}."
lambda: evaluate_diffusion_value(
state_evolution.diffusion_coefficient, x0, u0, t0
)
return inferred_bm_dim
).shape
return resolve_diffusion_metadata(
diffusion_shape,
state_dim=state_dim,
diffusion_type=state_evolution.diffusion_type,
bm_dim=state_evolution.bm_dim,
)


def _validate_continuous_state_evolution(
def _infer_bm_dim(
state_evolution: Any,
state_dim: int,
x0: State,
u0: Control | None,
t0: Time,
) -> int | None:
"""Compatibility wrapper returning only the resolved bm_dim."""
resolved = _resolve_ctse_diffusion_metadata(state_evolution, state_dim, x0, u0, t0)
return None if resolved is None else resolved[1]


def _validate_continuous_state_evolution(
state_evolution: Any,
state_dim: int,
x0: State,
u0: Control | None,
t0: Time,
) -> tuple[DiffusionType, int] | None:
"""Validate the shape of the continuous-time state evolution w.r.t. state_dim and bm_dim.

Returns the inferred bm_dim (or None if no diffusion coefficient).
Returns the resolved diffusion metadata (or None if no diffusion coefficient).
"""
drift_shape = jax.eval_shape(lambda: state_evolution.total_drift(x0, u0, t0)).shape
if drift_shape != (state_dim,):
Expand All @@ -142,7 +154,7 @@ def _validate_continuous_state_evolution(
f"Expected {(state_dim,)}, got {drift_shape}."
)

return _infer_bm_dim(state_evolution, state_dim, x0, u0, t0)
return _resolve_ctse_diffusion_metadata(state_evolution, state_dim, x0, u0, t0)


def _validate_state_evolution_output_shape(
Expand All @@ -157,16 +169,17 @@ def _validate_state_evolution_output_shape(
) -> int | None:
"""Validate the shape of the state evolution w.r.t. state_dim (and bm_dim for continuous-time models).

Returns the inferred bm_dim for continuous-time models, or None otherwise.
Returns the resolved bm_dim for continuous-time models, or None otherwise.
"""
if continuous_time:
return _validate_continuous_state_evolution(
resolved_diffusion = _validate_continuous_state_evolution(
state_evolution=state_evolution,
state_dim=state_dim,
x0=x0,
u0=u0,
t0=t0,
)
return None if resolved_diffusion is None else resolved_diffusion[1]
else:
if getattr(state_evolution, "bm_dim", None) is not None:
raise ValueError(
Expand Down
Loading
Loading