Skip to content

Support structured diffusion specs in continuous-time models and enforce cd-dynamax full-diffusion constraints#214

Merged
DanWaxman merged 31 commits into
mainfrom
ml-feature-212
May 15, 2026
Merged

Support structured diffusion specs in continuous-time models and enforce cd-dynamax full-diffusion constraints#214
DanWaxman merged 31 commits into
mainfrom
ml-feature-212

Conversation

@mattlevine22
Copy link
Copy Markdown
Collaborator

@mattlevine22 mattlevine22 commented Apr 27, 2026

Addresses #212

Summary

This PR introduces a structured diffusion API for continuous-time models.

The main change is a new dynestyx.models.diffusions module with first-class diffusion objects:

  • Diffusion
  • FullDiffusion
  • DiagonalDiffusion
  • ScalarDiffusion

These objects make the shape and interpretation of the SDE diffusion term explicit, instead of relying on loosely structured diffusion_coefficient inputs.

In particular, this PR:

  • adds structured diffusion classes for full, diagonal, and scalar diffusion coefficients
  • updates ContinuousTimeStateEvolution / DynamicalModel to canonicalize and validate diffusion metadata centrally
  • clarifies bm_dim behavior:
    • FullDiffusion: bm_dim is inferred from trailing matrix shape when omitted
    • DiagonalDiffusion / ScalarDiffusion: bm_dim is explicit and must be 1 or state_dim
  • updates SDE solver internals so Euler–Maruyama moments, sampling, and the Diffrax solve path all use the same diffusion semantics
  • updates continuous cd-dynamax integration to validate diffusion structure up front and fail early on unsupported cases
  • updates LTI_continuous and related examples/docs to use the new diffusion API directly
  • expands coverage across model validation, discretizers, bm_dim/plate behavior, and hierarchical smoke tests

Why

This change makes continuous-time model specification clearer and safer.

It also improves performance in simulation-oriented SDE paths when the diffusion has scalar or diagonal structure, since we can avoid treating every diffusion as a dense full matrix. The speedup figure below illustrates this benefit.

Just as importantly, invalid setups now fail earlier with clearer errors, especially in continuous cd-dynamax filtering where only square full diffusion is currently supported.

Notes

  • Continuous cd-dynamax filters currently require square full diffusion (bm_dim == state_dim).
  • More general rectangular diffusion remains supported in simulation-oriented SDE / Euler–Maruyama paths, but not in the current continuous cd-dynamax filter backend.
  • I explored exploiting scalar/diagonal structure more aggressively inside Filter + Discretizer, but did not see meaningful gains there, so those backend-specific optimizations were not kept; will matter more with multi-step discretizers, and can leverage then.

Tests

Added or updated coverage in:

  • tests/test_models_core.py
  • tests/test_discretizers.py
  • tests/test_bm_dim_plate.py
  • tests/test_hierarchical_simulator_discretizer_smokes.py

Below image showing speed-ups of solver using scalar/diag/full diffusion coefficient in SDE.
image

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR extends continuous-time SDE model support by introducing a structured diffusion specification API (constant scalar/diag/full or callable) and plumbing those semantics through validation, solvers, and the cd-dynamax integration path.

Changes:

  • Add dynestyx/models/diffusions.py to standardize diffusion evaluation/inference (diffusion_type, bm_dim) and conversions (matrix/covariance/application).
  • Expand ContinuousTimeStateEvolution to accept DiffusionSpec + optional diffusion_type, update core validation/inference accordingly.
  • Update SDE solver internals and cd-dynamax integration utilities to use the shared diffusion semantics; expand test coverage across models/discretizers/plates.

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
dynestyx/models/diffusions.py New helper module implementing evaluation, type inference, covariance/matrix conversion, and applying diffusion to Brownian increments.
dynestyx/models/core.py Extends ContinuousTimeStateEvolution API to accept structured diffusion specs and optional diffusion_type.
dynestyx/models/checkers.py Updates bm_dim inference/validation to use the new diffusion evaluation/metadata resolution helpers.
dynestyx/solvers/sde.py Replaces ad-hoc diffusion handling with shared helpers for EM moments/sampling and diffrax diffusion callbacks.
dynestyx/inference/integrations/cd_dynamax/utils.py Normalizes/validates diffusion for cd-dynamax using the new diffusion semantics and adds early validation.
dynestyx/models/lti_dynamics.py Updates LTI_continuous to pass constant diffusion matrix directly using the new API.
dynestyx/models/__init__.py Re-exports DiffusionType.
tests/test_models_core.py Adds unit tests for new diffusion semantics and cd-dynamax diffusion constraints/normalization.
tests/test_discretizers.py Adds parametrized tests ensuring structured diffusion forms match expected dense covariance behavior.
tests/test_bm_dim_plate.py Expands plate-context tests to cover structured diffusion specs and expected bm_dim resolution behavior.
tests/test_hierarchical_simulator_discretizer_smokes.py Updates/clarifies a comment related to rectangular diffusion behavior in cd-dynamax integration.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread dynestyx/models/diffusions.py Outdated
Comment thread dynestyx/inference/integrations/cd_dynamax/utils.py Outdated
Comment on lines +84 to +89
bm_dim = diffusion_as_matrix(diffusion, state_dim=state_dim).shape[-1]
if bm_dim > state_dim:
raise ValueError(
"Continuous cd-dynamax filters require bm_dim <= state_dim. "
f"Got state_dim={state_dim}, bm_dim={bm_dim}."
)
Copy link

