-
Notifications
You must be signed in to change notification settings - Fork 5
Support structured diffusion specs in continuous-time models and enforce cd-dynamax full-diffusion constraints #214
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 23 commits
Commits
Show all changes
31 commits
Select commit
Hold shift + click to select a range
021550d
support Array and lambda diffusions (special support for scalar/diag/…
mattlevine22 9652cbf
adding back padding of diffusion-coefficient in CD-Dynamax integration
mattlevine22 8c6726e
Update dynestyx/inference/integrations/cd_dynamax/utils.py
mattlevine22 38df332
Update dynestyx/models/diffusions.py
mattlevine22 aa50c38
docs: rename misleading test names to reflect bm_dim > state_dim cons…
Copilot ecd987e
plz lint
mattlevine22 bc0c252
Merge branch 'main' into ml-feature-212
mattlevine22 375cf59
get diffusion info at init
mattlevine22 0fa5667
give discretizer fallback if dynamicalModel init hadnt been run (supp…
mattlevine22 bf1f456
new Diffusion class with subtypes.
mattlevine22 ef5dab6
merge with main (plate upgrades)
mattlevine22 20d1d37
dont use cast, just assert bm_dim is not None (fight w linter...could…
mattlevine22 e9ea930
merge with main (smoothers)
mattlevine22 bc41705
please lint
mattlevine22 3f37300
rename as gram
mattlevine22 657a190
rename func to _coerce_to_param_dtype
mattlevine22 958a2a7
simplify edit to Quick example
mattlevine22 0e04441
dont change faq
mattlevine22 2cfab46
simplify notebook changes for PR
mattlevine22 90a2813
simplify notebook changes for PR
mattlevine22 c8c5cd8
simplify notebook changes for PR
mattlevine22 2e02f2f
improve API documentation for diffusions
mattlevine22 740750e
improve API documentation for diffusions
mattlevine22 2ebb45c
Update dynestyx/models/diffusions.py
mattlevine22 3612c92
simplify code and improve docs
mattlevine22 a4cee35
clarify resolve metadata docstring
mattlevine22 19a344b
Update dynestyx/inference/integrations/cd_dynamax/utils.py
mattlevine22 6d39004
probing and errors for cd-dynamax diffusion plumbing
mattlevine22 d2f8564
fix lint
mattlevine22 83641cb
Update dynestyx/models/core.py
mattlevine22 8501a17
streamline DynamicalModel init
mattlevine22 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,14 @@ | ||
| # Core Models | ||
|
|
||
| This page includes the developer-facing refined continuous-time state-evolution | ||
| classes: | ||
|
|
||
| - `DeterministicContinuousTimeStateEvolution` | ||
| - `StochasticContinuousTimeStateEvolution` | ||
|
|
||
| Most users should continue to construct `ContinuousTimeStateEvolution` directly and let | ||
| `DynamicalModel` canonicalize it internally. | ||
|
|
||
| ::: dynestyx.models.core | ||
| options: | ||
| filters: [] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| # Diffusions | ||
|
|
||
| This page documents the diffusion internals used by continuous-time solvers, | ||
| discretizers, and backend integrations. | ||
|
|
||
| ## Overview | ||
|
|
||
| Public model code should usually work with `FullDiffusion`, | ||
| `DiagonalDiffusion`, and `ScalarDiffusion` through | ||
| `ContinuousTimeStateEvolution(diffusion=...)`. | ||
|
|
||
| Developer-facing code sometimes needs the lower-level evaluation layer: | ||
|
|
||
| - `Diffusion.evaluate(...)` evaluates a structured diffusion object at a | ||
| concrete `(x, u, t)`. | ||
| - `EvaluatedDiffusion.as_matrix(...)` returns the corresponding matrix `L`. | ||
| - `EvaluatedDiffusion.apply(...)` applies that matrix to a Brownian increment. | ||
| - `resolve_metadata(...)` canonicalizes `bm_dim` and validates the coefficient | ||
| shape against `state_dim`. | ||
|
|
||
| In other words, the public classes describe the structure of the diffusion, | ||
| while `EvaluatedDiffusion` is the solver-facing object used after a concrete | ||
| state, control, and time have been chosen. | ||
|
|
||
| ## API | ||
|
|
||
| ### `EvaluatedDiffusion` | ||
|
|
||
| ::: dynestyx.models.diffusions.EvaluatedDiffusion | ||
| options: | ||
| show_root_heading: false | ||
|
|
||
| ### `Diffusion` | ||
|
|
||
| ::: dynestyx.models.diffusions.Diffusion | ||
| options: | ||
| show_root_heading: false | ||
| filters: [] | ||
|
|
||
| ### `FullDiffusion` | ||
|
|
||
| ::: dynestyx.models.diffusions.FullDiffusion | ||
| options: | ||
| show_root_heading: false | ||
| filters: [] | ||
|
|
||
| ### `DiagonalDiffusion` | ||
|
|
||
| ::: dynestyx.models.diffusions.DiagonalDiffusion | ||
| options: | ||
| show_root_heading: false | ||
| filters: [] | ||
|
|
||
| ### `ScalarDiffusion` | ||
|
|
||
| ::: dynestyx.models.diffusions.ScalarDiffusion | ||
| options: | ||
| show_root_heading: false | ||
| filters: [] |
11 changes: 10 additions & 1 deletion
11
docs/api_reference/public/models/core/continuous_time_state_evolution.md
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,16 @@ | ||
| # ContinuousTimeStateEvolution | ||
|
|
||
| `ContinuousTimeStateEvolution` is the public entry point for defining continuous-time | ||
| state evolution. Most users should instantiate this class directly and pass an optional | ||
| `diffusion=` built from [`FullDiffusion`](./diffusion.md), | ||
| [`DiagonalDiffusion`](./diffusion.md), or [`ScalarDiffusion`](./diffusion.md). | ||
|
|
||
| Internally, `DynamicalModel` canonicalizes continuous-time dynamics to refined | ||
| deterministic and stochastic subclasses. Those specialized classes are intended for | ||
| developer-facing integrations and are documented in the developer API rather than the | ||
| public tutorials. | ||
|
|
||
| ::: dynestyx.models.core.ContinuousTimeStateEvolution | ||
| options: | ||
| show_root_heading: false | ||
| show_root_toc_entry: false | ||
|
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,233 @@ | ||
| # Diffusion | ||
|
|
||
| `Diffusion` objects define the stochastic term in a continuous-time state evolution | ||
|
|
||
| \[ | ||
| dx_t = f(x_t, u_t, t)\,dt + L(x_t, u_t, t)\,dW_t, | ||
| \] | ||
|
|
||
| where: | ||
|
|
||
| - \(x_t \in \mathbb{R}^{d_x}\) is the latent state, | ||
| - \(u_t\) is an optional control, | ||
| - \(W_t \in \mathbb{R}^{d_w}\) is Brownian motion, | ||
| - `bm_dim = d_w`, | ||
| - and \(L(x_t, u_t, t)\) is the diffusion coefficient. | ||
|
|
||
| Dynestyx exposes three structured diffusion classes: | ||
|
|
||
| - `FullDiffusion`: general matrix-valued diffusion \(L \in \mathbb{R}^{d_x \times d_w}\) | ||
| - `DiagonalDiffusion`: vector-valued diffusion \(v \in \mathbb{R}^{d_x}\) | ||
| - `ScalarDiffusion`: scalar-valued diffusion \(\sigma \in \mathbb{R}\) | ||
|
|
||
| In practice, many models use a constant diffusion coefficient. In those cases, | ||
| pass the matrix/vector/scalar value directly. Reserve callable diffusion | ||
| coefficients for cases where the coefficient genuinely depends on state, | ||
| control, or time. | ||
|
|
||
| ## `Diffusion` | ||
|
|
||
| `Diffusion` is the common base class for structured diffusion coefficients. | ||
| Most users should instantiate one of its public subclasses instead of using | ||
| `Diffusion` directly: | ||
|
|
||
| - Use `ScalarDiffusion` when one scalar scale is enough. This is the most common choice for isotropic diffusion. | ||
| - Use `DiagonalDiffusion` when each state coordinate should have its own scale but the loading remains axis-aligned. | ||
| - Use `FullDiffusion` when you need to specify a genuinely matrix-valued loading. | ||
|
|
||
| Each class accepts either: | ||
|
|
||
| - a constant coefficient, or | ||
| - a callable `(x, u, t) -> value`. | ||
|
|
||
| ### `FullDiffusion` | ||
|
|
||
| Use `FullDiffusion(coefficient, bm_dim=None)` when you want to specify the | ||
| matrix-valued diffusion coefficient directly. | ||
|
|
||
| Accepted `coefficient` forms: | ||
|
|
||
| - a constant array with trailing shape `(state_dim, bm_dim)`, or | ||
| - a callable `(x, u, t) -> array` with trailing shape `(state_dim, bm_dim)`. | ||
|
|
||
| If `coefficient` is constant, `bm_dim` is inferred automatically when omitted. | ||
|
|
||
| Mathematically, `FullDiffusion` represents | ||
|
|
||
| \[ | ||
| L(x_t, u_t, t) \in \mathbb{R}^{d_x \times d_w}. | ||
| \] | ||
|
|
||
| Example: | ||
|
|
||
| ```python | ||
| import jax.numpy as jnp | ||
| from dynestyx import ContinuousTimeStateEvolution, FullDiffusion | ||
|
|
||
| state_evolution = ContinuousTimeStateEvolution( | ||
| drift=lambda x, u, t: -x, | ||
| diffusion=FullDiffusion(jnp.eye(2)), | ||
| ) | ||
| ``` | ||
|
|
||
| ### `DiagonalDiffusion` | ||
|
|
||
| Use `DiagonalDiffusion(coefficient, bm_dim)` when the diffusion is naturally | ||
| parameterized by a vector of state-wise loadings. | ||
|
|
||
| Accepted `coefficient` forms: | ||
|
|
||
| - a constant vector with trailing shape `(state_dim,)`, or | ||
| - a callable `(x, u, t) -> array` with trailing shape `(state_dim,)`. | ||
|
|
||
| `bm_dim` is required and must be either `1` or `state_dim`. | ||
|
|
||
| Mathematically, `DiagonalDiffusion` represents a vector-valued coefficient | ||
|
|
||
| \[ | ||
| v(x_t, u_t, t) \in \mathbb{R}^{d_x}. | ||
| \] | ||
|
|
||
| If `bm_dim = d_x`, the vector is interpreted as a diagonal matrix: | ||
|
|
||
| \[ | ||
| L = \mathrm{diag}(v(x_t, u_t, t)). | ||
| \] | ||
|
|
||
| If `bm_dim = 1`, the same vector is interpreted as a column loading vector: | ||
|
|
||
| \[ | ||
| L = v(x_t, u_t, t) \in \mathbb{R}^{d_x \times 1}. | ||
| \] | ||
|
|
||
| Example: | ||
|
|
||
| ```python | ||
| import jax.numpy as jnp | ||
| from dynestyx import ContinuousTimeStateEvolution, DiagonalDiffusion | ||
|
|
||
| state_evolution = ContinuousTimeStateEvolution( | ||
| drift=lambda x, u, t: -x, | ||
| diffusion=DiagonalDiffusion(jnp.array([0.1, 0.2]), bm_dim=2), | ||
| ) | ||
| ``` | ||
|
|
||
| ### `ScalarDiffusion` | ||
|
|
||
| Use `ScalarDiffusion(coefficient, bm_dim)` when one scalar scale is enough. | ||
| This is usually the simplest choice for isotropic diffusion. | ||
|
|
||
| Accepted `coefficient` forms: | ||
|
|
||
| - a scalar, | ||
| - a constant array with trailing shape `(1,)`, or | ||
| - a callable `(x, u, t) -> scalar_or_length_1_array`. | ||
|
|
||
| `bm_dim` is required and must be either `1` or `state_dim`. | ||
|
|
||
| Mathematically, `ScalarDiffusion` represents a scalar-valued coefficient | ||
|
|
||
| \[ | ||
| \sigma(x_t, u_t, t) \in \mathbb{R}. | ||
| \] | ||
|
|
||
| If `bm_dim = d_x`, it is interpreted as isotropic independent noise: | ||
|
|
||
| \[ | ||
| L = \sigma(x_t, u_t, t)\,I_{d_w} | ||
| \] | ||
|
|
||
| with \(d_w = d_x\). | ||
|
|
||
| If `bm_dim = 1`, it is interpreted as a shared scalar driver applied equally to | ||
| every state coordinate: | ||
|
|
||
| \[ | ||
| L = \sigma(x_t, u_t, t)\,\mathbf{1}_{d_x}, | ||
| \] | ||
|
|
||
| viewed as a column vector in \(\mathbb{R}^{d_x \times 1}\). | ||
|
|
||
| Constant example: | ||
|
|
||
| ```python | ||
| from dynestyx import ContinuousTimeStateEvolution, ScalarDiffusion | ||
|
|
||
| state_evolution = ContinuousTimeStateEvolution( | ||
| drift=lambda x, u, t: -x, | ||
| diffusion=ScalarDiffusion(0.1, bm_dim=2), | ||
| ) | ||
| ``` | ||
|
|
||
| Callable example: | ||
|
|
||
| ```python | ||
| import jax.numpy as jnp | ||
| from dynestyx import ContinuousTimeStateEvolution, ScalarDiffusion | ||
|
|
||
| state_evolution = ContinuousTimeStateEvolution( | ||
| drift=lambda x, u, t: -x, | ||
| diffusion=ScalarDiffusion( | ||
| lambda x, u, t: 0.1 + 0.05 * jnp.tanh(x[0]), | ||
| bm_dim=2, | ||
| ), | ||
| ) | ||
| ``` | ||
|
|
||
| ## API | ||
|
|
||
| ### `Diffusion` | ||
|
|
||
| ::: dynestyx.models.diffusions.Diffusion | ||
| options: | ||
| show_root_heading: false | ||
| show_root_toc_entry: false | ||
| filters: | ||
| - "!^evaluate_value$" | ||
| - "!^resolve_metadata$" | ||
| - "!^evaluate$" | ||
| - "!^as_matrix$" | ||
| - "!^gram_matrix$" | ||
| - "!^apply$" | ||
|
|
||
| ### `FullDiffusion` | ||
|
|
||
| ::: dynestyx.models.diffusions.FullDiffusion | ||
| options: | ||
| show_root_heading: false | ||
| show_root_toc_entry: false | ||
| filters: | ||
| - "!^evaluate_value$" | ||
| - "!^resolve_metadata$" | ||
| - "!^evaluate$" | ||
| - "!^as_matrix$" | ||
| - "!^gram_matrix$" | ||
| - "!^apply$" | ||
|
|
||
| ### `DiagonalDiffusion` | ||
|
|
||
| ::: dynestyx.models.diffusions.DiagonalDiffusion | ||
| options: | ||
| show_root_heading: false | ||
| show_root_toc_entry: false | ||
| filters: | ||
| - "!^evaluate_value$" | ||
| - "!^resolve_metadata$" | ||
| - "!^evaluate$" | ||
| - "!^as_matrix$" | ||
| - "!^gram_matrix$" | ||
| - "!^apply$" | ||
|
|
||
| ### `ScalarDiffusion` | ||
|
|
||
| ::: dynestyx.models.diffusions.ScalarDiffusion | ||
| options: | ||
| show_root_heading: false | ||
| show_root_toc_entry: false | ||
| filters: | ||
| - "!^evaluate_value$" | ||
| - "!^resolve_metadata$" | ||
| - "!^evaluate$" | ||
| - "!^as_matrix$" | ||
| - "!^gram_matrix$" | ||
| - "!^apply$" | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,4 +9,3 @@ | |
| - ObservationModel | ||
| - Drift | ||
| - Potential | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.