diff --git a/CHANGELOG.md b/CHANGELOG.md index 090fb6a4a1..3f8ab042b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 creating Voronoi regions around seed points. BVH-accelerated. - Added support for 1D, 2D, and 3D neighborhood attention (natten) via `physicsnemo.nn.functional` interface, with full `ShardTensor` support. +- Added distillation training recipe for CorrDiff (in `examples/weather/corrdiff/`) ### Changed diff --git a/examples/weather/corrdiff/conf/base/distill/base_all.yaml b/examples/weather/corrdiff/conf/base/distill/base_all.yaml new file mode 100644 index 0000000000..6369819c11 --- /dev/null +++ b/examples/weather/corrdiff/conf/base/distill/base_all.yaml @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Hyperparameters +hp: + training_duration: 5120000 + # Training duration based on the number of processed samples + total_batch_size: 512 + # Total batch size + batch_size_per_gpu: "auto" + # Batch size per GPU + validation_batch_size_per_gpu: 1 + # validation batch size per GPU + scheduler_name: LambdaInverseSquareRootScheduler # modulus_default + scheduler: + modulus_default: + lr_decay: 1 + # LR decay rate + lr_rampup: 0 + # Rampup for learning rate, in number of samples + LambdaInverseSquareRootScheduler: + warm_up_steps: 0 + decay_steps: 7000 + LambdaLinearScheduler: + warm_up_steps: [0] # [1000] + # warm up in the first 1000 iterations + cycle_lengths: [10000000000] + # it means there is no lr decay + f_start: [1.0e-6] + f_max: [1.0] + f_min: [1.0] + grad_clip_threshold: 1000000 + # no gradient clipping for default non-patch-based training + optimizer_name: Adam + optimizer: + lr: 0.00002 + # Learning rate + weight_decay: 0. # old: 0.01 + betas: [0.9, 0.99] + eps: 1e-11 + +# Performance +perf: + fp_optimizations: amp-bf16 + # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"] + # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16} + dataloader_workers: 4 + # DataLoader worker processes + songunet_checkpoint_level: 0 # 0 means no checkpointing + # Gradient checkpointing level, value is number of layers to checkpoint + +# IO +io: + regression_checkpoint_path: null + # Where to load the regression checkpoint. Should be overridden. 
+ print_progress_freq: 1000 + # How often to print progress + save_checkpoint_freq: 5000 + # How often to save the checkpoints, measured in number of processed samples + save_n_recent_checkpoints: -1 + # Set to a positive integer to only keep the most recent n checkpoints + validation_freq: 5000 + # how often to record the validation loss, measured in number of processed samples + validation_steps: 10 + # how many loss evaluations are used to compute the validation loss per checkpoint diff --git a/examples/weather/corrdiff/conf/base/distill/cwb.yaml b/examples/weather/corrdiff/conf/base/distill/cwb.yaml new file mode 100644 index 0000000000..83965f8d82 --- /dev/null +++ b/examples/weather/corrdiff/conf/base/distill/cwb.yaml @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +defaults: + - base_all + +# Hyperparameters +hp: + total_batch_size: 512 + # Total batch size + batch_size_per_gpu: 8 + # Batch size per GPU + validation_batch_size_per_gpu: 8 + # validation batch size per GPU + patching: null + scheduler_name: LambdaInverseSquareRootScheduler + optimizer: + fused: True + mode: cm + cm: + model: + use_ema: False + # multistep generation if larger than 1 (default: single-step generation) + student_sample_steps: 1 + # sampling type in multistep generation ('sde', 'ode') + student_sample_type: 'sde' + # precision for model/optimizer states and data - recommended to be float32 if precision_amp is not None + precision: 'float32' + # AMP during training - if None or equal to precision, AMP is disabled during training. + precision_amp: null + # AMP during inference - if None or equal to precision, AMP is disabled during inference. + precision_amp_infer: null + # AMP during en-/decoding (e.g., for VAEs or text encoders) - if None or equal to precision, AMP is disabled during en-/decoding. + precision_amp_enc: null + # FSDP2 precision for parameter storage and gradient reduction. + # If None, defaults to `precision`. Useful for storing params/grads in float32 while computing in bfloat16. + precision_fsdp: null + # whether to add the teacher model to the fsdp_dict + add_teacher_to_fsdp_dict: True + loss_config: + use_cd: False + huber_const: 0.06 + use_squared_l2: False + weighting_ct_loss: 'c_out_sq' + sample_t_cfg: + train_p_mean: -0.2 # 0.0 # from fastgen: -1.1 + train_p_std: 1.6 # 2 # from fastgen: 2.0 + min_t: 0.002 + t_list: null + max_t: 180 # 800 # from fastgen: 80 + min_r: 0. 
+ sigma_data: 0.5 + quantize: False + time_dist_type: lognormal + block_kwargs: + dropout: 0.2 + callbacks: + ct_schedule: + q: 4.0 + ratio_limit: 0.9961 + # duration // (4 * 1000) + kimg_per_stage: 1280 + # ema: + # type: power + # gamma: 16.97 + # beta: 0.999 + scm: + model: + use_ema: False + # multistep generation if larger than 1 (default: single-step generation) + student_sample_steps: 1 + # sampling type in multistep generation ('sde', 'ode') + student_sample_type: 'sde' + # precision for model/optimizer states and data - recommended to be float32 if precision_amp is not None + precision: 'float32' + # AMP during training - if None or equal to precision, AMP is disabled during training. + precision_amp: null + # AMP during inference - if None or equal to precision, AMP is disabled during inference. + precision_amp_infer: null + # AMP during en-/decoding (e.g., for VAEs or text encoders) - if None or equal to precision, AMP is disabled during en-/decoding. + precision_amp_enc: null + # FSDP2 precision for parameter storage and gradient reduction. + # If None, defaults to `precision`. Useful for storing params/grads in float32 while computing in bfloat16. + precision_fsdp: null + # whether to add the teacher model to the fsdp_dict + add_teacher_to_fsdp_dict: True + loss_config: + # use consistency distillation + use_cd: False + # warm-up steps for tangent + tangent_warmup_steps: 1000 # 1/10 of original + # tangent normalization constant + tangent_warmup_const: 0.1 + # enable prior weighting + prior_weighting_enabled: True + # enable g_norm_spatial_invariance + g_norm_spatial_invariance: True + # enable divide_x_0_spatial_dim + divide_x_0_spatial_dim: True + # finite difference approx. 
for jvp + use_jvp_finite_diff: False + # finite difference step size + jvp_finite_diff_eps: 1e-4 + # use fp32 jvp + use_fp32_jvp: False + sample_t_cfg: + train_p_mean: -0.2 + train_p_std: 1.6 + sigma_data: 0.5 + min_t: 0.002 + t_list: null + max_t: 180 + quantize: False + # TODO(jberner): change dropout? + block_kwargs: + dropout: 0.2 + # callbacks: + # ema: + # type: power + # gamma: 6.94 # ema_10 + # beta: 0.999 diff --git a/examples/weather/corrdiff/conf/base/distill/gefs_hrrr.yaml b/examples/weather/corrdiff/conf/base/distill/gefs_hrrr.yaml new file mode 100644 index 0000000000..8e1e2837b0 --- /dev/null +++ b/examples/weather/corrdiff/conf/base/distill/gefs_hrrr.yaml @@ -0,0 +1,142 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +defaults: + - base_all + +# Hyperparameters +hp: + total_batch_size: 512 + # Total batch size + batch_size_per_gpu: 1 + # Batch size per GPU + validation_batch_size_per_gpu: 1 + # validation batch size per GPU + patching: + patch_shape_x: ??? + patch_shape_y: ??? + # Patch size. Patch-based sampling will be utilized if these dimensions + # differ from img_shape_x and img_shape_y. Needs to be overridden. + patch_num: ??? + # Number of patches from a single sample. Total number of patches is + # patch_num * batch_size_global. Should be overridden. 
+ subpatch_num: 2 + # Number of patches to include in the super-patch for distillation. + overlap_pix: 32 + # Number of overlapping pixels between adjacent patches + window_function: KBD + # Window function to use for the window smoothing in superpatch training. + window_alpha: 1 + # Alpha for the window function. + scheduler_name: LambdaInverseSquareRootScheduler + optimizer: + fused: True + mode: cm + cm: + model: + use_ema: False + # multistep generation if larger than 1 (default: single-step generation) + student_sample_steps: 1 + # sampling type in multistep generation ('sde', 'ode') + student_sample_type: 'sde' + # precision for model/optimizer states and data - recommended to be float32 if precision_amp is not None + precision: 'float32' + # AMP during training - if None or equal to precision, AMP is disabled during training. + precision_amp: null + # AMP during inference - if None or equal to precision, AMP is disabled during inference. + precision_amp_infer: null + # AMP during en-/decoding (e.g., for VAEs or text encoders) - if None or equal to precision, AMP is disabled during en-/decoding. + precision_amp_enc: null + # FSDP2 precision for parameter storage and gradient reduction. + # If None, defaults to `precision`. Useful for storing params/grads in float32 while computing in bfloat16. + precision_fsdp: null + # whether to add the teacher model to the fsdp_dict + add_teacher_to_fsdp_dict: True + loss_config: + use_cd: False + huber_const: 0.06 + use_squared_l2: False + weighting_ct_loss: 'c_out_sq' + sample_t_cfg: + train_p_mean: -0.2 # 0.0 # from fastgen: -1.1 + train_p_std: 1.6 # 2 # from fastgen: 2.0 + min_t: 0.002 + t_list: null + max_t: 180 # 800 # from fastgen: 80 + min_r: 0. 
+ sigma_data: 0.5 + quantize: False + time_dist_type: lognormal + block_kwargs: + dropout: 0.2 + callbacks: + ct_schedule: + q: 4.0 + ratio_limit: 0.9961 + # duration // (4 * 1000) + kimg_per_stage: 1280 + # ema: + # type: power + # gamma: 16.97 + # beta: 0.999 + scm: + model: + use_ema: False + # multistep generation if larger than 1 (default: single-step generation) + student_sample_steps: 1 + # sampling type in multistep generation ('sde', 'ode') + student_sample_type: 'sde' + # precision for model/optimizer states and data - recommended to be float32 if precision_amp is not None + precision: 'float32' + # AMP during training - if None or equal to precision, AMP is disabled during training. + precision_amp: null + # AMP during inference - if None or equal to precision, AMP is disabled during inference. + precision_amp_infer: null + # AMP during en-/decoding (e.g., for VAEs or text encoders) - if None or equal to precision, AMP is disabled during en-/decoding. + precision_amp_enc: null + # FSDP2 precision for parameter storage and gradient reduction. + # If None, defaults to `precision`. Useful for storing params/grads in float32 while computing in bfloat16. + precision_fsdp: null + # whether to add the teacher model to the fsdp_dict + add_teacher_to_fsdp_dict: True + loss_config: + # use consistency distillation + use_cd: False + # warm-up steps for tangent + tangent_warmup_steps: 1000 # 1/10 of original + # tangent normalization constant + tangent_warmup_const: 0.1 + # enable prior weighting + prior_weighting_enabled: True + # enable g_norm_spatial_invariance + g_norm_spatial_invariance: True + # enable divide_x_0_spatial_dim + divide_x_0_spatial_dim: True + sample_t_cfg: + train_p_mean: -0.2 + train_p_std: 1.6 + sigma_data: 0.5 + min_t: 0.002 + t_list: null + max_t: 180 + # TODO(jberner): change dropout? 
+ block_kwargs: + dropout: 0.2 + # callbacks: + # ema: + # type: power + # gamma: 6.94 # ema_10 + # beta: 0.999 diff --git a/examples/weather/corrdiff/conf/base/distill/hrrr_mini.yaml b/examples/weather/corrdiff/conf/base/distill/hrrr_mini.yaml new file mode 100644 index 0000000000..ca6948ff09 --- /dev/null +++ b/examples/weather/corrdiff/conf/base/distill/hrrr_mini.yaml @@ -0,0 +1,138 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +defaults: + - base_all +# Hyperparameters +hp: + total_batch_size: 256 + # Total batch size + validation_batch_size_per_gpu: 4 + # validation batch size per GPU + patching: null + scheduler_name: LambdaLinearScheduler + optimizer_name: RAdam + optimizer: + lr: 0.0001 + eps: 1e-08 + mode: cm + cm: + model: + use_ema: False + # multistep generation if larger than 1 (default: single-step generation) + student_sample_steps: 1 + # sampling type in multistep generation ('sde', 'ode') + student_sample_type: 'sde' + # precision for model/optimizer states and data - recommended to be float32 if precision_amp is not None + precision: 'float32' + # AMP during training - if None or equal to precision, AMP is disabled during training. + precision_amp: null + # AMP during inference - if None or equal to precision, AMP is disabled during inference. 
+ precision_amp_infer: null + # AMP during en-/decoding (e.g., for VAEs or text encoders) - if None or equal to precision, AMP is disabled during en-/decoding. + precision_amp_enc: null + # FSDP2 precision for parameter storage and gradient reduction. + # If None, defaults to `precision`. Useful for storing params/grads in float32 while computing in bfloat16. + precision_fsdp: null + # whether to add the teacher model to the fsdp_dict + add_teacher_to_fsdp_dict: True + loss_config: + use_cd: False + huber_const: 1e-8 + use_squared_l2: False + weighting_ct_loss: 'default' + sample_t_cfg: + train_p_mean: -0.2 + train_p_std: 1.6 + min_t: 0.002 + t_list: null + max_t: 180 + min_r: 0. + sigma_data: 0.5 + quantize: False + time_dist_type: lognormal + block_kwargs: + dropout: 0.2 + callbacks: + ct_schedule: + q: 2.0 + ratio_limit: 0.999 + # duration // (8 * 1000) + kimg_per_stage: 2000 + # ema: + # beta: 0.999 + scm: # TODO(jberner): set betas to (0.9, 0.99) + model: + use_ema: False + # multistep generation if larger than 1 (default: single-step generation) + student_sample_steps: 1 + # sampling type in multistep generation ('sde', 'ode') + student_sample_type: 'sde' + # precision for model/optimizer states and data - recommended to be float32 if precision_amp is not None + precision: 'float32' + # AMP during training - if None or equal to precision, AMP is disabled during training. + precision_amp: null + # AMP during inference - if None or equal to precision, AMP is disabled during inference. + precision_amp_infer: null + # AMP during en-/decoding (e.g., for VAEs or text encoders) - if None or equal to precision, AMP is disabled during en-/decoding. + precision_amp_enc: null + # FSDP2 precision for parameter storage and gradient reduction. + # If None, defaults to `precision`. Useful for storing params/grads in float32 while computing in bfloat16. 
+ precision_fsdp: null + # whether to add the teacher model to the fsdp_dict + add_teacher_to_fsdp_dict: True + loss_config: + # use consistency distillation + use_cd: False + # warm-up steps for tangent + tangent_warmup_steps: 10000 + # tangent normalization constant + tangent_warmup_const: 0.1 + # enable prior weighting + prior_weighting_enabled: True + # enable g_norm_spatial_invariance + g_norm_spatial_invariance: True + # enable divide_x_0_spatial_dim + divide_x_0_spatial_dim: True + # finite difference approx. for jvp + use_jvp_finite_diff: False + sample_t_cfg: + train_p_mean: 0.0 + train_p_std: 1.2 + sigma_data: 0.5 + min_t: 0.002 + t_list: null + max_t: 180 + block_kwargs: + dropout: 0.2 + # callbacks: + # ema: + # type: "halflife" + dmd2: # TODO(jberner): change weight decay to 0.01 and opt_name to AdamW + model: + use_ema: False + sample_t_cfg: + train_p_mean: 0.0 + train_p_std: 1.2 + sigma_data: 0.5 + min_t: 0.002 + t_list: null + max_t: 180 + student_update_freq: 5 + guidance_scale: 1.0 + gan_loss_weight_gen: 0.001 + block_kwargs: + dropout: 0.0 diff --git a/examples/weather/corrdiff/conf/base/generation/base_all.yaml b/examples/weather/corrdiff/conf/base/generation/base_all.yaml index 194b4a57fe..c9275a9fce 100644 --- a/examples/weather/corrdiff/conf/base/generation/base_all.yaml +++ b/examples/weather/corrdiff/conf/base/generation/base_all.yaml @@ -17,6 +17,7 @@ defaults: - sampler: stochastic # Recommended is stochastic sampler. Change to deterministic if needed. + # Change to `few-step` sampler if running inference with distillation checkpoint. num_ensembles: ??? # Number of ensembles to generate per input. Should be overridden. 
diff --git a/examples/weather/corrdiff/conf/base/generation/sampler/few-step.yaml b/examples/weather/corrdiff/conf/base/generation/sampler/few-step.yaml new file mode 100644 index 0000000000..4ea3be9697 --- /dev/null +++ b/examples/weather/corrdiff/conf/base/generation/sampler/few-step.yaml @@ -0,0 +1,23 @@ +# @package _global_.sampler + +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +type: few-step +sigma_max: 800 +# Max noise level (adjust based on what's supported by the network) +sigma_mid: null +# Intermediate noise levels \ No newline at end of file diff --git a/examples/weather/corrdiff/conf/config_distill_gefs_hrrr_diffusion.yaml b/examples/weather/corrdiff/conf/config_distill_gefs_hrrr_diffusion.yaml new file mode 100644 index 0000000000..f8551b354a --- /dev/null +++ b/examples/weather/corrdiff/conf/config_distill_gefs_hrrr_diffusion.yaml @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +hydra: + job: + chdir: false + name: gefs_hrrr_diffusion_distill + run: + dir: ./output/${hydra:job.name} + searchpath: + - pkg://conf/base # Do not modify + +# Base parameters for dataset, model, training, and validation +defaults: + + - dataset: gefs_hrrr + # The dataset type for training. + # Accepted values: + # `gefs_hrrr`: full GEFS-HRRR dataset for continental US. + # `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments. + # `cwb`: full CWB dataset for Taiwan. + # `custom`: user-defined dataset. Parameters need to be specified below. + + - model: lt_aware_patched_diffusion + # The model type. + # Accepted values: + # `regression`: a regression UNet for deterministic predictions + # `lt_aware_ce_regression`: similar to `regression` but with lead time + # conditioning + # `diffusion`: a diffusion UNet for residual predictions + # `patched_diffusion`: a more memory-efficient diffusion model + # `lt_aware_patched_diffusion`: similar to `patched_diffusion` but + # with lead time conditioning + + - model_size: normal + # The model size configuration. + # Accepted values: + # `normal`: normal model size + # `mini`: smaller model size for fast experiments + + - distill: ${dataset} + # The base distillation parameters. Determined by the model type. + + +# Dataset parameters. Used for `custom` dataset type. +# Modify or add below parameters that should be passed as argument to the +# user-defined dataset class. 
+dataset: + data_path: ./data + # Path to .nc data file + stats_path: ./data/stats.json + # Path to json stats file + +# Distillation parameters +distill: + hp: + training_duration: 10000000 + # Training duration based on the number of processed samples + + io: + regression_checkpoint_path: + # Path to load the regression checkpoint + diffusion_checkpoint_path: + # Path to load the diffusion checkpoint + +# Parameters for wandb logging +wandb: + mode: offline + # Configure whether to use wandb: "offline", "online", "disabled" + results_dir: "./wandb" + # Directory to store wandb results + watch_model: false + # If true, wandb will track model parameters and gradients diff --git a/examples/weather/corrdiff/conf/config_distill_hrrr_mini_diffusion.yaml b/examples/weather/corrdiff/conf/config_distill_hrrr_mini_diffusion.yaml new file mode 100644 index 0000000000..7b90a077bd --- /dev/null +++ b/examples/weather/corrdiff/conf/config_distill_hrrr_mini_diffusion.yaml @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +hydra: + job: + chdir: false + name: hrrr_mini_diffusion_distill + run: + dir: ./output/${hydra:job.name} + searchpath: + - pkg://conf/base # Do not modify + +# Base parameters for dataset, model, training, and validation +defaults: + + - dataset: hrrr_mini + # The dataset type for training. + # Accepted values: + # `gefs_hrrr`: full GEFS-HRRR dataset for continental US. + # `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments. + # `cwb`: full CWB dataset for Taiwan. + # `custom`: user-defined dataset. Parameters need to be specified below. + + - model: diffusion + # The model type. + # Accepted values: + # `regression`: a regression UNet for deterministic predictions + # `lt_aware_ce_regression`: similar to `regression` but with lead time + # conditioning + # `diffusion`: a diffusion UNet for residual predictions + # `patched_diffusion`: a more memory-efficient diffusion model + # `lt_aware_patched_diffusion`: similar to `patched_diffusion` but + # with lead time conditioning + + - model_size: mini + # The model size configuration. + # Accepted values: + # `normal`: normal model size + # `mini`: smaller model size for fast experiments + + - distill: ${dataset} + # The base distillation parameters. Determined by the model type. + + +# Dataset parameters. Used for `custom` dataset type. +# Modify or add below parameters that should be passed as argument to the +# user-defined dataset class. 
+dataset: + data_path: ./data/hrrr_mini/hrrr_mini_train.nc + # Path to .nc data file + stats_path: ./data/hrrr_mini/stats.json + # Path to json stats file + +# Distillation parameters +distill: + hp: + training_duration: 8000000 + # Training duration based on the number of processed samples + io: + print_progress_freq: 10000 + regression_checkpoint_path: + # Path to load the regression checkpoint + diffusion_checkpoint_path: + # Path to load the diffusion checkpoint + +# Parameters for wandb logging +wandb: + mode: offline + # Configure whether to use wandb: "offline", "online", "disabled" + results_dir: "./wandb" + # Directory to store wandb results + watch_model: false + # If true, wandb will track model parameters and gradients diff --git a/examples/weather/corrdiff/conf/config_distill_taiwan_diffusion.yaml b/examples/weather/corrdiff/conf/config_distill_taiwan_diffusion.yaml new file mode 100644 index 0000000000..71f676ce1e --- /dev/null +++ b/examples/weather/corrdiff/conf/config_distill_taiwan_diffusion.yaml @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +hydra: + job: + chdir: false + name: taiwan_diffusion_distill + run: + dir: ./output/${hydra:job.name} + searchpath: + - pkg://conf/base # Do not modify + +# Base parameters for dataset, model, training, and validation +defaults: + + - dataset: cwb + # The dataset type for training. + # Accepted values: + # `gefs_hrrr`: full GEFS-HRRR dataset for continental US. + # `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments. + # `cwb`: full CWB dataset for Taiwan. + # `custom`: user-defined dataset. Parameters need to be specified below. + + - model: diffusion + # The model type. + # Accepted values: + # `regression`: a regression UNet for deterministic predictions + # `lt_aware_ce_regression`: similar to `regression` but with lead time + # conditioning + # `diffusion`: a diffusion UNet for residual predictions + # `patched_diffusion`: a more memory-efficient diffusion model + # `lt_aware_patched_diffusion`: similar to `patched_diffusion` but + # with lead time conditioning + + - model_size: normal + # The model size configuration. + # Accepted values: + # `normal`: normal model size + # `mini`: smaller model size for fast experiments + + - distill: ${dataset} + # The base distillation parameters. Determined by the model type. + + +# Dataset parameters. Used for `custom` dataset type. +# Modify or add below parameters that should be passed as argument to the +# user-defined dataset class. 
+dataset: + data_path: ./data/2023-01-24-cwb-4years.zarr + +model: + hr_mean_conditioning: false + # High-res mean (regression's output) as additional condition + +# Distillation parameters +distill: + hp: + training_duration: 200000000 + # Training duration based on the number of processed samples + lr_rampup: 10000000 + # Rampup for learning rate, in number of samples + io: + regression_checkpoint_path: + # Path to load the regression checkpoint + diffusion_checkpoint_path: + # Path to load the diffusion checkpoint + +# Additional parameters for validation +validation: + train: false + all_times: false + +# Parameters for wandb logging +wandb: + mode: offline + # Configure whether to use wandb: "offline", "online", "disabled" + results_dir: "./wandb" + # Directory to store wandb results + watch_model: false + # If true, wandb will track model parameters and gradients \ No newline at end of file diff --git a/examples/weather/corrdiff/distill.py b/examples/weather/corrdiff/distill.py new file mode 100644 index 0000000000..d9d73f5274 --- /dev/null +++ b/examples/weather/corrdiff/distill.py @@ -0,0 +1,1171 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import time +from contextlib import nullcontext + +import psutil +import hydra +from hydra.utils import to_absolute_path +from hydra.core.hydra_config import HydraConfig +from omegaconf import DictConfig, OmegaConf +import torch +from torch.nn.parallel import DistributedDataParallel +from torch.utils.tensorboard import SummaryWriter +from torch.optim.lr_scheduler import LambdaLR +from scipy.signal import windows +import nvtx +import wandb +from typing import Callable +import numpy as np +import gc + +from physicsnemo import Module +from physicsnemo.diffusion.preconditioners import EDMPrecondSuperResolution + +from physicsnemo.distributed import DistributedManager +from physicsnemo.diffusion.metrics import RegressionLoss, ResidualLoss, RegressionLossCE +from physicsnemo.diffusion.multi_diffusion import RandomPatching2D +from physicsnemo.utils.logging.wandb import initialize_wandb +from physicsnemo.utils.logging import PythonLogger, RankZeroLoggingWrapper +from physicsnemo.utils import ( + load_checkpoint, + save_checkpoint, + get_checkpoint_dir, +) +from physicsnemo.experimental.metrics.diffusion import tEDMResidualLoss +from physicsnemo.experimental.models.diffusion.preconditioning import ( + tEDMPrecondSuperRes, +) + +from fastgen.callbacks.ct_schedule import CTScheduleCallback +from fastgen.callbacks.ema import EMACallback +from fastgen.utils.distributed.ddp import DDPWrapper +from fastgen.utils import lr_scheduler +from helpers.distill_helpers import ( + MODEL_MAP, + PRECISION_MAP, + DistillLoss, + get_scheduler, + get_window_function, +) + + +from datasets.dataset import init_train_valid_datasets_from_config, register_dataset +from helpers.train_helpers import ( + set_patch_shape, + set_seed, + configure_cuda_for_consistent_precision, + compute_num_accumulation_rounds, + handle_and_clip_gradients, + is_time_for_periodic_task, +) + +torch._dynamo.reset() +# Increase the cache size limit +torch._dynamo.config.cache_size_limit = 264 # Set to a higher 
value +torch._dynamo.config.verbose = True # Enable verbose logging +torch._dynamo.config.suppress_errors = False # Forces the error to show all details +torch._logging.set_logs(recompiles=True, graph_breaks=True) + + +def checkpoint_list(path, suffix=".mdlus"): + """Helper function to return sorted list, in ascending order, of checkpoints in a path""" + checkpoints = [] + for file in os.listdir(path): + if file.endswith(suffix): + # Split the filename and extract the index + try: + index = int(file.split(".")[-2]) + checkpoints.append((index, file)) + except ValueError: + continue + + # Sort by index and return filenames + checkpoints.sort(key=lambda x: x[0]) + return [file for _, file in checkpoints] + + +# Define safe CUDA profiler tools that fallback to no-ops when CUDA is not available +def cuda_profiler(): + """ + Safe CUDA profiler tool that falls back to no-op when CUDA is not available. + """ + if torch.cuda.is_available(): + return torch.cuda.profiler.profile() + else: + return nullcontext() + + +def cuda_profiler_start(): + """ + Start CUDA profiler. + """ + if torch.cuda.is_available(): + torch.cuda.profiler.start() + + +def cuda_profiler_stop(): + """ + Stop CUDA profiler. + """ + if torch.cuda.is_available(): + torch.cuda.profiler.stop() + + +def profiler_emit_nvtx(): + """ + Emit NVTX markers for CUDA profiler. + """ + if torch.cuda.is_available(): + return torch.autograd.profiler.emit_nvtx() + else: + return nullcontext() + + +# Distill the CorrDiff Diffusion model +@hydra.main( + version_base="1.2", config_path="conf", config_name="config_distill_mini_diffusion" +) +def main(cfg: DictConfig) -> None: + """ + Entry point for CorrDiff distillation training. 
+ """ + # Initialize distributed environment for training + DistributedManager.initialize() + dist = DistributedManager() + + # Initialize loggers + if dist.rank == 0: + writer = SummaryWriter(log_dir="tensorboard") + logger = PythonLogger("main") # General python logger + logger0 = RankZeroLoggingWrapper(logger, dist) # Rank 0 logger + initialize_wandb( + project="Modulus-Launch", + entity="Modulus", + name=f"CorrDiff-Training-{HydraConfig.get().job.name}", + group="CorrDiff-DDP-Group", + mode=cfg.wandb.mode, + config=OmegaConf.to_container(cfg), + results_dir=cfg.wandb.results_dir, + ) + + # Resolve and parse configs + OmegaConf.resolve(cfg) + dataset_cfg = OmegaConf.to_container(cfg.dataset) # TODO needs better handling + + # Register custom dataset if specified in config + register_dataset(cfg.dataset.type) + logger0.info(f"Using dataset: {cfg.dataset.type}") + + if hasattr(cfg, "validation"): + validation = True + validation_dataset_cfg = OmegaConf.to_container(cfg.validation) + else: + validation = False + validation_dataset_cfg = None + fp_optimizations = cfg.distill.perf.fp_optimizations + songunet_checkpoint_level = cfg.distill.perf.songunet_checkpoint_level + fp16 = fp_optimizations == "fp16" + enable_amp = fp_optimizations.startswith("amp") + amp_dtype = torch.float16 if (fp_optimizations == "amp-fp16") else torch.bfloat16 + logger.info(f"Saving the outputs in {os.getcwd()}") + checkpoint_dir = get_checkpoint_dir( + str(cfg.distill.io.get("checkpoint_dir", ".")), cfg.model.name + ) + if cfg.distill.hp.batch_size_per_gpu == "auto": + cfg.distill.hp.batch_size_per_gpu = ( + cfg.distill.hp.total_batch_size // dist.world_size + ) + + # Load the current number of images for resuming + try: + cur_nimg = load_checkpoint( + path=checkpoint_dir, + ) + except Exception: + cur_nimg = 0 + + # Distillation Callbacks + callbacks = [] + distill_cfg = getattr(cfg.distill.hp, cfg.distill.hp.mode) + callbacks_cfg = distill_cfg.get("callbacks", {}) + if "ema" in 
callbacks_cfg: + callbacks.append(EMACallback(**distill_cfg.callbacks.ema)) + if "ct_schedule" in callbacks_cfg: + callbacks.append( + CTScheduleCallback( + **callbacks_cfg.ct_schedule, batch_size=cfg.distill.hp.total_batch_size + ) + ) + for callback in callbacks: + callback.on_app_begin() + + # Set seeds and configure CUDA and cuDNN settings to ensure consistent precision + set_seed(dist.rank + cur_nimg) + configure_cuda_for_consistent_precision() + + # Instantiate the dataset + for callback in callbacks: + callback.on_dataloader_init_start(None) + + # Instantiate the dataset + data_loader_kwargs = { + "pin_memory": True, + "num_workers": cfg.distill.perf.dataloader_workers, + "prefetch_factor": 2 if cfg.distill.perf.dataloader_workers > 0 else None, + } + ( + dataset, + dataset_iterator, + validation_dataset, + validation_dataset_iterator, + ) = init_train_valid_datasets_from_config( + dataset_cfg, + data_loader_kwargs, + batch_size=cfg.distill.hp.batch_size_per_gpu, + seed=0, + validation_dataset_cfg=validation_dataset_cfg, + validation=validation, + sampler_start_idx=cur_nimg, + ) + # Callbacks + for callback in callbacks: + callback.on_dataloader_init_end( + None, dataset_iterator, validation_dataset_iterator + ) + + # Parse image configuration & update model args + dataset_channels = len(dataset.input_channels()) + img_in_channels = dataset_channels + img_shape = dataset.image_shape() + img_out_channels = len(dataset.output_channels()) + if cfg.model.hr_mean_conditioning: + img_in_channels += img_out_channels + + # Handle distribution type + distribution = getattr(cfg.distill.hp, "distribution", None) + student_t_nu = getattr(cfg.distill.hp, "student_t_nu", None) + residual_loss, edm_precond_super_res = ResidualLoss, EDMPrecondSuperResolution + if distribution is not None and cfg.model.name not in [ + "diffusion", + "patched_diffusion", + "lt_aware_patched_diffusion", + ]: + raise ValueError( + f"cfg.distill.distribution should only be specified for 
diffusion models." + ) + if distribution not in ["normal", "student_t", None]: + raise ValueError(f"Invalid distribution {distribution}") + if distribution == "student_t": + if student_t_nu is None: + raise ValueError( + "student_t_nu must be provided in cfg.distill.hp.student_t_nu for student_t distribution" + ) + elif student_t_nu <= 2: + raise ValueError(f"Expected nu > 2, but got {student_t_nu}.") + # Reassign models and class for student-t distribution + else: + residual_loss, edm_precond_super_res = tEDMResidualLoss, tEDMPrecondSuperRes + logger0.info( + f"Using student-t distribution with nu={student_t_nu}. " + f"This is an experimental feature and APIs may change without notice." + ) + + # Parse P_mean and P_std + P_mean = getattr(cfg.distill.hp, "P_mean", None) + P_std = getattr(cfg.distill.hp, "P_std", None) + + # Handle patch shape + if cfg.model.name == "lt_aware_ce_regression": + prob_channels = dataset.get_prob_channel_index() + else: + prob_channels = None + + # Parse the patch shape - superpatch vs patch training for distillation + is_superpatch = False + if cfg.distill.hp.get("patching", None) is not None: + patch_shape_x = cfg.distill.hp.patching.patch_shape_x + patch_shape_y = cfg.distill.hp.patching.patch_shape_y + # compute super-patch shape for distillation + subpatch_num = cfg.distill.hp.patching.get("subpatch_num", 2) + overlap_pix = cfg.distill.hp.patching.get("overlap_pix", 32) + super_patch_shape_x = subpatch_num * (patch_shape_x - overlap_pix) + overlap_pix + super_patch_shape_y = subpatch_num * (patch_shape_y - overlap_pix) + overlap_pix + patching_cfg = { + "patch_shape": (patch_shape_y, patch_shape_x), + "overlap_pix": overlap_pix, + } + is_superpatch = True + else: + patch_shape_x = None + patch_shape_y = None + super_patch_shape_x = None + super_patch_shape_y = None + patching_cfg = {} + if ( + super_patch_shape_x + and super_patch_shape_y + and super_patch_shape_y >= img_shape[0] + and super_patch_shape_x >= img_shape[1] + ): + 
logger0.warning( + f"Patch shape {super_patch_shape_y}x{super_patch_shape_x} is larger than \ + the image shape {img_shape[0]}x{img_shape[1]}. Patching will not be used." + ) + super_patch_shape = (super_patch_shape_y, super_patch_shape_x) + use_patching, img_shape, super_patch_shape = set_patch_shape( + img_shape, super_patch_shape + ) + if use_patching: + # Utility to perform patches extraction and batching + patching = RandomPatching2D( + img_shape=img_shape, + patch_shape=super_patch_shape, + patch_num=cfg.distill.hp.patching.get("patch_num", 1), + ) + logger0.info( + f"Patch-based training enabled with patch shape {super_patch_shape} and patch num {cfg.distill.hp.patching.get('patch_num', 1)}." + ) + else: + patching = None + logger0.info("Patch-based training disabled") + + # set window function with superpatch + window = None + if is_superpatch: + window_function = cfg.distill.hp.patching.get("window_function", None) + window_alpha = cfg.distill.hp.patching.get("window_alpha", 1) + if window_function is not None: + logger0.info( + f"Enabling window function {window_function} with alpha {window_alpha} in superpatch training" + ) + window = get_window_function( + patch_shape_x=patch_shape_x, + patch_shape_y=patch_shape_y, + window_alpha=window_alpha, + type=window_function, + dtype=torch.float32, + device=dist.device, + ) + window = window.reshape((1, 1, window.shape[0], window.shape[1])) + else: + logger0.info("Window function is not used with regular patch training") + + # interpolate global channel if patch-based model is used + if use_patching: + img_in_channels += dataset_channels + + # Instantiate the model and move to device. 
+ model_args = { # default parameters for all networks + "img_out_channels": img_out_channels, + "img_resolution": list(img_shape), + "use_fp16": fp16, + "checkpoint_level": songunet_checkpoint_level, + } + if student_t_nu is not None: + model_args["nu"] = student_t_nu + if cfg.model.name == "lt_aware_ce_regression": + model_args["prob_channels"] = prob_channels + if hasattr(cfg.model, "model_args"): # override defaults from config file + model_args.update(OmegaConf.to_container(cfg.model.model_args)) + + use_torch_compile = getattr(cfg.distill.perf, "torch_compile", False) + use_apex_gn = getattr(cfg.distill.perf, "use_apex_gn", False) + profile_mode = getattr(cfg.distill.perf, "profile_mode", False) + + model_args["use_apex_gn"] = use_apex_gn + model_args["profile_mode"] = profile_mode + + if enable_amp: + model_args["amp_mode"] = enable_amp + + # Load the diffusion checkpoint for distillation + if ( + hasattr(cfg.distill.io, "diffusion_checkpoint_path") + and cfg.distill.io.diffusion_checkpoint_path is not None + ): + diffusion_checkpoint_path = to_absolute_path( + cfg.distill.io.diffusion_checkpoint_path + ) + if not os.path.exists(diffusion_checkpoint_path): + raise FileNotFoundError( + f"Expected this diffusion checkpoint but not found: {diffusion_checkpoint_path}" + ) + diffusion_model = Module.from_checkpoint( + diffusion_checkpoint_path, override_args={"use_apex_gn": use_apex_gn} + ) + diffusion_model.amp_mode = enable_amp + diffusion_model.profile_mode = profile_mode + diffusion_model.eval().requires_grad_(False).to(dist.device) + if use_apex_gn: + diffusion_model.to(memory_format=torch.channels_last) + logger0.success("Loaded the pre-trained diffusion model") + else: + raise ValueError( + "A diffusion checkpoint must be provided for distillation training. " + "Set cfg.distill.io.diffusion_checkpoint_path." 
+ ) + diffusion_model = None + + if cfg.wandb.watch_model and dist.rank == 0: + wandb.watch(diffusion_model) + + # Load the regression checkpoint if applicable + if ( + hasattr(cfg.distill.io, "regression_checkpoint_path") + and cfg.distill.io.regression_checkpoint_path is not None + ): + regression_checkpoint_path = to_absolute_path( + cfg.distill.io.regression_checkpoint_path + ) + if not os.path.exists(regression_checkpoint_path): + raise FileNotFoundError( + f"Expected this regression checkpoint but not found: {regression_checkpoint_path}" + ) + regression_net = Module.from_checkpoint( + regression_checkpoint_path, override_args={"use_apex_gn": use_apex_gn} + ) + regression_net.amp_mode = enable_amp + regression_net.profile_mode = profile_mode + regression_net.eval().requires_grad_(False).to(dist.device) + if use_apex_gn: + regression_net.to(memory_format=torch.channels_last) + logger0.success("Loaded the pre-trained regression model") + else: + regression_net = None + + # Compile the teacher diffusion model and regression net if applicable + if use_torch_compile: + logger0.info("Compiling the diffusion model and regression net...") + if diffusion_model: + diffusion_model = torch.compile(diffusion_model) + if regression_net: + regression_net = torch.compile(regression_net) + + input_dtype = torch.float32 + if enable_amp: + input_dtype = torch.float32 + elif fp16: + input_dtype = torch.float16 + + # Instantiate the loss function for distillation + loss_fn = DistillLoss( + regression_net=regression_net, + hr_mean_conditioning=cfg.model.hr_mean_conditioning, + ) + + # Instantiate the FastGenNet for distillation + model_cfg_update = DictConfig( + { + "precision": PRECISION_MAP[fp_optimizations], + "precision_infer": PRECISION_MAP[str(input_dtype)], + "input_shape": (diffusion_model.img_out_channels, *super_patch_shape), + "window": window, + "device": dist.device, + "net": diffusion_model, + **patching_cfg, + }, + flags={"allow_objects": True}, + ) + model_cfg = 
OmegaConf.merge(model_cfg_update, distill_cfg.model) + unwrapped_model = MODEL_MAP[cfg.distill.hp.mode](model_cfg) + + for callback in callbacks: + callback.on_model_init_start(unwrapped_model) + unwrapped_model.on_train_begin() + + if dist.world_size > 1: + model = DDPWrapper( + unwrapped_model, + device_ids=[dist.local_rank], + broadcast_buffers=True, + output_device=dist.device, + find_unused_parameters=True, + gradient_as_bucket_view=True, + ) + else: + model = unwrapped_model + + # Enable distributed data parallel if applicable + # move to device and initialize before DDP + save_models = [unwrapped_model.net.net, unwrapped_model.net.logvar_linear] + for name in ["fake_score", "discriminator"]: + if hasattr(unwrapped_model, name): + save_models.append(getattr(unwrapped_model, name)) + logger0.info(f"Saving {name} model") + if unwrapped_model.use_ema: + save_models += [unwrapped_model.ema.net, unwrapped_model.ema.logvar_linear] + logger0.info("Saving EMA model") + + # Compute the number of required gradient accumulation rounds + # It is automatically used if batch_size_per_gpu * dist.world_size < total_batch_size + batch_gpu_total, num_accumulation_rounds = compute_num_accumulation_rounds( + cfg.distill.hp.total_batch_size, + cfg.distill.hp.batch_size_per_gpu, + dist.world_size, + ) + batch_size_per_gpu = cfg.distill.hp.batch_size_per_gpu + logger0.info(f"Using {num_accumulation_rounds} gradient accumulation rounds") + + # calculate patch per iter + patch_num = getattr(cfg.distill.hp.patching, "patch_num", 1) + if hasattr(cfg.distill.hp.patching, "max_patch_per_gpu"): + max_patch_per_gpu = cfg.distill.hp.patching.max_patch_per_gpu + if max_patch_per_gpu // batch_size_per_gpu < 1: + raise ValueError( + f"max_patch_per_gpu ({max_patch_per_gpu}) must be greater or equal to batch_size_per_gpu ({batch_size_per_gpu})." 
+ ) + max_patch_num_per_iter = min( + patch_num, (max_patch_per_gpu // batch_size_per_gpu) + ) + patch_iterations = ( + patch_num + max_patch_num_per_iter - 1 + ) // max_patch_num_per_iter + patch_nums_iter = [ + min(max_patch_num_per_iter, patch_num - i * max_patch_num_per_iter) + for i in range(patch_iterations) + ] + logger0.info( + f"max_patch_num_per_iter is {max_patch_num_per_iter}, patch_iterations is {patch_iterations}, patch_nums_iter is {patch_nums_iter}" + ) + else: + patch_nums_iter = [patch_num] + + # Distillation only support diffusion models. + # Set patch gradient accumulation only for patched diffusion models + if cfg.model.name in { + "patched_diffusion", + "lt_aware_patched_diffusion", + }: + if len(patch_nums_iter) > 1: + if not patching: + logger0.info( + "Patching is not enabled: patch gradient accumulation automatically disabled." + ) + use_patch_grad_acc = False + else: + use_patch_grad_acc = True + else: + use_patch_grad_acc = False + + # Automatically disable patch gradient accumulation for non-patched models + else: + logger0.info( + "Training a non-patched model: patch gradient accumulation automatically disabled." + ) + use_patch_grad_acc = None + + # Instantiate the optimizer + for callback in callbacks: + callback.on_optimizer_init_start(unwrapped_model) + + optim_cls = getattr(torch.optim, cfg.distill.hp.get("optimizer_name", "Adam")) + optimizer = optim_cls( + params=model.parameters(), + **OmegaConf.to_container(cfg.distill.hp.optimizer, resolve=True), + ) + + scheduler = get_scheduler( + name=cfg.distill.hp.scheduler_name, + # cfg=getattr(cfg.distill.hp.scheduler, cfg.distill.hp.scheduler_name), + cfg=OmegaConf.to_container( + getattr(cfg.distill.hp.scheduler, cfg.distill.hp.scheduler_name), + resolve=True, + ), + optimizer=optimizer, + ) + for callback in callbacks: + callback.on_optimizer_init_end(unwrapped_model) + + # Record the current time to measure the duration of subsequent operations. 
+ start_time = time.time() + + ## Load optimizer checkpoint if exists + if dist.world_size > 1: + torch.distributed.barrier() + # Distill callback + for callback in callbacks: + callback.on_load_checkpoint_start(unwrapped_model) + + cur_nimg = load_checkpoint( + path=checkpoint_dir, + models=save_models, + optimizer=optimizer, + scheduler=scheduler, + device=dist.device, + ) + logger0.info(f"Resuming training from {cur_nimg} images...") + + for callback in callbacks: + callback.on_load_checkpoint_end(unwrapped_model) + + ############################################################################ + # MAIN TRAINING LOOP # + ############################################################################ + + logger0.info(f"Training for {cfg.distill.hp.training_duration} images...") + done = False + + # FastGen Initialization + for callback in callbacks: + callback.on_train_begin( + unwrapped_model, iteration=cur_nimg // cfg.distill.hp.total_batch_size + ) + + # init variables to monitor running mean of average loss since last periodic + average_loss_running_mean = 0 + n_average_loss_running_mean = 1 + start_nimg = cur_nimg + + # enable profiler: + with cuda_profiler(): + with profiler_emit_nvtx(): + while not done: + tick_start_nimg = cur_nimg + tick_start_time = time.time() + + if cur_nimg - start_nimg == 24 * cfg.distill.hp.total_batch_size: + logger0.info(f"Starting Profiler at {cur_nimg}") + cuda_profiler_start() + + if cur_nimg - start_nimg == 25 * cfg.distill.hp.total_batch_size: + logger0.info(f"Stopping Profiler at {cur_nimg}") + cuda_profiler_stop() + + with nvtx.annotate("Training iteration", color="green"): + # Compute & accumulate gradients + optimizer.zero_grad(set_to_none=True) + loss_accum = 0 + for callback in callbacks: + callback.on_training_step_begin( + unwrapped_model, + iteration=cur_nimg // cfg.distill.hp.total_batch_size, + ) + for n_i in range(num_accumulation_rounds): + with nvtx.annotate( + f"accumulation round {n_i}", color="Magenta" + ): + 
with nvtx.annotate("loading data", color="green"): + img_clean, img_lr, *lead_time_label = next( + dataset_iterator + ) + data_batch = { + "img_clean": img_clean, + "img_lr": img_lr, + "lead_time_label": lead_time_label, + } + for callback in callbacks: + callback.on_training_accum_step_begin( + unwrapped_model, + data_batch=data_batch, + iteration=cur_nimg + // cfg.distill.hp.total_batch_size, + accum_iter=n_i, + ) + if use_apex_gn: + img_clean = img_clean.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + img_lr = img_lr.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + else: + img_clean = ( + img_clean.to(dist.device) + .to(input_dtype) + .contiguous() + ) + img_lr = ( + img_lr.to(dist.device) + .to(input_dtype) + .contiguous() + ) + loss_fn_kwargs = { + "net": model, + "img_clean": img_clean, + "img_lr": img_lr, + "augment_pipe": None, + "iteration": cur_nimg + // cfg.distill.hp.total_batch_size, + } + if use_patch_grad_acc is not None: + loss_fn_kwargs["use_patch_grad_acc"] = ( + use_patch_grad_acc + ) + + if lead_time_label: + lead_time_label = ( + lead_time_label[0].to(dist.device).contiguous() + ) + loss_fn_kwargs.update( + {"lead_time_label": lead_time_label} + ) + else: + lead_time_label = None + if use_patch_grad_acc: + loss_fn.y_mean = None + + for patch_num_per_iter in patch_nums_iter: + if patching is not None: + patching.set_patch_num(patch_num_per_iter) + loss_fn_kwargs.update({"patching": patching}) + with nvtx.annotate(f"loss forward", color="green"): + with torch.autocast( + "cuda", dtype=amp_dtype, enabled=enable_amp + ): + loss, loss_map, output_batch = loss_fn( + **loss_fn_kwargs + ) + + # loss is averaged in the loss_fn (different from train.py); we need to sum it up and divide by num_accumulation_rounds * num_patches + assert loss.ndim == 0, ( + f"Loss has {loss.ndim} dimensions, expected 0" + ) + loss = ( + loss + * 
patch_num_per_iter + / (num_accumulation_rounds * patch_num) + ) + loss_accum += loss.detach() + with nvtx.annotate(f"loss backward", color="yellow"): + loss.backward() + + with nvtx.annotate(f"loss aggregate", color="green"): + loss_sum = torch.tensor([loss_accum], device=dist.device) + if dist.world_size > 1: + torch.distributed.barrier() + torch.distributed.all_reduce( + loss_sum, op=torch.distributed.ReduceOp.SUM + ) + average_loss = (loss_sum / dist.world_size).cpu().item() + + # update running mean of average loss since last periodic task + average_loss_running_mean += ( + average_loss - average_loss_running_mean + ) / n_average_loss_running_mean + n_average_loss_running_mean += 1 + + if dist.rank == 0: + loss_map = {f"training/{k}": v for k, v in loss_map.items()} + loss_map.update( + { + "training/loss": average_loss, + "training/loss_running_mean": average_loss_running_mean, + } + ) + if hasattr(unwrapped_model, "ratio"): + loss_map["schedule/ratio"] = unwrapped_model.ratio + for k, v in loss_map.items(): + writer.add_scalar(k, v, cur_nimg) + if wandb.run is not None: + wandb.log(loss_map, step=cur_nimg) + + ptt = is_time_for_periodic_task( + cur_nimg, + cfg.distill.io.print_progress_freq, + done, + cfg.distill.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ) + if ptt: + # reset running mean of average loss + average_loss_running_mean = 0 + n_average_loss_running_mean = 1 + + # Update weights. 
+ with nvtx.annotate("update weights", color="blue"): + if scheduler is None: + assert cfg.distill.hp.scheduler_name == "modulus_default" + scheduler_cfg = cfg.distill.hp.scheduler.modulus_default + lr_rampup = ( + scheduler_cfg.lr_rampup + ) # ramp up the learning rate + for g in optimizer.param_groups: + if lr_rampup > 0: + g["lr"] = cfg.distill.hp.opt.lr * min( + cur_nimg / lr_rampup, 1 + ) + if cur_nimg >= lr_rampup: + g["lr"] *= scheduler_cfg.lr_decay ** ( + (cur_nimg - lr_rampup) // 5e6 + ) + current_lr = g["lr"] + else: + scheduler.step() + current_lr = scheduler.get_last_lr()[0] + if dist.rank == 0: + writer.add_scalar("learning_rate", current_lr, cur_nimg) + if wandb.run is not None: + wandb.log( + {"learning_rate_decay": current_lr}, step=cur_nimg + ) + + handle_and_clip_gradients( + model, + grad_clip_threshold=cfg.distill.hp.grad_clip_threshold, + ) + with nvtx.annotate("optimizer step", color="blue"): + for callback in callbacks: + callback.on_optimizer_step_begin( + unwrapped_model, + iteration=cur_nimg // cfg.distill.hp.total_batch_size, + ) + optimizer.step() + + cur_nimg += cfg.distill.hp.total_batch_size + done = cur_nimg >= cfg.distill.hp.training_duration + + for callback in callbacks: + callback.on_training_step_end( + unwrapped_model, + data_batch=data_batch, + output_batch=output_batch, + loss_dict=loss_map, + iteration=cur_nimg // cfg.distill.hp.total_batch_size, + ) + del loss, loss_sum, loss_map, output_batch + + with nvtx.annotate("validation", color="red"): + # Validation + if validation_dataset_iterator is not None: + valid_loss_accum = 0 + if is_time_for_periodic_task( + cur_nimg, + cfg.distill.io.validation_freq, + done, + cfg.distill.hp.total_batch_size, + dist.rank, + ): + for callback in callbacks: + callback.on_validation_begin( + unwrapped_model, + iteration=cur_nimg + // cfg.distill.hp.total_batch_size, + ) + with torch.no_grad(): + for _ in range(cfg.distill.io.validation_steps): + ( + img_clean_valid, + img_lr_valid, + 
*lead_time_label_valid, + ) = next(validation_dataset_iterator) + + for callback in callbacks: + callback.on_validation_step_begin( + unwrapped_model, + data_batch=data_batch, + iteration=cur_nimg + // cfg.distill.hp.total_batch_size, + ) + + if use_apex_gn: + img_clean_valid = img_clean_valid.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + img_lr_valid = img_lr_valid.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + + else: + img_clean_valid = ( + img_clean_valid.to(dist.device) + .to(input_dtype) + .contiguous() + ) + img_lr_valid = ( + img_lr_valid.to(dist.device) + .to(input_dtype) + .contiguous() + ) + + loss_valid_kwargs = { + "net": model, + "img_clean": img_clean_valid, + "img_lr": img_lr_valid, + "augment_pipe": None, + "iteration": cur_nimg + // cfg.distill.hp.total_batch_size, + } + if use_patch_grad_acc is not None: + loss_valid_kwargs["use_patch_grad_acc"] = ( + use_patch_grad_acc + ) + if lead_time_label_valid: + lead_time_label_valid = ( + lead_time_label_valid[0] + .to(dist.device) + .contiguous() + ) + loss_valid_kwargs.update( + {"lead_time_label": lead_time_label_valid} + ) + if use_patch_grad_acc: + loss_fn.y_mean = None + + for patch_num_per_iter in patch_nums_iter: + if patching is not None: + patching.set_patch_num(patch_num_per_iter) + loss_valid_kwargs.update( + {"patching": patching} + ) + with torch.autocast( + "cuda", dtype=amp_dtype, enabled=enable_amp + ): + ( + loss_valid, + loss_map_valid, + output_batch_valid, + ) = loss_fn(**loss_valid_kwargs) + + loss_valid = ( + (loss_valid.sum() / batch_size_per_gpu) + .cpu() + .item() + ) + valid_loss_accum += ( + loss_valid + / cfg.distill.io.validation_steps + / len(patch_nums_iter) + ) + valid_loss_sum = torch.tensor( + [valid_loss_accum], device=dist.device + ) + if dist.world_size > 1: + torch.distributed.barrier() + torch.distributed.all_reduce( + valid_loss_sum, + 
op=torch.distributed.ReduceOp.SUM, + ) + average_valid_loss = valid_loss_sum / dist.world_size + if dist.rank == 0: + loss_map_valid = { + f"valid/{k}": v + for k, v in loss_map_valid.items() + } + loss_map_valid.update( + { + "valid/loss": average_valid_loss, + } + ) + for k, v in loss_map_valid.items(): + writer.add_scalar(k, v, cur_nimg) + if wandb.run is not None: + wandb.log(loss_map_valid, step=cur_nimg) + + # generate images and log to wandb + diff_out = output_batch_valid["gen_rand"] + if isinstance(diff_out, Callable): + with torch.autocast( + "cuda", dtype=amp_dtype, enabled=enable_amp + ): + diff_out = diff_out() + assert isinstance(diff_out, torch.Tensor) + y_mean = output_batch_valid["y_mean"] + if patching is not None: + y_mean = patching.apply(y_mean) + img_clean_valid = patching.apply( + img_clean_valid + ) + + image_out = diff_out + y_mean + + # log first element in batch + images = { + "mean": y_mean, + "diff": diff_out, + "pred": image_out, + "truth": img_clean_valid, + } + images = { + name: validation_dataset.denormalize_output( + img.float().cpu().numpy() + ) + for name, img in images.items() + } + + wandb_log = {"valid/loss": average_valid_loss} + for batch_idx in range(images["pred"].shape[0]): + for channel_idx in range( + images["pred"].shape[1] + ): + info = validation_dataset.output_channels()[ + channel_idx + ] + channel_name = info.name + info.level + channel_min = np.min( + images["truth"][batch_idx, channel_idx] + ) + channel_max = np.max( + images["truth"][batch_idx, channel_idx] + ) + span = (channel_max - channel_min) * 1.5 + channel_images = [] + for name, img in images.items(): + img = img[batch_idx, channel_idx] + img = ( + img - channel_min + 0.25 * span + ) / (1.5 * span) + img = ( + (img * 255) + .clip(0, 255) + .astype(np.uint8) + ) + channel_images.append( + wandb.Image(img, caption=name) + ) + wandb_log[ + f"images/{channel_name}_{batch_idx}" + ] = channel_images + wandb.log(wandb_log, step=cur_nimg) + + # free memory on 
rank 0 + del images, image_out, diff_out, y_mean, wandb_log + + for callback in callbacks: + callback.on_validation_step_end( + unwrapped_model, + data_batch=data_batch, + output_batch=output_batch_valid, + loss_dict=loss_map_valid, + iteration=cur_nimg + // cfg.distill.hp.total_batch_size, + ) + + # free memory after validation + loss_fn.y_mean = None + del img_clean_valid, img_lr_valid, lead_time_label_valid + del ( + loss_valid, + valid_loss_sum, + valid_loss_accum, + average_valid_loss, + loss_map_valid, + output_batch_valid, + ) + gc.collect() + with torch.cuda.device(dist.device): + torch.cuda.empty_cache() + + for callback in callbacks: + callback.on_validation_end( + unwrapped_model, + iteration=cur_nimg + // cfg.distill.hp.total_batch_size, + ) + + if is_time_for_periodic_task( + cur_nimg, + cfg.distill.io.print_progress_freq, + done, + cfg.distill.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ): + # Print stats if we crossed the printing threshold with this batch + tick_end_time = time.time() + fields = [] + fields += [f"samples {cur_nimg:<9.1f}"] + fields += [f"training_loss {average_loss:<7.2f}"] + fields += [ + f"training_loss_running_mean {average_loss_running_mean:<7.2f}" + ] + fields += [f"learning_rate {current_lr:<7.8f}"] + fields += [f"total_sec {(tick_end_time - start_time):<7.1f}"] + fields += [ + f"sec_per_tick {(tick_end_time - tick_start_time):<7.1f}" + ] + fields += [ + f"sec_per_sample {((tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg)):<7.2f}" + ] + fields += [ + f"cpu_mem_gb {(psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}" + ] + if torch.cuda.is_available(): + fields += [ + f"peak_gpu_mem_gb {(torch.cuda.max_memory_allocated(dist.device) / 2**30):<6.2f}" + ] + fields += [ + f"peak_gpu_mem_reserved_gb {(torch.cuda.max_memory_reserved(dist.device) / 2**30):<6.2f}" + ] + torch.cuda.reset_peak_memory_stats() + logger0.info(" ".join(fields)) + + # Save checkpoints + if dist.world_size > 1: + 
torch.distributed.barrier() + if is_time_for_periodic_task( + cur_nimg, + cfg.distill.io.save_checkpoint_freq, + done, + cfg.distill.hp.total_batch_size, + dist.rank, + rank_0_only=True, + ): + for callback in callbacks: + callback.on_save_checkpoint_start( + unwrapped_model, + iteration=cur_nimg // cfg.distill.hp.total_batch_size, + ) + save_checkpoint( + path=checkpoint_dir, + models=save_models, + optimizer=optimizer, + scheduler=scheduler, + epoch=cur_nimg, + ) + for callback in callbacks: + callback.on_save_checkpoint_success( + unwrapped_model, + iteration=cur_nimg // cfg.distill.hp.total_batch_size, + path=checkpoint_dir, + ) + callback.on_save_checkpoint_end( + unwrapped_model, + iteration=cur_nimg // cfg.distill.hp.total_batch_size, + ) + + # Retain only the recent n checkpoints, if desired + if cfg.distill.io.save_n_recent_checkpoints > 0: + for suffix in [".mdlus", ".pt"]: + ckpts = checkpoint_list(checkpoint_dir, suffix=suffix) + while len(ckpts) > cfg.distill.io.save_n_recent_checkpoints: + os.remove(os.path.join(checkpoint_dir, ckpts[0])) + ckpts = ckpts[1:] + + # Done. 
+ for callback in callbacks: + callback.on_train_end( + unwrapped_model, iteration=cur_nimg // cfg.distill.hp.total_batch_size + ) + callback.on_app_end( + unwrapped_model, iteration=cur_nimg // cfg.distill.hp.total_batch_size + ) + logger0.info("Training Completed.") + + +if __name__ == "__main__": + main() diff --git a/examples/weather/corrdiff/generate.py b/examples/weather/corrdiff/generate.py index a27dd8fef3..3742da62ad 100644 --- a/examples/weather/corrdiff/generate.py +++ b/examples/weather/corrdiff/generate.py @@ -43,6 +43,8 @@ deterministic_sampler, stochastic_sampler, ) +from helpers.distill_helpers import few_step_sampler + from physicsnemo.distributed import DistributedManager from physicsnemo.experimental.models.diffusion.preconditioning import ( tEDMPrecondSuperRes, @@ -197,8 +199,21 @@ def main(cfg: DictConfig) -> None: patching=patching, num_steps=getattr(cfg.sampler, "num_steps", 18), ) + elif cfg.sampler.type == "few-step": + sigma_max = cfg.sampler.get("sigma_max", 800.0) + sigma_mid = cfg.sampler.get("sigma_mid", None) + logger0.info( + f"Using few-step sampler with sigma_max={sigma_max} and sigma_mid={sigma_mid}" + ) + sampler_fn = partial( + few_step_sampler, + sigma_max=sigma_max, + sigma_mid=sigma_mid, + patching=patching, + ) + else: - raise ValueError(f"Unknown sampling method {cfg.sampling.type}") + raise ValueError(f"Unknown sampling method {cfg.sampler.type}") # Parse the distribution type distribution = getattr(cfg.generation, "distribution", None) diff --git a/examples/weather/corrdiff/helpers/distill_helpers.py b/examples/weather/corrdiff/helpers/distill_helpers.py new file mode 100644 index 0000000000..c251aa0d8f --- /dev/null +++ b/examples/weather/corrdiff/helpers/distill_helpers.py @@ -0,0 +1,908 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Distillation helpers for CorrDiff models. + +This module adapts FastGen's distillation methods for CorrDiff model. +The main additions over the base FastGen models are: + +- ``FastGenNet`` wraps the CorrDiff network into FastGen's + ``FastGenNetwork`` interface, adding SuperPatching2D unfold/fold and + window smoothing to support super-patch distillation training. +- ``CMModel``, ``SCMModel``, and ``DMD2Model`` override ``build_model`` + to construct the network via ``FastGenNet`` instead of FastGen's + default path. +- ``DistillLoss`` bridges CorrDiff's training interface to FastGen's ``single_train_step`` API. + +See ``https://github.com/NVlabs/FastGen/blob/main/fastgen/methods/README.md`` for the base implementations. 
+""" + +from functools import partial +import torch +from typing import Optional, Tuple, List +from copy import deepcopy +import einops +import numpy as np + +from physicsnemo.nn.module import UNetBlock +from physicsnemo.diffusion.multi_diffusion.patching import ( + RandomPatching2D, + BasePatching2D, + GridPatching2D, +) + +from fastgen.networks.network import FastGenNetwork +from fastgen.networks.noise_schedule import NET_PRED_TYPES +from fastgen.methods.consistency_model.CM import CMModel as CMBaseModel +from fastgen.methods.consistency_model.sCM import SCMModel as SCMBaseModel +from fastgen.methods.distribution_matching.dmd2 import DMD2Model as DMD2BaseModel +from fastgen.networks.discriminators import Discriminator_EDM as BaseDiscriminator_EDM +from fastgen.utils import lr_scheduler +from torch.optim.lr_scheduler import LambdaLR +from scipy.signal import windows + +from omegaconf import DictConfig + +PRECISION_MAP = { + "fp32": "float32", + "torch.float32": "float32", + "fp16": "float16", + "torch.float16": "float16", + "amp-fp16": "float16", + "amp-bf16": "bfloat16", +} + + +def change_block(module, attr, value): + """Set an attribute on UNetBlock modules, used to override block-level settings like dropout.""" + if isinstance(module, UNetBlock): + assert hasattr(module, attr), f"Attribute {attr} not found in module" + setattr(module, attr, value) + + +class DistillLoss: + """ + Loss function for CorrDiff-distillation training supported by FastGen framework. 
+ """ + + def __init__(self, regression_net, hr_mean_conditioning=False): + self.regression_net = regression_net + self.hr_mean_conditioning = hr_mean_conditioning + self.y_mean = None + + def compute_loss( + self, + net, + y, + y_lr, + y_lr_res, + lead_time_label=None, + global_index=None, + augment_labels=None, + iteration=None, + ): + """Compute the distillation loss by delegating to FastGen's single_train_step.""" + assert not any(p.requires_grad for p in [y, y_lr, y_lr_res]) + data = { + "real": y, + # low-res (patched), low-res (unpatched), lead_time_label, global_index, augment_labels + "condition": ( + y_lr, + y_lr_res, + lead_time_label, + global_index, + augment_labels, + ), + "neg_condition": None, + } + + loss_map, output = net.single_train_step(data, iteration) + output["y_mean"] = self.y_mean + return loss_map["total_loss"], loss_map, output + + def patching(self, y, y_lr, batch_size, patching=None): + """Extract patches from various input tensors""" + global_index = None + if patching is not None: + # Patched residual + # (batch_size * patch_num, c_out, patch_shape_y, patch_shape_x) + y = patching.apply(input=y) + # Patched conditioning on y_lr and interp(img_lr) + # (batch_size * patch_num, 2*c_in, patch_shape_y, patch_shape_x) + y_lr = patching.apply(input=y_lr) + + global_index = patching.global_index(batch_size, y.device) + + return y, y_lr, global_index + + def augment(self, img_clean, img_lr, augment_pipe=None): + """Apply data augmentation jointly to the clean and low-res images.""" + img_tot = torch.cat((img_clean, img_lr), dim=1) + y_tot, augment_labels = ( + augment_pipe(img_tot) if augment_pipe is not None else (img_tot, None) + ) + y = y_tot[:, : img_clean.shape[1], :, :] + y_lr = y_tot[:, img_clean.shape[1] :, :, :] + return y, y_lr, augment_labels + + def regression(self, y_lr_res, y, lead_time_label=None, augment_labels=None): + """Run the regression network to produce the mean prediction.""" + if lead_time_label is not None: + return 
self.regression_net( + torch.zeros_like(y, device=y.device), + y_lr_res, + lead_time_label=lead_time_label, + augment_labels=augment_labels, + ) + return self.regression_net( + torch.zeros_like(y, device=y.device), + y_lr_res, + augment_labels=augment_labels, + ) + + def __call__( + self, + net, + img_clean, + img_lr, + patching=None, + lead_time_label=None, + augment_pipe=None, + iteration=None, + use_patch_grad_acc=False, + ): + """Compute the full distillation loss""" + # Safety check: enforce patching object + if patching and not isinstance(patching, RandomPatching2D): + raise ValueError("patching must be a 'RandomPatching2D' object.") + # Safety check: enforce shapes + if ( + img_clean.shape[0] != img_lr.shape[0] + or img_clean.shape[2:] != img_lr.shape[2:] + ): + raise ValueError( + f"Shape mismatch between img_clean {img_clean.shape} and " + f"img_lr {img_lr.shape}. " + f"Batch size, height and width must match." + ) + + # augment for conditional generation + y, y_lr, augment_labels = self.augment( + img_clean=img_clean, img_lr=img_lr, augment_pipe=augment_pipe + ) + del img_clean, img_lr + y_lr_res = y_lr + batch_size = y.shape[0] + + # if using multi-iterations of patching, switch to optimized version + if use_patch_grad_acc: + if self.y_mean is None: + self.y_mean = self.regression( + y_lr_res=y_lr_res, + y=y, + lead_time_label=lead_time_label, + augment_labels=augment_labels, + ) + # if on full domain, or if using patching without multi-iterations + else: + self.y_mean = self.regression( + y_lr_res=y_lr_res, + y=y, + lead_time_label=lead_time_label, + augment_labels=augment_labels, + ) + + y = y - self.y_mean + assert not y.requires_grad + + if self.hr_mean_conditioning: + y_lr = torch.cat((self.y_mean, y_lr), dim=1) + assert not y_lr.requires_grad + + # patchified training + # conditioning: cat(y_mean, y_lr, input_interp, pos_embd), 4+12+100+4 + # removed patch_embedding_selector due to compilation issue with dynamo. 
+        y, y_lr, global_index = self.patching(
+            y=y, y_lr=y_lr, batch_size=batch_size, patching=patching
+        )
+
+        return self.compute_loss(
+            net=net,
+            y=y,
+            y_lr=y_lr,
+            y_lr_res=y_lr_res,
+            lead_time_label=lead_time_label,
+            global_index=global_index,
+            augment_labels=augment_labels,
+            iteration=iteration,
+        )
+
+
+class SuperPatching2D(BasePatching2D):
+    """Patching utilities that decompose a superpatch into regular patches for superpatch-distillation training.
+
+    Parameters
+    ----------
+    img_shape : Tuple[int, int]
+        Height and width of the superpatch :math:`(H, W)`.
+    patch_shape : Tuple[int, int]
+        Height and width of each patch :math:`(H_p, W_p)`.
+        Must divide the superpatch dimensions after accounting for overlap.
+    overlap_pix : int, optional
+        Number of pixels of overlap between adjacent patches, by default 0.
+        When non-zero, the ``fuse`` method averages (or applies windowed
+        smoothing to) the overlapping regions during reassembly.
+    """
+
+    def __init__(
+        self,
+        img_shape: Tuple[int, int],
+        patch_shape: Tuple[int, int],
+        overlap_pix: int = 0,
+    ):
+        super().__init__(img_shape, patch_shape)
+        self.overlap_pix = overlap_pix
+        self.patch_shape_y = self.patch_shape[0]
+        self.patch_shape_x = self.patch_shape[1]
+        self.img_shape_y = self.img_shape[0]
+        self.img_shape_x = self.img_shape[1]
+
+        self.num_patches_y, remainder_y = divmod(
+            self.img_shape_y - self.overlap_pix, self.patch_shape_y - self.overlap_pix
+        )
+        self.num_patches_x, remainder_x = divmod(
+            self.img_shape_x - self.overlap_pix, self.patch_shape_x - self.overlap_pix
+        )
+        assert remainder_x == 0 and remainder_y == 0
+
+        # Initialize cache for overlap count
+        self.overlap_count = None
+
+    def unfold(self, x):
+        """
+        Wrapper around torch.nn.functional.unfold to extract regular patches from the superpatch. 
+ """ + + # Cast to float + dtype = x.dtype + if dtype == torch.int32: + x = x.view(torch.float32) + elif dtype == torch.int64: + x = x.view(torch.float64) + + x = torch.nn.functional.unfold( + input=x, + kernel_size=(self.patch_shape_y, self.patch_shape_x), + stride=( + self.patch_shape_y - self.overlap_pix, + self.patch_shape_x - self.overlap_pix, + ), + ) + + # cast back + if dtype in [torch.int32, torch.int64]: + x = x.view(dtype) + + return x + + def fold(self, x): + """ + Wrapper around torch.nn.functional.fold to fold the regular patches into the superpatch. + """ + # Cast to float + dtype = x.dtype + if dtype == torch.int32: + x = x.view(torch.float32) + elif dtype == torch.int64: + x = x.view(torch.float64) + + x = torch.nn.functional.fold( + input=x, + output_size=(self.img_shape_y, self.img_shape_x), + kernel_size=(self.patch_shape_y, self.patch_shape_x), + stride=( + self.patch_shape_y - self.overlap_pix, + self.patch_shape_x - self.overlap_pix, + ), + ) + + # cast back + if dtype in [torch.int32, torch.int64]: + x = x.view(dtype) + + return x + + def apply(self, input, additional_input=None): + """ + Unfold the superpatch into regular patches. 
+ """ + unfold = self.unfold(input) + unfold = einops.rearrange( + unfold, + "b (c p_h p_w) (nb_p_h nb_p_w) -> (nb_p_w nb_p_h b) c p_h p_w", + p_h=self.patch_shape_y, + p_w=self.patch_shape_x, + nb_p_h=self.num_patches_y, + nb_p_w=self.num_patches_x, + ) + if additional_input is not None: + additional_input = torch.nn.functional.interpolate( + input=additional_input, size=self.patch_shape, mode="bilinear" + ) + num_super_patches, rem = divmod(input.shape[0], additional_input.shape[0]) + assert rem == 0, ( + f"{additional_input.shape[0]} must be a factor of {input.shape[0]}" + ) + repeats = self.num_patches_y * self.num_patches_x * num_super_patches + # repeat each patch in the batch patch_num times + additional_input = additional_input.repeat(repeats, 1, 1, 1) + unfold = torch.cat((unfold, additional_input), dim=1) + + return unfold + + def get_overlap_count(self, device, dtype): + """ + Compute the overlap count for the overlapping pixels. + """ + # compute overlap count + ones = torch.ones( + (1, 1, self.img_shape_y, self.img_shape_x), device=device, dtype=dtype + ) + overlap_count = self.unfold(ones) + return self.fold(overlap_count) + + def fuse(self, input, batch_size=None, window=None): + """ + Fold the regular patches into the superpatch. 
+ """ + if window is not None: + if window.shape[0] == 1: + window = window.tile((input.shape[0], input.shape[1], 1, 1)) + + x = einops.rearrange( + input * window, + "(nb_p_w nb_p_h b) c p_h p_w -> b (c p_h p_w) (nb_p_h nb_p_w)", + p_h=self.patch_shape_y, + p_w=self.patch_shape_x, + nb_p_h=self.num_patches_y, + nb_p_w=self.num_patches_x, + ) + weights = einops.rearrange( + window, + "(nb_p_w nb_p_h b) c p_h p_w -> b (c p_h p_w) (nb_p_h nb_p_w)", + p_h=self.patch_shape_y, + p_w=self.patch_shape_x, + nb_p_h=self.num_patches_y, + nb_p_w=self.num_patches_x, + ) + + # Stitch patches together (by summing over overlapping patches) + folded = self.fold(x) + weights = self.fold(weights) + return folded / weights + else: + # Reshape input to make it 3D to apply fold + x = einops.rearrange( + input, + "(nb_p_w nb_p_h b) c p_h p_w -> b (c p_h p_w) (nb_p_h nb_p_w)", + p_h=self.patch_shape_y, + p_w=self.patch_shape_x, + nb_p_h=self.num_patches_y, + nb_p_w=self.num_patches_x, + ) + # Stitch patches together (by summing over overlapping patches) + folded = self.fold(x) + + if self.overlap_count is None: + self.overlap_count = self.get_overlap_count( + device=folded.device, dtype=folded.dtype + ) + if not ( + self.overlap_count.dtype == folded.dtype + and self.overlap_count.device == folded.device + ): + self.overlap_count = self.overlap_count.to(folded) + + # Normalize by overlap count + return folded / self.overlap_count + + +class FastGenNet(FastGenNetwork): + """ + A wrapper around the FastGenNetwork in FastGen, which enables distilling CorrDiff models with various methods in FastGen framework. + Supports super-patching training and window smoothing. + + See `fastgen.networks.network.FastGenNetwork` for more details. 
+ """ + + def __init__( + self, + net, + block_kwargs=None, + patching=None, + window=None, + net_pred_type="x0", + schedule_type="edm", + **kwargs, + ): + super().__init__( + net_pred_type=net_pred_type, schedule_type=schedule_type, **kwargs + ) + self.net = net + self.logvar_linear = torch.nn.Linear(self.net.model.map_noise.num_channels, 1) + if block_kwargs is not None: + for attr, value in block_kwargs.items(): + self.apply(partial(change_block, attr=attr, value=value)) + + # patching + if patching is not None and not isinstance(patching, SuperPatching2D): + raise ValueError("patching must be a 'SuperPatching2D' object.") + self.patching = patching + self.window = window + + def forward( + self, + y_t, + t, + condition=None, + return_features_early=False, + feature_indices=None, + return_logvar=False, + fwd_pred_type: Optional[str] = None, + ): + """Forward pass with superpatch unfold/fold and optional window smoothing.""" + y_lr, y_lr_res, lead_time_label, global_index, augment_labels = condition + # squeeze all dims after the first one and expand to batchsize + t = t.squeeze(list(range(1, t.ndim))).expand(y_t.shape[0]) + assert t.shape == (y_t.shape[0],) + + # Preconditioning weights for input + y_t_in, t_in = y_t, t + + self.net.model.feature_indices = feature_indices + self.net.model.features = [] + + if fwd_pred_type is None: + fwd_pred_type = self.net_pred_type + else: + assert fwd_pred_type in NET_PRED_TYPES, ( + f"{fwd_pred_type} is not supported as fwd_pred_type" + ) + + # superpatch unfolding: superpatches -> regular patches + if self.patching is not None: + y_t = self.patching.apply(y_t) + y_lr = self.patching.apply(y_lr, additional_input=y_lr_res) + + # TODO(jberner): memory optimizations + global_index = self.patching.apply(global_index) + t = t.repeat(self.patching.num_patches_y * self.patching.num_patches_x) + + del y_lr_res + + if lead_time_label is not None: + out = self.net( + y_t, + y_lr, + t, + embedding_selector=None, + 
global_index=global_index, + lead_time_label=lead_time_label, + augment_labels=augment_labels, + ) + else: + out = self.net( + y_t, + y_lr, + t, + embedding_selector=None, + global_index=global_index, + augment_labels=augment_labels, + ) + + # superpatch folding: regular patches -> superpatches + if self.patching is not None: + out = self.patching.fuse(out, window=self.window) + + out = self.noise_scheduler.convert_model_output( + y_t_in, + out, + t_in, + src_pred_type=self.net_pred_type, + target_pred_type=fwd_pred_type, + ) + + if feature_indices is not None and len(feature_indices) > 0: + features = self.net.model.features + # reset features + self.net.model.features = None + assert len(features) == len(feature_indices), ( + f"{len(features)} != {len(feature_indices)}" + ) + if return_features_early: + return features + # score and features; score, features + out = [out, features] + + if return_logvar: + emb_timestep = self.net.model.map_noise(t.flatten()) + logvar = self.logvar_linear(emb_timestep) + return out, logvar + return out + + +def build(config: DictConfig, use_ema: bool = False): + """Build a FastGenNet and optional EMA copy from a distillation config.""" + # Patching + patching = None + if "patch_shape" in config: + patching = SuperPatching2D( + img_shape=config.input_shape[-2:], + patch_shape=config.patch_shape, + overlap_pix=config.overlap_pix, + ) + window = None + if "window" in config: + window = config.window + + # Instantiate the generator network + net = FastGenNet( + net=config.net, + patching=patching, + window=window, + train_p_mean=config.sample_t_cfg.train_p_mean, + train_p_std=config.sample_t_cfg.train_p_std, + min_t=config.sample_t_cfg.min_t, + max_t=config.sample_t_cfg.max_t, + net_pred_type="x0", + schedule_type="edm", + block_kwargs=config.get("block_kwargs"), + ) + + net.train().requires_grad_(True) + + # initialize EMA network + ema = None + if use_ema: + ema = deepcopy(net) + ema.eval().requires_grad_(False) + + return net, ema + 
+
+
+class CMModel(CMBaseModel):
+    """Consistency Model for CorrDiff distillation.
+
+    A wrapper around the FastGen CM model in the FastGen framework.
+    See `fastgen.methods.consistency_model.CM.CMModel` for more details.
+
+    References:
+    - Song et al., 2023: https://arxiv.org/abs/2303.01469
+    - Geng et al., 2024: https://arxiv.org/abs/2406.14548
+    """
+
+    def build_model(self):
+        """Build the student, EMA, and optional teacher networks for consistency model."""
+        self.net, self.ema = build(self.config, use_ema=self.use_ema)
+
+        # instantiate the teacher and consistency network
+        if self.config.loss_config.use_cd:
+            self.teacher = deepcopy(self.net)
+            self.teacher.eval().requires_grad_(False)
+
+
+class SCMModel(SCMBaseModel):
+    """Continuous-time Consistency Model with TrigFlow for CorrDiff distillation.
+
+    A wrapper around the FastGen sCM model in the FastGen framework.
+    See `fastgen.methods.consistency_model.sCM.SCMModel` for more details.
+
+    References:
+    - Lu & Song, 2024: https://arxiv.org/abs/2410.11081
+    """
+
+    def build_model(self):
+        """Build the student, EMA, and optional teacher networks for sCM model."""
+        self.net, self.ema = build(self.config, use_ema=self.use_ema)
+
+        # instantiate the teacher and consistency network
+        if self.config.loss_config.use_cd:
+            self.teacher = deepcopy(self.net)
+            self.teacher.eval().requires_grad_(False)
+        else:
+            # TODO(jberner): remove this once we do not require a teacher anymore
+            self.teacher = torch.nn.Identity()
+
+
+class Discriminator_EDM(BaseDiscriminator_EDM):
+    """EDM Discriminator for CorrDiff distillation.
+
+    A wrapper around the FastGen EDM discriminator in the FastGen framework.
+    See `fastgen.networks.discriminators.Discriminator_EDM` for more details. 
+ """ + + def __init__( + self, + feature_indices=None, + all_res=[32, 16, 8], + in_channels=256, + ): + torch.nn.Module.__init__(self) + if feature_indices is None: + feature_indices = {len(all_res) - 1} # use the middle bottleneck feature + self.feature_indices = { + i for i in feature_indices if i < len(all_res) + } # make sure feature indices are valid + self.in_res = [all_res[i] for i in sorted(feature_indices)] + if not isinstance(in_channels, (list, tuple)): + in_channels = [in_channels] * len(self.feature_indices) + self.in_channels = [in_channels[i] for i in sorted(self.feature_indices)] + + self.discriminator_heads = torch.nn.ModuleList() + for res, in_channels in zip(self.in_res, self.in_channels): + layers = [] + while res > 8: + # reduce the resolution by half, until 8x8 + layers.extend( + [ + torch.nn.Conv2d( + kernel_size=4, + in_channels=in_channels, + out_channels=in_channels, + stride=2, + padding=1, + ), + torch.nn.GroupNorm(num_groups=32, num_channels=in_channels), + torch.nn.SiLU(), + ] + ) + res //= 2 + + layers.extend( + [ + torch.nn.Conv2d( + kernel_size=4, + in_channels=in_channels, + out_channels=in_channels, + stride=2, + padding=1, + ), + # 8x8 -> 4x4 + torch.nn.GroupNorm(num_groups=32, num_channels=in_channels), + torch.nn.SiLU(), + torch.nn.Conv2d( + kernel_size=4, + in_channels=in_channels, + out_channels=in_channels, + stride=4, + padding=0, + ), + # 4x4 -> 1x1 + torch.nn.GroupNorm(num_groups=32, num_channels=in_channels), + torch.nn.SiLU(), + torch.nn.Conv2d( + kernel_size=1, + in_channels=in_channels, + out_channels=1, + stride=1, + padding=0, + ), + # 1x1 -> 1x1 + ] + ) + + # append the layers for current resolution to the discriminator head + self.discriminator_heads.append(torch.nn.Sequential(*layers)) + + +class DMD2Model(DMD2BaseModel): + """VSD + GAN for CorrDiff distillation. + + A wrapper around the FastGen DMD2 model in FastGen framework. + See `fastgen.methods.distribution_matching.dmd2.DMD2Model` for more details. 
+ + References: + - Yin et al., 2024: https://arxiv.org/abs/2405.14867 + """ + + def build_model(self): + """Build the student, teacher, fake-score, and optional discriminator networks for DMD2.""" + self.net, self.ema = build(self.config, use_ema=self.use_ema) + + # instantiate the teacher and consistency network + self.teacher = deepcopy(self.net) + self.teacher.eval().requires_grad_(False) + + # instantiate the fake_score + self.fake_score = deepcopy(self.net) + + if self.config.gan_loss_weight_gen > 0: + # instantiate the discriminator in DMD2 ({0, 1, 2} are all features) + self.discriminator = Discriminator_EDM( + feature_indices={0, 1, 2}, + all_res=[64, 32, 16], + in_channels=[64, 128, 128], + ) + + +MODEL_MAP = { + "cm": CMModel, + "scm": SCMModel, + "dmd2": DMD2Model, +} + + +def get_window_function( + patch_shape_x, patch_shape_y, window_alpha, type="KBD", **kwargs +): + """ + Get the window function for the superpatch. + returns: window function of shape (patch_shape_y, patch_shape_x) + """ + functions = { + "uniform": torch.ones, + "hann": lambda ps: windows.hann(ps, sym=True), + "hamming": lambda ps: windows.hamming(ps, sym=True), + "general_hamming": lambda ps: windows.general_hamming( + ps, window_alpha, sym=True + ), + "kaiser": lambda ps: windows.kaiser(ps, beta=window_alpha * np.pi, sym=True), + "tukey": lambda ps: windows.tukey(ps, alpha=window_alpha, sym=True), + "gaussian": lambda ps: windows.gaussian( + ps, std=window_alpha * ps / 2, sym=True + ), + "KBD": lambda ps: windows.kaiser_bessel_derived(ps, window_alpha * np.pi), + } + if type not in functions.keys(): + raise ValueError( + f"Unknown window function type {type}. 
Supported types are {list(functions.keys())}"
+        )
+
+    window_x = torch.tensor(functions[type](patch_shape_x), **kwargs)
+    window_y = torch.tensor(functions[type](patch_shape_y), **kwargs)
+    window = window_x.unsqueeze(0) * window_y.unsqueeze(1)
+    return window
+
+
+def get_scheduler(name, cfg, optimizer):
+    """
+    Get the scheduler for the CorrDiff-distillation training supported by FastGen framework.
+    """
+    if name is None or name == "modulus_default":
+        scheduler = LambdaLR(
+            optimizer, lr_lambda=lambda _: 1.0
+        )  # null scheduler, lr stays constant
+    else:
+        schedule = getattr(lr_scheduler, name)(**cfg)
+        scheduler = LambdaLR(
+            optimizer,
+            lr_lambda=schedule,
+        )
+    return scheduler
+
+
+def few_step_sampler(
+    net: torch.nn.Module,
+    latents: torch.Tensor,
+    img_lr: torch.Tensor,
+    class_labels: Optional[torch.Tensor] = None,
+    patching: Optional[GridPatching2D] = None,
+    mean_hr: Optional[torch.Tensor] = None,
+    lead_time_label: Optional[torch.Tensor] = None,
+    sigma_max: float = 800,
+    sigma_mid: List[float] = None,
+    dtype: Optional[torch.dtype] = None,
+    **kwargs,
+) -> torch.Tensor:
+    """
+    Few-step sampler for distillation inference.
+    """
+    # Safety check on type of patching
+    if patching is not None and not isinstance(patching, GridPatching2D):
+        raise ValueError("patching must be an instance of GridPatching2D.")
+
+    # Safety check: if patching is used then img_lr and latents must have same
+    # height and width, otherwise there is mismatch in the number
+    # of patches extracted to form the final batch_size.
+    if patching:
+        if img_lr.shape[-2:] != latents.shape[-2:]:
+            raise ValueError(
+                f"img_lr and latents must have the same height and width, "
+                f"but found {img_lr.shape[-2:]} vs {latents.shape[-2:]}. 
" + ) + # img_lr and latents must also have the same batch_size, otherwise mismatch + # when processed by the network + if img_lr.shape[0] != latents.shape[0]: + raise ValueError( + f"img_lr and latents must have the same batch size, but found " + f"{img_lr.shape[0]} vs {latents.shape[0]}." + ) + batch_size = img_lr.shape[0] + + # latents to dtype if specified + if dtype is not None: + latents = latents.to(dtype) + + # Time step discretization. + sigma_mid = [] if sigma_mid is None else sigma_mid + # t_0 = T, t_N = 0 + # Max noise level (adjust based on what's supported by the network) + sigma_max = min(sigma_max, net.sigma_max) + t_steps = torch.tensor( + [sigma_max] + list(sigma_mid), dtype=latents.dtype, device=latents.device + ) + t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) + assert torch.all(t_steps[1:] <= t_steps[:-1]) + + # conditioning = [mean_hr, img_lr, global_lr, pos_embd] + x_lr = img_lr + if mean_hr is not None: + if mean_hr.shape[-2:] != img_lr.shape[-2:]: + raise ValueError( + f"mean_hr and img_lr must have the same height and width, " + f"but found {mean_hr.shape[-2:]} vs {img_lr.shape[-2:]}." 
+ ) + x_lr = torch.cat((mean_hr.expand(x_lr.shape[0], -1, -1, -1), x_lr), dim=1) + x_lr = x_lr.to(latents.device) + + # input and position padding + patching + if patching: + # Patched conditioning [x_lr, mean_hr] + # (batch_size * patch_num, C_in + C_out, patch_shape_y, patch_shape_x) + x_lr = patching.apply(input=x_lr, additional_input=img_lr) + + # Function to select the correct positional embedding for each patch + def patch_embedding_selector(emb): + """Select and patch positional embeddings.""" + return patching.apply(emb[None].expand(batch_size, -1, -1, -1)) + else: + patch_embedding_selector = None + + # Sampling steps + latents = latents * t_steps[0] + for t_cur, t_next in zip(t_steps[:-1], t_steps[1:]): + # latent patching + if patching: + latents = patching.apply(input=latents) + + if lead_time_label is not None: + latents = net( + latents, + x_lr, + t_cur, + class_labels, + lead_time_label=lead_time_label, + embedding_selector=patch_embedding_selector, + ).to(latents.dtype) + else: + latents = net( + latents, + x_lr, + t_cur, + class_labels, + embedding_selector=patch_embedding_selector, + ).to(latents.dtype) + + if patching: + # Un-patch the denoised image + # (batch_size, C_out, img_shape_y, img_shape_x) + latents = patching.fuse(input=latents, batch_size=batch_size) + + if t_next > 0: + latents = latents + t_next * torch.randn_like(latents) + + return latents