Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
021550d
support Array and lambda diffusions (special support for scalar/diag/…
mattlevine22 Apr 27, 2026
9652cbf
adding back padding of diffusion-coefficient in CD-Dynamax integration
mattlevine22 Apr 27, 2026
8c6726e
Update dynestyx/inference/integrations/cd_dynamax/utils.py
mattlevine22 Apr 27, 2026
38df332
Update dynestyx/models/diffusions.py
mattlevine22 Apr 27, 2026
aa50c38
docs: rename misleading test names to reflect bm_dim > state_dim cons…
Copilot Apr 27, 2026
ecd987e
plz lint
mattlevine22 Apr 27, 2026
bc0c252
Merge branch 'main' into ml-feature-212
mattlevine22 Apr 28, 2026
375cf59
get diffusion info at init
mattlevine22 Apr 28, 2026
0fa5667
give discretizer fallback if dynamicalModel init hadnt been run (supp…
mattlevine22 Apr 28, 2026
bf1f456
new Diffusion class with subtypes.
mattlevine22 May 10, 2026
ef5dab6
merge with main (plate upgrades)
mattlevine22 May 10, 2026
20d1d37
dont use cast, just assert bm_dim is not None (fight w linter...could…
mattlevine22 May 11, 2026
e9ea930
merge with main (smoothers)
mattlevine22 May 14, 2026
bc41705
please lint
mattlevine22 May 14, 2026
3f37300
rename as gram
mattlevine22 May 14, 2026
657a190
rename func to _coerce_to_param_dtype
mattlevine22 May 14, 2026
958a2a7
simplify edit to Quick example
mattlevine22 May 14, 2026
0e04441
dont change faq
mattlevine22 May 14, 2026
2cfab46
simplify notebook changes for PR
mattlevine22 May 14, 2026
90a2813
simplify notebook changes for PR
mattlevine22 May 14, 2026
c8c5cd8
simplify notebook changes for PR
mattlevine22 May 14, 2026
2e02f2f
improve API documentation for diffusions
mattlevine22 May 14, 2026
740750e
improve API documentation for diffusions
mattlevine22 May 14, 2026
2ebb45c
Update dynestyx/models/diffusions.py
mattlevine22 May 15, 2026
3612c92
simplify code and improve docs
mattlevine22 May 15, 2026
a4cee35
clarify resolve metadata docstring
mattlevine22 May 15, 2026
19a344b
Update dynestyx/inference/integrations/cd_dynamax/utils.py
mattlevine22 May 15, 2026
6d39004
probing and errors for cd-dynamax diffusion plumbing
mattlevine22 May 15, 2026
d2f8564
fix lint
mattlevine22 May 15, 2026
83641cb
Update dynestyx/models/core.py
mattlevine22 May 15, 2026
8501a17
streamline DynamicalModel init
mattlevine22 May 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 63 additions & 16 deletions dynestyx/inference/integrations/cd_dynamax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,40 +22,74 @@
LinearGaussianObservation,
LinearGaussianStateEvolution,
)
from dynestyx.models.diffusions import diffusion_as_matrix, evaluate_diffusion

type SSMType = ContDiscreteNonlinearGaussianSSM | ContDiscreteNonlinearSSM


def _normalize_cd_dynamax_diffusion(
diffusion_coefficient,
state_evolution: ContinuousTimeStateEvolution,
state_dim: int,
):
"""Return a diffusion coeff compatible with cd-dynamax's EnKF SDE solve.

cd-dynamax's internal diffrax wrapper builds Brownian controls with shape
equal to `y0.shape` (state_dim). For non-square diffusion coefficients
(state_dim, bm_dim) with bm_dim != state_dim, pad/truncate columns so the
returned matrix is always (state_dim, state_dim).
equal to `y0.shape` (state_dim). For diffusion with `bm_dim < state_dim`,
pad trailing Brownian columns with zeros to match `(state_dim, state_dim)`.
Diffusion with `bm_dim > state_dim` is rejected.
"""

def _wrapped(x, u, t):
L = diffusion_coefficient(x, u, t)
if L.ndim == 1:
L = jnp.diag(L)
if L.ndim != 2:
diffusion = evaluate_diffusion(
state_evolution.diffusion_coefficient,
diffusion_type=state_evolution.diffusion_type,
bm_dim=state_evolution.bm_dim,
x=x,
u=u,
t=t,
state_dim=state_dim,
)
L = diffusion_as_matrix(diffusion, state_dim=state_dim)
n_cols = L.shape[-1]
if n_cols > state_dim:
raise ValueError(
"diffusion_coefficient must return a vector or matrix for cd-dynamax."
"cd-dynamax continuous diffusion requires bm_dim <= state_dim. "
f"Got state_dim={state_dim}, bm_dim={n_cols}."
)
n_cols = L.shape[-1]
if n_cols == state_dim:
return L
if n_cols < state_dim:
return jnp.pad(L, ((0, 0), (0, state_dim - n_cols)))
return L[:, :state_dim]
pad_width = ((0, 0),) * (L.ndim - 1) + ((0, state_dim - n_cols),)
L = jnp.pad(L, pad_width)
return L

return _wrapped


def _validate_cd_dynamax_continuous_diffusion(
state_evolution: ContinuousTimeStateEvolution,
state_dim: int,
) -> None:
"""Eagerly validate diffusion shape constraints for cd-dynamax continuous filters."""
if state_evolution.diffusion_coefficient is None:
return

probe_x = jnp.zeros(state_dim)
diffusion = evaluate_diffusion(
state_evolution.diffusion_coefficient,
diffusion_type=state_evolution.diffusion_type,
bm_dim=state_evolution.bm_dim,
x=probe_x,
u=None,
t=jnp.array(0.0),
state_dim=state_dim,
)
bm_dim = diffusion_as_matrix(diffusion, state_dim=state_dim).shape[-1]
if bm_dim > state_dim:
raise ValueError(
"Continuous cd-dynamax filters require bm_dim <= state_dim. "
f"Got state_dim={state_dim}, bm_dim={bm_dim}."
)
Copy link

Copilot AI Apr 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PR description says continuous cd-dynamax filtering requires square full diffusion (bm_dim == state_dim / no rectangular diffusion), but the implemented validation only rejects bm_dim > state_dim and _normalize_cd_dynamax_diffusion explicitly pads when bm_dim < state_dim. Please reconcile the intended constraint: either tighten validation to require bm_dim == state_dim (and possibly diffusion_type == 'full') or update the PR description/docs to match the current bm_dim <= state_dim + zero-padding behavior.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot update the docs/descriptions to match the current behavior.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in commit aa50c38. Updated the PR description Notes section and renamed the two misleading test functions to accurately reflect that only bm_dim > state_dim is rejected — rectangular diffusion with bm_dim < state_dim is zero-padded and supported.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure



class _ConstantFunction(eqx.Module):
value: Any

Expand Down Expand Up @@ -198,7 +232,16 @@ def dsx_to_cdlgssm_params(dsx_model: DynamicalModel) -> ParamsCDLGSSM:

# Extract constant L and use inferred Brownian dimension.
x0 = jnp.zeros(dsx_model.state_dim)
L = state_evo.diffusion_coefficient(x0, None, jnp.array(0.0))
diffusion = evaluate_diffusion(
state_evo.diffusion_coefficient,
diffusion_type=state_evo.diffusion_type,
bm_dim=state_evo.bm_dim,
x=x0,
u=None,
Comment thread
DanWaxman marked this conversation as resolved.
Outdated
t=jnp.array(0.0),
state_dim=dsx_model.state_dim,
)
L = diffusion_as_matrix(diffusion, state_dim=dsx_model.state_dim)
if state_evo.bm_dim is None:
raise ValueError(
"state_evolution.bm_dim is not set on ContinuousTimeStateEvolution."
Expand Down Expand Up @@ -273,8 +316,12 @@ def dsx_to_cd_dynamax(
raise ValueError(
"state_evolution.bm_dim is not set on ContinuousTimeStateEvolution."
)
_validate_cd_dynamax_continuous_diffusion(
state_evo,
dsx_model.state_dim,
)
diffusion_coeff = _normalize_cd_dynamax_diffusion(
state_evo.diffusion_coefficient,
state_evo,
dsx_model.state_dim,
)
shared_params.update(
Expand Down
2 changes: 2 additions & 0 deletions dynestyx/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
DynamicalModel,
ObservationModel,
)
from dynestyx.models.diffusions import DiffusionType
from dynestyx.models.lti_dynamics import LTI_continuous, LTI_discrete
from dynestyx.models.observations import (
DiracIdentityObservation,
Expand All @@ -27,6 +28,7 @@
"AffineDrift",
"DiracIdentityObservation",
"DiscreteTimeStateEvolution",
"DiffusionType",
"DynamicalModel",
"Drift",
"GaussianObservation",
Expand Down
33 changes: 13 additions & 20 deletions dynestyx/models/checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
import numpyro.distributions as dist
import numpyro.primitives

from dynestyx.models.diffusions import (
evaluate_diffusion_value,
resolve_diffusion_metadata,
)
from dynestyx.types import Control, State, Time


Expand Down Expand Up @@ -100,27 +104,16 @@ def _infer_bm_dim(
return None

diffusion_shape = jax.eval_shape(
lambda: state_evolution.diffusion_coefficient(x0, u0, t0)
).shape
if len(diffusion_shape) < 2:
raise ValueError(
"diffusion_coefficient must return shape (..., state_dim, bm_dim). "
f"Got shape {diffusion_shape}."
)
if int(diffusion_shape[-2]) != state_dim:
raise ValueError(
"diffusion_coefficient penultimate dimension must match state_dim. "
f"Got diffusion shape {diffusion_shape}, state_dim={state_dim}."
)
inferred_bm_dim = int(diffusion_shape[-1])
if (
state_evolution.bm_dim is not None
and int(state_evolution.bm_dim) != inferred_bm_dim
):
raise ValueError(
"bm_dim does not match inferred diffusion_coefficient output shape. "
f"Got bm_dim={state_evolution.bm_dim}, inferred={inferred_bm_dim}."
lambda: evaluate_diffusion_value(
state_evolution.diffusion_coefficient, x0, u0, t0
)
).shape
_, inferred_bm_dim = resolve_diffusion_metadata(
diffusion_shape,
state_dim=state_dim,
diffusion_type=state_evolution.diffusion_type,
bm_dim=state_evolution.bm_dim,
)
return inferred_bm_dim


Expand Down
29 changes: 22 additions & 7 deletions dynestyx/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,15 @@
_validate_state_dim,
_validate_state_evolution_output_shape,
)
from dynestyx.types import Control, State, Time, TimeLike, as_scalar_time_array, dState
from dynestyx.models.diffusions import DiffusionSpec, DiffusionType
from dynestyx.types import (
Control,
State,
Time,
TimeLike,
as_scalar_time_array,
dState,
)


class DynamicalModel(eqx.Module):
Expand Down Expand Up @@ -298,18 +306,25 @@ class ContinuousTimeStateEvolution(eqx.Module):
At least one of `drift` or `potential` must be non-None.
use_negative_gradient (bool): If True, use $-\\nabla_x V$ (e.g., gradient descent on potential);
otherwise use $+\\nabla_x V$. Default is False.
diffusion_coefficient (Drift | None): Diffusion coefficient $L(x, u, t)$ mapping to a matrix;
multiplies the Brownian increment $dW_t$.
Defaults to zero if None (i.e., deterministic ODE).
diffusion_coefficient (DiffusionSpec | None): Diffusion coefficient specification.
This may be a callable `L(x, u, t)`, a constant scalar/vector/matrix, or `None`
for deterministic dynamics.
diffusion_type ("full" | "diag" | "scalar" | None): Optional explicit diffusion semantics.
Use `"full"` for matrix-valued diffusion, `"diag"` for diagonal shorthand with
trailing shape `(..., state_dim)`, and `"scalar"` for scalar shorthand with shape
`()` or `(..., 1)`. If omitted, legacy behavior infers `"full"` from trailing
shape `(..., state_dim, bm_dim)` and otherwise infers scalar/diagonal shorthand
from the trailing dimension.
bm_dim (int | None): Dimension of the Brownian motion $W_t$.
Inferred automatically from the output shape of `diffusion_coefficient`;
if passed by the user, it must match diffusion_coefficient(...).shape[1].
Inferred automatically only for full matrix diffusion. Scalar and diagonal
diffusion require explicit `bm_dim`, which must be either `1` or `state_dim`.
"""

drift: Drift | None = None
potential: Potential | None = None
use_negative_gradient: bool = eqx.field(static=True, default=False)
diffusion_coefficient: Drift | None = None
diffusion_coefficient: DiffusionSpec | None = None
diffusion_type: DiffusionType | None = eqx.field(static=True, default=None)
bm_dim: int | None = eqx.field(static=True, default=None)

def total_drift(self, x: State, u: Control | None, t: Time) -> dState:
Expand Down
Loading
Loading