Add distillation training recipe for CorrDiff #1533
jialusui1102 wants to merge 10 commits into NVIDIA:main
Conversation
Signed-off-by: jialusui1102 <jialusui1102@gmail.com>
Greptile Summary: This PR adds distillation training support for CorrDiff using the FastGen framework. Critical issues requiring fixes before merging:
Minor issues:
Important Files Changed
Reviews (1): Last reviewed commit: "update CHANEGLOG.md"
```python
# Train the CorrDiff model using the configurations in "conf/config_training.yaml"
@hydra.main(version_base="1.2", config_path="conf", config_name="config_distill_mini")
```
Non-existent default config name
The config_name="config_distill_mini" does not correspond to any file in conf/. The three configs added in this PR are config_distill_hrrr_mini_diffusion.yaml, config_distill_gefs_hrrr_diffusion.yaml, and config_distill_taiwan_diffusion.yaml. Running distill.py without an explicit --config-name override will fail immediately with a Hydra MissingConfigException.
The closest match to the apparent intent is config_distill_hrrr_mini_diffusion:
```diff
- @hydra.main(version_base="1.2", config_path="conf", config_name="config_distill_mini")
+ @hydra.main(version_base="1.2", config_path="conf", config_name="config_distill_hrrr_mini_diffusion")
```
```python
{
    "precision": PRECISION_MAP[fp_optimizations],
    "precision_infer": PRECISION_MAP[str(input_dtype)],
    "input_shape": (diffusion_model.img_out_channels, *super_patch_shape),
```
AttributeError when diffusion_model is None
diffusion_model is explicitly set to None (around line 403) when no diffusion_checkpoint_path is provided or the key is absent. The attribute access diffusion_model.img_out_channels on line 455 will raise AttributeError: 'NoneType' object has no attribute 'img_out_channels' in that case.
A similar crash exists a few lines earlier (line ~434): torch.compile(diffusion_model) when use_torch_compile=True and diffusion_model is None.
A guard is needed before this block, e.g.:
```python
if diffusion_model is None:
    raise ValueError(
        "A diffusion checkpoint must be provided for distillation training. "
        "Set cfg.distill.io.diffusion_checkpoint_path."
    )
```

```python
loss_valid = (
    (loss_valid.sum() / batch_size_per_gpu)
    .cpu()
    .item()
)
```
Incorrect validation loss scaling
Training explicitly asserts loss.ndim == 0 (the loss returned by loss_fn is already a scalar). However, the validation path applies .sum() / batch_size_per_gpu to the same scalar output. loss_valid.sum() on a 0-dim tensor is a no-op, but the subsequent division by batch_size_per_gpu shrinks the value by that factor, making the logged validation loss batch_size_per_gpu times smaller than the comparable training loss. This makes the two metrics incomparable in TensorBoard/WandB.
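The mismatch can be shown with plain numbers (a minimal sketch; the batch size and loss value are made up for illustration):

```python
# loss_fn already returns a scalar (0-dim) loss; training logs it directly.
batch_size_per_gpu = 16
loss_scalar = 0.8  # what the training path logs

# Validation path: sum() over a scalar is a no-op, but the extra
# division still shrinks the logged value by the batch size.
logged_valid = loss_scalar / batch_size_per_gpu

print(logged_valid)               # 0.05
print(loss_scalar / logged_valid)  # 16.0 -- the spurious scaling factor
```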
```diff
- loss_valid = (
-     (loss_valid.sum() / batch_size_per_gpu)
-     .cpu()
-     .item()
- )
+ loss_valid = loss_valid.cpu().item()
```
skip since that's the original code
```python
    else:
        raise ValueError(f"Unknown sampling method {cfg.sampling.type}")
```
Wrong config key in error message
Every other branch in this if/elif chain uses cfg.sampler.type, but the fallback error message references cfg.sampling.type, which is a different (and likely non-existent) key. Accessing cfg.sampling.type will raise an omegaconf.errors.ConfigAttributeError instead of surfacing the intended ValueError.
```diff
- raise ValueError(f"Unknown sampling method {cfg.sampling.type}")
+ raise ValueError(f"Unknown sampling method {cfg.sampler.type}")
```
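A minimal sketch of the corrected branch (the sampler names and the config object here are illustrative stand-ins, not the actual CorrDiff code):

```python
from types import SimpleNamespace

def build_sampler(cfg):
    # Every branch, including the fallback, reads the same key
    # (cfg.sampler.type), so an unknown value surfaces as the intended
    # ValueError rather than an attribute error on a non-existent key.
    if cfg.sampler.type == "deterministic":
        return "deterministic_sampler"
    elif cfg.sampler.type == "stochastic":
        return "stochastic_sampler"
    else:
        raise ValueError(f"Unknown sampling method {cfg.sampler.type}")

cfg = SimpleNamespace(sampler=SimpleNamespace(type="bogus"))
try:
    build_sampler(cfg)
except ValueError as err:
    print(err)  # Unknown sampling method bogus
```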
```python
torch._dynamo.reset()
# Increase the cache size limit
torch._dynamo.config.cache_size_limit = 264  # Set to a higher value
torch._dynamo.config.verbose = True  # Enable verbose logging
torch._dynamo.config.suppress_errors = False  # Forces the error to show all details
torch._logging.set_logs(recompiles=True, graph_breaks=True)
```
Global-scope torch._dynamo configuration with side effects
These five statements run at module import time (not inside main()), so they apply to any process that imports distill.py:
- `torch._dynamo.config.verbose = True` emits verbose graph-break logs for all dynamo usage.
- `torch._logging.set_logs(recompiles=True, graph_breaks=True)` turns on extra logging globally.
- `torch._dynamo.reset()` discards any previously compiled caches.
This is especially problematic for multi-process training where workers import the module, and for test suites that import helper functions. These should be moved inside main() and ideally made conditional on a config flag (e.g., cfg.distill.perf.profile_mode).
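One possible shape for that change (a sketch only; `profile_mode` is the assumed flag name from the comment, and the deferred import keeps a plain `import distill` free of side effects):

```python
def configure_dynamo_debug(profile_mode: bool) -> bool:
    """Apply the dynamo debug settings only when explicitly requested.

    Returns True when global torch state was modified, False otherwise.
    """
    if not profile_mode:
        return False  # importing this module stays side-effect free
    import torch  # deferred so module import never touches dynamo state

    torch._dynamo.reset()
    torch._dynamo.config.cache_size_limit = 264
    torch._dynamo.config.verbose = True
    torch._dynamo.config.suppress_errors = False
    torch._logging.set_logs(recompiles=True, graph_breaks=True)
    return True

# main() would then call, e.g.:
# configure_dynamo_debug(cfg.distill.perf.profile_mode)
```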
Keep it there for torch.compile
```python
class SCMModel(SCMBaseModel):
    """-time Consistency Model with TrigFlow for CorrDiff distillation.
```
Truncated docstring for SCMModel
The class docstring starts with -time Consistency Model, which looks like the beginning of the model name was accidentally deleted (probably "Straight-time" or "Stochastic-time").
| """-time Consistency Model with TrigFlow for CorrDiff distillation. | |
| """Straight-time Consistency Model with TrigFlow for CorrDiff distillation. |
```python
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
import torch
from torch.nn.parallel import DistributedDataParallel
```
Several symbols are imported but never referenced in distill.py:
- `DistributedDataParallel` (line 27) — `DDPWrapper` from fastgen is used instead.
- `RegressionLoss`, `RegressionLossCE` (line 41) — only `ResidualLoss` is referenced.
- `residual_loss` and `edm_precond_super_res` (lines 246/266) are assigned but never consumed downstream.
Removing these will reduce confusion about which loss/preconditioner classes are actually active during distillation.
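Unused imports like these can also be caught mechanically. Below is a small stdlib-only sketch (not part of the PR) that uses `ast` to list imported names never referenced in a module; linters such as flake8/ruff do this more thoroughly:

```python
import ast

def unused_imports(source: str) -> list[str]:
    """Return names that are imported but never referenced in the module."""
    tree = ast.parse(source)
    imported, used = [], set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                # "import a.b" binds the top-level name "a"
                imported.append(alias.asname or alias.name.split(".")[0])
        elif isinstance(node, ast.ImportFrom):
            for alias in node.names:
                imported.append(alias.asname or alias.name)
        elif isinstance(node, ast.Name):
            used.add(node.id)  # attribute roots (torch.zeros) are Names too
    return [name for name in imported if name not in used]

sample = (
    "from torch.nn.parallel import DistributedDataParallel\n"
    "import torch\n"
    "x = torch.zeros(1)\n"
)
print(unused_imports(sample))  # ['DistributedDataParallel']
```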
PhysicsNeMo Pull Request
Description
Added a distillation recipe and helper functions for CorrDiff distillation training with the fastgen framework.
Checklist
Dependencies
Review Process
All PRs are reviewed by the PhysicsNeMo team before merging.
Depending on which files are changed, GitHub may automatically assign a maintainer for review.
We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.
AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.