Support structured diffusion specs in continuous-time models and enforce cd-dynamax full-diffusion constraints #214
Pull request overview
This PR extends continuous-time SDE model support by introducing a structured diffusion specification API (constant scalar/diag/full or callable) and plumbing those semantics through validation, solvers, and the cd-dynamax integration path.
Changes:
- Add dynestyx/models/diffusions.py to standardize diffusion evaluation/inference (diffusion_type, bm_dim) and conversions (matrix/covariance/application).
- Expand ContinuousTimeStateEvolution to accept DiffusionSpec + optional diffusion_type, update core validation/inference accordingly.
- Update SDE solver internals and cd-dynamax integration utilities to use the shared diffusion semantics; expand test coverage across models/discretizers/plates.
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| dynestyx/models/diffusions.py | New helper module implementing evaluation, type inference, covariance/matrix conversion, and applying diffusion to Brownian increments. |
| dynestyx/models/core.py | Extends ContinuousTimeStateEvolution API to accept structured diffusion specs and optional diffusion_type. |
| dynestyx/models/checkers.py | Updates bm_dim inference/validation to use the new diffusion evaluation/metadata resolution helpers. |
| dynestyx/solvers/sde.py | Replaces ad-hoc diffusion handling with shared helpers for EM moments/sampling and diffrax diffusion callbacks. |
| dynestyx/inference/integrations/cd_dynamax/utils.py | Normalizes/validates diffusion for cd-dynamax using the new diffusion semantics and adds early validation. |
| dynestyx/models/lti_dynamics.py | Updates LTI_continuous to pass constant diffusion matrix directly using the new API. |
| dynestyx/models/__init__.py | Re-exports DiffusionType. |
| tests/test_models_core.py | Adds unit tests for new diffusion semantics and cd-dynamax diffusion constraints/normalization. |
| tests/test_discretizers.py | Adds parametrized tests ensuring structured diffusion forms match expected dense covariance behavior. |
| tests/test_bm_dim_plate.py | Expands plate-context tests to cover structured diffusion specs and expected bm_dim resolution behavior. |
| tests/test_hierarchical_simulator_discretizer_smokes.py | Updates/clarifies a comment related to rectangular diffusion behavior in cd-dynamax integration. |
bm_dim = diffusion_as_matrix(diffusion, state_dim=state_dim).shape[-1]
if bm_dim > state_dim:
    raise ValueError(
        "Continuous cd-dynamax filters require bm_dim <= state_dim. "
        f"Got state_dim={state_dim}, bm_dim={bm_dim}."
    )
The PR description says continuous cd-dynamax filtering requires square full diffusion (bm_dim == state_dim / no rectangular diffusion), but the implemented validation only rejects bm_dim > state_dim and _normalize_cd_dynamax_diffusion explicitly pads when bm_dim < state_dim. Please reconcile the intended constraint: either tighten validation to require bm_dim == state_dim (and possibly diffusion_type == 'full') or update the PR description/docs to match the current bm_dim <= state_dim + zero-padding behavior.
@copilot update the docs/descriptions to match the current behavior.
Done in commit aa50c38. Updated the PR description Notes section and renamed the two misleading test functions to accurately reflect that only bm_dim > state_dim is rejected — rectangular diffusion with bm_dim < state_dim is zero-padded and supported.
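The rule this thread settles on (reject bm_dim > state_dim, zero-pad when bm_dim < state_dim) can be sketched as follows. This is a minimal illustration on plain Python lists, not the actual `_normalize_cd_dynamax_diffusion` code; the helper names `pad_diffusion_to_square` and `covariance` are hypothetical.

```python
def pad_diffusion_to_square(L, state_dim):
    """Zero-pad a (state_dim, bm_dim) diffusion matrix to (state_dim, state_dim).

    Hypothetical helper for illustration; it mirrors the rule discussed above.
    """
    if len(L) != state_dim:
        raise ValueError(f"Expected {state_dim} rows, got {len(L)}.")
    bm_dim = len(L[0])
    if any(len(row) != bm_dim for row in L):
        raise ValueError("Diffusion matrix rows must have equal length.")
    if bm_dim > state_dim:
        raise ValueError(
            "Continuous cd-dynamax filters require bm_dim <= state_dim. "
            f"Got state_dim={state_dim}, bm_dim={bm_dim}."
        )
    # Appending zero columns leaves the covariance L @ L.T unchanged.
    return [list(row) + [0.0] * (state_dim - bm_dim) for row in L]


def covariance(L):
    """Return L @ L.T for a list-of-lists matrix."""
    return [[sum(a * b for a, b in zip(ri, rj)) for rj in L] for ri in L]


L_rect = [[0.5], [0.2]]                    # state_dim=2, bm_dim=1
L_sq = pad_diffusion_to_square(L_rect, 2)  # square, with one zero column added
```

Padding is harmless here because the added Brownian components are multiplied by zero, so the process covariance is the same before and after.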
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
…traint Agent-Logs-Url: https://github.com/BasisResearch/dynestyx/sessions/9646e756-8a5e-47b4-9c36-33cb58c2110e Co-authored-by: mattlevine22 <11492591+mattlevine22@users.noreply.github.com>
Thanks, this looks interesting!!
Out of curiosity, did you try for state estimation only or for parameter inference?
Only did a forward pass of the filter; didn't use it in parameter inference, and actually didn't even look at the expense of gradients. Would you expect those to be noticeably different?
@DanWaxman I looked at value_and_grad, but saw inconsistent improvements. The best I saw was like a 5% gain. I think this is because currently, we only do this multiply once per observation (and we have other per-observation operations that dominate). When we have fancier discretizations that take many steps between observations, it will be more impactful I think. I'd recommend for now introducing the API and leveraging it in the simple SDE solver where we know it clearly helps (esp. in large dims), then leveraging diagonal structure as it becomes usable on the methods end.
That makes sense!
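To make the cost argument concrete, here is a rough sketch of applying a diffusion coefficient to a Brownian increment in each of the three structured forms. It uses plain Python lists to stay dependency-free; `apply_diffusion` is a hypothetical name, and the sketch assumes bm_dim == state_dim for the scalar and diagonal cases.

```python
def apply_diffusion(coefficient, dW, diffusion_type):
    """Compute L @ dW for a diffusion stored in scalar, diag, or full form.

    Illustrative only: assumes bm_dim == state_dim for "scalar" and "diag".
    """
    if diffusion_type == "scalar":
        # L = c * I: one multiply per state component, O(n).
        return [coefficient * w for w in dW]
    if diffusion_type == "diag":
        # L = diag(d): elementwise product, O(n).
        return [d * w for d, w in zip(coefficient, dW)]
    if diffusion_type == "full":
        # Dense (state_dim, bm_dim) matrix-vector product, O(n * m).
        return [sum(l * w for l, w in zip(row, dW)) for row in coefficient]
    raise ValueError(f"Unknown diffusion_type: {diffusion_type!r}")


dW = [1.0, -2.0]
scalar_out = apply_diffusion(0.5, dW, "scalar")
diag_out = apply_diffusion([0.5, 0.5], dW, "diag")
full_out = apply_diffusion([[0.5, 0.0], [0.0, 0.5]], dW, "full")
```

All three calls describe the same diffusion, but only the full form pays for a dense product. That saving is repeated at every solver step, which is consistent with the observation above that the gain is small when the multiply happens once per observation.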
…ort direct API by checking if bm_dim or diffusion_type are None)
DanWaxman left a comment
I find all this evaluate_diffusion(...) business pretty awkward, to be honest. Namely, I find it awkward that we have to manually go "into" state_evo and pick out state_evo.diffusion_coefficient, state_evo.diffusion_type, and state_evo.bm_dim every time -- this metadata only makes sense as a package, so we shouldn't have to pick out all of them every time.
Do you have thoughts on putting these into a single dataclass (i.e., have a Diffusion(eqx.Module)) that packages all this info together? Then, we can just have something like diffusion.evaluate_diffusion(x=..., u=..., t=...).
I like this idea, and it gets at #16. Could also have some constructor classmethods:

ContinuousTimeStateEvolution(
    drift=...,
    diffusion=Diffusion.diag(jnp.array([0.1, 0.2]), bm_dim=2),
)

class Diffusion(eqx.Module):
    coefficient: DiffusionSpec
    diffusion_type: DiffusionType | None = eqx.field(static=True, default=None)
    bm_dim: int | None = eqx.field(static=True, default=None)

    @classmethod
    def full(cls, coefficient, bm_dim: int | None = None):
        return cls(coefficient=coefficient, diffusion_type="full", bm_dim=bm_dim)

    @classmethod
    def diag(cls, coefficient, bm_dim: int):
        return cls(coefficient=coefficient, diffusion_type="diag", bm_dim=bm_dim)

    @classmethod
    def scalar(cls, coefficient, bm_dim: int):
        return cls(coefficient=coefficient, diffusion_type="scalar", bm_dim=bm_dim)

Could also do this for Drifts...and I think this may help us a lot down the road in adding new features. May want to think ahead a bit more around names to anticipate future dev
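One way the packaged object could expose evaluation is sketched below. This is a dependency-free stand-in that uses a frozen standard-library dataclass in place of eqx.Module; the `evaluate_diffusion` method follows the name floated in the review comment, and the way it handles constants versus callables is an assumption for illustration, not the merged implementation.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union


@dataclass(frozen=True)
class Diffusion:
    """Bundles a diffusion coefficient with its metadata (illustrative stand-in)."""

    coefficient: Union[Any, Callable[..., Any]]
    diffusion_type: Optional[str] = None
    bm_dim: Optional[int] = None

    def evaluate_diffusion(self, x, u, t):
        """Return the coefficient at (x, u, t); constants are returned as-is."""
        if callable(self.coefficient):
            return self.coefficient(x, u, t)
        return self.coefficient

    @classmethod
    def diag(cls, coefficient, bm_dim):
        return cls(coefficient=coefficient, diffusion_type="diag", bm_dim=bm_dim)


# A constant diagonal diffusion and a state-dependent one share one call site.
constant = Diffusion.diag([0.1, 0.2], bm_dim=2)
state_dependent = Diffusion(
    coefficient=lambda x, u, t: [0.1 * abs(xi) for xi in x],
    diffusion_type="diag",
    bm_dim=2,
)
```

With this shape, callers no longer reach into the state evolution for the coefficient, type, and bm_dim separately; they hold one object and call `evaluate_diffusion(x=..., u=..., t=...)` on it.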
… also push it up with a new resolvedDiffusion class)
DanWaxman left a comment
Didn't get through all the tests, but this looks like a major step! I don't like coercion or __setattr__ much. Also not sure that we should call a resolved meta-type "canonical."
Co-authored-by: Dan Waxman <dan.waxman1@gmail.com>
DanWaxman left a comment
Thanks! A few other small things that were remaining.
Addressed missing
DanWaxman left a comment
I think probably the last batch of comments :))
# In a plate, parameter callables often expect batched parameters, so
# we resolve diffusion metadata using synthetic per-trajectory probes.
x_probe = _make_probe_state(
    initial_condition=initial_condition, state_dim=inferred_state_dim
)
u_probe = None if control_dim == 0 else jnp.zeros((control_dim,))
t_probe = jnp.array(0.0) if self.t0 is None else self.t0
resolved_state_evolution = _refine_continuous_state_evolution(
    resolved_state_evolution,
    x_probe=x_probe,
    u_probe=u_probe,
    t_probe=t_probe,
)
self.state_evolution = resolved_state_evolution
I still don't think I understand this block, unfortunately. In particular, I still don't really see how it differs other than (as the earlier comment states) not doing validation. In which case, I think this whole block can be simplified?
(Sorry I know I'm making this PR a long process, just think having this block be understandable is very important)
This is the part where we turn the CTSE into Discrete/StochasticCTSE; we use the probes in case it is a callable to check shapes so that we can detect/verify bm_dim as well.
I'm not entirely clear how to make sure I'm playing well with the plates...previously, there was infer_bm_dim code for _inside_plate True and False.
I'll refactor to try to streamline and clarify
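The probe idea described in this reply can be sketched in a few lines: build a synthetic state, control, and time, call the diffusion callable once, and read bm_dim off the trailing dimension of the result. The function name `infer_bm_dim_from_callable` is hypothetical, and the sketch uses plain Python lists rather than the project's array types.

```python
def infer_bm_dim_from_callable(diffusion_fn, state_dim, control_dim=0, t0=0.0):
    """Infer bm_dim by evaluating a diffusion callable on synthetic probes.

    Hypothetical illustration: expects diffusion_fn(x, u, t) to return a
    (state_dim, bm_dim) matrix as a list of rows.
    """
    x_probe = [0.0] * state_dim
    u_probe = None if control_dim == 0 else [0.0] * control_dim
    L = diffusion_fn(x_probe, u_probe, t0)
    if len(L) != state_dim:
        raise ValueError(
            f"Diffusion returned {len(L)} rows, expected state_dim={state_dim}."
        )
    # The trailing dimension is the number of Brownian motion components.
    return len(L[0])


def example_diffusion(x, u, t):
    # State-dependent diffusion with 2 states driven by 1 Brownian motion.
    return [[0.1 + x[0]], [0.2]]


bm_dim = infer_bm_dim_from_callable(example_diffusion, state_dim=2)
```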
Sure, I guess my point is what's the real difference between here:
dynestyx/dynestyx/models/core.py, lines 226 to 239 in d2f8564
and the non-plate path:
dynestyx/dynestyx/models/core.py, lines 242 to 261 in d2f8564
except that one validates and one doesn't?
Yeah I'm trying to consolidate that now. I think that question could be asked of the current main branch as well, right?
This might've been some sloppy code from before (maybe it was just "early return before validation in plate" -> add in bm_dim inference -> add in probe states -> lots of duplicate code), but we should clean that up now if so, I think.
agreed, just making sure we are understanding the situation similarly
should we continue to not validate inside the plate?
proposing:
if _inside_plate:
    # Cannot validate shapes with batched parameters; trust the user.
    # Infer observation_dim from observation model if not explicitly provided.
    inferred_obs_dim = _infer_observation_dim_in_plate_context(
        initial_condition=initial_condition,
        observation_model=observation_model,
        inferred_state_dim=inferred_state_dim,
        control_dim=control_dim,
        t_probe=self.t0,
        observation_dim=observation_dim,
    )
else:
    x_probe, u_probe, t_probe = _make_probes()
    if self.continuous_time:
        _validate_continuous_state_evolution(
            state_evolution=state_evolution,
            state_dim=inferred_state_dim,
            x_probe=x_probe,
            u_probe=u_probe,
            t_probe=t_probe,
        )
    else:
        _validate_discrete_state_evolution_output_shape(
            state_evolution=state_evolution,
            state_dim=inferred_state_dim,
            x_probe=x_probe,
            u_probe=u_probe,
            t_probe=t_probe,
        )
    obs_dist = observation_model(x_probe, u_probe, t_probe)
    inferred_obs_dim = _infer_vector_dim_from_distribution(
        obs_dist, "observation_model(x, u, t)"
    )
    _validate_observation_dim(observation_dim, inferred_obs_dim)
self.state_evolution = _resolve_continuous_state_evolution(
    state_evolution,
)
self.state_dim = int(inferred_state_dim)
self.observation_dim = int(inferred_obs_dim)
self.control_dim = int(control_dim)
self.categorical_state = bool(inferred_categorical_state)
Doesn't
self.state_evolution = _resolve_continuous_state_evolution(
    state_evolution,
)
require the probe states? This looks good to me though, otherwise.
I refactored to let it use _make_probe...will update in a bit.
Addresses #212
Summary
This PR introduces a structured diffusion API for continuous-time models.
The main change is a new dynestyx.models.diffusions module with first-class diffusion objects:
- Diffusion
- FullDiffusion
- DiagonalDiffusion
- ScalarDiffusion
These objects make the shape and interpretation of the SDE diffusion term explicit, instead of relying on loosely structured diffusion_coefficient inputs.
In particular, this PR:
- … ContinuousTimeStateEvolution / DynamicalModel to canonicalize and validate diffusion metadata centrally
- … bm_dim behavior:
  - FullDiffusion: bm_dim is inferred from trailing matrix shape when omitted
  - DiagonalDiffusion / ScalarDiffusion: bm_dim is explicit and must be 1 or state_dim
- … LTI_continuous and related examples/docs to use the new diffusion API directly
- … bm_dim/plate behavior, and hierarchical smoke tests
Why
This change makes continuous-time model specification clearer and safer.
It also improves performance in simulation-oriented SDE paths when the diffusion has scalar or diagonal structure, since we can avoid treating every diffusion as a dense full matrix. The speedup figure below illustrates this benefit.
Just as importantly, invalid setups now fail earlier with clearer errors, especially in continuous cd-dynamax filtering where only square full diffusion is currently supported.
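The bm_dim rules listed in the Summary, and the kind of early failure described here, can be sketched as a single resolution function. This is an illustration under stated assumptions: `resolve_bm_dim` is a hypothetical name, the string tags stand in for the diffusion classes, and raising on a mismatched explicit bm_dim for full diffusion is an assumption rather than confirmed behavior.

```python
def resolve_bm_dim(diffusion_type, coefficient, state_dim, bm_dim=None):
    """Resolve bm_dim for a constant diffusion, failing early on invalid input.

    Hypothetical sketch of the rules in the Summary; not the library's code.
    """
    if diffusion_type == "full":
        # Full diffusion: infer bm_dim from the trailing matrix dimension.
        inferred = len(coefficient[0])
        if bm_dim is not None and bm_dim != inferred:
            # Assumption: an explicit bm_dim must agree with the matrix shape.
            raise ValueError(
                f"bm_dim={bm_dim} does not match matrix trailing dim {inferred}."
            )
        return inferred
    if diffusion_type in ("diag", "scalar"):
        # Diagonal/scalar diffusion: bm_dim must be given explicitly.
        if bm_dim is None:
            raise ValueError(f"{diffusion_type} diffusion requires an explicit bm_dim.")
        if bm_dim not in (1, state_dim):
            raise ValueError(
                f"bm_dim must be 1 or state_dim={state_dim}, got {bm_dim}."
            )
        return bm_dim
    raise ValueError(f"Unknown diffusion_type: {diffusion_type!r}")


full_bm_dim = resolve_bm_dim("full", [[1.0], [0.5]], state_dim=2)
diag_bm_dim = resolve_bm_dim("diag", [0.1, 0.2], state_dim=2, bm_dim=2)
```

Checking these conditions at model construction means a bad shape surfaces as a ValueError with the offending dimensions in the message, rather than as a shape error deep inside a solver or filter.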
Notes
- … (bm_dim == state_dim).
- … Filter + Discretizer, but did not see meaningful gains there, so those backend-specific optimizations were not kept; will matter more with multi-step discretizers, and can leverage then.
Tests
Added or updated coverage in:
- tests/test_models_core.py
- tests/test_discretizers.py
- tests/test_bm_dim_plate.py
- tests/test_hierarchical_simulator_discretizer_smokes.py
The image below shows the speedups of the solver when using a scalar/diag/full diffusion coefficient in the SDE.
