diff --git a/examples/weather/stormcast/test_training.py b/examples/weather/stormcast/test_training.py index 04efca2d0c..3f5a1daa0b 100644 --- a/examples/weather/stormcast/test_training.py +++ b/examples/weather/stormcast/test_training.py @@ -48,16 +48,19 @@ def _load_config(config_name: str) -> DictConfig: @pytest.fixture def cfg_regression(): + """Load the test regression U-Net config.""" return _load_config(config_name="test_regression_unet.yaml") @pytest.fixture def cfg_diffusion(): + """Load the test diffusion DiT config.""" return _load_config(config_name="test_diffusion.yaml") @pytest.fixture def cfg_diffusion_unet(): + """Load the test diffusion U-Net config.""" return _load_config(config_name="test_diffusion_unet.yaml") @@ -469,6 +472,56 @@ def _check_sigma_pattern(label: str) -> None: torch.distributed.barrier() +@pytest.mark.parametrize( + "sigma_data", + [0.3, 1.0, [0.2, 0.5, 0.8]], + ids=["scalar_0.3", "scalar_1.0", "per_channel"], +) +def test_sigma_data_preconditioner( + tmp_path: Path, + cfg_diffusion_unet: DictConfig, + *, + sigma_data, +): + """Verify that training.loss.sigma_data is forwarded to the EDM preconditioner.""" + dist = DistributedManager() + if dist.world_size > 1: + pytest.skip("Skipping: single-process test.") + + rundir = _setup_rundir(tmp_path, dist.world_size) + cfg = cfg_diffusion_unet.copy() + cfg.training.rundir = rundir + cfg.training.total_train_steps = 1 + cfg.training.loss.sigma_data = sigma_data + + if isinstance(sigma_data, list): + cfg.dataset.num_state_channels = len(sigma_data) + + t = trainer.Trainer(cfg) + + # The preconditioner's sigma_data buffer should match the loss config. + net = t.net + # Unwrap FSDP if needed. 
+ raw_net = net.module if hasattr(net, "module") else net + precond_sd = raw_net.sigma_data + + if isinstance(sigma_data, (list, tuple)): + expected = torch.as_tensor(sigma_data, dtype=torch.float32).reshape(1, -1, 1, 1) + assert precond_sd.shape == expected.shape, ( + f"Per-channel sigma_data shape mismatch: {precond_sd.shape} vs {expected.shape}" + ) + assert torch.allclose(precond_sd.cpu().float(), expected), ( + f"Per-channel sigma_data value mismatch: {precond_sd} vs {expected}" + ) + else: + expected_val = float(sigma_data) + actual_val = float(precond_sd) + assert abs(actual_val - expected_val) < 1e-6, ( + f"Scalar sigma_data mismatch: preconditioner has {actual_val}, " + f"expected {expected_val}" + ) + + @pytest.mark.parametrize("net_architecture", ["unet", "dit"]) @pytest.mark.parametrize( "model_type", ["hybrid", "nowcasting", "downscaling", "unconditional"] diff --git a/examples/weather/stormcast/utils/trainer.py b/examples/weather/stormcast/utils/trainer.py index 0b84842357..759ac29221 100644 --- a/examples/weather/stormcast/utils/trainer.py +++ b/examples/weather/stormcast/utils/trainer.py @@ -403,6 +403,29 @@ def _setup_model(self) -> Module: # Build network model_cfg = self.cfg.model + model_hparams = dict(model_cfg.hyperparameters) + # Ensure the preconditioner uses the same sigma_data as the loss. + # Without this, EDMPrecond defaults to sigma_data=0.5 regardless of + # the value set in training.loss.sigma_data, causing a mismatch + # between the loss weighting and the preconditioning coefficients. 
+    if self.loss_type == "edm":
+        loss_sigma_data = self.cfg.training.loss.sigma_data
+        precond_sd_override = model_hparams.get("sigma_data")
+        # Compare raw config values here; after the tensor conversion below `!=` would be elementwise.
+        sigma_mismatch = (precond_sd_override is not None
+                          and precond_sd_override != loss_sigma_data)
+        if not isinstance(loss_sigma_data, (int, float)):
+            loss_sigma_data = torch.as_tensor(
+                list(loss_sigma_data), dtype=torch.float32
+            )[None, :, None, None]
+        model_hparams.setdefault("sigma_data", loss_sigma_data)
+        if sigma_mismatch:
+            self.logger.info(
+                f"sigma_data override: preconditioner uses {precond_sd_override} "
+                f"(from model.hyperparameters), loss uses {loss_sigma_data}"
+            )
+        else:
+            self.logger.info(f"sigma_data: {model_hparams['sigma_data']}")
     if model_cfg.architecture == "unet":
         net = get_preconditioned_unet(
             name=self.net_name,
@@ -412,7 +435,7 @@ def _setup_model(self) -> Module:
             lead_time_steps=self.lead_time_steps,
             amp_mode=self.enable_amp,
             use_apex_gn=self.use_apex_gn,
-            **model_cfg.hyperparameters,
+            **model_hparams,
         )
     elif model_cfg.architecture == "dit":
         net = get_preconditioned_natten_dit(
@@ -421,7 +444,7 @@ def _setup_model(self) -> Module:
             name=self.net_name,
             conditional_channels=num_condition_channels,
             scalar_condition_channels=len(self.scalar_cond_channels),
             lead_time_steps=self.lead_time_steps,
-            **model_cfg.hyperparameters,
+            **model_hparams,
         )
     else:
         raise ValueError("model.architecture must be 'unet' or 'dit'")
diff --git a/physicsnemo/models/dit/dit.py b/physicsnemo/models/dit/dit.py
index 47485b17fc..c7532a6896 100644
--- a/physicsnemo/models/dit/dit.py
+++ b/physicsnemo/models/dit/dit.py
@@ -396,9 +396,7 @@ def _migrate_legacy_checkpoint(
         if not old_key.startswith(legacy_prefix):
             continue
         new_key = new_prefix + old_key[len(legacy_prefix) :]
-        if old_key == legacy_prefix + "freqs":
-            del state_dict[old_key]
-        elif new_key not in state_dict:
+        if new_key not in state_dict:
             state_dict[new_key] = state_dict.pop(old_key)

     def initialize_weights(self):
diff --git a/physicsnemo/nn/module/embedding_layers.py b/physicsnemo/nn/module/embedding_layers.py
index 9686e5f6fd..00dbaf1fed 100644 --- a/physicsnemo/nn/module/embedding_layers.py +++ b/physicsnemo/nn/module/embedding_layers.py @@ -127,7 +127,15 @@ def __init__( freqs = torch.arange(start=0, end=self.freq_embed_dim // 2, dtype=torch.float32) freqs = freqs / (self.freq_embed_dim // 2 - (1 if self.endpoint else 0)) freqs = (1 / self.max_positions) ** freqs - self.register_buffer("freqs", freqs, persistent=False) + self.register_buffer("freqs", freqs, persistent=True) + self.register_load_state_dict_pre_hook(self._fill_missing_freqs) + + @staticmethod + def _fill_missing_freqs(module, state_dict, prefix, *_args, **_kwargs): + """Backward compat: old checkpoints saved freqs as non-persistent.""" + key = prefix + "freqs" + if key not in state_dict: + state_dict[key] = module.freqs.clone() def forward(self, x): x = torch.outer(x, self.freqs)