Copilot AI Apr 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PR description says continuous cd-dynamax filtering requires square full diffusion (bm_dim == state_dim / no rectangular diffusion), but the implemented validation only rejects bm_dim > state_dim and _normalize_cd_dynamax_diffusion explicitly pads when bm_dim < state_dim. Please reconcile the intended constraint: either tighten validation to require bm_dim == state_dim (and possibly diffusion_type == 'full') or update the PR description/docs to match the current bm_dim <= state_dim + zero-padding behavior.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot update the docs/descriptions to match the current behavior.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in commit aa50c38. Updated the PR description Notes section and renamed the two misleading test functions to accurately reflect that only bm_dim > state_dim is rejected — rectangular diffusion with bm_dim < state_dim is zero-padded and supported.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure

mattlevine22 and others added 2 commits April 26, 2026 23:15
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@DanWaxman
Copy link
Copy Markdown
Collaborator

Thanks, this looks interesting!!

Note that I tried to exploit in Filter + Discretizer, but didn't see any speedups so I removed those implementations.

Out of curiosity, did you try for state estimation only or for parameter inference?

@mattlevine22
Copy link
Copy Markdown
Collaborator Author

Thanks, this looks interesting!!

Note that I tried to exploit in Filter + Discretizer, but didn't see any speedups so I removed those implementations.

Out of curiosity, did you try for state estimation only or for parameter inference?

Only did a forward pass of the filter; didn't use in parameter inference, and actually didn't even look at expense of gradients. Would you expect those to be noticeably different?

@mattlevine22
Copy link
Copy Markdown
Collaborator Author

Thanks, this looks interesting!!

Note that I tried to exploit in Filter + Discretizer, but didn't see any speedups so I removed those implementations.

Out of curiosity, did you try for state estimation only or for parameter inference?

Only did a forward pass of the filter; didn't use in parameter inference, and actually didn't even look at expense of gradients. Would you expect those to be noticeably different?

Thanks, this looks interesting!!

Note that I tried to exploit in Filter + Discretizer, but didn't see any speedups so I removed those implementations.

Out of curiosity, did you try for state estimation only or for parameter inference?

Only did a forward pass of the filter; didn't use in parameter inference, and actually didn't even look at expense of gradients. Would you expect those to be noticeably different?

@DanWaxman I looked at value_and_grad, but saw inconsistent improvements. The best I saw was like a 5% gain. I think this is because currently, we only do this multiply once-per-observation (and we have other operations that occur per-observation that dominate). When we have fancier discretization that take many steps between observations, it will be more impactful I think.

I'd recommend for now introducing the API and leveraging in the simple SDE solver where we know it clearly helps (esp. in large dims), then leveraging diagonal structure as it becomes usable on the methods end.

@mattlevine22 mattlevine22 marked this pull request as ready for review April 27, 2026 20:00
@DanWaxman
Copy link
Copy Markdown
Collaborator

@DanWaxman I looked at value_and_grad, but saw inconsistent improvements. The best I saw was like a 5% gain. I think this is because currently, we only do this multiply once-per-observation (and we have other operations that occur per-observation that dominate). When we have fancier discretization that take many steps between observations, it will be more impactful I think.

I'd recommend for now introducing the API and leveraging in the simple SDE solver where we know it clearly helps (esp. in large dims), then leveraging diagonal structure as it becomes usable on the methods end.

That makes sense!

