[WIP - DO NOT REVIEW] Diffusion SDA recipe #1511
Open

CharlelieLrt wants to merge 32 commits into NVIDIA:main from CharlelieLrt:diffusion-sda-recipe
Commits (32)
All 32 commits authored by CharlelieLrt:

55d94b6 Merged initial SDA training recipe
688247d Refactored sda training recipe
1f9a6a4 Some more refactors before starting training
926938b Some updates to training script
367cc56 Added missing nn.py
513339d Fixed missing iter in train.py
af5a09c Switched to synchronous zarr reads in data.py
1f068d0 Added random seeds in train.py
9a2e345 Added functionality to handle DDP and compile wrappers in physicsnemo…
6d0f79e Fixed wrong normalization: now uses variance instead of std
2a1b350 Fixed missing updates of samples counter in train.py
2bae3ff Fixed some minor bugs in data.py
2ebe6ec Improved resetting of global_index in patching.py
fc6ef61 Minor fixes in diffusion_sda example
26e7820 Added AMP + minor fixes in train.py
0990840 Added missing argument to enable AMP
bde5b81 Changed params in train.py
e528479 Added option to skip compilation of patching op in losses.py
2ab7b12 Changed params in train.py
c2d9654 Merge branch 'diffusion-sda-recipe' of https://github.com/CharlelieLr…
ff0157d Reverted option to skip compilation in losses.py and changed compile/…
835ba56 Reworked patch extraction to enable compilation in patching.py
82a2ac5 Switched to torch.randint instead of random in patching.py
13d6412 Merge branch 'diffusion-sda-recipe' of https://github.com/CharlelieLr…
a292492 Added logging back in train.py - NEEDED when not using hydra
c8d5c1b Merge branch 'diffusion-sda-recipe' of https://github.com/CharlelieLr…
1448ebb Try some fix for the duplicated logging in train.py
b41fa08 Changed parameters in train.py
1c5b97d Merge branch 'diffusion-sda-recipe' of https://github.com/CharlelieLr…
299dd11 Handle NaN in log-transformed variables
4185eac Config used to launch experiments
c156dcc Merge branch 'main' into diffusion-sda-recipe
The diff adds one new file (+202 lines in the original diff): a PyTorch `Dataset` that reads HRRR surface fields from a Zarr store on S3. The bounds check below defines the `out_of_bounds_mask` it uses (the mask was referenced but never assigned), and the docstring parameter order now matches the constructor signature:

```python
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import csv
import datetime
import os

import numpy as np
import torch
import zarr
from torch.utils.data import Dataset

from physicsnemo.utils.zenith_angle import cos_zenith_angle


class HRRRSurfaceDataset(Dataset):
    """HRRR Surface dataset on S3

    Parameters
    ----------
    zarr_url : str
        URL to Zarr group (e.g., s3://bucket/path)
    time_indices : np.ndarray
        Index array of times to use as part of the dataset
    stats_csv : str, optional
        Stats CSV location, by default "stats/stats.csv"
    storage_options : dict, optional
        Backend/storage kwargs passed to the Zarr opener (e.g., endpoint_url)
    """

    VARIABLES = [
        "u10m",
        "v10m",
        "u80m",
        "v80m",
        "t2m",
        "d2m",
        "q2m",
        "sp",
        "fg10m",
        "tcc",
        "sde",
        "snowc",
        "refc",
        "rsds",
        "tp",
        "aerot",
    ]
    LOG_VARIABLES = ("tp", "aerot")  # Make sure this is consistent with the stats CSV
    EPSILON = 1e-8

    def __init__(
        self,
        zarr_url: str,
        time_indices: np.ndarray,
        stats_csv: str = "stats/stats.csv",
        storage_options: dict | None = None,
    ):
        self.zarr_url = zarr_url
        self.storage_options = storage_options or {}
        self.idx = np.asarray(time_indices, dtype=int).ravel()

        # Verify bounds against the available time coordinate in the Zarr store
        _root = zarr.open_group(
            store=self.zarr_url, mode="r", storage_options=self.storage_options
        )
        n_time = _root["time"].size
        out_of_bounds_mask = (self.idx < 0) | (self.idx >= n_time)
        if np.any(out_of_bounds_mask):
            invalid_values = np.unique(self.idx[out_of_bounds_mask])
            raise IndexError(
                "time_indices contain out-of-bounds values for zarr_root['time']: "
                f"{invalid_values}"
            )

        # Load normalization stats from the stats CSV
        stats_csv = os.path.join(os.path.dirname(__file__), stats_csv)
        means = []
        stds = []
        stats_map = {}
        with open(stats_csv, "r", newline="") as f:
            reader = csv.DictReader(f)
            for row in reader:
                var = row.get("variable")
                mu = float(row.get("mean", "nan"))
                sd = float(row.get("std", "nan"))
                stats_map[var] = (mu, sd)

        # Order based on VARIABLES
        for var in self.VARIABLES:
            mu, sd = stats_map[var]
            means.append(mu)
            stds.append(sd)

        # Per-channel normalization statistics, shaped (C, 1, 1) for broadcasting
        self.target_means = (
            torch.tensor(means, dtype=torch.float32).unsqueeze(-1).unsqueeze(-1)
        )
        self.target_stds = (
            torch.tensor(stds, dtype=torch.float32).unsqueeze(-1).unsqueeze(-1)
        )
        # Cache Zarr coordinates in memory
        self.grid_lat = _root["lat"][:]
        self.grid_lon = _root["lon"][:]
        self.time_array = _root["time"][:]

    def __len__(self):
        return self.idx.shape[0]

    async def _zarr_read(
        self,
        root,
        array_name: str,
        array_idx: int,
        time_idx: int,
        data_arrays: np.ndarray,
    ):
        # Fetch one (H, W) field for a single time index; log-transform
        # the variables whose stats were computed in log space
        arr = await root.get(array_name)
        arr = await arr.getitem((time_idx, slice(None), slice(None)))
        if array_name in self.LOG_VARIABLES:
            data_arrays[array_idx] = np.log(arr + self.EPSILON)
        else:
            data_arrays[array_idx] = arr

    async def _get_array(self, idx):
        root = await zarr.api.asynchronous.open_group(
            self.zarr_url, mode="r", storage_options=self.storage_options
        )

        time_idx = self.idx[idx]
        data_arrays = np.empty(
            (len(self.VARIABLES), self.grid_lat.shape[0], self.grid_lat.shape[1])
        )
        jobs = []
        for i, var in enumerate(self.VARIABLES):
            jobs.append(self._zarr_read(root, var, i, time_idx, data_arrays))
        await asyncio.gather(*jobs)
        return data_arrays
```
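`_get_array` fans out one read coroutine per variable and awaits them all with `asyncio.gather`, so the per-variable fetches can overlap. A minimal self-contained sketch of this fan-out pattern, with the Zarr/S3 fetch replaced by a hypothetical `fake_read` stand-in:

```python
import asyncio

import numpy as np

# Subset of variable names, for illustration only
VARIABLES = ["u10m", "v10m", "t2m"]


async def fake_read(name: str, idx: int, out: np.ndarray) -> None:
    # Stand-in for an awaitable chunk fetch; writes its slot in `out`
    await asyncio.sleep(0)
    out[idx] = idx  # pretend this is the (H, W) field for `name`


async def gather_all() -> np.ndarray:
    out = np.empty((len(VARIABLES), 2, 2))
    # One task per variable, awaited together, exactly like _get_array
    await asyncio.gather(*(fake_read(v, i, out) for i, v in enumerate(VARIABLES)))
    return out


fields = asyncio.run(gather_all())
print(fields.shape)  # (3, 2, 2)
```

Because every coroutine writes to a distinct first-axis slot of the preallocated array, no locking is needed even though the tasks run concurrently.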
The `__getitem__` method and the smoke test at the bottom of the file. The constructor takes a Zarr URL, so the smoke test passes the URL string (not the already-opened group) and unpacks all three returned tensors:

```python
    def __getitem__(self, idx):
        time_idx = self.idx[idx]
        time_stamp = self.time_array[time_idx]
        data_arrays = asyncio.run(self._get_array(idx))

        target = torch.Tensor(data_arrays)
        target = (target - self.target_means) / self.target_stds

        # Conditional encoding: cosine of solar zenith angle plus normalized lat/lon
        data_arrays = np.empty(
            (3, self.grid_lat.shape[0], self.grid_lat.shape[1]), dtype=np.float32
        )
        ts = (time_stamp - np.datetime64("1970-01-01T00:00:00Z")) / np.timedelta64(
            1, "s"
        )
        data_arrays[0] = cos_zenith_angle(
            datetime.datetime.utcfromtimestamp(ts), self.grid_lat, self.grid_lon
        )
        data_arrays[1] = self.grid_lat / 90.0
        data_arrays[2] = self.grid_lon / 360.0
        # Day of year (1-based) as an integer time condition
        condition_time = np.array(
            [
                (
                    time_stamp.astype("datetime64[D]")
                    - time_stamp.astype("datetime64[Y]")
                    + 1
                ).astype(int)
            ],
            dtype=np.int32,
        )

        condition_spatial = torch.Tensor(data_arrays)
        return target, condition_spatial, condition_time


if __name__ == "__main__":
    zarr_url = "s3://hrrr-surface-sda/zarr-v2"
    storage_options = {"endpoint_url": "https://pdx.s8k.io"}
    root = zarr.open_group(store=zarr_url, mode="r", storage_options=storage_options)
    time = root["time"][:]
    sidx = np.where(time == np.datetime64("2023-01-01T00:00:00"))[0][0]
    eidx = np.where(time == np.datetime64("2023-02-01T00:00:00"))[0][0]

    time_idx = np.arange(sidx, eidx)

    dataset = HRRRSurfaceDataset(zarr_url, time_idx, storage_options=storage_options)
    target, condition_spatial, condition_time = dataset[30]

    print(target.shape)
    print(condition_spatial.shape)
    print(condition_time)
```
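Since `tp` and `aerot` are log-transformed before z-scoring, inverting a generated sample for those channels requires undoing both steps: un-normalize with the log-space mean/std, then exponentiate and subtract the epsilon. A minimal round-trip sketch with hypothetical stats (not the recipe's actual CSV values):

```python
import numpy as np
import torch

EPSILON = 1e-8
mean, std = 0.5, 2.0  # hypothetical per-channel stats computed in log space

raw = np.array([0.0, 1.0, 10.0])    # e.g. total precipitation values
x = np.log(raw + EPSILON)           # log transform, as in _zarr_read
z = (torch.tensor(x) - mean) / std  # z-score with log-space stats

# Invert: un-normalize first, then undo the log transform
x_rec = z * std + mean
raw_rec = torch.exp(x_rec) - EPSILON
print(torch.allclose(raw_rec, torch.tensor(raw), atol=1e-6))  # True
```

The epsilon makes the log transform well-defined at exact zeros (common for precipitation), at the cost of mapping zero to `log(1e-8) ≈ -18.4` in normalized space.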