-
Notifications
You must be signed in to change notification settings - Fork 5
Support structured diffusion specs in continuous-time models and enforce cd-dynamax full-diffusion constraints #214
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
021550d
9652cbf
8c6726e
38df332
aa50c38
ecd987e
bc0c252
375cf59
0fa5667
bf1f456
ef5dab6
20d1d37
e9ea930
bc41705
3f37300
657a190
958a2a7
0e04441
2cfab46
90a2813
c8c5cd8
2e02f2f
740750e
2ebb45c
3612c92
a4cee35
19a344b
6d39004
d2f8564
83641cb
8501a17
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,40 +22,73 @@ | |
| LinearGaussianObservation, | ||
| LinearGaussianStateEvolution, | ||
| ) | ||
| from dynestyx.models.diffusions import diffusion_as_matrix, evaluate_diffusion | ||
|
|
||
| type SSMType = ContDiscreteNonlinearGaussianSSM | ContDiscreteNonlinearSSM | ||
|
|
||
|
|
||
| def _normalize_cd_dynamax_diffusion( | ||
| diffusion_coefficient, | ||
| state_evolution: ContinuousTimeStateEvolution, | ||
| state_dim: int, | ||
| ): | ||
| """Return a diffusion coeff compatible with cd-dynamax's EnKF SDE solve. | ||
|
|
||
| cd-dynamax's internal diffrax wrapper builds Brownian controls with shape | ||
| equal to `y0.shape` (state_dim). For non-square diffusion coefficients | ||
| (state_dim, bm_dim) with bm_dim != state_dim, pad/truncate columns so the | ||
| returned matrix is always (state_dim, state_dim). | ||
| equal to `y0.shape` (state_dim). For diffusion with `bm_dim < state_dim`, | ||
| pad trailing Brownian columns with zeros to match `(state_dim, state_dim)`. | ||
| Diffusion with `bm_dim > state_dim` is rejected. | ||
| """ | ||
|
|
||
| def _wrapped(x, u, t): | ||
| L = diffusion_coefficient(x, u, t) | ||
| if L.ndim == 1: | ||
| L = jnp.diag(L) | ||
| if L.ndim != 2: | ||
| diffusion = evaluate_diffusion( | ||
| state_evolution.diffusion_coefficient, | ||
| diffusion_type=state_evolution.diffusion_type, | ||
| bm_dim=state_evolution.bm_dim, | ||
| x=x, | ||
| u=u, | ||
| t=t, | ||
| state_dim=state_dim, | ||
| ) | ||
| L = diffusion_as_matrix(diffusion, state_dim=state_dim) | ||
| n_cols = L.shape[-1] | ||
| if n_cols > state_dim: | ||
| raise ValueError( | ||
| "diffusion_coefficient must return a vector or matrix for cd-dynamax." | ||
| "cd-dynamax continuous diffusion requires bm_dim <= state_dim. " | ||
| f"Got state_dim={state_dim}, bm_dim={n_cols}." | ||
| ) | ||
| n_cols = L.shape[-1] | ||
| if n_cols == state_dim: | ||
| return L | ||
| if n_cols < state_dim: | ||
| return jnp.pad(L, ((0, 0), (0, state_dim - n_cols))) | ||
| return L[:, :state_dim] | ||
| L = jnp.pad(L, ((0, 0), (0, state_dim - n_cols))) | ||
| return L | ||
|
|
||
| return _wrapped | ||
|
|
||
|
|
||
| def _validate_cd_dynamax_continuous_diffusion( | ||
| state_evolution: ContinuousTimeStateEvolution, | ||
| state_dim: int, | ||
| ) -> None: | ||
| """Eagerly validate diffusion shape constraints for cd-dynamax continuous filters.""" | ||
| if state_evolution.diffusion_coefficient is None: | ||
| return | ||
|
|
||
| probe_x = jnp.zeros(state_dim) | ||
| diffusion = evaluate_diffusion( | ||
| state_evolution.diffusion_coefficient, | ||
| diffusion_type=state_evolution.diffusion_type, | ||
| bm_dim=state_evolution.bm_dim, | ||
| x=probe_x, | ||
| u=None, | ||
| t=jnp.array(0.0), | ||
| state_dim=state_dim, | ||
| ) | ||
| bm_dim = diffusion_as_matrix(diffusion, state_dim=state_dim).shape[-1] | ||
| if bm_dim > state_dim: | ||
| raise ValueError( | ||
| "Continuous cd-dynamax filters require bm_dim <= state_dim. " | ||
| f"Got state_dim={state_dim}, bm_dim={bm_dim}." | ||
| ) | ||
|
||
|
|
||
|
|
||
| class _ConstantFunction(eqx.Module): | ||
| value: Any | ||
|
|
||
|
|
@@ -198,7 +231,16 @@ def dsx_to_cdlgssm_params(dsx_model: DynamicalModel) -> ParamsCDLGSSM: | |
|
|
||
| # Extract constant L and use inferred Brownian dimension. | ||
| x0 = jnp.zeros(dsx_model.state_dim) | ||
| L = state_evo.diffusion_coefficient(x0, None, jnp.array(0.0)) | ||
| diffusion = evaluate_diffusion( | ||
| state_evo.diffusion_coefficient, | ||
| diffusion_type=state_evo.diffusion_type, | ||
| bm_dim=state_evo.bm_dim, | ||
| x=x0, | ||
| u=None, | ||
|
DanWaxman marked this conversation as resolved.
Outdated
|
||
| t=jnp.array(0.0), | ||
| state_dim=dsx_model.state_dim, | ||
| ) | ||
| L = diffusion_as_matrix(diffusion, state_dim=dsx_model.state_dim) | ||
| if state_evo.bm_dim is None: | ||
| raise ValueError( | ||
| "state_evolution.bm_dim is not set on ContinuousTimeStateEvolution." | ||
|
|
@@ -273,8 +315,12 @@ def dsx_to_cd_dynamax( | |
| raise ValueError( | ||
| "state_evolution.bm_dim is not set on ContinuousTimeStateEvolution." | ||
| ) | ||
| _validate_cd_dynamax_continuous_diffusion( | ||
| state_evo, | ||
| dsx_model.state_dim, | ||
| ) | ||
| diffusion_coeff = _normalize_cd_dynamax_diffusion( | ||
| state_evo.diffusion_coefficient, | ||
| state_evo, | ||
| dsx_model.state_dim, | ||
| ) | ||
| shared_params.update( | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.