Comment thread dynestyx/models/diffusions.py Outdated
Copy link
Copy Markdown
Collaborator

@DanWaxman DanWaxman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find all this evaluate_diffusion(...) business pretty awkward, to be honest. Namely, I find it awkward that we have to manually go "into" state_evo and pick out state_evo.diffusion_coefficient, state_evo.diffusion_type, and state_evo.bm_dim every time -- this metadata only makes sense as a package, so we shouldn't have to pick out all of them every time.

Do you have thoughts on putting these into a single dataclass (i.e., have a Diffusion(eqx.Module)) that packages all this info together? Then, we can just have something likediffusion.evaluate_diffusion(x=...,u=...,t=...).

@mattlevine22
Copy link
Copy Markdown
Collaborator Author

I like this idea and gets at #16.

Could also have some constructor classes:

ContinuousTimeStateEvolution(
    drift=...,
    diffusion=Diffusion.diag(jnp.array([0.1, 0.2]), bm_dim=2),
)

class Diffusion(eqx.Module):
    coefficient: DiffusionSpec
    diffusion_type: DiffusionType | None = eqx.field(static=True, default=None)
    bm_dim: int | None = eqx.field(static=True, default=None)

    @classmethod
    def full(cls, coefficient, bm_dim: int | None = None):
        return cls(coefficient=coefficient, diffusion_type="full", bm_dim=bm_dim)

    @classmethod
    def diag(cls, coefficient, bm_dim: int):
        return cls(coefficient=coefficient, diffusion_type="diag", bm_dim=bm_dim)

    @classmethod
    def scalar(cls, coefficient, bm_dim: int):
        return cls(coefficient=coefficient, diffusion_type="scalar", bm_dim=bm_dim)

@mattlevine22
Copy link
Copy Markdown
Collaborator Author

Could also do this for Drifts...and I think this may help us a lot down the road in adding new features.

May want to think ahead a bit more around names to anticipate future dev

@mattlevine22 mattlevine22 requested a review from DanWaxman May 14, 2026 20:18
Copy link
Copy Markdown
Collaborator

@DanWaxman DanWaxman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Didn't get through all the tests, but this looks like a major step! I don't like coercion or __setattr__ much. Also not sure that we should call a resolved meta-type "canonical."

Comment thread docs/api_reference/public/models/core/diffusion.md
Comment thread docs/logo/make_logo.py
Comment thread docs/math_intro.md
Comment thread dynestyx/inference/integrations/cd_dynamax/continuous_filter.py Outdated
Comment thread dynestyx/inference/integrations/cd_dynamax/continuous_filter.py Outdated
Comment thread dynestyx/models/core.py Outdated
Comment thread dynestyx/models/core.py Outdated
Comment thread dynestyx/models/diffusions.py Outdated
Comment thread dynestyx/models/diffusions.py
Comment thread dynestyx/models/diffusions.py Outdated
@mattlevine22 mattlevine22 requested a review from DanWaxman May 15, 2026 01:23
Copy link
Copy Markdown
Collaborator

@DanWaxman DanWaxman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! A few other small things that were remaining.

Comment thread docs/logo/make_logo.py
Comment thread dynestyx/inference/integrations/cd_dynamax/utils.py Outdated
Comment thread dynestyx/inference/integrations/cd_dynamax/utils.py Outdated
@mattlevine22
Copy link
Copy Markdown
Collaborator Author

Addressed missing u_probe in cd_dynamax/utils.py; also surfaced an outstanding issue hd-UQ/cd_dynamax#25 which I'll try to address separately.

@mattlevine22 mattlevine22 requested a review from DanWaxman May 15, 2026 19:13
Copy link
Copy Markdown
Collaborator

@DanWaxman DanWaxman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think probably the last batch of comments :))

Comment thread dynestyx/models/core.py Outdated
Comment thread dynestyx/models/core.py Outdated
Comment on lines +226 to +239
# In a plate, parameter callables often expect batched parameters, so
# we resolve diffusion metadata using synthetic per-trajectory probes.
x_probe = _make_probe_state(
initial_condition=initial_condition, state_dim=inferred_state_dim
)
u_probe = None if control_dim == 0 else jnp.zeros((control_dim,))
t_probe = jnp.array(0.0) if self.t0 is None else self.t0
resolved_state_evolution = _refine_continuous_state_evolution(
resolved_state_evolution,
x_probe=x_probe,
u_probe=u_probe,
t_probe=t_probe,
)
self.state_evolution = resolved_state_evolution
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still don't think I understand this block, unfortunately. In particular, I still don't really see how it differs other than (as the earlier comment states) not doing validation. In which case, I think this whole block can be simplified?

(Sorry I know I'm making this PR a long process, just think having this block be understandable is very important)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the part where we turn the CTSE into Discrete/StochasticCTSE; we use the probes in case it is a callable to check shapes so that we can detect/verify bm_dim as well.

I'm not entirely clear how to make sure I'm playing well with the plates...previously, there was infer_bm_dim code for _inside_plate True and False.

I'll refactor to try to streamline and clarify

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I guess my point is what's the real difference between here:

# In a plate, parameter callables often expect batched parameters, so
# we resolve diffusion metadata using synthetic per-trajectory probes.
x_probe = _make_probe_state(
initial_condition=initial_condition, state_dim=inferred_state_dim
)
u_probe = None if control_dim == 0 else jnp.zeros((control_dim,))
t_probe = jnp.array(0.0) if self.t0 is None else self.t0
resolved_state_evolution = _refine_continuous_state_evolution(
resolved_state_evolution,
x_probe=x_probe,
u_probe=u_probe,
t_probe=t_probe,
)
self.state_evolution = resolved_state_evolution

and the non-plate path:
x_probe = _make_probe_state(
initial_condition=initial_condition, state_dim=inferred_state_dim
)
u_probe = None if control_dim == 0 else jnp.zeros((control_dim,))
t_probe = jnp.array(0.0) if self.t0 is None else self.t0
if self.continuous_time:
_validate_continuous_state_evolution(
state_evolution=state_evolution,
state_dim=inferred_state_dim,
x_probe=x_probe,
u_probe=u_probe,
t_probe=t_probe,
)
resolved_state_evolution = _refine_continuous_state_evolution(
resolved_state_evolution,
x_probe=x_probe,
u_probe=u_probe,
t_probe=t_probe,
)

except that one validates and one doesn't?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I'm trying to consolidate that now. I think that question could be asked of the current main branch as well, right?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might've been some sloppy code from before (maybe it was just "early return before validation in plate" -> add in bm_dim inference -> add in probe states -> lots of duplicate code), but we should clean that up now if so, I think.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed, just making sure we are understanding the situation similarly

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we continue to not validate inside the plate?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

proposing:

        if _inside_plate:
            # Cannot validate shapes with batched parameters; trust the user.
            # Infer observation_dim from observation model if not explicitly provided.
            inferred_obs_dim = _infer_observation_dim_in_plate_context(
                initial_condition=initial_condition,
                observation_model=observation_model,
                inferred_state_dim=inferred_state_dim,
                control_dim=control_dim,
                t_probe=self.t0,
                observation_dim=observation_dim,
            )
        else:
            x_probe, u_probe, t_probe = _make_probes()

            if self.continuous_time:
                _validate_continuous_state_evolution(
                    state_evolution=state_evolution,
                    state_dim=inferred_state_dim,
                    x_probe=x_probe,
                    u_probe=u_probe,
                    t_probe=t_probe,
                )
            else:
                _validate_discrete_state_evolution_output_shape(
                    state_evolution=state_evolution,
                    state_dim=inferred_state_dim,
                    x_probe=x_probe,
                    u_probe=u_probe,
                    t_probe=t_probe,
                )

            obs_dist = observation_model(x_probe, u_probe, t_probe)
            inferred_obs_dim = _infer_vector_dim_from_distribution(
                obs_dist, "observation_model(x, u, t)"
            )
            _validate_observation_dim(observation_dim, inferred_obs_dim)

        self.state_evolution = _resolve_continuous_state_evolution(
            state_evolution,
        )

        self.state_dim = int(inferred_state_dim)
        self.observation_dim = int(inferred_obs_dim)
        self.control_dim = int(control_dim)
        self.categorical_state = bool(inferred_categorical_state)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't

self.state_evolution = _resolve_continuous_state_evolution(
            state_evolution,
        )

require the probe states? This looks good to me though, otherwise.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I refactored to let it use _make_probe...will update in a bit.

mattlevine22 and others added 2 commits May 15, 2026 16:46
Co-authored-by: Dan Waxman <dan.waxman1@gmail.com>
@mattlevine22 mattlevine22 requested a review from DanWaxman May 15, 2026 21:42
Copy link
Copy Markdown
Collaborator

@DanWaxman DanWaxman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@DanWaxman DanWaxman merged commit aea4e37 into main May 15, 2026
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants