From 720d97e7e5ba256351a5b8139dfe770129ad1d7d Mon Sep 17 00:00:00 2001 From: Peter Harrington Date: Wed, 8 Apr 2026 17:44:50 -0700 Subject: [PATCH 01/10] Add healda protocols and loaders to experimental --- examples/weather/healda/README.md | 104 +++ examples/weather/healda/requirements.txt | 9 + examples/weather/healda/test/conftest.py | 22 + .../healda/test/test_combined_schema.py | 64 ++ examples/weather/healda/test/test_features.py | 98 +++ examples/weather/healda/test/test_indexing.py | 93 +++ .../weather/healda/test/test_obs_filtering.py | 70 ++ examples/weather/healda/test/test_prefetch.py | 38 + examples/weather/healda/test/test_samplers.py | 143 ++++ .../weather/healda/test/test_time_utils.py | 68 ++ examples/weather/healda/test/test_types.py | 180 +++++ .../experimental/datapipes/__init__.py | 15 + .../experimental/datapipes/healda/__init__.py | 88 +++ .../datapipes/healda/configs/__init__.py | 14 + .../healda/configs/combined_schema.py | 77 ++ .../healda/configs/era5_13_levels_stats.csv | 77 ++ .../airs-pca_normalizations.csv | 65 ++ .../normalizations/amsua_normalizations.csv | 113 +++ .../normalizations/amsub_normalizations.csv | 16 + .../normalizations/atms_normalizations.csv | 67 ++ .../normalizations/conv_normalizations.csv | 9 + .../cris-fsr-pca_normalizations.csv | 97 +++ .../cris-fsr_normalizations.csv | 301 ++++++++ .../iasi-pca_normalizations.csv | 129 ++++ .../normalizations/iasi_normalizations.csv | 699 ++++++++++++++++++ .../normalizations/mhs_normalizations.csv | 31 + .../datapipes/healda/configs/sensors.py | 287 +++++++ .../datapipes/healda/configs/static_data.py | 60 ++ .../healda/configs/variable_configs.py | 75 ++ .../experimental/datapipes/healda/dataset.py | 190 +++++ .../experimental/datapipes/healda/indexing.py | 227 ++++++ .../datapipes/healda/loaders/__init__.py | 14 + .../datapipes/healda/loaders/era5.py | 207 ++++++ .../datapipes/healda/loaders/ufs_obs.py | 324 ++++++++ .../datapipes/healda/loaders/zarr_loader.py | 183 +++++ 
.../experimental/datapipes/healda/prefetch.py | 179 +++++ .../datapipes/healda/protocols.py | 104 +++ .../experimental/datapipes/healda/samplers.py | 178 +++++ .../datapipes/healda/time_utils.py | 69 ++ .../datapipes/healda/transforms/__init__.py | 14 + .../datapipes/healda/transforms/era5_obs.py | 412 +++++++++++ .../healda/transforms/obs_features.py | 361 +++++++++ .../healda/transforms/obs_features_ext.py | 318 ++++++++ .../healda/transforms/obs_filtering.py | 116 +++ .../experimental/datapipes/healda/types.py | 355 +++++++++ 45 files changed, 6360 insertions(+) create mode 100644 examples/weather/healda/README.md create mode 100644 examples/weather/healda/requirements.txt create mode 100644 examples/weather/healda/test/conftest.py create mode 100644 examples/weather/healda/test/test_combined_schema.py create mode 100644 examples/weather/healda/test/test_features.py create mode 100644 examples/weather/healda/test/test_indexing.py create mode 100644 examples/weather/healda/test/test_obs_filtering.py create mode 100644 examples/weather/healda/test/test_prefetch.py create mode 100644 examples/weather/healda/test/test_samplers.py create mode 100644 examples/weather/healda/test/test_time_utils.py create mode 100644 examples/weather/healda/test/test_types.py create mode 100644 physicsnemo/experimental/datapipes/__init__.py create mode 100644 physicsnemo/experimental/datapipes/healda/__init__.py create mode 100644 physicsnemo/experimental/datapipes/healda/configs/__init__.py create mode 100644 physicsnemo/experimental/datapipes/healda/configs/combined_schema.py create mode 100644 physicsnemo/experimental/datapipes/healda/configs/era5_13_levels_stats.csv create mode 100644 physicsnemo/experimental/datapipes/healda/configs/normalizations/airs-pca_normalizations.csv create mode 100644 physicsnemo/experimental/datapipes/healda/configs/normalizations/amsua_normalizations.csv create mode 100644 
physicsnemo/experimental/datapipes/healda/configs/normalizations/amsub_normalizations.csv create mode 100644 physicsnemo/experimental/datapipes/healda/configs/normalizations/atms_normalizations.csv create mode 100644 physicsnemo/experimental/datapipes/healda/configs/normalizations/conv_normalizations.csv create mode 100644 physicsnemo/experimental/datapipes/healda/configs/normalizations/cris-fsr-pca_normalizations.csv create mode 100644 physicsnemo/experimental/datapipes/healda/configs/normalizations/cris-fsr_normalizations.csv create mode 100644 physicsnemo/experimental/datapipes/healda/configs/normalizations/iasi-pca_normalizations.csv create mode 100644 physicsnemo/experimental/datapipes/healda/configs/normalizations/iasi_normalizations.csv create mode 100644 physicsnemo/experimental/datapipes/healda/configs/normalizations/mhs_normalizations.csv create mode 100644 physicsnemo/experimental/datapipes/healda/configs/sensors.py create mode 100644 physicsnemo/experimental/datapipes/healda/configs/static_data.py create mode 100644 physicsnemo/experimental/datapipes/healda/configs/variable_configs.py create mode 100644 physicsnemo/experimental/datapipes/healda/dataset.py create mode 100644 physicsnemo/experimental/datapipes/healda/indexing.py create mode 100644 physicsnemo/experimental/datapipes/healda/loaders/__init__.py create mode 100644 physicsnemo/experimental/datapipes/healda/loaders/era5.py create mode 100644 physicsnemo/experimental/datapipes/healda/loaders/ufs_obs.py create mode 100644 physicsnemo/experimental/datapipes/healda/loaders/zarr_loader.py create mode 100644 physicsnemo/experimental/datapipes/healda/prefetch.py create mode 100644 physicsnemo/experimental/datapipes/healda/protocols.py create mode 100644 physicsnemo/experimental/datapipes/healda/samplers.py create mode 100644 physicsnemo/experimental/datapipes/healda/time_utils.py create mode 100644 physicsnemo/experimental/datapipes/healda/transforms/__init__.py create mode 100644 
physicsnemo/experimental/datapipes/healda/transforms/era5_obs.py create mode 100644 physicsnemo/experimental/datapipes/healda/transforms/obs_features.py create mode 100644 physicsnemo/experimental/datapipes/healda/transforms/obs_features_ext.py create mode 100644 physicsnemo/experimental/datapipes/healda/transforms/obs_filtering.py create mode 100644 physicsnemo/experimental/datapipes/healda/types.py diff --git a/examples/weather/healda/README.md b/examples/weather/healda/README.md new file mode 100644 index 0000000000..fcda6ef1a9 --- /dev/null +++ b/examples/weather/healda/README.md @@ -0,0 +1,104 @@ +# HealDA — AI-based Data Assimilation on the HEALPix Grid + +> 🏗️🏗️ **This recipe is under active construction.** Structure and functionality are subject to changes 🏗️🏗️ + +HealDA is a stateless assimilation model that produces a single global weather analysis from conventional and satellite observations. It operates on a HEALPix level-6 padded XY grid and outputs ERA5-compatible atmospheric variables. + +This example provides a recipe to train HealDA, with support for extension to custom data. + +## Setup + +Start by installing PhysicsNeMo (if not already installed) with the `datapipes-extras` optional dependency group, along with the packages in `requirements.txt`. Then, copy this folder (`examples/weather/healda`) to a system with a GPU available. Also, prepare a dataset that can serve training data according to the protocols outlined in the [Generalized Data Loading](#generalized-data-loading) section below. + +## Generalized Data Loading + +The ``physicsnemo.experimental.datapipes.healda`` package provides a composable data loading pipeline with clear extension points. The architecture separates components into loaders, transforms, datasets, and sampling infrastructure. 
+ +### Architecture + +``` +ObsERA5Dataset(era5_data, obs_loader, transform) + | Temporal windowing via FrameIndexGenerator + | __getitems__ -> get() per index -> transform.transform() + v +ChunkedDistributedSampler (contiguous chunks for cache locality) + | +DataLoader (1 worker each, pin_memory, persistent_workers) + | +RoundRobinLoader (interleaves per-worker DataLoaders) + | +prefetch_map(loader, transform.device_transform) + | +Training loop (GPU-ready batch) +``` + +### Key Protocols + +Custom data sources and transforms plug in via these protocols (see `physicsnemo.experimental.datapipes.healda.protocols`): + +**`ObsLoader`** — the observation loading interface: +```python +class MyObsLoader: + async def sel_time(self, times: pd.DatetimeIndex) -> dict[str, list[Any]]: + """Return {"obs": [pa.Table_per_time, ...]}""" + ... +``` + +**`Transform`** / **`DeviceTransform`** — two-stage batch processing: +```python +class MyTransform: + def transform(self, times, frames) -> dict[str, Any]: + """CPU-side: normalize, encode observations, time features.""" + ... + + def device_transform(self, batch, device) -> dict[str, Any]: + """GPU-side: move to device, compute observation features.""" + ... 
+``` + +### Provided Implementations + +| Component | Module | Description | +|---|---|---| +| `ObsERA5Dataset` | `healda.dataset` | Map-style dataset combining ERA5 state + observations | +| `UFSUnifiedLoader` | `healda.loaders.ufs_obs` | Parquet-based observation loader (satellite + conventional) | +| `ERA5Loader` | `healda.loaders.era5` | Async ERA5 zarr loader (not used by ObsERA5Dataset directly) | +| `ERA5ObsTransform` | `healda.transforms.era5_obs` | Two-stage transform with Triton feature kernels | +| `ChunkedDistributedSampler` | `healda.samplers` | Cache-friendly distributed sampler | +| `RoundRobinLoader` | `healda.samplers` | Multi-loader interleaving | +| `prefetch_map` | `healda.prefetch` | Background CUDA stream prefetching | + +All modules above are under `physicsnemo.experimental.datapipes` (abbreviated as `healda` in the table). + +### Writing a Custom Observation Loader + +Implement `async def sel_time(times)` returning a dict with observation data per timestamp: + +```python +class GOESRadianceLoader: + def __init__(self, data_path, channels): + self.data_path = data_path + self.channels = channels + + async def sel_time(self, times): + tables = [] + for t in times: + table = self._load_goes_radiances(t) + tables.append(table) + return {"obs": tables} +``` + +Then pass it to the dataset: +```python +from physicsnemo.experimental.datapipes.healda import ObsERA5Dataset +from physicsnemo.experimental.datapipes.healda.transforms.era5_obs import ERA5ObsTransform +from physicsnemo.experimental.datapipes.healda.configs.variable_configs import VARIABLE_CONFIGS + +dataset = ObsERA5Dataset( + era5_data=era5_xr["data"], + obs_loader=GOESRadianceLoader(...), + transform=ERA5ObsTransform(...), + variable_config=VARIABLE_CONFIGS["era5"], +) +``` + diff --git a/examples/weather/healda/requirements.txt b/examples/weather/healda/requirements.txt new file mode 100644 index 0000000000..1889a50887 --- /dev/null +++ b/examples/weather/healda/requirements.txt @@ -0,0 
+1,9 @@ +# nvidia-physicsnemo[datapipes-extras] +cftime +pyarrow +python-dotenv +earth2grid @ git+https://github.com/NVlabs/earth2grid.git@main +healpy +matplotlib +joblib +icechunk \ No newline at end of file diff --git a/examples/weather/healda/test/conftest.py b/examples/weather/healda/test/conftest.py new file mode 100644 index 0000000000..73553c39ed --- /dev/null +++ b/examples/weather/healda/test/conftest.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + + +@pytest.fixture +def device(): + return "cuda" if torch.cuda.is_available() else "cpu" diff --git a/examples/weather/healda/test/test_combined_schema.py b/examples/weather/healda/test/test_combined_schema.py new file mode 100644 index 0000000000..48c9a43ebe --- /dev/null +++ b/examples/weather/healda/test/test_combined_schema.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for the combined observation schema and sensor config consistency.""" + +import pyarrow as pa + +from physicsnemo.experimental.datapipes.healda.configs.combined_schema import ( + get_channel_table_schema, + get_combined_observation_schema, +) +from physicsnemo.experimental.datapipes.healda.configs.sensors import SENSOR_CONFIGS, SENSOR_NAME_TO_ID + + +def test_combined_schema_has_required_fields(): + schema = get_combined_observation_schema() + required = [ + "Latitude", "Longitude", "Absolute_Obs_Time", "DA_window", + "Platform_ID", "Observation", "Global_Channel_ID", + ] + for name in required: + assert name in schema.names, f"Missing required field: {name}" + + +def test_combined_schema_satellite_fields(): + schema = get_combined_observation_schema() + for name in ["Sat_Zenith_Angle", "Sol_Zenith_Angle", "Scan_Angle"]: + assert name in schema.names + + +def test_combined_schema_conventional_fields(): + schema = get_combined_observation_schema() + for name in ["Pressure", "Height", "Observation_Type"]: + assert name in schema.names + + +def test_channel_table_schema(): + schema = get_channel_table_schema() + assert "Global_Channel_ID" in schema.names + assert "sensor_id" in schema.names + assert "mean" in schema.names + assert "stddev" in schema.names + + +def test_sensor_configs_consistent(): + """All sensors in SENSOR_CONFIGS have a matching SENSOR_NAME_TO_ID entry.""" + for name in SENSOR_CONFIGS: + assert name in SENSOR_NAME_TO_ID + + +def test_sensor_channels_positive(): + for name, cfg in SENSOR_CONFIGS.items(): + assert cfg.channels > 0, 
f"Sensor {name} has non-positive channel count" diff --git a/examples/weather/healda/test/test_features.py b/examples/weather/healda/test/test_features.py new file mode 100644 index 0000000000..080a8e210c --- /dev/null +++ b/examples/weather/healda/test/test_features.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for observation metadata featurization (standard and extended). + +The Triton kernel tests require CUDA and validate that the Triton +implementation matches the reference Python implementation. 
+""" + +import pytest +import torch + +from physicsnemo.experimental.datapipes.healda.transforms import obs_features as standard +from physicsnemo.experimental.datapipes.healda.transforms import obs_features_ext as extended + + +def _make_obs_data(n, device, include_lat=False): + g = torch.Generator(device=device) + g.manual_seed(42) + + height = torch.rand(n, device=device, generator=g) * 50000 + pressure = torch.rand(n, device=device, generator=g) * 1100 + scan_angle = torch.rand(n, device=device, generator=g) * 100 - 50 + sat_zenith_angle = torch.rand(n, device=device, generator=g) * 120 - 60 + sol_zenith_angle = torch.rand(n, device=device, generator=g) * 160 + 10 + + # Conv/sat split: NaN height -> satellite, valid height -> conventional + is_sat = torch.rand(n, device=device, generator=g) < 0.4 + height[is_sat] = float("nan") + pressure[is_sat] = float("nan") + scan_angle[~is_sat] = float("nan") + sat_zenith_angle[~is_sat] = float("nan") + sol_zenith_angle[~is_sat] = float("nan") + + data = dict( + target_time_sec=torch.full( + (n,), 1_700_000_000, dtype=torch.int64, device=device + ), + time=torch.full( + (n,), 1_700_000_100_000_000_000, dtype=torch.int64, device=device + ), + lon=torch.rand(n, device=device, generator=g) * 360 - 180, + height=height, + pressure=pressure, + scan_angle=scan_angle, + sat_zenith_angle=sat_zenith_angle, + sol_zenith_angle=sol_zenith_angle, + ) + if include_lat: + data["lat"] = torch.rand(n, device=device, generator=g) * 180 - 90 + return data + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA required for Triton kernel" +) +@pytest.mark.parametrize("n", [0, 1, 137, 10_000]) +def test_standard_triton_matches_reference(n): + device = torch.device("cuda") + data = _make_obs_data(max(n, 1), device) + if n == 0: + data = {k: v[:0] for k, v in data.items()} + + ref = standard._compute_unified_metadata_reference(**data) + triton_out = standard.compute_unified_metadata(**data) + + assert ref.shape == 
triton_out.shape == (n, standard.N_FEATURES) + if n > 0: + torch.testing.assert_close(ref, triton_out, atol=1e-5, rtol=1e-5) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA required for Triton kernel" +) +@pytest.mark.parametrize("n", [0, 1, 137, 10_000]) +def test_extended_triton_matches_reference(n): + device = torch.device("cuda") + data = _make_obs_data(max(n, 1), device, include_lat=True) + if n == 0: + data = {k: val[:0] for k, val in data.items()} + + ref = extended._compute_unified_metadata_reference(**data) + triton_out = extended.compute_unified_metadata(**data) + + assert ref.shape == triton_out.shape == (n, extended.N_FEATURES) + if n > 0: + torch.testing.assert_close(ref, triton_out, atol=1e-5, rtol=1e-5) diff --git a/examples/weather/healda/test/test_indexing.py b/examples/weather/healda/test/test_indexing.py new file mode 100644 index 0000000000..b722c689ff --- /dev/null +++ b/examples/weather/healda/test/test_indexing.py @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for frame indexing and temporal windowing.""" + +import numpy as np +import torch + +from physicsnemo.experimental.datapipes.healda.indexing import FrameIndexGenerator, split_array_contiguous + + +def test_split_array_contiguous_single_segment(): + arr = np.arange(10) + (output,) = split_array_contiguous(arr) + assert np.all(arr == output) + + +def test_split_array_contiguous_two_segments(): + arr = np.array([0, 1, 2, 5, 6]) + out1, out2 = split_array_contiguous(arr) + assert out1.tolist() == [0, 1, 2] + assert out2.tolist() == [5, 6] + + +def test_frame_index_generator_basic(): + """Test basic frame index generation with striding.""" + times = np.arange(100) + generator = FrameIndexGenerator( + times=times, time_length=3, frame_step=2, model_rank=0, model_world_size=1 + ) + + start_indices = torch.tensor([0, 10]) + frame_idxs = generator.generate_frame_indices(start_indices) + + expected = [[0, 2, 4], [10, 12, 14]] + assert frame_idxs == expected + + +def test_frame_index_generator_model_rank_slicing(): + """Test model-parallel rank slicing of frame indices.""" + times = np.arange(100) + generator = FrameIndexGenerator( + times=times, time_length=4, frame_step=1, model_rank=1, model_world_size=2 + ) + + start_indices = torch.tensor([5]) + frame_idxs = generator.generate_frame_indices(start_indices) + + # Full range: [5, 6, 7, 8], rank 1 gets second half: [7, 8] + assert frame_idxs[0] == [7, 8] + + +def test_frame_index_generator_multiple_segments(): + """Test frame index generation across non-contiguous segments.""" + times = np.concatenate([ + np.arange(0, 10), # [0, 1, ..., 9] + np.arange(20, 35), # [20, 21, ..., 34] + ]) + + generator = FrameIndexGenerator( + times=times, time_length=3, frame_step=1, model_rank=0, model_world_size=1 + ) + + # Verify mapping across segment boundary + assert times[generator._map_logical_to_physical(0)] == 0 + assert times[generator._map_logical_to_physical(1)] == 1 + assert times[generator._map_logical_to_physical(7)] 
== 7 + assert times[generator._map_logical_to_physical(8)] == 20 + + assert all(times[generator.generate_frame_indices([7])[0]] == [7, 8, 9]) + assert all(times[generator.generate_frame_indices([8])[0]] == [20, 21, 22]) + + +def test_frame_index_generator_valid_length(): + """Test valid length computation.""" + times = np.arange(20) + generator = FrameIndexGenerator( + times=times, time_length=3, frame_step=2, model_rank=0, model_world_size=1 + ) + # frames_per_window = (3-1)*2 + 1 = 5 + # valid_length = 20 - 5 + 1 = 16 + assert generator.get_valid_length() == 16 diff --git a/examples/weather/healda/test/test_obs_filtering.py b/examples/weather/healda/test/test_obs_filtering.py new file mode 100644 index 0000000000..fa54ed205a --- /dev/null +++ b/examples/weather/healda/test/test_obs_filtering.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for observation quality-control filtering.""" + +import numpy as np +import pyarrow as pa + +from physicsnemo.experimental.datapipes.healda.configs.sensors import SENSOR_OFFSET +from physicsnemo.experimental.datapipes.healda.transforms.obs_filtering import filter_observations + + +def _make_filter_test_table(): + """Create a minimal table with channel metadata columns required for filtering.""" + conv_offset = SENSOR_OFFSET["conv"] + + # Mix of GPS (0,1,2), PS (3), UV (6,7) channels + channels = [ + conv_offset + 0, + conv_offset + 1, + conv_offset + 3, + conv_offset + 6, + conv_offset + 7, + ] + n = len(channels) + + return pa.table( + { + "Observation": np.array([100.0, 200.0, 500.0, 50.0, 60.0], dtype=np.float32), + "Global_Channel_ID": np.array(channels, dtype=np.uint16), + "Pressure": np.array([500.0, 800.0, 600.0, 400.0, 300.0], dtype=np.float32), + "Height": np.array([1000.0, 5000.0, 100.0, 2000.0, 3000.0], dtype=np.float32), + "Observation_Type": np.array([200, 210, 220, 230, 280], dtype=np.uint16), + "QC_Flag": np.array([0, 0, 0, 0, 0], dtype=np.int32), + "Analysis_Use_Flag": np.array([1, 1, 0, 1, 1], dtype=np.int8), + "min_valid": np.array([0.0, 0.0, 0.0, -100.0, -100.0], dtype=np.float32), + "max_valid": np.array([400.0, 400.0, 1e6, 100.0, 100.0], dtype=np.float32), + "is_conv": np.array([True, True, True, True, True]), + "local_channel_id": np.array([0, 1, 3, 6, 7], dtype=np.uint16), + } + ) + + +def test_filter_observations_basic(): + """Basic filtering removes out-of-range observations.""" + table = _make_filter_test_table() + filtered = filter_observations(table, qc_filter=False) + + assert filtered.num_rows >= 0 + assert filtered.num_rows <= table.num_rows + + +def test_filter_observations_qc(): + """QC filtering is more restrictive.""" + table = _make_filter_test_table() + no_qc = filter_observations(table, qc_filter=False) + with_qc = filter_observations(table, qc_filter=True) + + assert with_qc.num_rows <= no_qc.num_rows diff --git 
a/examples/weather/healda/test/test_prefetch.py b/examples/weather/healda/test/test_prefetch.py new file mode 100644 index 0000000000..f2c1b000fd --- /dev/null +++ b/examples/weather/healda/test/test_prefetch.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for prefetch_map background processing.""" + +import pytest + +from physicsnemo.experimental.datapipes.healda.prefetch import prefetch_map + + +def test_prefetch_map_basic(): + """Prefetch with a simple doubling transform.""" + data = list(range(10)) + loader = prefetch_map(data, lambda x: 2 * x) + assert list(loader) == list(range(0, 20, 2)) + + +def test_prefetch_map_error_propagation(): + """Exceptions in the background thread propagate to the consumer.""" + data = list(range(4)) + + def failing_transform(x): + raise ValueError("Test error") + + loader = prefetch_map(data, failing_transform) + with pytest.raises(ValueError, match="Test error"): + list(loader) diff --git a/examples/weather/healda/test/test_samplers.py b/examples/weather/healda/test/test_samplers.py new file mode 100644 index 0000000000..948a7a0c0f --- /dev/null +++ b/examples/weather/healda/test/test_samplers.py @@ -0,0 +1,143 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for ChunkedDistributedSampler and RoundRobinLoader.""" + +import itertools + +import torch +import torch.utils.data + +from physicsnemo.experimental.datapipes.healda.samplers import ChunkedDistributedSampler, RoundRobinLoader + + +def test_chunked_sampler_sequential(): + """Indices within a chunk must be consecutive.""" + s = ChunkedDistributedSampler(list(range(100)), chunk_size=5) + it = iter(s) + visited = set() + for chunk in range(20): + last_i = 0 + for i in range(5): + idx = next(it) + if i > 0: + assert idx - last_i == 1 + last_i = idx + visited.add(idx) + + assert len(visited) == 100 + + +def test_chunked_sampler_with_islice(): + """Verify iter(sampler) continues rather than resetting.""" + dataset = list(range(100)) + sampler = ChunkedDistributedSampler(dataset, chunk_size=10, drop_last=False) + + iterator = iter(sampler) + first_10 = list(itertools.islice(iterator, 10)) + assert first_10 == list(range(10)) + + # Re-calling iter should continue, not restart + iterator2 = iter(sampler) + next_10 = list(itertools.islice(iterator2, 10)) + assert next_10 == list(range(10, 20)) + assert first_10 != next_10 + + +def test_shuffle_within_chunk(): + """Within-chunk shuffle randomizes order but preserves membership.""" + s = ChunkedDistributedSampler( + list(range(100)), + chunk_size=10, + shuffle=False, + shuffle_within_chunk=True, + seed=42, + ) + + indices = list(s) + 
assert sorted(indices) == list(range(100)) + + first_chunk = indices[:10] + assert sorted(first_chunk) == list(range(10)) + assert first_chunk != list(range(10)) # order should differ + + +def test_shuffle_epoch_changes_chunks(): + """Epoch auto-increment produces different chunk orderings.""" + s = ChunkedDistributedSampler( + list(range(100)), + chunk_size=10, + shuffle=True, + shuffle_within_chunk=True, + seed=42, + ) + + epoch1 = list(s) + epoch2 = list(s) + + assert sorted(epoch1) == list(range(100)) + assert sorted(epoch2) == list(range(100)) + assert sorted(epoch1[:10]) != sorted(epoch2[:10]) + + +# --------------------------------------------------------------------------- +# RoundRobinLoader tests +# --------------------------------------------------------------------------- + + +def test_round_robin_loader(): + """Round-robin interleaving across three loaders.""" + loader1 = torch.utils.data.DataLoader(list(range(0, 10)), batch_size=2) + loader2 = torch.utils.data.DataLoader(list(range(10, 15)), batch_size=2) + loader3 = torch.utils.data.DataLoader(list(range(15, 20)), batch_size=2) + + rr = RoundRobinLoader([loader1, loader2, loader3]) + assert len(rr) == len(loader1) + len(loader2) + len(loader3) + + batches = list(rr) + assert len(batches) == 11 + + # First round + assert torch.equal(batches[0], torch.tensor([0, 1])) + assert torch.equal(batches[1], torch.tensor([10, 11])) + assert torch.equal(batches[2], torch.tensor([15, 16])) + + +def test_round_robin_loader_uneven(): + """Uneven loader lengths — shorter ones drop out first.""" + loader1 = torch.utils.data.DataLoader(list(range(0, 20)), batch_size=2) + loader2 = torch.utils.data.DataLoader(list(range(20, 22)), batch_size=2) + + rr = RoundRobinLoader([loader1, loader2]) + batches = list(rr) + assert len(batches) == 11 + + assert torch.equal(batches[0], torch.tensor([0, 1])) + assert torch.equal(batches[1], torch.tensor([20, 21])) + assert torch.equal(batches[2], torch.tensor([2, 3])) + + +def 
test_round_robin_loader_empty(): + rr = RoundRobinLoader([]) + assert list(rr) == [] + + +def test_round_robin_loader_single(): + loader = torch.utils.data.DataLoader(list(range(10)), batch_size=3) + rr = RoundRobinLoader([loader]) + expected = list(torch.utils.data.DataLoader(list(range(10)), batch_size=3)) + actual = list(rr) + assert len(actual) == len(expected) + for a, e in zip(actual, expected): + assert torch.equal(a, e) diff --git a/examples/weather/healda/test/test_time_utils.py b/examples/weather/healda/test/test_time_utils.py new file mode 100644 index 0000000000..aa4cd27664 --- /dev/null +++ b/examples/weather/healda/test/test_time_utils.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for datetime conversion utilities.""" + +import datetime + +import cftime +import numpy as np +import pandas as pd + +from physicsnemo.experimental.datapipes.healda.time_utils import as_cftime, as_numpy, as_pydatetime, as_timestamp + + +def test_as_numpy_from_pandas_index(): + idx = pd.date_range("2020-01-01", periods=3, freq="h") + result = as_numpy(idx) + assert isinstance(result, np.ndarray) + assert result.dtype == np.dtype("datetime64[ns]") + assert len(result) == 3 + + +def test_as_numpy_from_timestamp(): + ts = pd.Timestamp("2020-06-15T12:00:00") + result = as_numpy(ts) + assert result.shape == (1,) + + +def test_as_numpy_from_cftime(): + t = cftime.DatetimeGregorian(2022, 3, 1, 6, 0, 0) + result = as_numpy(t) + assert result.shape == (1,) + + +def test_as_cftime_roundtrip(): + ts = pd.Timestamp("2023-07-04T18:30:00") + cf = as_cftime(ts) + assert isinstance(cf, cftime.DatetimeGregorian) + assert cf.year == 2023 + assert cf.month == 7 + assert cf.day == 4 + assert cf.hour == 18 + assert cf.minute == 30 + + +def test_as_pydatetime_from_cftime(): + cf = cftime.DatetimeGregorian(2021, 12, 25, 0, 0, 0) + result = as_pydatetime(cf) + assert isinstance(result, datetime.datetime) + assert result.tzinfo is not None # UTC + + +def test_as_timestamp(): + idx = pd.date_range("2020-01-01", periods=1, freq="h") + result = as_timestamp(idx) + assert result.dtype == int + assert result[0] == 1577836800 # 2020-01-01T00:00:00 UTC diff --git a/examples/weather/healda/test/test_types.py b/examples/weather/healda/test/test_types.py new file mode 100644 index 0000000000..d83fa3b919 --- /dev/null +++ b/examples/weather/healda/test/test_types.py @@ -0,0 +1,180 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for UnifiedObservation and split_by_sensor.""" + +import pytest +import torch + +from physicsnemo.experimental.datapipes.healda.types import UnifiedObservation, split_by_sensor + + +def make_realistic_obs( + B: int = 2, T: int = 2, sensors: list[int] = [0, 1, 2] +) -> UnifiedObservation: + """Create realistic cyclic observation data matching real UFS patterns.""" + S = len(sensors) + + all_obs = [] + for b in range(B): + for t in range(T): + for i in range(6): + sensor_id = sensors[i % S] + all_obs.append((sensor_id, b, t, len(all_obs))) + + all_obs.sort(key=lambda x: (x[0], x[3])) + + values = torch.tensor([x[3] for x in all_obs], dtype=torch.float32) + + lengths_3d = torch.zeros((S, B, T), dtype=torch.int32) + for s_local, s_id in enumerate(sensors): + for b in range(B): + for t in range(T): + lengths_3d[s_local, b, t] = sum( + 1 for obs in all_obs if obs[0] == s_id and obs[1] == b and obs[2] == t + ) + + sensor_id_to_local = torch.full((max(sensors) + 1,), -1, dtype=torch.int32) + for local_idx, s_id in enumerate(sensors): + sensor_id_to_local[s_id] = local_idx + + nobs = len(all_obs) + return UnifiedObservation( + obs=values.unsqueeze(1).expand(nobs, 3), + time=values.long(), + float_metadata=values.unsqueeze(1).expand(nobs, 5), + pix=torch.arange(nobs, dtype=torch.long), + local_channel=torch.zeros(nobs, dtype=torch.long), + platform=torch.zeros(nobs, dtype=torch.long), + obs_type=torch.zeros(nobs, dtype=torch.long), + global_channel=torch.zeros(nobs, dtype=torch.long), + hpx_level=6, + lengths=lengths_3d, + 
sensor_id_to_local=sensor_id_to_local, + ) + + +def test_split_preserves_all_observations(): + obs = make_realistic_obs(B=2, T=2, sensors=[0, 1, 2]) + total_before = obs.obs.shape[0] + + split = split_by_sensor(obs, [0, 1, 2]) + + total_after = sum(split[sid].obs.shape[0] for sid in [0, 1, 2]) + assert total_after == total_before + + for sid in [0, 1, 2]: + assert split[sid].obs.shape[0] == 8 + + +def test_split_content_correctness(): + obs = make_realistic_obs(B=2, T=2, sensors=[0, 1, 2]) + split = split_by_sensor(obs, [0, 1, 2]) + + for sid in [0, 1, 2]: + assert split[sid].obs.shape[0] == 8 + + +def test_split_lengths_match_obs_count(): + obs = make_realistic_obs(B=1, T=2, sensors=[0, 1]) + split = split_by_sensor(obs, [0, 1]) + + for sid in [0, 1]: + s_obs = split[sid] + assert s_obs.lengths.sum().item() == s_obs.obs.shape[0] + + +def test_split_empty_sensor(): + obs = make_realistic_obs(B=1, T=1, sensors=[0, 1]) + split = split_by_sensor(obs, [0, 1, 2]) + + assert split[2].obs.shape[0] == 0 + assert split[2].lengths.shape == (1, 1, 1) + + +def test_split_requires_lengths(): + obs = UnifiedObservation( + obs=torch.randn(10, 3), + time=torch.zeros(10, dtype=torch.long), + float_metadata=torch.randn(10, 5), + pix=torch.zeros(10, dtype=torch.long), + local_channel=torch.zeros(10, dtype=torch.long), + platform=torch.zeros(10, dtype=torch.long), + obs_type=torch.zeros(10, dtype=torch.long), + global_channel=torch.zeros(10, dtype=torch.long), + hpx_level=6, + lengths=None, + sensor_id_to_local=None, + ) + + with pytest.raises(ValueError, match="lengths is required"): + split_by_sensor(obs, [0, 1]) + + +def test_lengths_nonnegative(): + obs = make_realistic_obs(B=2, T=3, sensors=[0, 1, 2]) + assert torch.all(obs.lengths >= 0) + + +def test_split_handles_sparse_windows(): + """Sensor missing from some (b,t) windows.""" + B, T = 2, 3 + sensors = [0, 4] + + all_obs = [] + for b in range(B): + for t in range(T): + all_obs.extend([(0, b, t)] * 2) + all_obs.extend([(4, 1, 
2)] * 3) + + nobs = len(all_obs) + + lengths_3d = torch.zeros((2, B, T), dtype=torch.int32) + lengths_3d[0, :, :] = 2 + lengths_3d[1, 1, 2] = 3 + + sensor_id_to_local = torch.full((5,), -1, dtype=torch.int32) + for local_idx, s_id in enumerate(sensors): + sensor_id_to_local[s_id] = local_idx + + obs = UnifiedObservation( + obs=torch.arange(nobs, dtype=torch.float32).unsqueeze(1).expand(nobs, 3), + time=torch.zeros(nobs, dtype=torch.long), + float_metadata=torch.arange(nobs, dtype=torch.float32).unsqueeze(1).expand(nobs, 5), + pix=torch.arange(nobs, dtype=torch.long), + local_channel=torch.zeros(nobs, dtype=torch.long), + platform=torch.zeros(nobs, dtype=torch.long), + obs_type=torch.zeros(nobs, dtype=torch.long), + global_channel=torch.zeros(nobs, dtype=torch.long), + hpx_level=6, + lengths=lengths_3d, + sensor_id_to_local=sensor_id_to_local, + ) + + assert obs.batch_dims == (2, 3) + + split = split_by_sensor(obs, [0, 4, 99]) + + s0 = split[0] + assert s0.obs.shape[0] == 12 + assert s0.lengths.shape == (1, 2, 3) + assert s0.lengths.sum().item() == 12 + + s4 = split[4] + assert s4.obs.shape[0] == 3 + assert s4.lengths[0, 1, 2].item() == 3 + + s99 = split[99] + assert s99.obs.shape[0] == 0 + assert torch.all(s99.lengths == 0) diff --git a/physicsnemo/experimental/datapipes/__init__.py b/physicsnemo/experimental/datapipes/__init__.py new file mode 100644 index 0000000000..af85283aa4 --- /dev/null +++ b/physicsnemo/experimental/datapipes/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/physicsnemo/experimental/datapipes/healda/__init__.py b/physicsnemo/experimental/datapipes/healda/__init__.py new file mode 100644 index 0000000000..14315e6e3b --- /dev/null +++ b/physicsnemo/experimental/datapipes/healda/__init__.py @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""HealDA data loading pipeline. + +Provides the complete data pipeline for HealDA training: observation loading, +ERA5 state loading, two-stage transforms (CPU + GPU), distributed sampling, +and background CUDA prefetching. 
+ +Key entry points: + +- :class:`ObsERA5Dataset` — map-style dataset combining ERA5 state + observations +- :class:`UFSUnifiedLoader` — parquet-based observation loader +- :class:`ERA5ObsTransform` — two-stage transform with Triton feature kernels +- :func:`prefetch_map` — background CUDA stream prefetching +- :class:`ChunkedDistributedSampler` — cache-friendly distributed sampler +- :class:`RoundRobinLoader` — multi-loader round-robin interleaving + +Protocols for custom loaders/transforms: + +- :class:`ObsLoader` — async observation loading interface +- :class:`Transform` — CPU-side batch transform +- :class:`DeviceTransform` — GPU-side batch transform +""" + +from physicsnemo.experimental.datapipes.healda.dataset import ObsERA5Dataset +from physicsnemo.experimental.datapipes.healda.indexing import ( + FrameIndexGenerator, + MultiCoordIndex, + get_flat_indexer, + split_array_contiguous, +) +from physicsnemo.experimental.datapipes.healda.prefetch import prefetch_map +from physicsnemo.experimental.datapipes.healda.protocols import ( + DeviceTransform, + ObsLoader, + Transform, +) +from physicsnemo.experimental.datapipes.healda.samplers import ( + ChunkedDistributedSampler, + RoundRobinLoader, +) +from physicsnemo.experimental.datapipes.healda.types import ( + Batch, + BatchInfo, + TimeUnit, + UnifiedObservation, + VariableConfig, + empty_batch, + split_by_sensor, +) + +__all__ = [ + # Dataset + "ObsERA5Dataset", + # Protocols + "ObsLoader", + "Transform", + "DeviceTransform", + # Types + "UnifiedObservation", + "Batch", + "BatchInfo", + "VariableConfig", + "TimeUnit", + "empty_batch", + "split_by_sensor", + # Infrastructure + "prefetch_map", + "ChunkedDistributedSampler", + "RoundRobinLoader", + "FrameIndexGenerator", + "MultiCoordIndex", + "get_flat_indexer", + "split_array_contiguous", +] diff --git a/physicsnemo/experimental/datapipes/healda/configs/__init__.py b/physicsnemo/experimental/datapipes/healda/configs/__init__.py new file mode 100644 index 
0000000000..3159bfe656 --- /dev/null +++ b/physicsnemo/experimental/datapipes/healda/configs/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/physicsnemo/experimental/datapipes/healda/configs/combined_schema.py b/physicsnemo/experimental/datapipes/healda/configs/combined_schema.py new file mode 100644 index 0000000000..fb2138c099 --- /dev/null +++ b/physicsnemo/experimental/datapipes/healda/configs/combined_schema.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Combined PyArrow schema for unified satellite and conventional observation data. + +This schema handles both satellite observations (atms, mhs, amsua, etc.) 
+and conventional observations (gps, ps, q, t, uv). Conventional observations +are flattened into a single ``Observation`` column with multiple rows per +location for multi-component observations. +""" + +import pyarrow as pa + +GLOBAL_CHANNEL_ID = pa.field("Global_Channel_ID", pa.uint16(), nullable=False) +SENSOR_ID = pa.field("sensor_id", pa.uint16()) + + +def get_combined_observation_schema() -> pa.Schema: + """Create a combined PyArrow schema for satellite and conventional obs.""" + common_fields = [ + pa.field("Latitude", pa.float32()), + pa.field("Longitude", pa.float32()), + pa.field("Absolute_Obs_Time", pa.timestamp("ns")), + pa.field("DA_window", pa.timestamp("ns")), + pa.field("Platform_ID", pa.uint16()), + pa.field("Observation", pa.float32()), + GLOBAL_CHANNEL_ID, + ] + + satellite_fields = [ + pa.field("Sat_Zenith_Angle", pa.float32(), nullable=True), + pa.field("Sol_Zenith_Angle", pa.float32(), nullable=True), + pa.field("Scan_Angle", pa.float32(), nullable=True), + ] + + conventional_fields = [ + pa.field("Pressure", pa.float32(), nullable=True), + pa.field("Height", pa.float32(), nullable=True), + pa.field("Observation_Type", pa.uint16(), nullable=True), + ] + + analysis_fields = [ + pa.field("QC_Flag", pa.int32(), nullable=True), + pa.field("Analysis_Use_Flag", pa.int8(), nullable=True), + ] + + all_fields = ( + common_fields + satellite_fields + conventional_fields + analysis_fields + ) + return pa.schema(all_fields) + + +def get_channel_table_schema(): + """Schema for the channel metadata table.""" + return pa.schema( + [ + GLOBAL_CHANNEL_ID, + pa.field("min_valid", pa.float32()), + pa.field("max_valid", pa.float32()), + SENSOR_ID, + pa.field("is_conv", pa.bool_()), + pa.field("name", pa.string()), + pa.field("mean", pa.float32()), + pa.field("stddev", pa.float32()), + ] + ) diff --git a/physicsnemo/experimental/datapipes/healda/configs/era5_13_levels_stats.csv b/physicsnemo/experimental/datapipes/healda/configs/era5_13_levels_stats.csv new file 
mode 100644 index 0000000000..d57746565e --- /dev/null +++ b/physicsnemo/experimental/datapipes/healda/configs/era5_13_levels_stats.csv @@ -0,0 +1,77 @@ +variable,level,std,mean +U,1000,6.0761532856545495,-0.4012036280949351 +U,925,7.8438841318137085,0.1875586020534487 +U,850,8.117545798838897,1.0676392264524672 +U,700,9.161262910521476,3.1993547283822656 +U,600,10.351920350301654,4.753506500293617 +U,500,12.020863918100371,6.643045555863143 +U,400,14.37786455089647,9.166420770646372 +U,300,17.368910160479544,12.604205075841238 +U,250,18.63126418939197,14.458098693848147 +U,200,18.826071570402007,15.603421939654016 +U,150,17.338743800703853,14.694208759254527 +U,100,14.207755780676433,10.084541123485184 +U,50,14.828288062604061,3.447298300210448 +V,1000,5.0562302721142185,0.1881812905956076 +V,925,6.087634459436962,0.18351119177675437 +V,850,5.788139666651214,0.08865528649740444 +V,700,6.326620906882278,-0.01624326876106696 +V,600,7.18152443362136,-0.046446792220198006 +V,500,8.45848562350414,-0.032289933410975316 +V,400,10.38343861775161,-0.02010503440455113 +V,300,12.609499025107015,-0.025803734238538843 +V,250,13.057155983469311,-0.04324580333432856 +V,200,12.04089626815829,-0.07330144652625546 +V,150,9.733874174308893,-0.06046714286886432 +V,100,7.052043591267376,0.016139828470052277 +V,50,5.626839932669169,0.000506756767366129 +T,1000,13.3926808838397,288.34811963631944 +T,925,12.75887207878649,284.14180529926887 +T,850,12.337912690900826,281.1236718383152 +T,700,11.485912313471955,273.6522781145025 +T,600,10.914242703518575,266.8225997851199 +T,500,10.947319271596404,258.48344241794354 +T,400,10.889317302118906,247.5078738667765 +T,300,9.38121872205424,233.2984452545507 +T,250,7.28726013236191,225.6329557585312 +T,200,5.29271246649905,218.53493343877926 +T,150,7.447585393440247,211.46310701251957 +T,100,11.427625612914374,204.83933151080333 +T,50,7.494050550499991,211.44287440973292 +Z,1000,893.4657098478391,935.5049140418249 
+Z,925,1008.7658942548571,7360.676041141024 +Z,850,1197.2156800245555,14248.78134896371 +Z,700,1736.7489061933863,29767.144398835746 +Z,600,2188.213530821762,41747.04921153613 +Z,500,2729.9217253050874,55509.29599258445 +Z,400,3402.601245989482,71726.5567290568 +Z,300,4213.390725518206,91575.34112031905 +Z,250,4602.533249864132,103579.02469155286 +Z,200,4831.571016444258,117791.77588622355 +Z,150,4711.8643515216945,135539.4191093247 +Z,100,4105.159434190763,159702.46332065153 +Z,50,3851.9190680022352,200924.18906200194 +Q,1000,0.005781177709317301,0.00936470198136523 +Q,925,0.004981825818460944,0.008008882445730241 +Q,850,0.004179579792520903,0.006031608541363992 +Q,700,0.002737859580661181,0.0031682692015671905 +Q,600,0.0019479582011688481,0.0019997242562778714 +Q,500,0.0012087878889463929,0.0011044648648525284 +Q,400,0.0005684813340060475,0.0005001939892631466 +Q,300,0.00018788822542209214,0.00016860593486309685 +Q,250,8.253217825122686e-05,7.728339864456503e-05 +Q,200,2.476106074180645e-05,2.5977126545155267e-05 +Q,150,4.08396736011409e-06,6.446599195006216e-06 +Q,100,6.154373024833254e-07,2.6868721151493317e-06 +Q,50,2.5960505606920645e-07,2.6752383145132987e-06 +tcwv,-1,16.707112756025314,24.224098621015028 +tas,-1,15.355112655773071,287.40642670567786 +uas,-1,5.436672873915895,-0.3751276131251324 +vas,-1,4.491587780564064,0.18441198768158903 +100u,-1,6.684378709883516,-0.36058662354821425 +100v,-1,5.613643415909631,0.1893235299805761 +pres_msl,-1,1109.2809461275167,101138.98195665042 +sst,-1,8.851771853018453,290.8944586515412 +sic,-1,0.18702627733432053,0.04226433445734063 +orog,-1.0,627.3885284872,232.56013904090733 +lfrac,-1.0,0.4695501683565522,0.3410480857539571 \ No newline at end of file diff --git a/physicsnemo/experimental/datapipes/healda/configs/normalizations/airs-pca_normalizations.csv b/physicsnemo/experimental/datapipes/healda/configs/normalizations/airs-pca_normalizations.csv new file mode 100644 index 0000000000..704ad1fc8a --- /dev/null +++ 
b/physicsnemo/experimental/datapipes/healda/configs/normalizations/airs-pca_normalizations.csv @@ -0,0 +1,65 @@ +Raw_Channel_ID,Platform_ID,obs_std,obs_mean,obs_min,obs_max +1,-1,8.214708551830356,-0.007816872517859445,-54.583263,25.529593 +1,0,8.214708551830421,-0.007816872517859456,-54.583263,25.529593 +2,-1,6.204024141659925,0.060118015162609874,-54.41878,28.946451 +2,0,6.204024141659928,0.060118015162609874,-54.41878,28.946451 +3,-1,1.7957874344187403,-0.04797540239732354,-11.155533,14.771423 +3,0,1.7957874344187437,-0.047975402397323494,-11.155533,14.771423 +4,-1,1.4675865871956242,-0.0006924302620134747,-9.648644,9.285899 +4,0,1.4675865871956315,-0.0006924302620134743,-9.648644,9.285899 +5,-1,0.8913469566843333,-0.005252729000527411,-5.767573,7.9473076 +5,0,0.8913469566843248,-0.005252729000527415,-5.767573,7.9473076 +6,-1,0.623793467189813,-0.003225442053726113,-5.3376155,3.5984225 +6,0,0.6237934671898061,-0.0032254420537261115,-5.3376155,3.5984225 +7,-1,0.39469768069981787,-0.0023120367162287455,-5.812195,7.6763997 +7,0,0.39469768069981237,-0.002312036716228743,-5.812195,7.6763997 +8,-1,0.33593410156027675,0.005151922195575682,-9.785906,9.08154 +8,0,0.33593410156027365,0.005151922195575682,-9.785906,9.08154 +9,-1,0.28105098555908775,0.001158328666100196,-3.716396,3.4178336 +9,0,0.281050985559087,0.0011583286661001964,-3.716396,3.4178336 +10,-1,0.22843497168668764,0.0036994384355875766,-3.6491232,4.755522 +10,0,0.22843497168668825,0.003699438435587577,-3.6491232,4.755522 +11,-1,0.21192941892048298,-0.0003545486185746004,-13.372186,3.3225749 +11,0,0.21192941892048256,-0.0003545486185746005,-13.372186,3.3225749 +12,-1,0.1347720053776834,-0.0007126142925684956,-6.045535,2.4844725 +12,0,0.13477200537768147,-0.0007126142925684956,-6.045535,2.4844725 +13,-1,0.11211111955725256,-0.0005063795610494643,-1.8508754,4.7352424 +13,0,0.11211111955725352,-0.000506379561049464,-1.8508754,4.7352424 +14,-1,0.11594864230525384,0.002015852397427362,-2.623983,5.3442745 
+14,0,0.11594864230525545,0.002015852397427362,-2.623983,5.3442745 +15,-1,0.11065300843350233,0.0006929825699963552,-2.9404325,3.254625 +15,0,0.11065300843350065,0.0006929825699963552,-2.9404325,3.254625 +16,-1,0.11059560698104977,0.00025496963146238625,-5.1120696,1.139813 +16,0,0.11059560698104766,0.0002549696314623864,-5.1120696,1.139813 +17,-1,0.10827448760995192,-0.0008072181292068404,-11.610527,1.4685344 +17,0,0.1082744876099539,-0.0008072181292068405,-11.610527,1.4685344 +18,-1,0.10103222698655831,-0.001441306732794491,-3.8353717,4.2164545 +18,0,0.10103222698655692,-0.0014413067327944903,-3.8353717,4.2164545 +19,-1,0.10147859528906694,-0.0006220671585844798,-2.781216,3.3901258 +19,0,0.10147859528906512,-0.0006220671585844798,-2.781216,3.3901258 +20,-1,0.10368942955363945,-0.00020344373144167612,-1.5792689,8.15311 +20,0,0.10368942955363931,-0.00020344373144167606,-1.5792689,8.15311 +21,-1,0.10095808472137889,-0.00023500744504603737,-1.6666156,1.577552 +21,0,0.10095808472137917,-0.00023500744504603707,-1.6666156,1.577552 +22,-1,0.10052316918429405,-3.0905465844828725e-05,-1.1307209,9.927661 +22,0,0.1005231691842936,-3.090546584482874e-05,-1.1307209,9.927661 +23,-1,0.09880220102875298,0.00047988024789517484,-0.9339404,6.719097 +23,0,0.0988022010287523,0.00047988024789517495,-0.9339404,6.719097 +24,-1,0.08934753826574814,0.0005772402600992258,-10.382512,1.9781011 +24,0,0.0893475382657479,0.0005772402600992256,-10.382512,1.9781011 +25,-1,0.08806596592232664,-0.001283645375872585,-5.6984606,2.2212803 +25,0,0.08806596592232749,-0.001283645375872586,-5.6984606,2.2212803 +26,-1,0.08731936349619145,-0.0006102947848890263,-1.3738391,7.589438 +26,0,0.08731936349619131,-0.0006102947848890257,-1.3738391,7.589438 +27,-1,0.08717698513915002,-0.0005823197553379738,-2.9844427,1.3560678 +27,0,0.087176985139149,-0.0005823197553379737,-2.9844427,1.3560678 +28,-1,0.08492734980919361,-0.00038831001660858896,-2.6943138,1.0411136 
+28,0,0.08492734980919496,-0.0003883100166085897,-2.6943138,1.0411136 +29,-1,0.08305636257870289,-0.00016189496049658577,-9.235278,2.5213282 +29,0,0.08305636257870143,-0.0001618949604965858,-9.235278,2.5213282 +30,-1,0.08194517117431961,-0.00030206228476307577,-1.3414413,1.4573736 +30,0,0.08194517117432089,-0.0003020622847630757,-1.3414413,1.4573736 +31,-1,0.0812069714976724,-8.516453259253043e-05,-0.6720387,7.2829967 +31,0,0.08120697149767325,-8.516453259253042e-05,-0.6720387,7.2829967 +32,-1,0.08075419401644526,-0.0008985027944582864,-11.253723,0.53134376 +32,0,0.08075419401644411,-0.0008985027944582865,-11.253723,0.53134376 diff --git a/physicsnemo/experimental/datapipes/healda/configs/normalizations/amsua_normalizations.csv b/physicsnemo/experimental/datapipes/healda/configs/normalizations/amsua_normalizations.csv new file mode 100644 index 0000000000..44ccce3ac0 --- /dev/null +++ b/physicsnemo/experimental/datapipes/healda/configs/normalizations/amsua_normalizations.csv @@ -0,0 +1,113 @@ +Raw_Channel_ID,Platform_ID,obs_std,obs_mean +1,-1,43.14827526898869,208.81245889545207 +1,14,43.245863004474046,208.5147247714152 +1,15,42.644794686368726,207.84771330773998 +1,16,43.32608185643857,208.88801777990457 +1,20,42.862800259005354,209.69306267322926 +1,21,41.00672569472209,206.3581961906114 +1,23,43.10684451429107,208.63778656625053 +1,24,43.477805956939,208.56640086972172 +2,-1,44.96814209305951,200.67487673075226 +2,14,45.0688240991038,200.64044574997308 +2,15,44.65034395075911,199.57948659408837 +2,16,45.10653156173914,200.78904240241823 +2,20,44.54634966722693,200.9030014910355 +2,21,42.16959824168713,197.74683727240892 +2,23,44.961453342136245,200.72370211205973 +2,24,45.38303896246749,200.65617535973996 +3,-1,20.60997229307027,241.33242782399282 +3,14,20.638403211245485,240.9346938961322 +3,15,20.56083967162193,241.01957817359767 +3,16,20.756384664385838,241.36646570607004 +3,20,20.40978773273391,242.06647213486357 +3,21,18.549164473665893,239.64565679658295 
+3,23,20.594839602248793,241.05065270256063 +3,24,20.776392520670253,241.31193090983035 +4,-1,12.391091090191496,255.40195807265977 +4,14,12.39558220414599,255.2485698626346 +4,15,12.441051959212107,255.61576880929422 +4,16,12.39107026991132,255.39039410065763 +4,20,12.457001058263584,255.74958936586165 +4,21,11.702991216116937,254.80897097129633 +4,23,12.367115477186875,255.28127542166982 +4,24,12.324249047160235,255.26146633367378 +5,-1,9.72289107163208,248.30325035989532 +5,14,9.712855426440534,248.12249857181723 +5,15,9.782916048429138,248.4087687449923 +5,16,9.773101672516752,248.26028480979204 +5,20,9.71372617268573,248.84428383985446 +5,21,9.270351647297238,248.21563820415776 +5,23,9.699518172278252,248.0950168016271 +5,24,9.705036110488924,247.99733787940366 +6,-1,6.591734627516939,233.2546738888034 +6,14,6.626838235834015,233.21555828000677 +6,15,6.715863262633315,233.8078009560835 +6,16,6.751256745920496,233.18885362349366 +6,20,6.2770916623924276,233.12018885091325 +6,21,6.055423123419333,233.83423502032096 +6,23,6.615882647299575,233.14683849524496 +6,24,6.603232459352038,233.36377680690163 +7,-1,5.200585036541759,223.37161080931776 +7,14,5.328507768691345,222.63488189090833 +7,15,5.290039398907846,223.40604859527315 +7,16,5.286683193657975,222.87131600917934 +7,20,5.065071419902494,224.32254865454104 +7,21,4.660127715690419,223.87643319144212 +7,23,5.171385847225267,222.70343722755194 +7,24,5.07430384498421,223.1480783475126 +8,-1,6.0071988697567456,216.48895482206066 +8,14,5.993478811669694,216.18038748856023 +8,15,5.992285017757789,216.32746437053305 +8,16,5.9158778196382995,216.28546186411424 +8,20,6.002563219758394,217.0717997186405 +8,21,6.320274200173209,217.0598985070878 +8,23,6.017509215427473,215.95926857080985 +8,24,5.297709489024461,215.76674689563418 +9,-1,8.509114940634637,211.2537007874646 +9,14,8.546910890556669,211.04022344033265 +9,15,8.485593502788452,211.30400409997011 +9,16,8.31323746741205,210.86774331842972 
+9,20,8.532774635936512,211.7353577068666 +9,23,8.602809925010613,210.69029206493593 +9,24,8.490408071518324,211.0450224915925 +10,-1,8.008770826557447,214.5290848362116 +10,14,8.031084627218561,214.49865927153715 +10,15,7.98725303457022,214.67147267452094 +10,16,7.901115635400929,214.31155514650166 +10,20,8.053008659834964,215.03146673742023 +10,23,8.016193430235884,214.09543446059095 +10,24,7.944071272830509,214.45702080831023 +11,-1,8.572729274070603,220.667324154054 +11,14,8.612188988458303,220.78406698291334 +11,15,8.578335740928825,220.91224176979185 +11,16,8.606801252318974,220.51350307302872 +11,23,8.590054337884801,220.36146839247078 +11,24,8.489296811083372,220.762503852709 +12,-1,9.827017164339123,229.2238827001046 +12,14,9.848325150318507,229.15015194738777 +12,15,9.787099403535226,229.28036940555563 +12,16,9.802699078194221,228.54596812574724 +12,20,9.876530004901452,229.60541801900766 +12,23,9.857903078183186,228.98018510273 +12,24,9.733668818886466,229.31037852193862 +13,-1,10.657981916161267,239.62808380868842 +13,14,10.676786662443906,239.41935483813367 +13,15,10.57700750393631,239.59523381548578 +13,16,10.591162650433894,239.05696101001703 +13,20,10.702002553441044,240.14076712717042 +13,23,10.708136300859543,239.4459129578857 +13,24,10.59151936202361,239.66792244377427 +14,-1,10.460706383630459,249.80052228607403 +14,14,10.518120440742436,249.5965305240693 +14,15,10.355629661230186,249.94929470130802 +14,16,10.340302426577065,249.42994207978592 +14,23,10.525085837365992,249.78271346792957 +14,24,10.436851910807803,250.01848408363801 +15,-1,27.206282747597683,241.08475017755075 +15,14,27.38726738210949,241.0812384969291 +15,15,27.297859237142635,241.426373404627 +15,16,27.312180489599005,241.39647412995765 +15,20,26.740265157879765,240.9054814678835 +15,21,26.25128791327429,237.4066499607801 +15,23,27.290915782286874,240.861508938203 +15,24,27.384349423192475,241.38266784296735 diff --git 
a/physicsnemo/experimental/datapipes/healda/configs/normalizations/amsub_normalizations.csv b/physicsnemo/experimental/datapipes/healda/configs/normalizations/amsub_normalizations.csv new file mode 100644 index 0000000000..07d1fe95b3 --- /dev/null +++ b/physicsnemo/experimental/datapipes/healda/configs/normalizations/amsub_normalizations.csv @@ -0,0 +1,16 @@ +Raw_Channel_ID,Platform_ID,obs_std,obs_mean +1,-1,27.2737718832671,236.66623997760456 +1,21,27.025331796802135,236.49292906247072 +1,22,27.46773606448012,236.80319681448057 +2,-1,27.203280859273686,259.53019990435104 +2,21,27.22256458977998,259.7678475989869 +2,22,27.186577013708682,259.34254894910356 +3,-1,9.489860500316203,247.52103646621188 +3,21,9.638507561000745,247.71051502953418 +3,22,9.36779125986415,247.37103182370328 +4,-1,12.906392940834248,258.5867625015679 +4,21,12.922531248194963,259.3452157730607 +4,22,12.862125887997312,257.98840743403474 +5,-1,18.451601367383574,264.40828198308105 +5,21,18.290864177699667,263.3042943725722 +5,22,18.53145681557096,265.2848912188969 diff --git a/physicsnemo/experimental/datapipes/healda/configs/normalizations/atms_normalizations.csv b/physicsnemo/experimental/datapipes/healda/configs/normalizations/atms_normalizations.csv new file mode 100644 index 0000000000..e1b35d0f4d --- /dev/null +++ b/physicsnemo/experimental/datapipes/healda/configs/normalizations/atms_normalizations.csv @@ -0,0 +1,67 @@ +Raw_Channel_ID,Platform_ID,obs_std,obs_mean +1,-1,43.09279016587131,210.02112362010823 +1,25,43.22600229860692,209.8786665026229 +1,26,43.02804318886837,210.09002498629346 +2,-1,44.86846399304575,201.07331959768962 +2,25,45.01897932416165,200.9561896404016 +2,26,44.79537395581081,201.12997111896166 +3,-1,20.91136978667041,240.9698729566799 +3,25,21.021514548376974,240.87188064377247 +3,26,20.857843018367628,241.0171624463704 +4,-1,15.539263168345263,251.10140095634452 +4,25,15.666611148345122,250.9308978682481 +4,26,15.476750931187427,251.18369532111635 
+5,-1,11.871982173921344,255.5391212721772 +5,25,11.960523646605074,255.4958955549361 +5,26,11.828949044784242,255.55998657237063 +6,-1,9.684743648063607,248.12555462660075 +6,25,9.764115534036241,248.04013477579554 +6,26,9.645898474886815,248.16681639990207 +7,-1,6.784248199795853,233.28375341081514 +7,25,6.856615715379073,232.9826081385835 +7,26,6.744210848770702,233.4291778163615 +8,-1,5.345625849992116,222.9450280954434 +8,25,5.436060756468369,222.7293455903235 +8,26,5.298259131893759,223.0491825905517 +9,-1,5.991791106218113,215.77753474885782 +9,25,6.008561080213766,215.47870864850285 +9,26,5.97833237796406,215.92181916545474 +10,-1,8.255314915635376,211.32904065225617 +10,25,8.188014023078008,210.92054661357824 +10,26,8.280406858326261,211.52630443706133 +11,-1,7.73402919628314,214.83510628159783 +11,25,7.7618820734836405,214.45412105711267 +11,26,7.713808772103638,215.01908359278923 +12,-1,8.383691337628381,221.3587919894676 +12,25,8.511318845815346,220.9317387710988 +12,26,8.313507519747509,221.56501849554937 +13,-1,9.592169458722168,229.7670172844918 +13,25,9.730511977369336,229.30138467444334 +13,26,9.516498564639026,229.99184850867996 +14,-1,10.36873946997381,240.4398718490441 +14,25,10.485895010432683,240.0398505723095 +14,26,10.3061359195915,240.63302697849377 +15,-1,10.10820595492633,250.57600766864383 +15,25,10.142481762137182,250.22218274721686 +15,26,10.087171228051647,250.74685867758924 +16,-1,27.16492755041575,241.56023841681622 +16,25,27.28324796338492,241.81119744782654 +16,26,27.106770027220723,241.43904225866117 +17,-1,25.07116867528301,263.5651818722436 +17,25,25.050729106218757,263.4870884050113 +17,26,25.080946292324455,263.6028943828347 +18,-1,18.129795132200364,264.1620960226355 +18,25,18.08394182285079,263.9588970537857 +18,26,18.151082646579844,264.2602246948714 +19,-1,14.859816103843695,261.3678407504765 +19,25,14.858920932630344,261.123239331056 +19,26,14.858806723767309,261.48596330493314 +20,-1,11.999951828546328,257.6685400346904 
+20,25,12.028708960118182,257.37727666792404 +20,26,11.983504400438001,257.8092033704697 +21,-1,10.093610918549906,252.1632616023329 +21,25,10.136081891315454,251.87539123929696 +21,26,10.070091212448709,252.30227626712684 +22,-1,8.78830692800466,246.87701063895838 +22,25,8.83211644958051,246.57528075835782 +22,26,8.76335506458442,247.02271043043058 diff --git a/physicsnemo/experimental/datapipes/healda/configs/normalizations/conv_normalizations.csv b/physicsnemo/experimental/datapipes/healda/configs/normalizations/conv_normalizations.csv new file mode 100644 index 0000000000..abe77dbbc8 --- /dev/null +++ b/physicsnemo/experimental/datapipes/healda/configs/normalizations/conv_normalizations.csv @@ -0,0 +1,9 @@ +Raw_Channel_ID,Platform_ID,obs_std,obs_mean,obs_min,obs_max +1,-1,0.006021306479256463,0.004113416818592855,9.99999993922529e-09,0.0828777477145195 +2,-1,24.040581924661886,237.75966260053028,171.33290100097656,322.6221618652344 +3,-1,0.001796639343013556,0.0004967400106678111,1.0000000116860974e-07,0.026085082441568375 +4,-1,51.69925682101712,983.1479345318987,200.0,1100.0 +5,-1,0.005461665852510159,0.004603054301095736,0.0,0.06551100313663483 +6,-1,27.571154926597,258.28742986188126,173.14999389648438,349.95001220703125 +7,-1,14.89412603357212,7.152579155956843,-100.0,100.0 +8,-1,10.19761780393707,0.21634765886519516,-100.0,100.0 diff --git a/physicsnemo/experimental/datapipes/healda/configs/normalizations/cris-fsr-pca_normalizations.csv b/physicsnemo/experimental/datapipes/healda/configs/normalizations/cris-fsr-pca_normalizations.csv new file mode 100644 index 0000000000..c82685028f --- /dev/null +++ b/physicsnemo/experimental/datapipes/healda/configs/normalizations/cris-fsr-pca_normalizations.csv @@ -0,0 +1,97 @@ +Raw_Channel_ID,Platform_ID,obs_std,obs_mean,obs_min,obs_max +1,-1,8.113731802616636,0.01651320704049069,-22.699179,180.61093 +1,25,8.158990654606411,0.00731116688961064,-22.484966,180.61093 
+1,26,8.03518241049331,0.032357432721205136,-22.699179,180.61093 +2,-1,5.2520426951039205,0.03775147719903329,-27.162464,137.37079 +2,25,5.264311190825827,0.05796637692049009,-27.092379,121.12732 +2,26,5.230668154079859,0.0029451268627828514,-27.162464,137.37079 +3,-1,1.8908030805896912,-0.02063435869345464,-62.440613,8.056832 +3,25,1.8883629409442553,-0.0377655578077736,-36.20284,8.056832 +3,26,1.894634275706942,0.008862424682628262,-62.440613,7.7325597 +4,-1,1.2550114321169146,0.005899992689136042,-51.8124,50.653717 +4,25,1.2384460465664056,0.008078419422640086,-51.8124,50.653717 +4,26,1.2830242644175782,0.0021491413313272094,-18.104618,30.854942 +5,-1,1.0355770245602534,0.0011728299215596702,-8.10866,5.8395243 +5,25,1.0342660405895214,-0.0012681624146074068,-8.10866,5.8395243 +5,26,1.0378169736710516,0.005375771098188517,-6.8847127,5.7952814 +6,-1,0.5493822121451002,0.0036791990498476032,-12.703717,13.661785 +6,25,0.5223683556764287,0.006430280210638828,-12.703717,13.661785 +6,26,0.592988689396745,-0.0010576582190667792,-12.383174,6.8123903 +7,-1,0.3255828136582528,0.002480108574867595,-6.011691,14.647189 +7,25,0.26640313973354024,0.0015354013766269161,-3.6407008,14.647189 +7,26,0.4078251001117691,0.00410672113096357,-6.011691,4.4886813 +8,-1,0.26758123164244024,0.0033795806969370086,-3.2466252,10.183069 +8,25,0.26632283059089557,0.008081420107786804,-3.2466252,10.183069 +8,26,0.269542094424536,-0.004716124550866895,-2.3600237,1.7463634 +9,-1,0.1962527740104057,-5.5949416656172864e-05,-5.058476,3.9574463 +9,25,0.19222928636496564,-0.001677236175508665,-5.058476,3.9574463 +9,26,0.20296335507173885,0.002735609072463909,-1.5942534,3.1422544 +10,-1,0.13015063025884935,-0.0008409591891093529,-3.7672338,8.647489 +10,25,0.1280901455750696,-0.0010135838580174522,-3.7672338,8.647489 +10,26,0.13362345536560788,-0.000543731164954626,-2.8124948,1.4940547 +11,-1,0.09736682370697192,0.0004948809771133328,-3.2520914,1.9735277 
+11,25,0.09665249450538914,0.0020333158875877563,-3.179755,1.0819921 +11,26,0.09852836924270333,-0.002154021822372433,-3.2520914,1.9735277 +12,-1,0.08449291647792997,-0.00015908060227801872,-3.8171158,6.006262 +12,25,0.08288337321816482,-0.0010891632007279718,-3.8171158,2.6338766 +12,26,0.08717141354969027,0.0014423510748199206,-0.91960233,6.006262 +13,-1,0.07876476730583207,-0.0007481134102039366,-3.015542,8.513709 +13,25,0.0771227240777569,-0.0013911656654801326,-3.015542,8.513709 +13,26,0.08150270567906999,0.00035910465122932915,-0.54739743,2.686968 +14,-1,0.06435283954670194,0.0008178949375494241,-1.7583451,2.3168159 +14,25,0.06342331547347822,0.0012792422123076726,-1.3276563,1.434597 +14,26,0.06591503969325464,2.353953172637753e-05,-1.7583451,2.3168159 +15,-1,0.062281492352083224,9.956594785958151e-05,-9.202019,23.577919 +15,25,0.05960121279694898,3.3575702879594815e-05,-9.202019,23.577919 +15,26,0.06664417849634201,0.000213189048503035,-7.7245126,0.7483987 +16,-1,0.053170138403606786,-0.0002782876452943954,-5.2730403,6.4986434 +16,25,0.05090963976793033,-9.089318427967239e-05,-5.2730403,6.4986434 +16,26,0.05685063477109915,-0.0006009465429899427,-2.4634545,3.1796315 +17,-1,0.0462249969905426,6.735756532865781e-05,-6.796467,3.3835852 +17,25,0.04627713717767539,0.0005964508644584201,-6.796467,3.3835852 +17,26,0.04612086292721632,-0.0008436440712773303,-4.0116854,0.58511776 +18,-1,0.03788835000169828,-0.0003514180435949096,-2.7663088,1.4395287 +18,25,0.03745761229051916,-0.0008174831079994006,-2.7663088,1.3517247 +18,26,0.038605563668077596,0.0004510605308563524,-0.5526807,1.4395287 +19,-1,0.03330582676476361,-0.0008574601679794651,-6.19803,5.5526767 +19,25,0.032727610415141196,-0.0016696842637779065,-6.19803,5.5526767 +19,26,0.034233436147920786,0.0005410407803036202,-1.491681,3.0981164 +20,-1,0.028468433380136483,-0.0005675972791015586,-9.06101,2.9966123 +20,25,0.026643184578347776,-0.0009988920709598087,-9.06101,2.3903935 
+20,26,0.03134934849528494,0.000175013262156009,-2.6174066,2.9966123 +21,-1,0.024268825929043395,0.0002599081395317108,-1.9682393,0.8665377 +21,25,0.0224297582412051,2.0550463659553276e-06,-1.9682393,0.73365885 +21,26,0.027139217697359913,0.0007038838811717654,-1.2463758,0.8665377 +22,-1,0.023653236177545802,-2.3053129620237552e-05,-1.571571,6.7783623 +22,25,0.022901615809238295,0.0005573283528326744,-1.571571,6.7783623 +22,26,0.024862543398320246,-0.0010223636127430628,-1.4353843,0.94891393 +23,-1,0.0227950591817633,7.47340289880043e-05,-1.3103967,3.68755 +23,25,0.021006005909240646,0.0002808284683341765,-1.3103967,3.68755 +23,26,0.02558009822371429,-0.0002801228022146704,-1.0873473,0.76674837 +24,-1,0.02163025281614505,0.00047457330557465434,-6.2863092,5.610426 +24,25,0.02078166535950987,0.0015827913900130566,-6.2863092,5.610426 +24,26,0.022892809629481283,-0.00143357501212334,-1.5460097,2.214486 +25,-1,0.020180399097139597,-0.00013897911048542052,-7.081303,3.0626888 +25,25,0.019194882874143886,-0.0007289904180184681,-7.081303,3.0626888 +25,26,0.02173550438228641,0.0008769121655196096,-1.4447576,0.97537893 +26,-1,0.019364271442448903,-0.0002416602799426142,-6.701268,1.8637762 +26,25,0.01836366698129971,-0.0004320107116650731,-6.701268,1.7319096 +26,26,0.020971529909368472,8.608825724928664e-05,-0.93850017,1.8637762 +27,-1,0.017319452118580603,-0.00013860826794000325,-0.6164454,1.9288743 +27,25,0.016136272119552268,-0.000514231594614049,-0.6164454,1.9288743 +27,26,0.019169215854635987,0.0005081462192369678,-0.58550715,0.61321104 +28,-1,0.017700405302912354,-0.0001325316854622738,-13.401714,2.195423 +28,25,0.017145463430931325,-0.0004955198328022546,-13.401714,2.195423 +28,26,0.018600601430655454,0.0004924673399254899,-1.5438087,2.142796 +29,-1,0.015425732668083376,0.0002193715662504277,-0.6605541,0.53531766 +29,25,0.014380885705600187,0.0004007458176717921,-0.14273225,0.53531766 +29,26,0.01707109280259627,-9.29216345041803e-05,-0.6605541,0.22257876 
+30,-1,0.01504755953893869,-4.3247367891373935e-05,-0.40668693,0.15491618 +30,25,0.013958022374844569,-5.858884883002894e-05,-0.12909138,0.13453749 +30,26,0.016758339226176453,-1.6832151017408964e-05,-0.40668693,0.15491618 +31,-1,0.015065970340400598,-6.205377879144825e-05,-4.89397,1.1568414 +31,25,0.014121691891133752,-0.00018564009517773203,-4.89397,1.1568414 +31,26,0.016564038096794912,0.0001507391949150614,-0.6891684,1.0968304 +32,-1,0.014772341527657062,-6.904989192718075e-05,-8.888641,1.298905 +32,25,0.013787270516411243,-0.00032197688415858746,-8.888641,1.298905 +32,26,0.01632061780609942,0.0003664440072464144,-1.6017404,1.008066 diff --git a/physicsnemo/experimental/datapipes/healda/configs/normalizations/cris-fsr_normalizations.csv b/physicsnemo/experimental/datapipes/healda/configs/normalizations/cris-fsr_normalizations.csv new file mode 100644 index 0000000000..15ba09fdf0 --- /dev/null +++ b/physicsnemo/experimental/datapipes/healda/configs/normalizations/cris-fsr_normalizations.csv @@ -0,0 +1,301 @@ +Raw_Channel_ID,Platform_ID,obs_std,obs_mean +24,-1,7.148680945983152,221.7428510438544 +24,25,7.226625973984494,221.72723726263527 +24,26,7.012362664944118,221.7697351118772 +26,-1,7.1183690738760745,220.21698041628375 +26,25,7.187047453778234,220.20723472808095 +26,26,6.998506694575881,220.23376070421637 +28,-1,6.810027715339033,225.11215231553538 +28,25,6.88641664148765,225.09473612568024 +28,26,6.676345902019827,225.1421398007059 +32,-1,8.586831187042542,232.24373320645796 +32,25,8.674517098267549,232.20074248110168 +32,26,8.433202677079201,232.31775535193378 +37,-1,7.101698793998617,220.8547745437012 +37,25,7.17415360906482,220.84003686001532 +37,26,6.975108627925394,220.8801501324295 +39,-1,7.084499131404986,219.52015323580326 +39,25,7.147340844991026,219.50862805040438 +39,26,6.9749260966309325,219.5399974915854 +42,-1,7.103298339878284,220.7258401200952 +42,25,7.176376792531341,220.7126013511023 +42,26,6.975618121257783,220.74863485259138 
+44,-1,7.072375481378026,219.21996256837906 +44,25,7.130500410254753,219.21370781837263 +44,26,6.9711463883192,219.23073210090158 +47,-1,7.088522908525025,220.65073272762893 +47,25,7.159615620612103,220.6372255890556 +47,26,6.9643526598480525,220.67398954332933 +49,-1,7.0585366790228745,219.0336902154235 +49,25,7.113326222513172,219.0264455221647 +49,26,6.963171413476821,219.04616424869195 +51,-1,7.043834215481206,220.0082110773996 +51,25,7.10766727653143,219.99779181336595 +51,26,6.932511637784182,220.0261511393934 +53,-1,7.097154641713036,221.48956417686423 +53,25,7.173213631048205,221.47356906935863 +53,26,6.964162811668813,221.5171048188252 +55,-1,7.027936818498324,220.56581122312792 +55,25,7.093666375094964,220.55327740186326 +55,26,6.913245169151606,220.58739216490855 +57,-1,6.986009691056735,218.17761325408964 +57,25,7.025644659287636,218.17248883733 +57,26,6.917224570420424,218.1864365600403 +59,-1,6.957474561686182,217.5947181448047 +59,25,6.988781837074984,217.5902339182468 +59,26,6.903229715403152,217.60243916062527 +61,-1,6.858448008153826,218.81845319444992 +61,25,6.903815698966285,218.80969093107814 +61,26,6.77959533976492,218.8335402051628 +63,-1,6.824041345248238,220.35039481195363 +63,25,6.879358836517029,220.33796076391843 +63,26,6.72767527166563,220.3718039625294 +65,-1,6.772041886497526,218.48205618962905 +65,25,6.809038015889509,218.47133553268873 +65,26,6.707823133395746,218.5005151949391 +67,-1,6.615958129465173,216.75712420819255 +67,25,6.633803666213863,216.74480895375086 +67,26,6.585064154878169,216.77832881798457 +69,-1,6.3869434244147705,217.9668858199093 +69,25,6.422831493887996,217.95072153884342 +69,26,6.324576842770392,217.99471774771823 +71,-1,6.3687329469696845,218.37176920764338 +71,25,6.409529123604809,218.3557861949569 +71,26,6.2977752056619,218.39928902454378 +73,-1,6.231960909808288,217.74396842271508 +73,25,6.263902359178369,217.7267904371157 +73,26,6.176464649332763,217.77354576383877 
+75,-1,5.747537030125413,217.7108760713893 +75,25,5.778702638529444,217.6882536945453 +75,26,5.693265248141307,217.7498276558503 +77,-1,5.3709703711015955,218.9489486746189 +77,25,5.415050605310716,218.92350724874214 +77,26,5.293925909453148,218.99275414456045 +79,-1,5.402782926138339,219.24549479231524 +79,25,5.4485355837028715,219.22184646945195 +79,26,5.3228377012751595,219.2862128676847 +81,-1,5.310695809183949,219.12941337127202 +81,25,5.354317655404501,219.10460393324956 +81,26,5.234459791990948,219.17213067402167 +83,-1,4.929063461686605,220.7598649082012 +83,25,4.983904706653539,220.73255139432644 +83,26,4.832817317317919,220.80689377037405 +85,-1,4.781525500353753,221.85931616238483 +85,25,4.837270020633143,221.83139884092395 +85,26,4.683599498018493,221.9073846704955 +87,-1,4.803066436137887,224.03463127414375 +87,25,4.865750841450445,224.00949031701813 +87,26,4.6928583823622905,224.0779193919955 +89,-1,4.870657216747649,224.8576073834463 +89,25,4.932906658212551,224.8337831094861 +89,26,4.761289248416708,224.89862841433893 +91,-1,5.31551725723172,227.77403112213702 +91,25,5.372932300819551,227.75304317342793 +91,26,5.214980267447988,227.8101685210692 +93,-1,5.488476582628741,229.10937273728692 +93,25,5.538821292290113,229.08959890888914 +93,26,5.400523033578721,229.14341964368694 +95,-1,5.7185069753729065,229.94375370060385 +95,25,5.768839014268012,229.92334964657877 +95,26,5.630617096124981,229.97888573996406 +97,-1,5.997081260478454,231.75153722972925 +97,25,6.040867757908995,231.73510874510055 +97,26,5.920823464593723,231.77982406762405 +99,-1,6.424621856295242,233.73030489322696 +99,25,6.458604879276553,233.7147772666865 +99,26,6.3655955344728685,233.75704061869206 +103,-1,7.035743304214101,236.78961918345988 +103,25,7.064369191596282,236.77916474817476 +103,26,6.986143332907935,236.807619803899 +105,-1,7.4242187585584665,237.96473991634969 +105,25,7.447903012602698,237.9575507834448 +105,26,7.383244450964255,237.97711828487732 
+107,-1,8.533038279366547,242.58541758542094 +107,25,8.541967753830795,242.58391604346247 +107,26,8.517640878304602,242.58800296532246 +109,-1,8.28367833280382,242.44918642551977 +109,25,8.28910233788292,242.44401670375743 +109,26,8.274323360711675,242.45808773837712 +111,-1,4.790411106057031,226.38039215135038 +111,25,4.8422137544251695,226.36026785065238 +111,26,4.699676826899978,226.41504250678219 +113,-1,6.392115911364461,221.9808166808176 +113,25,6.4592633516151805,221.95933631146835 +113,26,6.274644570709405,222.01780193779265 +115,-1,6.116276270112719,232.69438277919767 +115,25,6.168936097091183,232.6725244382642 +115,26,6.0243414372417785,232.73201883392872 +117,-1,10.54643588799885,249.50830848344825 +117,25,10.535631454865753,249.51378747227818 +117,26,10.565006680733363,249.4988746694135 +119,-1,10.363587856616613,248.56650995907498 +119,25,10.358034751616351,248.57401517208282 +119,26,10.37312968268347,248.55358735857243 +121,-1,10.967417711963517,251.44628922255885 +121,25,10.955245440244632,251.46081311811614 +121,26,10.988299645740154,251.42128173778374 +123,-1,11.282576583228686,252.41196448645687 +123,25,11.267508670282483,252.42674889218097 +123,26,11.308428524381442,252.38650845094656 +125,-1,10.54213713130317,249.47747508620503 +125,25,10.534673480729518,249.4848214461781 +125,26,10.554963909955422,249.46482600149804 +127,-1,8.459408971865988,243.05743068769692 +127,25,8.46795124999899,243.05835354849458 +127,26,8.44468032282402,243.05584169063653 +129,-1,8.237674224604191,242.01746666952454 +129,25,8.25135132315659,242.01626185300702 +129,26,8.214071061213657,242.01954114262986 +131,-1,8.718113620057526,244.53885340586203 +131,25,8.721612200272396,244.54375661657105 +131,26,8.712080021880643,244.53041097616605 +133,-1,10.403694347328528,249.1326945566447 +133,25,10.39718529991075,249.14180445731117 +133,26,10.414873617249011,249.11700897825978 +135,-1,10.24724794932284,248.6327243329683 +135,25,10.241615444019942,248.6390002768527 
+135,26,10.256929931371603,248.62191830847445 +137,-1,9.648325500286765,246.36191055958292 +137,25,9.650819400555067,246.36605969715396 +137,26,9.644025848995888,246.35476650554952 +139,-1,9.392715598800262,245.2503819911986 +139,25,9.399069827113424,245.2551671251538 +139,26,9.381759058852538,245.24214286801984 +141,-1,9.866835778064734,247.45362455657633 +141,25,9.865824701186552,247.4630557642839 +141,26,9.868555387705598,247.43738574638746 +143,-1,10.543732184899328,250.63131243458636 +143,25,10.531687636176892,250.64483612349778 +143,26,10.564398035150901,250.60802712223872 +145,-1,9.800556910225932,249.48376419906586 +145,25,9.786692769019846,249.49302897465697 +145,26,9.824362183589985,249.46781195446144 +147,-1,6.243908891438654,236.1068024864158 +147,25,6.2589117726836125,236.09968273836986 +147,26,6.217972808619235,236.11906138694007 +149,-1,10.205670187916358,249.64682340121078 +149,25,10.197496817532627,249.65789313948812 +149,26,10.219699905627355,249.62776334185259 +151,-1,10.784897364791094,252.06538713703387 +151,25,10.770662532573168,252.084207426667 +151,26,10.80928651714695,252.03298204952773 +153,-1,10.703081043649087,253.44747511832583 +153,25,10.681924446415128,253.46985329906903 +153,26,10.739301986605877,253.40894399477003 +155,-1,13.446725295052872,260.17642585611475 +155,25,13.41432123692421,260.2092786181187 +155,26,13.502149725059851,260.1198594244442 +157,-1,13.60488630690257,260.9768444226937 +157,25,13.5694977712579,261.01023015656875 +157,26,13.665413130455873,260.9193603111967 +159,-1,13.096089675072577,259.7058672761275 +159,25,13.061942802067788,259.73613337033834 +159,26,13.154513566496913,259.65375461205645 +163,-1,13.331338771418029,260.2095728654148 +163,25,13.296419572474836,260.2425660779894 +163,26,13.391059499082516,260.15276460361724 +167,-1,13.587918871163545,262.69527396898957 +167,25,13.545840308783054,262.7331766035323 +167,26,13.659820392779968,262.6300125828467 +171,-1,14.932715767841447,265.4096066838884 
+171,25,14.886538027054604,265.45202958907805 +171,26,15.011611730011678,265.3365622206134 +175,-1,16.18625774698439,267.99002808488234 +175,25,16.13913515874035,268.04179493148297 +175,26,16.266688454763624,267.9008950679783 +179,-1,17.250792853649305,270.6285052231255 +179,25,17.20019076695825,270.68717622984076 +179,26,17.337109203156867,270.5274845085808 +183,-1,17.69838250496001,271.3241391527176 +183,25,17.647664931363416,271.3851463730712 +183,26,17.784879753374206,271.21909590669964 +187,-1,17.813474312013092,271.84774949656054 +187,25,17.762426958504125,271.9097494317418 +187,26,17.90052436217987,271.7409969776523 +190,-1,17.683204396031265,271.84363517407184 +190,25,17.631725363083298,271.9054357498551 +190,26,17.770988982777947,271.7372259154886 +194,-1,17.558059698808247,271.79111099502103 +194,25,17.50633482338257,271.85271556246425 +194,26,17.646261393996433,271.6850392268557 +197,-1,18.64238809582609,273.5180108138881 +197,25,18.591783117763946,273.58673379598474 +197,26,18.72860946459705,273.3996824410558 +200,-1,18.45098367803441,273.24475539100337 +200,25,18.400659185487925,273.31204355512915 +200,26,18.536740914936548,273.1288975116016 +211,-1,19.045454120899937,274.30958014766765 +211,25,18.996554831667453,274.38204733545325 +211,26,19.128713704359686,274.1848049392469 +224,-1,19.34341472510976,274.875252048082 +224,25,19.295004732546765,274.9498619493414 +224,26,19.42581388008933,274.74678747998763 +275,-1,19.677606842238163,275.69239458033877 +275,25,19.632786615760565,275.76721775587464 +275,26,19.75387684370453,275.5635627930503 +279,-1,19.69230669261267,275.7887325256918 +279,25,19.64738030421323,275.8635272917121 +279,26,19.768759539400712,275.65994965437943 +291,-1,19.788261313683435,276.04009135761567 +291,25,19.744434742832055,276.11609070096864 +291,26,19.862814836897808,275.9092344250304 +311,-1,19.65409404285493,275.76714292030044 +311,25,19.613290243596435,275.84371484788505 +311,26,19.723456504242087,275.6353001026681 
+332,-1,19.775508237054876,276.17478883734236 +332,25,19.734330526951346,276.25031522721594 +332,26,19.845535163291395,276.04474624391906 +342,-1,19.850639473695704,276.2970560588583 +342,25,19.809790748482026,276.37262007759597 +342,26,19.920105724935965,276.1669486754313 +389,-1,19.840464153508705,276.4154742148947 +389,25,19.803142422520466,276.488371623284 +389,26,19.903935924086518,276.2899582454877 +410,-1,19.47964336183736,276.26509997475426 +410,25,19.443557945767978,276.3360428998123 +410,26,19.541016350911352,276.142949267233 +427,-1,19.70068820857589,276.57320984754745 +427,25,19.669879710714365,276.64771142902345 +427,26,19.752963940209064,276.44493178625646 +464,-1,19.721798236970347,277.1173727365535 +464,25,19.693294557540447,277.1893792234199 +464,26,19.77016569189154,276.99339077063274 +482,-1,19.58381751840419,276.9227136873008 +482,25,19.557473735968596,276.99210970578446 +482,26,19.628519102464296,276.80322646923076 +501,-1,19.79281300616905,277.4870385598215 +501,25,19.768782897624316,277.55951676617264 +501,26,19.833499775489052,277.3622443794514 +529,-1,19.364877769477527,276.74638892333707 +529,25,19.342787037626803,276.81595388527614 +529,26,19.4022707510178,276.62661081561737 +710,-1,19.384145248686572,276.06837788192524 +710,25,19.37510150979452,276.14231782796134 +710,26,19.39993057773661,275.93328179332025 +713,-1,19.3397500288157,275.73983974096427 +713,25,19.33390920742277,275.8164601466514 +713,26,19.349633867111372,275.5998461695514 +742,-1,18.657448340927807,275.78992506912755 +742,25,18.649609651382555,275.852503928025 +742,26,18.671220428598353,275.67558690993167 +882,-1,11.13239209721585,258.43401833615354 +882,25,11.12548494962383,258.44849624591393 +882,26,11.14495265209976,258.40756567204477 +890,-1,9.119963275360648,249.36032074235715 +890,25,9.120943674809512,249.37628759441878 +890,26,9.1180995802985,249.33114762375055 +937,-1,9.436974459758718,250.20699610787517 +937,25,9.435502631337794,250.228092237003 
+937,26,9.43954137084727,250.1684512600328 +995,-1,10.343182041679201,255.52739461811973 +995,25,10.344860301584001,255.5508926216122 +995,26,10.339977157523222,255.48446129306623 +1008,-1,8.012700182682034,243.34596152629126 +1008,25,8.017486115182507,243.3770852762682 +1008,26,8.003635879717297,243.28909516051965 +1022,-1,10.952386967893961,257.5198463600538 +1022,25,10.952908931651697,257.53621659855366 +1022,26,10.951370119488704,257.4899362116897 +1058,-1,9.392306071523679,250.018266429657 +1058,25,9.391673367329897,250.02775441362263 +1058,26,9.393437311645817,250.00093088454014 diff --git a/physicsnemo/experimental/datapipes/healda/configs/normalizations/iasi-pca_normalizations.csv b/physicsnemo/experimental/datapipes/healda/configs/normalizations/iasi-pca_normalizations.csv new file mode 100644 index 0000000000..5566b3c4a5 --- /dev/null +++ b/physicsnemo/experimental/datapipes/healda/configs/normalizations/iasi-pca_normalizations.csv @@ -0,0 +1,129 @@ +Raw_Channel_ID,Platform_ID,obs_std,obs_mean,obs_min,obs_max +1,-1,10.16384598075378,-0.11578126917322405,-70.115036,67.5379 +1,14,10.144844944019896,-0.17180982432481942,-70.115036,67.12935 +1,15,10.117032923964555,-0.13534534542551505,-23.84075,67.5379 +1,16,10.36645870193704,0.15457158672327753,-23.895481,64.10818 +2,-1,7.3720902502472345,-0.011398788802979953,-69.58385,27.106546 +2,14,7.408354623192937,-0.033540395834743233,-69.58385,26.86886 +2,15,7.38308754618453,-0.040821627762674104,-35.73581,26.775658 +2,16,7.197277111477421,0.15891298734824988,-32.380424,27.106546 +3,-1,2.415178591123785,0.014606195810605006,-17.465494,10.614266 +3,14,2.4117889904979277,0.025530758192094077,-17.465494,10.614266 +3,15,2.4118225548535817,0.006701324015950001,-16.245832,10.486391 +3,16,2.4375695805475215,-0.003814585183830675,-16.093887,10.174647 +4,-1,2.1137909933142707,-0.012366427297395179,-23.838263,10.660836 +4,14,2.1131650971261413,-0.11091984976782285,-23.838263,10.266308 
+4,15,2.1115362994913744,0.04197765763510022,-10.771783,10.660836 +4,16,2.100924890207585,0.20346317164919536,-10.313636,10.579529 +5,-1,1.4895056055253255,0.015714926696960375,-8.8541355,7.3725705 +5,14,1.4831121674309395,0.008687727346691402,-8.8541355,7.163042 +5,15,1.4920626547765274,0.022477324325635862,-8.29941,7.3725705 +5,16,1.5060660076310253,0.022655104789458494,-7.775946,7.17718 +6,-1,0.7474589262986069,-0.0044207293706204605,-4.39215,7.544279 +6,14,0.7481429604526755,-0.03261158083485469,-4.39215,7.1897388 +6,15,0.7466810279890996,0.014906071159060997,-4.111029,6.9743967 +6,16,0.7426449201064272,0.04625019242205241,-3.9642997,7.544279 +7,-1,0.592554877955259,0.014294444715331244,-6.5000753,4.1782694 +7,14,0.5945889391301633,0.0423889235728364,-6.278107,4.1782694 +7,15,0.5888286835429928,-0.0007285253674890736,-6.5000753,4.154543 +7,16,0.5892362289435799,-0.04860361744250353,-5.979437,4.081578 +8,-1,0.5073028941599157,-0.005480693408424849,-5.6289406,3.6449907 +8,14,0.5092852370823666,-0.01318423016998478,-5.6289406,3.6412852 +8,15,0.5062152497851627,-0.003205194563961296,-5.5404353,3.6449907 +8,16,0.5021587544974437,0.01716132356058474,-4.9796276,3.256661 +9,-1,0.319155301263443,0.0018874141315705772,-2.6887596,2.7577827 +9,14,0.31777908736103516,-0.00011142141680679401,-2.54663,2.7577827 +9,15,0.3188619436735035,0.0029212103142570974,-2.6170118,2.524858 +9,16,0.325119227540568,0.006464955278109927,-2.6887596,2.2580612 +10,-1,0.24576400857313757,-0.002904770072379496,-1.9653,1.9914408 +10,14,0.24560582464325617,-0.005595659568989295,-1.6574273,1.9914408 +10,15,0.24575333449517595,-0.0025926002334247184,-1.7135054,1.9291075 +10,16,0.24616300142474046,0.006416599973651741,-1.9653,1.9329331 +11,-1,0.2052519194630894,-0.0008292540185629837,-2.4269035,4.275482 +11,14,0.2060414754178197,0.0017183040757965202,-2.4269035,4.275482 +11,15,0.20437881101577896,-0.0028544550010604465,-2.4239132,4.0686407 
+11,16,0.20466777673089995,-0.004592858962850275,-2.4268682,4.0996847 +12,-1,0.1600413184979059,0.0016911035141503803,-2.369342,1.7326239 +12,14,0.15977749011970102,0.004403783414487348,-2.369342,1.5905337 +12,15,0.1598870547522381,0.0003468629593378756,-1.5889701,1.7326239 +12,16,0.16125921857781852,-0.004693142973078381,-1.2982864,1.4566338 +13,-1,0.156637722895137,0.0007217974244652316,-1.4074177,1.5935816 +13,14,0.15670216109142604,-0.0012057353127545027,-1.389396,1.5685824 +13,15,0.15679951689457444,0.002861331665849974,-1.4074177,1.5935816 +13,16,0.15582527147071262,0.0017925715098349312,-1.1188769,1.3014976 +14,-1,0.12122455985938849,-9.395895121152183e-05,-1.7415439,2.1401186 +14,14,0.12179802829533352,0.005630057425092977,-1.7415439,1.7244978 +14,15,0.12046868939058411,-0.005026536688557099,-1.6983584,2.1401186 +14,16,0.12020912437038608,-0.007431810773962197,-1.5034088,1.031158 +15,-1,0.08912428923694071,-0.0002954394086278218,-0.7643497,1.2684778 +15,14,0.08970672290694674,-0.0002542048062865138,-0.7643497,1.2684778 +15,15,0.0886182664155935,-0.0013750804237421547,-0.73106354,0.8765023 +15,16,0.08830474897751397,0.0027069262062209753,-0.65798026,0.8396846 +16,-1,0.08369533176258243,0.0005640216112958324,-0.7478854,0.88349766 +16,14,0.08503635698094053,0.0033331062231248174,-0.6551791,0.88349766 +16,15,0.08265565108905186,-0.0004444043735053519,-0.7478854,0.7513186 +16,16,0.08099840504223273,-0.007017407857148513,-0.615259,0.6203963 +17,-1,0.07541322011917798,-3.555967540353909e-05,-0.9543784,1.3299066 +17,14,0.07624033199914984,0.0005687846091489142,-0.9543784,1.3299066 +17,15,0.07491390360670387,-0.0004671704929098946,-0.8859532,0.8546724 +17,16,0.07366610926007723,-0.001071227909829426,-0.57342154,0.75387377 +18,-1,0.0659239608950789,0.0010402435586074807,-0.79023623,0.814086 +18,14,0.06659983497483432,0.004355797390116598,-0.79023623,0.814086 +18,15,0.0652039486671305,-0.0009010464515499121,-0.7182315,0.66189367 
+18,16,0.06465827843705547,-0.005889978803059955,-0.6976637,0.6214272 +19,-1,0.06077963219299458,0.0004908257852194686,-1.1956189,1.1055057 +19,14,0.06210216425899309,-0.002500127317215281,-1.1341717,1.1055057 +19,15,0.059604944556958035,0.003800923075883944,-1.1956189,0.94502634 +19,16,0.058488399555782285,0.0021810770513757422,-1.1363206,0.5038915 +20,-1,0.058573607575298864,0.0004369104733404216,-0.57480556,0.49712297 +20,14,0.06078973213817348,-0.000779628941855008,-0.54442686,0.49712297 +20,15,0.056712589417216325,0.0020965845832922837,-0.57480556,0.47636917 +20,16,0.05512542351595473,0.0002075606890555336,-0.4483251,0.46493176 +21,-1,0.057409387944944774,-0.0004679878322903191,-0.5531818,0.4385244 +21,14,0.060094850882160815,-0.0005988649628692427,-0.5233578,0.4385244 +21,15,0.05514273407186683,-0.0013685880821012705,-0.5531818,0.3923587 +21,16,0.05319666224382016,0.0026651048667955804,-0.53307384,0.34304833 +22,-1,0.056716850416005146,0.00017889373424435378,-0.40499026,0.59640205 +22,14,0.05952874548112656,0.0006144679683606531,-0.4044233,0.59640205 +22,15,0.05433567242167013,0.0004226595796739127,-0.40499026,0.5523689 +22,16,0.052352792304419325,-0.002191115582539295,-0.37258968,0.558844 +23,-1,0.05738503506841794,0.0002915964775403029,-0.41909268,0.47572902 +23,14,0.062119215856558885,-0.00022881234431018098,-0.41909268,0.47572902 +23,15,0.05301019621444499,0.0014990994538740948,-0.350389,0.39156067 +23,16,0.05046482305973548,-0.0012623690448801631,-0.35189015,0.41263208 +24,-1,0.05629817832557509,-0.0011535983524246073,-0.44527066,1.0995467 +24,14,0.05982346401260559,-0.004190661963303076,-0.44527066,0.51397127 +24,15,0.05308728552666641,0.0012205944939495295,-0.3945093,1.0995467 +24,16,0.05066987937483793,0.0034506405080156435,-0.3653711,0.38389608 +25,-1,0.05577545508179762,0.0006631969394137377,-0.99278796,0.53750217 +25,14,0.06021785360778572,0.0022225314253001586,-0.67450505,0.53750217 
+25,15,0.051633919940863506,-0.000829543573435437,-0.99278796,0.5108202 +25,16,0.049340210160091644,-0.0008997525705235764,-0.55203766,0.32546312 +26,-1,0.05455172060307614,0.0006970651908912962,-0.6276804,0.47650766 +26,14,0.05734995220430608,0.0012991450758166171,-0.38686976,0.47650766 +26,15,0.05211831002280347,0.0014061990604566333,-0.6276804,0.35979897 +26,16,0.05022503419390421,-0.003667991362457362,-0.33765188,0.3331192 +27,-1,0.05463785516106688,2.7725567611897368e-05,-0.4889761,0.37899846 +27,14,0.05829975099484399,4.670773788950273e-05,-0.4889761,0.37899846 +27,15,0.05136745997753125,-0.0003221444038925016,-0.4038399,0.37244293 +27,16,0.04924903152187081,0.0009793020341476472,-0.40843204,0.33169365 +28,-1,0.05412503531329384,0.000520615850812093,-0.6537395,0.5192948 +28,14,0.05772479660184095,-6.39800634337659e-05,-0.6537395,0.5192948 +28,15,0.050864762758778374,0.0008256212450902778,-0.6340794,0.46081963 +28,16,0.04895271859242637,0.0018516386055598207,-0.5266372,0.36741775 +29,-1,0.05323330275954097,-0.0005420847090697627,-0.6835556,0.7398786 +29,14,0.056069739890299036,-0.000284307457383957,-0.63065916,0.7398786 +29,15,0.05071739997253867,-0.0018203811583674687,-0.6835556,0.6795276 +29,16,0.049063259953557356,0.002217955976672146,-0.33469966,0.5843655 +30,-1,0.05268664507093167,0.0004305211782272328,-0.7483244,0.4951867 +30,14,0.05563085869364301,0.002247515040689564,-0.44440868,0.4951867 +30,15,0.050144180134041254,-0.0010720510383476591,-0.7483244,0.44467148 +30,16,0.04798381762649267,-0.002083669567404401,-0.35060766,0.42631716 +31,-1,0.05179105242478689,0.00041054237241177857,-0.6042099,0.5955411 +31,14,0.054305141792349125,0.0020272261618122013,-0.4733694,0.5955411 +31,15,0.0496583890438443,-0.0007867759335017497,-0.6042099,0.56808573 +31,16,0.04774666693922744,-0.002234988385272294,-0.4294472,0.47578403 +32,-1,0.05031063394652676,0.0007722156055080443,-0.79842126,0.60976607 +32,14,0.05281679260164394,0.0010774775378229234,-0.79842126,0.60976607 
+32,15,0.048046099000917164,0.00172496072845701,-0.69703746,0.59639174 +32,16,0.04670410862212604,-0.003176735890430844,-0.3257097,0.5569593 diff --git a/physicsnemo/experimental/datapipes/healda/configs/normalizations/iasi_normalizations.csv b/physicsnemo/experimental/datapipes/healda/configs/normalizations/iasi_normalizations.csv new file mode 100644 index 0000000000..2781a37561 --- /dev/null +++ b/physicsnemo/experimental/datapipes/healda/configs/normalizations/iasi_normalizations.csv @@ -0,0 +1,699 @@ +Raw_Channel_ID,Platform_ID,obs_std,obs_mean +16,-1,8.148566703939132,231.5696314828929 +16,14,8.16659063824984,231.5899448741447 +16,15,8.080259900068903,231.62268837141235 +16,16,8.274089773247885,231.33711634584245 +38,-1,7.6839326337454095,215.09468773136138 +38,14,7.730132510666448,215.06334598883302 +38,15,7.678726627442439,215.16592609227328 +38,16,7.519422450941319,215.00544159369431 +49,-1,7.457045667309594,218.2708576624594 +49,14,7.4710016471106755,218.24828727173008 +49,15,7.4440864250822765,218.33711300901177 +49,16,7.440009147744245,218.16283068363074 +51,-1,7.643226914012837,215.43678103395163 +51,14,7.683411781818187,215.40718181659258 +51,15,7.638104249733504,215.5052609380326 +51,16,7.502042426802164,215.3489788668497 +55,-1,7.440690620121674,218.28582424491685 +55,14,7.455226361928304,218.26378748340835 +55,15,7.4280326685980365,218.35058120073694 +55,16,7.420631482841768,218.18015213166495 +57,-1,7.694807649805706,215.1106416837574 +57,14,7.73882961690392,215.0809186227904 +57,15,7.690803215337125,215.17780451032993 +57,16,7.535425892847301,215.02716453618947 +59,-1,7.715303264905434,224.1601880926176 +59,14,7.712982162872568,224.16819773248594 +59,15,7.675182498227635,224.20274932570686 +59,16,7.838327688548475,224.00518237654217 +61,-1,7.450445404992458,218.79085568561987 +61,14,7.460583984780689,218.77233137521844 +61,15,7.435683055098165,218.852557543473 +61,16,7.453362129650873,218.6807636287963 +63,-1,7.6513378709615765,215.27879751876034 
+63,14,7.6933374022502266,215.25118642931005 +63,15,7.64731900538729,215.343943122196 +63,16,7.499932709971022,215.19319018030149 +66,-1,7.6752320381806625,223.59292997873996 +66,14,7.672589390465481,223.5992553580554 +66,15,7.637483531657956,223.6355357697093 +66,16,7.792771337594965,223.44419998336863 +70,-1,6.961835073823521,221.28114876183045 +70,14,6.976858210690474,221.28742549756706 +70,15,6.940766929347125,221.32968641152831 +70,16,6.963705479380555,221.11524621764858 +72,-1,8.108291979131826,230.7910603085405 +72,14,8.122010202711568,230.8181510695525 +72,15,8.043559597982819,230.82946258327934 +72,16,8.240356614800811,230.5756491069457 +74,-1,7.452844614872846,222.83378317315635 +74,14,7.454468446682034,222.82733508837134 +74,15,7.423904118368168,222.8900615554259 +74,16,7.528850145493042,222.69362900984186 +79,-1,7.442997263911691,221.28872655502428 +79,14,7.445449486515733,221.2768642120128 +79,15,7.418089325198928,221.34613277115895 +79,16,7.504479748262484,221.16586536259808 +81,-1,7.408187798505816,222.10725904651184 +81,14,7.409841551960028,222.09165892893446 +81,15,7.381469879709513,222.1724782416252 +81,16,7.477542689284157,221.97575249891688 +83,-1,7.399205454556235,219.67256732880614 +83,14,7.411305524562781,219.65052422591268 +83,15,7.382022082128196,219.73350195523292 +83,16,7.4019055543511865,219.57810408083688 +85,-1,7.436787452490125,218.97070521542918 +85,14,7.447228344707931,218.95058240780335 +85,15,7.420918095087655,219.0315770940754 +85,16,7.441906308860176,218.86912171258479 +87,-1,7.6184393178012275,216.78241632580614 +87,14,7.638259297699789,216.7508562494609 +87,15,7.609631865616113,216.8512074640012 +87,16,7.56694824541699,216.70116158748002 +104,-1,7.470118463793038,218.67629623030197 +104,14,7.485022143012302,218.66828902833248 +104,15,7.458636327817316,218.7281622741795 +104,16,7.445379281795185,218.55498339316736 +106,-1,7.8276406702962325,225.96181958261502 +106,14,7.827171665425341,225.98079264508354 
+106,15,7.781040122036949,225.99847433319738 +106,16,7.961834196087203,225.7823976743125 +109,-1,7.646890479703366,215.79138536301636 +109,14,7.685607198234471,215.7704194193942 +109,15,7.642526488304995,215.84936051917222 +109,16,7.509509523499673,215.7014850097762 +111,-1,7.4601429791654565,219.70520302473733 +111,14,7.466081407569862,219.7012514648927 +111,15,7.44213627081482,219.7553186285631 +111,16,7.488462835169181,219.5735865735657 +113,-1,7.908964113646412,225.87223306969673 +113,14,7.9068295305100325,225.8862131523485 +113,15,7.862145917092457,225.920740135559 +113,16,8.049535903843982,225.67712025108088 +116,-1,7.686448485950832,215.28666941326773 +116,14,7.7244833554766075,215.2676743084813 +116,15,7.6834513020794635,215.3444874815548 +116,16,7.547602298978031,215.18973261626874 +119,-1,7.990387843639178,226.5240851005008 +119,14,7.988941603660206,226.53175128199766 +119,15,7.939914531516684,226.5833000659401 +119,16,8.138521416541417,226.3216542538732 +122,-1,7.769527848121822,214.83300143084742 +122,14,7.81314090523951,214.80882094589958 +122,15,7.766097129860208,214.89679619993765 +122,16,7.610093781045913,214.7382985059943 +125,-1,8.075167374779538,226.9218117677382 +125,14,8.073941489993771,226.92711200339158 +125,15,8.02208936747214,226.9895825246283 +125,16,8.22936784451435,226.70334421196372 +128,-1,7.74625027447432,214.9907825820514 +128,14,7.788793895147038,214.9657786931796 +128,15,7.742342327720888,215.05787913085814 +128,16,7.592226756730147,214.8895499490575 +131,-1,7.952543587990458,226.22214474548852 +131,14,7.950999744641339,226.22631462687164 +131,15,7.9026753136196906,226.29309714412057 +131,16,8.098470967380626,225.9986665135917 +133,-1,7.4382507114881085,220.04690415277065 +133,14,7.443443185664738,220.03183925266788 +133,15,7.418760003574452,220.12127186678202 +133,16,7.472500250238374,219.88659187325294 +135,-1,7.748896805898435,214.9785618321739 +135,14,7.790954575464112,214.95381085160935 
+135,15,7.7455774668438515,215.04453134339434 +135,16,7.595076612100289,214.879665151381 +138,-1,8.111588382186833,227.39487771646546 +138,14,8.111624612583228,227.40168637633337 +138,15,8.057239789157286,227.46719528213163 +138,16,8.264085813584925,227.15736816399718 +141,-1,7.80581710677709,214.60647398592974 +141,14,7.852059805141757,214.58421572854238 +141,15,7.803361848213758,214.67025887853808 +141,16,7.6331354301319,214.50448874684713 +144,-1,8.060191156588465,226.95382337875432 +144,14,8.059744065284622,226.96752988454628 +144,15,8.007496694660412,227.01400250935208 +144,16,8.21025045806167,226.72559678885128 +146,-1,7.423434283795452,218.02202366485116 +146,14,7.439210190224939,218.01099732707357 +146,15,7.4131522987179785,218.08640795337274 +146,16,7.391069238124511,217.87556376367863 +148,-1,7.724277093513682,214.98308251316752 +148,14,7.765566167169593,214.968658118887 +148,15,7.7233596135935825,215.03621205205593 +148,16,7.566649224618156,214.8824803519597 +151,-1,7.872359682039427,225.48796575701633 +151,14,7.869664711101843,225.49810631334444 +151,15,7.827415277331509,225.54174621686371 +151,16,8.009551282611222,225.29202582814483 +154,-1,7.806028291262147,214.31838838341184 +154,14,7.85334866660939,214.30482485825817 +154,15,7.806705590822431,214.37074033688643 +154,16,7.620082952674023,214.21678723569167 +157,-1,7.892393875633006,226.20433852144834 +157,14,7.892158042911494,226.22408306406024 +157,15,7.844968164801328,226.24410440795404 +157,16,8.02777869637783,226.01287859947368 +159,-1,7.464026514044788,216.4104396525505 +159,14,7.4915467436021705,216.40745842395606 +159,15,7.461538916891947,216.45868668297425 +159,16,7.364182434477678,216.2806002793832 +161,-1,7.69591839706524,214.7091329577858 +161,14,7.738255569663876,214.70115960937682 +161,15,7.697153497133314,214.75829189024446 +161,16,7.5276800140762,214.59561279824783 +163,-1,7.698867054746856,225.2672977944934 +163,14,7.699274574169289,225.2934191562494 
+163,15,7.656351176251079,225.2945247105456 +163,16,7.818068246327807,225.08827459682033 +167,-1,7.718018838785841,214.15798893418102 +167,14,7.763211710057439,214.15741223560684 +167,15,7.720933941701662,214.2064877498692 +167,16,7.533168143977793,214.0182671398676 +170,-1,7.546657926708745,223.77545629991488 +170,14,7.5461322820150825,223.79833952257727 +170,15,7.512742311399138,223.80006419984113 +170,16,7.645109368356236,223.6164130602323 +173,-1,7.552237663206375,214.1789944669328 +173,14,7.591783086782173,214.1925938020429 +173,15,7.557935743250408,214.21820238320956 +173,16,7.380624199965868,214.0125406215213 +176,-1,7.564065178026076,224.83719212332218 +176,14,7.566693114928036,224.86707767713 +176,15,7.525817301959962,224.8608799430772 +176,16,7.662517511203877,224.65420773887905 +180,-1,7.252942147716254,215.3265210200013 +180,14,7.283814565233424,215.34775487150597 +180,15,7.259103868481533,215.36094635815928 +180,16,7.113257685897501,215.14502379139554 +185,-1,6.8869581018557975,215.74855491132158 +185,14,6.906874636845168,215.7818819327798 +185,15,6.893310747644031,215.77820638098365 +185,16,6.787983743229243,215.53503027424588 +187,-1,7.062935066210278,215.25181028462808 +187,14,7.087792760005573,215.28216474212186 +187,15,7.07010013109274,215.27815728777406 +187,16,6.943214365135618,215.05926122474042 +193,-1,6.53596841076748,215.57741454642513 +193,14,6.547352499276671,215.63103640641575 +193,15,6.544946438746775,215.5936921415412 +193,16,6.460164671712149,215.32583235119435 +199,-1,5.800751854744839,216.8961097964946 +199,14,5.793573309769425,216.9724593866467 +199,15,5.808696761376613,216.8991096338648 +199,16,5.795143699812542,216.5969352590282 +205,-1,5.2700606364895375,218.76010916310278 +205,14,5.250813415499617,218.8505357450125 +205,15,5.272625140182401,218.75399623654695 +205,16,5.32232170623981,218.43405802877663 +207,-1,6.288345263437963,217.06348538934702 +207,14,6.294573370528691,217.1165019236446 
+207,15,6.293036118285355,217.07757839806638 +207,16,6.245178491230654,216.82059800567745 +210,-1,5.731451283090575,219.13252439930653 +210,14,5.726789576413561,219.2017527024309 +210,15,5.730222454320901,219.13033470505874 +210,16,5.745404332176095,218.87562115568838 +212,-1,5.028236697209228,220.00352679133002 +212,14,5.00567800627236,220.10186181852578 +212,15,5.02714807626114,219.98761010324364 +212,16,5.102096536959653,219.6760831660727 +214,-1,6.236040619863338,218.1030528103395 +214,14,6.240742367899467,218.15758451666318 +214,15,6.235890313948984,218.10497702079036 +214,16,6.214004894970708,217.89001029820267 +217,-1,4.9927452542318544,221.16464182092992 +217,14,4.974166818565571,221.25652074426412 +217,15,4.987014339327083,221.14875199590915 +217,16,5.06700321527197,220.86167545453176 +219,-1,4.920934776073383,220.96992925332594 +219,14,4.896871518072765,221.07311635336245 +219,15,4.916166366697869,220.9478960098413 +219,16,5.0101825091003835,220.64192873097093 +222,-1,5.961704357965497,218.49242315960194 +222,14,5.962238086522809,218.54662712593242 +222,15,5.964325572159455,218.5041659201264 +222,16,5.946159169413703,218.25189656594722 +224,-1,4.778606753772839,222.6182877919982 +224,14,4.757612773384559,222.72436240249357 +224,15,4.768753718603335,222.58948269320317 +224,16,4.870803982708677,222.2991201200239 +226,-1,4.88119513675867,222.05944378173623 +226,14,4.8567976506723625,222.1656664843828 +226,15,4.87196256335097,222.03986092158735 +226,16,4.982866176273519,221.71272709766347 +230,-1,4.664961338696116,222.94336360464408 +230,14,4.646192798192134,223.0501747836415 +230,15,4.654830750517009,222.91899341827536 +230,16,4.748211026947446,222.60841712453356 +232,-1,4.8350244025225,223.74000897751085 +232,14,4.815150911422786,223.8523450423326 +232,15,4.820446612837068,223.7162947829025 +232,16,4.933666047443132,223.38212899475104 +236,-1,4.859761377552439,225.9191299803178 +236,14,4.842740474441567,226.0295641248934 
+236,15,4.841599800068688,225.88598215071707 +236,16,4.960546555720021,225.59608829420836 +239,-1,5.467313246912352,227.8967950024798 +239,14,5.451835008694295,228.0154773106447 +239,15,5.4458889515755455,227.86808965445513 +239,16,5.57023557425349,227.52938192233472 +243,-1,5.240257239687782,228.65808899000416 +243,14,5.22646382384671,228.7584816653345 +243,15,5.222186498375415,228.63459306713528 +243,16,5.331479379629498,228.3449971270908 +246,-1,5.880139203896406,230.60246717315405 +246,14,5.869396606351132,230.72025063610533 +246,15,5.856217232148349,230.56504060616504 +246,16,5.975468937959747,230.26399253938123 +249,-1,5.74782862396384,230.85625401632393 +249,14,5.742160061573653,230.97752671931502 +249,15,5.724773031686518,230.8180735421985 +249,16,5.8203843223439,230.50671407095544 +252,-1,6.273771895326091,233.2878545221113 +252,14,6.272362232578364,233.41427312285842 +252,15,6.245326911623123,233.23585231425395 +252,16,6.3476596682873385,232.9591866945953 +254,-1,5.136005280497208,227.67468038407844 +254,14,5.120598649244679,227.77475306402775 +254,15,5.112426967408727,227.6549382846771 +254,16,5.248160781504058,227.3518213407548 +260,-1,5.726940641008719,231.23836402055048 +260,14,5.718888335051473,231.3448463080937 +260,15,5.70022601100582,231.21643324195225 +260,16,5.820736843838926,230.8975303864223 +262,-1,6.393923886569731,234.0608486343007 +262,14,6.3958963857630575,234.18622153489193 +262,15,6.364043919171431,234.01611848517447 +262,16,6.458827250200367,233.71487893493682 +265,-1,7.378273750338902,238.30549465200343 +265,14,7.392674886207412,238.45090390230934 +265,15,7.342690670824257,238.23883084156176 +265,16,7.412221994822641,237.9474978616577 +267,-1,5.970650279599814,232.74623868551862 +267,14,5.968881263915609,232.8660409306501 +267,15,5.937687144334024,232.69422886908845 +267,16,6.059952029484149,232.4427585057461 +269,-1,6.072017788435322,232.69632348167733 +269,14,6.072017788435295,232.69632348167733 
+275,-1,6.052555500232738,233.88854924028468 +275,14,6.041958871290569,233.9789303864483 +275,15,6.030892205748039,233.8760285344375 +275,16,6.144998628740027,233.58142110606303 +282,-1,6.2495952313646965,232.85263983393475 +282,14,6.242416084903742,232.96962075010026 +282,15,6.217785151752775,232.81560791510654 +282,16,6.355287814727145,232.5160628924962 +294,-1,5.494039295767262,231.68938071415954 +294,14,5.494057870049777,231.80042014753633 +294,15,5.464084156480761,231.65182547783928 +294,16,5.567365996122895,231.37693358058252 +296,-1,4.829539784513678,226.99509365083532 +296,14,4.8213307835236,227.10483239075413 +296,15,4.807015251600032,226.9597757584889 +296,16,4.910484129516022,226.68104689412962 +299,-1,5.448866379042412,219.94589289829761 +299,14,5.448212848647708,220.00687918187617 +299,15,5.447110386315124,219.9585615736143 +299,16,5.448506698638163,219.6768603094058 +303,-1,7.750554377252405,225.81008058875298 +303,14,7.7462404816965975,225.83657203315397 +303,15,7.709112544692212,225.84480714311061 +303,16,7.8837132766385025,225.60770463909384 +306,-1,5.614883443542906,229.38874751003064 +306,14,5.612786171679389,229.5368667776371 +306,15,5.583975713417016,229.33463088641128 +306,16,5.689750355743291,228.98372806719698 +323,-1,7.5100961266193424,238.44182641661223 +323,14,7.511412641374718,238.56762864868503 +323,15,7.4706584880549,238.38920058960383 +323,16,7.6079905781592725,238.11732776351005 +327,-1,11.970767887574866,256.096874642667 +327,14,11.995565188393346,256.1931710632066 +327,15,11.929783684600043,256.0704806865801 +327,16,11.99100990826014,255.80784303909473 +329,-1,6.828871508535757,235.64816212750588 +329,14,6.823601413641264,235.76569057504003 +329,15,6.796979202003627,235.6117052546416 +329,16,6.928670985973985,235.3078199680253 +335,-1,6.618308722179919,234.7690499213081 +335,14,6.610746561789071,234.8758506409663 +335,15,6.5889520441805285,234.74721015987876 +335,16,6.719753990161554,234.4267387962481 
+345,-1,10.15041058737343,251.8145274858527 +345,14,10.16008759540388,251.87994192274056 +345,15,10.119130577788827,251.79821843395695 +345,16,10.20201294579696,251.61344581942174 +347,-1,7.02157731140916,237.0483588959896 +347,14,7.013258743082633,237.14238549358302 +347,15,6.995906813137304,237.03850401127193 +347,16,7.117410282343044,236.71956462776447 +350,-1,11.58840162871626,254.62142256963074 +350,14,11.616685398916749,254.73921169641244 +350,15,11.545236788631783,254.57520474458045 +350,16,11.599991671605705,254.30865101736663 +354,-1,6.7964002801322625,235.4253909848031 +354,14,6.791804254104565,235.5491452019627 +354,15,6.763637137284354,235.38408759537108 +354,16,6.895140785193838,235.0755507205936 +356,-1,11.507586718151394,254.3002904734637 +356,14,11.535542797388775,254.4165010403384 +356,15,11.464133996783593,254.25189033531038 +356,16,11.521577462179808,253.9999087881439 +360,-1,6.77672566817323,235.38185441564718 +360,14,6.769577522822322,235.49307639638096 +360,15,6.746485201493061,235.35166852303226 +360,16,6.879477689038735,235.047149085477 +366,-1,6.908380071815146,235.95007347085465 +366,14,6.902349377773114,236.0554013851113 +366,15,6.87874093832045,235.92631923553213 +366,16,7.006052154748827,235.61896623400625 +371,-1,9.516379557069625,246.65488108848788 +371,14,9.523490386968227,246.74324624407714 +371,15,9.485993846557921,246.63910887809115 +371,16,9.571917218671853,246.3649352185187 +373,-1,8.894879473914333,244.63032202360267 +373,14,8.909259864454283,244.75815476708672 +373,15,8.855351184172754,244.57972387733676 +373,16,8.94506425316418,244.29216697983114 +375,-1,12.27350491321289,257.99445036350943 +375,14,12.308225326917551,258.1014793103917 +375,15,12.225681954163868,257.9599650719904 +375,16,12.275188056806059,257.68827401255896 +377,-1,9.939829683218493,248.68326854414727 +377,14,9.95098238557818,248.7721507577566 +377,15,9.907150088497454,248.66660181344898 +377,16,9.986981037528164,248.39397354726245 
+379,-1,8.503895671238197,243.7093437460912 +379,14,8.516419018717109,243.83098255704354 +379,15,8.46516328651614,243.66238986056536 +379,16,8.559218626227599,243.384083706702 +381,-1,11.540785170230743,256.7985204481981 +381,14,11.570559849088905,256.8933898559721 +381,15,11.496483626254067,256.7717612867613 +381,16,11.551738035507407,256.5159851530432 +383,-1,8.063462630428717,244.58393101467829 +383,14,8.069503904357187,244.65757576855154 +383,15,8.036488739142774,244.57319197900864 +383,16,8.113980877679472,244.3352464876607 +386,-1,6.056839764813776,236.16099910546123 +386,14,6.0751048422731735,236.28772711901934 +386,15,6.025361336707141,236.11328028982072 +386,16,6.063861619114356,235.81862053365026 +389,-1,9.10388002663358,246.87457208355843 +389,14,9.11858023400268,246.97370253653116 +389,15,9.068033695258684,246.8487166315192 +389,16,9.145364757228005,246.5731854566144 +398,-1,11.353089958256309,256.66458549740116 +398,14,11.365006758784029,256.71982371111517 +398,15,11.313166310790434,256.6522470217014 +398,16,11.422283369207516,256.4905906720129 +401,-1,10.330896533053043,253.38017926908918 +401,14,10.335075775778028,253.4166369006013 +401,15,10.301263392878102,253.39208476396655 +401,16,10.39963425278723,253.20667487509613 +404,-1,13.63778447444461,261.7360465824113 +404,14,13.657818962931827,261.7966134650399 +404,15,13.594106668067507,261.72798288950935 +404,16,13.686956270976031,261.52927540757753 +407,-1,12.935376589053641,259.5612783877454 +407,14,12.951318904902077,259.60818968739346 +407,15,12.898895644469256,259.5750919068645 +407,16,12.979031143727637,259.34243016468224 +410,-1,13.924543322990335,262.71473785611414 +410,14,13.943801799840758,262.76918836815963 +410,15,13.880645296739726,262.7108038594756 +410,16,13.977611883960067,262.5191462161045 +414,-1,9.698625013440502,249.6312714744092 +414,14,9.704615939732308,249.6886547277935 +414,15,9.667778376110249,249.62804768349324 +414,16,9.762842270571943,249.42244694408222 
+416,-1,13.580244135615462,262.44255683934585 +416,14,13.594829403849909,262.48083780661335 +416,15,13.538058838059145,262.4476133110677 +416,16,13.646670888301138,262.28215870175234 +426,-1,11.126427121785305,254.3804781826789 +426,14,11.14363559391911,254.44406488937602 +426,15,11.089175013737325,254.3767405751045 +426,16,11.166585648638677,254.149562256647 +428,-1,14.53588401180475,265.13996300574706 +428,14,14.553356605289801,265.17236151218736 +428,15,14.492108044681666,265.15412544570063 +428,16,14.596074858798858,264.97529335869234 +432,-1,11.405138523224721,255.9063139726011 +432,14,11.432646738364928,255.9998031617535 +432,15,11.359629812364503,255.87994425158848 +432,16,11.428316189927665,255.62788878984793 +434,-1,14.847994247457473,266.03754936053457 +434,14,14.867051825011933,266.0750597483496 +434,15,14.803205373552865,266.0495249567533 +434,16,14.904956901061958,265.85983566825684 +439,-1,10.495895774701248,254.98205128279662 +439,14,10.507221501050529,255.02823825740842 +439,15,10.459881581465607,254.98103497807034 +439,16,10.55601803777103,254.80935246928834 +445,-1,12.591433814305,260.644800636277 +445,14,12.61051925167218,260.6914824781264 +445,15,12.548142151960638,260.6460797071974 +445,16,12.643505702832249,260.46350295417517 +457,-1,14.880979235884821,266.0321454147847 +457,14,14.898259314982242,266.06060032490154 +457,15,14.84040627158832,266.0551889855668 +457,16,14.932495903313692,265.85648767519433 +515,-1,18.616299447313228,274.62756596122557 +515,14,18.59349307023486,274.5640268106597 +515,15,18.59822766251503,274.7177701583755 +515,16,18.75426974616796,274.60528596797053 +546,-1,18.75597450437593,274.9546460225111 +546,14,18.730767371985028,274.88663340579564 +546,15,18.74014763057299,275.04957567686847 +546,16,18.89637130991383,274.9355534667884 +552,-1,18.642194868105243,274.8093847761069 +552,14,18.618144678084686,274.74309736927034 +552,15,18.624932474580024,274.9023149942151 +552,16,18.782436482700035,274.7895810345104 
+559,-1,13.067877296512076,263.337966804856 +559,14,13.055116170921012,263.29421882164144 +559,15,13.037193285290632,263.3876881337999 +559,16,13.204782665906883,263.35887021387873 +566,-1,18.76259695242809,275.09332568997456 +566,14,18.738058912536506,275.0276461394868 +566,15,18.745749679838248,275.1858111594942 +566,16,18.903497838384443,275.07251136518175 +571,-1,18.84143266493505,275.2338212878694 +571,14,18.815991717575205,275.16638130608817 +571,15,18.825470772739934,275.32843146073253 +571,16,18.98312202542528,275.2134855624938 +573,-1,18.856943026223398,275.2346418821147 +573,14,18.83102226708298,275.16707335014434 +573,15,18.84171312497865,275.32986134712405 +573,16,18.99830646889034,275.2130122127985 +646,-1,18.04592973889992,273.949408003435 +646,14,18.030180062483897,273.8989773580635 +646,15,18.024660765201908,274.03577970666134 +646,16,18.166628431510084,273.8884841503539 +662,-1,19.185774288732137,275.99279054363757 +662,14,19.15636478378864,275.92164475459964 +662,15,19.173138731994413,276.0915764072293 +662,16,19.332692477901812,275.9743311609092 +668,-1,19.20198788733615,276.01891380471204 +668,14,19.17275620232605,275.9486847002138 +668,15,19.18924202926627,276.11705330909894 +668,16,19.348578370753795,275.9988591520856 +756,-1,19.34556970172583,276.49093563382763 +756,14,19.315010279373908,276.4191241660325 +756,15,19.33488469062897,276.59048431995075 +756,16,19.491156046477627,276.4727760273699 +867,-1,19.421788916570417,276.7479601946552 +867,14,19.391811361899883,276.676702071589 +867,15,19.41226301733536,276.8480674662396 +867,16,19.561819444549077,276.7260614250309 +906,-1,16.422217700030203,271.7480802068924 +906,14,16.402760624572746,271.6879650474609 +906,15,16.403696065428843,271.8364124058574 +906,16,16.54871814487146,271.7182547898225 +921,-1,19.426181996325298,276.7900859768641 +921,14,19.395694919391612,276.71700067044156 +921,15,19.4179459913375,276.89123667890834 +921,16,19.56435088908476,276.7720836758886 
+1027,-1,19.46579528406605,277.1604254829402 +1027,14,19.43468806378315,277.0862659578881 +1027,15,19.45860072801157,277.2631518593977 +1027,16,19.603241798242813,277.14189830997066 +1046,-1,18.282295039306184,275.55010866320254 +1046,14,18.258355649331165,275.48333455200776 +1046,15,18.265128488213495,275.6403443748135 +1046,16,18.421847194256124,275.54004066315525 +1121,-1,17.261161478035785,273.8952547436953 +1121,14,17.236975373061266,273.82386399149675 +1121,15,17.242554862181034,273.9836573660174 +1121,16,17.40567242653493,273.90811010199513 +1133,-1,19.29533650578031,277.1044125820725 +1133,14,19.261237875116827,277.0189986403982 +1133,15,19.290030525112403,277.21278992009127 +1133,16,19.438299977471182,277.11215609546423 +1191,-1,18.852822365193983,276.54135819067676 +1191,14,18.832900515573378,276.482544009772 +1191,15,18.838369726217405,276.6295916735632 +1191,16,18.969484601453956,276.5068733459806 +1194,-1,19.451907241562893,277.75645682341855 +1194,14,19.423803951190358,277.68411094990626 +1194,15,19.444049896149764,277.8576828657158 +1194,16,19.57999329843948,277.73542162335525 +1271,-1,19.48941888805434,278.0101384860552 +1271,14,19.46026513041647,277.93263328721787 +1271,15,19.48328945794513,278.11463836225863 +1271,16,19.61633357203983,277.9991470781219 +1479,-1,14.003271796193586,259.1039398028353 +1479,14,14.023870037608342,259.04525476995155 +1479,15,13.97171954432275,259.1785104148273 +1479,16,14.016036288059379,259.10894348266174 +1509,-1,15.591353871907671,258.5569114049743 +1509,14,15.609310298685017,258.48531238328394 +1509,15,15.572719683093155,258.6640189061237 +1509,16,15.575747634521157,258.5158254583758 +1513,-1,15.079763374153343,255.89685919331828 +1513,14,15.105898417236446,255.84534871770634 +1513,15,15.047391776801499,255.9622248694587 +1513,16,15.074144451656034,255.90150947412917 +1521,-1,15.064716730305417,255.33448328473793 +1521,14,15.088352039526685,255.27792553752496 +1521,15,15.039099556571495,255.41986509120565 
+1521,16,15.048484960880778,255.2997605290852 +1536,-1,14.28939737897196,255.2968575988283 +1536,14,14.312329654356407,255.25806893946623 +1536,15,14.251110199797772,255.34173535418165 +1536,16,14.31355273236883,255.31307123240396 +1574,-1,13.614524404385973,249.39249961335182 +1574,14,13.643401174873075,249.34998003817978 +1574,15,13.577509299119162,249.47812747185034 +1574,16,13.611409250975568,249.30366252315122 +1579,-1,12.236378526082376,244.623206508289 +1579,14,12.265659367159126,244.59630407161652 +1579,15,12.194427428269599,244.70370711914782 +1579,16,12.24580437744268,244.48997259893395 +1585,-1,12.998776357787941,245.46666446126147 +1585,14,13.022781756884422,245.43871875394353 +1585,15,12.958316421064357,245.54687680010332 +1585,16,13.024097991428917,245.33824217359629 +1587,-1,13.089022743202776,246.57049258361712 +1587,14,13.114640635222658,246.54915508132876 +1587,15,13.043621738584203,246.6348466438326 +1587,16,13.123097847699738,246.46333976679844 +1626,-1,13.630933361197691,251.45510143924912 +1626,14,13.6557836910461,251.4072996663385 +1626,15,13.597509044810439,251.53567865152095 +1626,16,13.632891260370846,251.40113424504813 +1639,-1,14.345958863144594,252.0692843484591 +1639,14,14.377405832995205,252.0425410256133 +1639,15,14.306399635354172,252.12084787022783 +1639,16,14.341365204114293,252.0201197881668 +1643,-1,13.029604765305658,252.98801279431243 +1643,14,13.060427059398492,252.9575962315062 +1643,15,12.993281750365608,253.0792245794171 +1643,16,13.016418886607449,252.83680229860192 +1652,-1,14.526103063991611,255.2184106133653 +1652,14,14.549156435838132,255.16899169122942 +1652,15,14.499962790792573,255.31311525888339 +1652,16,14.51326942584134,255.12925514057477 +1658,-1,14.3179641996147,253.99964876675853 +1658,14,14.348125828665946,253.953650560755 +1658,15,14.283742248570645,254.07517693892515 +1658,16,14.302196072174167,253.95359598235908 +1671,-1,13.921213680433716,255.9039444813858 +1671,14,13.948791498258847,255.86331132628683 
+1671,15,13.88467284744436,255.97222424447693 +1671,16,13.922212394348211,255.85869566646258 +1786,-1,15.722962789233087,271.4685861841164 +1786,14,15.712200034853957,271.42250837908585 +1786,15,15.693448973638358,271.5346392573149 +1786,16,15.848988222275315,271.4505618070914 +1805,-1,19.216621445503318,276.93693949617574 +1805,14,19.18914934421494,276.86388605914993 +1805,15,19.2103957857847,277.03952568409505 +1805,16,19.337529168829906,276.9146155188494 +1884,-1,19.31675177300453,277.19297637469265 +1884,14,19.290838616179162,277.1304312664147 +1884,15,19.309086978303597,277.28728209756633 +1884,16,19.436205402182516,277.15491381151975 +1991,-1,19.347680239749884,277.3406286548743 +1991,14,19.322204512627234,277.27761583963525 +1991,15,19.340853260551267,277.43894816653284 +1991,16,19.46296299211293,277.2926000339633 +2019,-1,14.235255729618855,268.0216077158337 +2019,14,14.21888058640275,267.97801394344276 +2019,15,14.215647536273929,268.09358605893146 +2019,16,14.35349248502533,267.97679700715736 +2094,-1,19.33295184956739,277.3530364023607 +2094,14,19.301167783751467,277.28632382478645 +2094,15,19.32941472516561,277.45610208129824 +2094,16,19.462359161042546,277.30519182674897 +2119,-1,10.711796666027817,257.5818102968717 +2119,14,10.69214136329505,257.5475919456452 +2119,15,10.69445482871388,257.62997845884485 +2119,16,10.835941744355006,257.57101245663057 +2213,-1,12.817802029389666,264.4890486234572 +2213,14,12.800252833872966,264.46930115928285 +2213,15,12.795835352964355,264.5297428487416 +2213,16,12.947747097842765,264.44508053111105 +2239,-1,19.37025674073832,277.61169319938904 +2239,14,19.33679863406121,277.55992181886194 +2239,15,19.369038350674273,277.70040250870807 +2239,16,19.499579989800335,277.54902863927293 +2271,-1,12.364759081132737,263.61131345421563 +2271,14,12.348895285933164,263.59187163120913 +2271,15,12.342593337758021,263.65381216366546 +2271,16,12.488863680448933,263.5609026435921 +2321,-1,10.649745680688648,257.3869318986451 
+2321,14,10.633256182606967,257.3627158336515 +2321,15,10.631329898024656,257.4235501162515 +2321,16,10.765329469803822,257.37188722924014 +2398,-1,10.489027352009948,256.2541897939565 +2398,14,10.469945446216366,256.223054352885 +2398,15,10.476576474018023,256.3010877344666 +2398,16,10.596897332362706,256.2353829405692 +2701,-1,8.25973293977833,244.6944761946365 +2701,14,8.244452068297635,244.67174994217345 +2701,15,8.251703698077069,244.71672008724994 +2701,16,8.340653188579871,244.71582657503615 +2889,-1,12.324990249941138,263.2867496129323 +2889,14,12.306602516379373,263.2613272918366 +2889,15,12.305311857966384,263.3301014613091 +2889,16,12.45140439890726,263.25658927417913 +2958,-1,12.05427991740345,262.2901246398327 +2958,14,12.035137463334316,262.265139183448 +2958,15,12.036010904352661,262.33263720520006 +2958,16,12.179444979071329,262.2607585571961 +2993,-1,10.190504907617465,253.85828053116762 +2993,14,10.170282996944392,253.82852160573967 +2993,15,10.181879161274857,253.89021946864526 +2993,16,10.291739023428397,253.87801051482728 +3002,-1,7.760179029064197,240.63143191734082 +3002,14,7.7416097397684025,240.59383525136502 +3002,15,7.758928370181207,240.6557305887662 +3002,16,7.833247728625199,240.7033294739927 +3049,-1,11.212765053659972,258.75077355793337 +3049,14,11.192915564559042,258.7273404861564 +3049,15,11.198220747726959,258.78579015191326 +3049,16,11.32985514644393,258.73743737123226 +3105,-1,9.117239439627202,247.6530050811262 +3105,14,9.097663419689905,247.61961326971704 +3105,15,9.113938490772538,247.68209635883213 +3105,16,9.200430530934419,247.69488549484674 +3110,-1,10.001364873914747,252.09191209401047 +3110,14,9.980762315709631,252.06165585991624 +3110,15,9.99535884000425,252.1227725293938 +3110,16,10.096419965548487,252.1166894660688 +5381,-1,13.204452482935643,268.144309824923 +5381,14,13.195324529035533,268.11913805670696 +5381,15,13.179048368501745,268.1877530662393 +5381,16,13.312643110063604,268.1129290751175 
+5399,-1,14.489583214759742,271.1797220556748 +5399,14,14.477904547949347,271.14395712877456 +5399,15,14.466168725748568,271.2349350012913 +5399,16,14.601514188998735,271.1541924724893 +5480,-1,15.398775357115134,272.9882842439366 +5480,14,15.387155916364732,272.9522836539787 +5480,15,15.37632947640946,273.0454741803754 +5480,16,15.507699913426912,272.9578660172472 diff --git a/physicsnemo/experimental/datapipes/healda/configs/normalizations/mhs_normalizations.csv b/physicsnemo/experimental/datapipes/healda/configs/normalizations/mhs_normalizations.csv new file mode 100644 index 0000000000..49df36963d --- /dev/null +++ b/physicsnemo/experimental/datapipes/healda/configs/normalizations/mhs_normalizations.csv @@ -0,0 +1,31 @@ +Raw_Channel_ID,Platform_ID,obs_std,obs_mean +1,-1,27.314051741583636,237.9566862096014 +1,14,27.237688807568933,237.82756663196426 +1,15,27.25539242564268,238.15684577291427 +1,16,27.104623367838013,238.3287148131825 +1,23,27.342755373663014,237.84289789226148 +1,24,27.452495102189413,237.89947810831046 +2,-1,26.618813103955723,261.915233839472 +2,14,26.714805283241724,261.70902595462064 +2,15,26.566889256684032,262.0244276665933 +2,16,26.501021796581963,262.0089071790408 +2,23,26.635057018276363,261.8615537641429 +2,24,26.598255721079855,262.0179211819286 +3,-1,9.442343271781501,247.38245044616173 +3,14,9.360823103781955,247.3196715240357 +3,15,9.358463680886302,247.33785121070034 +3,16,9.332064063561445,247.28928810997937 +3,23,9.357006869204472,247.3811298765863 +3,24,10.021122147008533,247.67779653526242 +4,-1,12.917529068672312,258.2913764089477 +4,14,12.894516788434766,258.1149392180758 +4,15,12.878421890127624,258.1926237313513 +4,16,12.800390334130613,258.1741921825457 +4,23,12.968961486043167,258.30936712616074 +4,24,12.952925642222597,258.52580165711 +5,-1,17.997020725294178,265.0566627761782 +5,14,18.121249551151013,265.00266964773476 +5,15,18.000085396832166,265.10250063093014 +5,16,17.85497341358148,265.1404196358589 
+5,23,18.051736849065584,264.9905259829178 +5,24,17.88897715825186,265.09578591765796 diff --git a/physicsnemo/experimental/datapipes/healda/configs/sensors.py b/physicsnemo/experimental/datapipes/healda/configs/sensors.py new file mode 100644 index 0000000000..e1322c9dea --- /dev/null +++ b/physicsnemo/experimental/datapipes/healda/configs/sensors.py @@ -0,0 +1,287 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Sensor metadata and configuration for observation loading. + +Defines sensor configurations, platform mappings, and channel offsets used +by ``UFSUnifiedLoader`` and the observation transform pipeline. 
+""" + +import pathlib +from dataclasses import dataclass, field + +import numpy as np +import pandas as pd + + +@dataclass +class SensorConfig: + """Sensor metadata for data loading and normalization.""" + + name: str + platforms: list[str] + channels: int + nc_file_template: str + means: np.ndarray = field(init=False) + stds: np.ndarray = field(init=False) + min_valid: float = 0.0 + max_valid: float = 400.0 + sensor_type: str = "microwave" + raw_to_local: np.ndarray = field(init=False) + + def __post_init__(self): + base = pathlib.Path(__file__).parent / "normalizations" + norm_file = base / f"{self.name}_normalizations.csv" + + if norm_file.exists(): + df = pd.read_csv(norm_file) + channel_col = "Raw_Channel_ID" + df = df[df["Platform_ID"] == -1].sort_values(channel_col) + + self.means = df["obs_mean"].to_numpy() + self.stds = df["obs_std"].to_numpy() + + raw_ids = df[channel_col].to_numpy() + max_raw = raw_ids.max() + lookup_table = np.full(max_raw + 1, 0, dtype=int) + for local_idx, raw in enumerate(raw_ids, start=1): + lookup_table[raw] = local_idx + self.raw_to_local = lookup_table + else: + self.means = np.zeros(self.channels, dtype=float) + self.stds = np.ones(self.channels, dtype=float) + self.raw_to_local = None + + +def _build_identity_lut(raw_ids: np.ndarray) -> np.ndarray: + """Build a 1-indexed identity LUT from observed raw channel IDs.""" + raw_ids = np.asarray(raw_ids).ravel() + unique = np.unique(raw_ids) + lut = np.zeros(int(unique.max()) + 1, dtype=int) + for local_idx, raw in enumerate(unique, start=1): + lut[raw] = local_idx + return lut + + +def get_global_channel_id(sensor, raw_channel_ids): + """Map per-sensor raw channel IDs to unified global IDs.""" + cfg = SENSOR_CONFIGS[sensor] + if cfg.raw_to_local is None: + cfg.raw_to_local = _build_identity_lut(raw_channel_ids) + raw_to_local = cfg.raw_to_local + channel_offset = SENSOR_OFFSET[sensor] + raw_channel_ids = np.asarray(raw_channel_ids) + safe_ids = np.minimum(raw_channel_ids, 
len(raw_to_local) - 1) + local_channels = raw_to_local[safe_ids] - 1 + return (local_channels + channel_offset).astype(np.uint16) + + +# --------------------------------------------------------------------------- +# Sensor registry +# --------------------------------------------------------------------------- + +SENSOR_CONFIGS = { + "atms": SensorConfig( + name="atms", + platforms=["npp", "n20"], + channels=22, + nc_file_template="diag_atms_{platform}_ges.{date}_control.nc4", + min_valid=0.0, + max_valid=400.0, + sensor_type="microwave", + ), + "mhs": SensorConfig( + name="mhs", + platforms=["metop-a", "metop-b", "metop-c", "n18", "n19"], + channels=5, + nc_file_template="diag_mhs_{platform}_ges.{date}_control.nc4", + min_valid=0.0, + max_valid=400.0, + sensor_type="microwave", + ), + "amsua": SensorConfig( + name="amsua", + platforms=[ + "metop-a", "metop-b", "metop-c", "n15", "n16", "n17", "n18", "n19", + ], + channels=15, + nc_file_template="diag_amsua_{platform}_ges.{date}_control.nc4", + min_valid=0.0, + max_valid=400.0, + sensor_type="microwave", + ), + "amsub": SensorConfig( + name="amsub", + platforms=["n15", "n16", "n17"], + channels=5, + nc_file_template="diag_amsub_{platform}_ges.{date}_control.nc4", + min_valid=0.0, + max_valid=400.0, + sensor_type="microwave", + ), + "iasi": SensorConfig( + name="iasi", + platforms=["metop-a", "metop-b", "metop-c"], + channels=175, + nc_file_template="diag_iasi_{platform}_ges.{date}_control.nc4", + min_valid=150.0, + max_valid=350.0, + sensor_type="infrared", + ), + "cris-fsr": SensorConfig( + name="cris-fsr", + platforms=["npp", "n20"], + channels=100, + nc_file_template="diag_cris_fsr_{platform}_ges.{date}_control.nc4", + min_valid=150.0, + max_valid=350.0, + sensor_type="infrared", + ), + "conv": SensorConfig( + name="conv", + platforms=[], + channels=8, + nc_file_template="conv_{platform}_ges.{date}_control.nc4", + sensor_type="conv", + ), + "iasi-pca": SensorConfig( + name="iasi-pca", + platforms=["metop-a", 
"metop-b", "metop-c"], + channels=32, + nc_file_template="", + min_valid=float("-inf"), + max_valid=float("inf"), + sensor_type="infrared", + ), + "cris-fsr-pca": SensorConfig( + name="cris-fsr-pca", + platforms=["npp", "n20"], + channels=32, + nc_file_template="", + min_valid=float("-inf"), + max_valid=float("inf"), + sensor_type="infrared", + ), + "airs": SensorConfig( + name="airs", + platforms=["aqua"], + channels=117, + nc_file_template="diag_airs_{platform}_ges.{date}_control.nc4", + min_valid=150.0, + max_valid=350.0, + sensor_type="infrared", + ), + "airs-pca": SensorConfig( + name="airs-pca", + platforms=["aqua"], + channels=32, + nc_file_template="", + min_valid=float("-inf"), + max_valid=float("inf"), + sensor_type="infrared", + ), +} + + +# --------------------------------------------------------------------------- +# QC filtering limits for conventional observations +# --------------------------------------------------------------------------- + + +class QCLimits: + HEIGHT_MIN = 0 + HEIGHT_MAX = 60000 + PRESSURE_MIN_GPS = 0.5 + PRESSURE_MIN_DEFAULT = 200 + PRESSURE_MAX = 1100 + + +# --------------------------------------------------------------------------- +# Conventional channel definitions +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class ConvChannel: + name: str + platform: str + nc_column: str + min_valid: float + max_valid: float + + +CONV_CHANNELS = [ + ConvChannel("gps_angle", "gps", "Observation", float("-inf"), float("inf")), + ConvChannel("gps_t", "gps", "Temperature_at_Obs_Location", 150, 350), + ConvChannel("gps_q", "gps", "Specific_Humidity_at_Obs_Location", 0.0, 1.0), + ConvChannel("ps", "ps", "Observation", float("-inf"), float("inf")), + ConvChannel("q", "q", "Observation", 0, 1), + ConvChannel("t", "t", "Observation", 150, 350), + ConvChannel("u", "uv", "u_Observation", -100, 100), + ConvChannel("v", "uv", "v_Observation", -100, 100), +] + +CONV_CHANNEL_NAMES = [c.name 
for c in CONV_CHANNELS] +CONV_PLATFORMS = list(dict.fromkeys(c.platform for c in CONV_CHANNELS)) +CONV_GPS_CHANNELS = [i for i, c in enumerate(CONV_CHANNELS) if c.platform == "gps"] +CONV_GPS_LEVEL2_CHANNELS = [ + i for i, c in enumerate(CONV_CHANNELS) if c.name in ("gps_t", "gps_q") +] +CONV_UV_CHANNELS = [i for i, c in enumerate(CONV_CHANNELS) if c.platform == "uv"] +CONV_UV_IN_SITU_TYPES = [220, 221, 229, 230, 231, 232, 233, 234, 235, 280, 282] + + +def _build_conv_channel_map() -> dict[str, int]: + channel_map = {} + for i, channel in enumerate(CONV_CHANNELS, start=1): + if channel.platform not in channel_map: + channel_map[channel.platform] = i + return channel_map + + +CONV_CHANNEL_MAP = _build_conv_channel_map() + + +def _next_power_of_two(n: int) -> int: + return 1 << (n - 1).bit_length() + + +# --------------------------------------------------------------------------- +# Platform and sensor ID mappings +# --------------------------------------------------------------------------- + +PLATFORM_NAME_TO_ID = { + "aqua": 0, "aura": 1, "f10": 2, "f11": 3, "f13": 4, "f14": 5, + "f15": 6, "g08": 7, "g10": 8, "g11": 9, "g12": 10, "m08": 11, + "m09": 12, "m10": 13, "metop-a": 14, "metop-b": 15, "metop-c": 16, + "n11": 17, "n12": 18, "n14": 19, "n15": 20, "n16": 21, "n17": 22, + "n18": 23, "n19": 24, "n20": 25, "npp": 26, "gps": 27, "ps": 28, + "q": 29, "t": 30, "uv": 31, +} + +PLATFORM_ID_TO_NAME = {v: k for k, v in PLATFORM_NAME_TO_ID.items()} +NPLATFORMS = _next_power_of_two(max(len(PLATFORM_NAME_TO_ID), 64)) # 64 + +# Global channel offsets (contiguous across sensors) +SENSOR_OFFSET = {} +offset = 0 +for name, cfg in SENSOR_CONFIGS.items(): + SENSOR_OFFSET[name] = offset + offset += cfg.channels +NCHANNEL = _next_power_of_two(max(offset, 1024)) # 1024 + +CONV_GPS_GLOBAL_IDS = [SENSOR_OFFSET["conv"] + i for i in CONV_GPS_CHANNELS] + +SENSOR_NAME_TO_ID = {name: idx for idx, name in enumerate(SENSOR_CONFIGS.keys())} +SENSOR_ID_TO_NAME = {idx: name for name, idx in 
SENSOR_NAME_TO_ID.items()} diff --git a/physicsnemo/experimental/datapipes/healda/configs/static_data.py b/physicsnemo/experimental/datapipes/healda/configs/static_data.py new file mode 100644 index 0000000000..5d4a2d5ff4 --- /dev/null +++ b/physicsnemo/experimental/datapipes/healda/configs/static_data.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Static geospatial data loaders (orography, land fraction). + +These functions load time-invariant fields used as conditioning inputs. +Data paths are read from environment variables set in ``.env``. +""" + +import functools +import os + +import earth2grid +import numpy as np +import torch +import zarr + + +@functools.cache +def load_lfrac(hpx_level: int) -> torch.Tensor: + """Load land fraction data regridded to HEALPix NEST ordering. + + Reads from the zarr path specified by ``UFS_LAND_DATA_ZARR``. + + Args: + hpx_level: HEALPix resolution level. 
@functools.cache
def load_orography() -> np.ndarray:
    """Load orography (surface elevation) on the HEALPix level-6 NEST grid.

    Reads from the zarr path specified by ``UFS_HPX6_ZARR``. Cached, so
    the zarr store is opened at most once per process.
    """
    store = zarr.open_group(os.environ["UFS_HPX6_ZARR"])
    return store["orog"][:]
"""Variable configurations for supported datasets."""

from physicsnemo.experimental.datapipes.healda.types import VariableConfig

# The 13 standard pressure levels (hPa) shared by every configuration below.
_LEVELS = [1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50]

VARIABLE_CONFIGS = {
    "default": VariableConfig(
        name="ufs",
        levels=list(_LEVELS),
        variables_3d=["Q", "U", "V", "T", "Z"],
        variables_2d=[
            "tas", "uas", "vas", "rlut", "rsut", "pressfc",
            "pr", "rsds", "sst", "sic", "hfls", "huss",
        ],
        variables_static=["orog", "lfrac"],
    ),
    "era5": VariableConfig(
        name="era5",
        levels=list(_LEVELS),
        variables_3d=["U", "V", "T", "Z", "Q"],
        variables_2d=[
            "tcwv", "tas", "uas", "vas", "100u", "100v",
            "pres_msl", "sst", "sic",
        ],
        variables_static=["orog", "lfrac"],
    ),
    "gfs": VariableConfig(
        name="gfs",
        levels=list(_LEVELS),
        variables_3d=["U", "V", "T", "Z", "Q"],
        variables_2d=[
            "tcwv", "tas", "uas", "vas", "100u", "100v",
            "pres_msl", "sp",
        ],
        variables_static=["orog", "lfrac"],
    ),
}
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Map-style dataset combining ERA5 state with observations. + +``ObsERA5Dataset`` is the primary dataset class for HealDA training. It loads +ERA5 analysis state from an xarray DataArray, observations from an +``ObsLoader`` (e.g. ``UFSUnifiedLoader``), and applies a ``Transform`` +(e.g. ``ERA5ObsTransform``) to produce training batches. + +Temporal windowing, model-parallel rank slicing, and train/test splitting are +handled internally via ``FrameIndexGenerator`` and ``MultiCoordIndex``. + +Example usage:: + + from physicsnemo.experimental.datapipes.healda.dataset import ObsERA5Dataset + from physicsnemo.experimental.datapipes.healda.loaders.ufs_obs import UFSUnifiedLoader + from physicsnemo.experimental.datapipes.healda.transforms.era5_obs import ERA5ObsTransform + from physicsnemo.experimental.datapipes.healda.configs.variable_configs import VARIABLE_CONFIGS + + obs_loader = UFSUnifiedLoader( + data_path="/path/to/obs", + sensors=["atms", "mhs", "conv"], + normalization="zscore", + obs_context_hours=(-21, 3), + ) + transform = ERA5ObsTransform(variable_config=VARIABLE_CONFIGS["era5"]) + + dataset = ObsERA5Dataset( + era5_data=era5_xarray["data"], + obs_loader=obs_loader, + transform=transform, + variable_config=VARIABLE_CONFIGS["era5"], + split="train", + ) +""" + +from __future__ import annotations + +import asyncio +from typing import Union + +import numpy as np +import pandas as pd +import torch + +from physicsnemo.experimental.datapipes.healda.indexing import get_flat_indexer +from physicsnemo.experimental.datapipes.healda.loaders.era5 import 
# HEALPix level-6 pixel count: 12 * 4^6
NPIX_HPX6 = 12 * 4**6

# Default year held out for evaluation
DEFAULT_TEST_YEAR = 2022

# ERA5 time range available in the zarr store
ERA5_TIME_START = "2000-01-01 00:00:00"
ERA5_TIME_END = "2023-10-31 23:00:00"


class ObsERA5Dataset(torch.utils.data.Dataset):
    """Map-style dataset loading ERA5 state + observations.

    Args:
        era5_data: xarray DataArray with dimensions ``(time, variable, pixel)``
            containing the ERA5 state. Must have a ``"time"`` coordinate.
        obs_loader: Any object implementing the ``ObsLoader`` protocol
            (``async def sel_time(times) -> dict``).
        transform: Any object implementing the ``Transform`` protocol
            (``def transform(times, frames) -> dict``).
        variable_config: ``VariableConfig`` describing the variables and levels.
        split: ``"train"`` (year != ``DEFAULT_TEST_YEAR``),
            ``"test"`` (year == ``DEFAULT_TEST_YEAR``), ``""`` (all),
            or a list of years to include.
        time_length: Number of frames per training window.
        frame_step: Step size between frames (default 1).
        model_rank: Model-parallel rank for time slicing.
        model_world_size: Total model-parallel world size.
    """

    def __init__(
        self,
        era5_data,
        obs_loader: ObsLoader,
        transform: Transform,
        variable_config: VariableConfig,
        *,
        split: Union[str, list[int]] = "",
        time_length: int = 1,
        frame_step: int = 1,
        model_rank: int = 0,
        model_world_size: int = 1,
    ):
        self.variable_config = variable_config
        # time_step=6 (hours) — presumably the DA window cadence; TODO confirm.
        self.batch_info = get_batch_info(variable_config, time_step=6)

        # Restrict to the store's known-valid range, then apply the split mask.
        windowed = era5_data.sel(time=slice(ERA5_TIME_START, ERA5_TIME_END))
        keep = self._create_time_mask(pd.to_datetime(windowed["time"].values), split)
        self._era5 = windowed.isel(time=keep)

        self._obs_loader = obs_loader
        self.npix = NPIX_HPX6
        self.time_length = time_length
        self.transform = transform
        self._indexer = get_flat_indexer(
            self._era5,
            [],
            "time",
            time_length=time_length,
            frame_step=frame_step,
            model_rank=model_rank,
            model_world_size=model_world_size,
        )

    @staticmethod
    def _create_time_mask(
        time: pd.DatetimeIndex, split: Union[str, list[int]]
    ) -> np.ndarray:
        """Create a boolean mask selecting the times that belong to *split*."""
        if not isinstance(split, str):
            # A list of years: keep exactly those years.
            return time.year.isin(split)
        if split == "train":
            return time.year != DEFAULT_TEST_YEAR
        if split == "test":
            return time.year == DEFAULT_TEST_YEAR
        if split == "":
            return np.ones_like(time, dtype=np.bool_)
        # Unknown split names fail loudly, as with the original dict lookup.
        raise KeyError(split)

    def __len__(self):
        return len(self._indexer)

    @property
    def times(self):
        """All available times in the dataset."""
        return pd.to_datetime(self._era5["time"].values)

    def _get_state(self, i):
        """State array for sample *i*, channels ordered as in ``batch_info``."""
        coords = self._indexer[i]
        return self._era5.sel(variable=self.batch_info.channels).isel(coords).values

    def _get_times(self, i):
        """Timestamps of the frames that make up sample *i*."""
        return pd.to_datetime(self._era5.isel(self._indexer[i]).time)

    def _get_obs(self, i):
        """Observations aligned with sample *i*'s timestamps."""
        return asyncio.run(self._obs_loader.sel_time(self._get_times(i)))["obs"]

    def get(self, i):
        """Load state + obs for a single sample index.

        Returns:
            ``(times, frames)`` where ``times`` is a list of ``cftime`` objects
            and ``frames`` is a list of dicts with ``"state"`` and ``"obs"``.
        """
        frame_times = [as_cftime(t) for t in self._get_times(i)]
        state = self._get_state(i)
        obs = self._get_obs(i)
        frames = [
            {"state": state[t], "obs": obs[t]} for t in range(state.shape[0])
        ]
        return frame_times, frames

    def __getitems__(self, indexes):
        """Batched access — called by DataLoader with a batched sampler.

        Loads all samples, then applies the transform to produce the batch dict.
        """
        times, objs = zip(*[self.get(i) for i in indexes])
        return self.transform.transform(times, objs)
import numpy as np
import torch


# ---------------------------------------------------------------------------
# Segment splitting
# ---------------------------------------------------------------------------


def split_array_contiguous(x):
    """Split *x* into sub-arrays at points where the step size changes.

    Detects gaps in a time array (e.g. year boundaries, missing data) and
    returns the contiguous segments as views into *x*. The reference step
    is the first difference ``x[1] - x[0]``; a new segment starts wherever
    the step deviates from it.

    Args:
        x: 1-D array of timestamps or numbers.

    Returns:
        List of contiguous sub-arrays. Empty input yields ``[]``; a single
        element yields ``[x]`` (the original probed ``x[1] - x[0]`` and
        raised ``IndexError`` for size-1 input).
    """
    if x.size == 0:
        return []
    if x.size == 1:
        # Fix: a lone element is trivially one contiguous segment.
        return [x]

    steps = x[1:] - x[:-1]
    # Indices in x where the step deviates from the initial step.
    breakpoints = np.flatnonzero(steps != steps[0]) + 1
    return np.split(x, breakpoints)


# ---------------------------------------------------------------------------
# Frame index generator
# ---------------------------------------------------------------------------


class FrameIndexGenerator:
    """Generate frame indices with striding, permutation, and model-rank slicing.

    Given a 1-D time array (possibly with gaps), this class:

    1. Splits the array into contiguous segments.
    2. Computes the number of valid sliding windows per segment.
    3. Maps a logical sample index to the corresponding physical frame
       indices, applying temporal striding and model-rank slicing.

    Args:
        times: 1-D array of timestamps (used only for contiguity detection).
        time_length: Number of frames per window.
        frame_step: Step size between consecutive frames in a window.
        model_rank: This rank's index for model-parallel time slicing.
        model_world_size: Total number of model-parallel ranks.
    """

    def __init__(
        self,
        times,
        time_length: int,
        frame_step: int,
        model_rank: int,
        model_world_size: int,
    ):
        self.time_length = time_length
        self.frame_step = frame_step
        self.model_rank = model_rank
        self.model_world_size = model_world_size

        self.segments = split_array_contiguous(times)
        self.sizes = [len(segment) for segment in self.segments]
        self.total_samples = sum(self.sizes)

        # Physical frames spanned by one window of strided frames.
        window_span = (time_length - 1) * frame_step + 1
        self.segment_valid_lengths = [
            max(len(segment) - window_span + 1, 0) for segment in self.segments
        ]

        self.cumulative_valid_sizes = [0] + list(
            np.cumsum(self.segment_valid_lengths)
        )
        self.cumulative_sizes = [0] + list(np.cumsum(self.sizes))
        self.valid_length = sum(self.segment_valid_lengths)

    def generate_frame_indices(self, sample_indices: torch.Tensor) -> list[list[int]]:
        """Generate frame indices from sample indices.

        For each logical sample index, returns the physical frame indices
        after applying striding and model-rank slicing.

        Args:
            sample_indices: Sequence of logical sample indices.

        Returns:
            List of frame index lists, one per sample.
        """
        # NOTE(review): time_length not divisible by model_world_size drops
        # the trailing frames — presumably intended; confirm with callers.
        per_rank = self.time_length // self.model_world_size
        lo = self.model_rank * per_rank

        frame_idxs = []
        for sample_idx in sample_indices:
            start = self._map_logical_to_physical(sample_idx)
            window = list(
                range(
                    start,
                    start + self.time_length * self.frame_step,
                    self.frame_step,
                )
            )
            frame_idxs.append(window[lo : lo + per_rank])
        return frame_idxs

    def _map_logical_to_physical(self, logical_idx: int) -> int:
        """Map a logical sample index to a physical frame index across segments."""
        if logical_idx >= self.total_samples:
            raise IndexError(
                f"Sample index {logical_idx} out of bounds "
                f"for {self.total_samples} samples"
            )

        segment_idx = 0
        for pos, cum_size in enumerate(self.cumulative_valid_sizes[1:], 1):
            if logical_idx < cum_size:
                segment_idx = pos - 1
                break

        # Offset of the window within its segment, rebased onto the
        # segment's physical start index.
        return (
            self.cumulative_sizes[segment_idx]
            + logical_idx
            - self.cumulative_valid_sizes[segment_idx]
        )

    def get_valid_length(self) -> int:
        """Total number of valid sample windows across all segments."""
        return self.valid_length


# ---------------------------------------------------------------------------
# Multi-coordinate index
# ---------------------------------------------------------------------------


class MultiCoordIndex:
    """Map a flat integer index to multi-dimensional xarray coordinates.

    Combines arbitrary sample dimensions (e.g. ensemble members) with a
    temporal frame dimension managed by a ``FrameIndexGenerator``.

    Args:
        sample_dims: Names of the non-temporal sample dimensions.
        sample_sizes: Sizes of each sample dimension.
        frame_dim: Name of the temporal dimension.
        frame_indexer: A ``FrameIndexGenerator`` instance.
    """

    def __init__(self, sample_dims, sample_sizes, frame_dim, frame_indexer):
        self.sample_dims = sample_dims
        self.sample_sizes = sample_sizes
        self.sequence_dim = frame_dim
        self._frame_indexer = frame_indexer

    def __len__(self):
        total = self._frame_indexer.get_valid_length()
        for size in self.sample_sizes:
            total *= size
        return total

    def __getitem__(self, i):
        shape = [*self.sample_sizes, self._frame_indexer.get_valid_length()]
        *sample_idx, seq_idx = np.unravel_index(i, shape)
        coords = dict(zip(self.sample_dims, sample_idx))
        coords[self.sequence_dim] = self._frame_indexer.generate_frame_indices(
            [seq_idx]
        )[0]
        return coords


def get_flat_indexer(
    ds,
    sample_dims,
    frame_dim,
    time_length,
    frame_step,
    model_rank,
    model_world_size,
):
    """Create a ``MultiCoordIndex`` for an xarray Dataset.

    Args:
        ds: xarray Dataset with the relevant dimensions.
        sample_dims: List of non-temporal dimension names.
        frame_dim: Name of the temporal dimension.
        time_length: Frames per window.
        frame_step: Step between frames.
        model_rank: Model-parallel rank.
        model_world_size: Model-parallel world size.
    """
    frame_indexer = FrameIndexGenerator(
        times=ds[frame_dim].values,
        time_length=time_length,
        frame_step=frame_step,
        model_rank=model_rank,
        model_world_size=model_world_size,
    )
    sample_sizes = [ds.sizes[dim] for dim in sample_dims]
    return MultiCoordIndex(sample_dims, sample_sizes, frame_dim, frame_indexer)
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/physicsnemo/experimental/datapipes/healda/loaders/era5.py b/physicsnemo/experimental/datapipes/healda/loaders/era5.py new file mode 100644 index 0000000000..09470cb094 --- /dev/null +++ b/physicsnemo/experimental/datapipes/healda/loaders/era5.py @@ -0,0 +1,207 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ERA5 analysis loader and normalization statistics. + +Provides ``get_batch_info`` for normalization constants (mean/std per channel) +and ERA5-specific variable name mapping from ECMWF conventions to the standard +names used internally. 
import os
import pathlib
from typing import Optional

import numpy as np
import pandas as pd
import xarray

from physicsnemo.experimental.datapipes.healda.loaders.zarr_loader import (
    NO_LEVEL,
    ZarrLoader,
)
from physicsnemo.experimental.datapipes.healda.types import (
    BatchInfo,
    TimeUnit,
    VariableConfig,
)

__all__ = ["ERA5Loader", "get_batch_info"]

# Kelvin value used to fill masked (land) SST pixels.
SST_LAND_FILL_VALUE = 290
# HEALPix resolution level of the target grid.
HPX_LEVEL = 6


class ERA5Loader:
    """Load ERA5 reanalysis state via async zarr I/O.

    Wraps ``ZarrLoader`` with ERA5-specific variable naming conventions
    and returns ``{"state": ndarray, "label": list[int]}``.

    Args:
        variable_config: Describes which 2D/3D variables and levels to load.
        era5_zarr_path: Path to ERA5 zarr store. If *None*, reads from
            ``ERA5_74VAR`` environment variable.
    """

    def __init__(self, variable_config: VariableConfig, era5_zarr_path: str | None = None):
        self.variable_config = variable_config
        surface_vars = [
            "sstk", "ci", "msl", "10u", "10v", "2t", "tcwv", "100u", "100v",
        ]
        store_path = era5_zarr_path or os.environ.get(
            "ERA5_74VAR", os.environ.get("V6_ERA5_ZARR", "")
        )
        self._loader = ZarrLoader(
            path=store_path,
            variables_3d=["u", "v", "t", "z", "q"],
            variables_2d=surface_vars,
            level_coord_name="levels",
            levels=variable_config.levels,
        )

    async def sel_time(self, times):
        """Load, standardize, and stack the state for *times*."""
        raw = await self._loader.sel_time(times)
        self._convert_to_standard(raw)
        out_shape = (len(times), 4**HPX_LEVEL * 12)
        stacked = _collect_fields(
            _get_index(self.variable_config), raw, shape=out_shape
        )
        stacked = np.moveaxis(stacked, 0, 1)  # c t x -> t c x
        return {
            "state": stacked,
            "label": [1] * len(times),  # 1 = era5 label index
        }

    def _convert_to_standard(self, data):
        """Standardize ECMWF names/units in *data*, in place."""
        # SST: masked (land) pixels filled with a constant.
        sst_key = ("sstk", NO_LEVEL)
        if sst_key in data:
            sst = data[sst_key]
            if not np.ma.isMaskedArray(sst):
                sst = np.ma.masked_invalid(sst)
            data[sst_key] = sst.filled(SST_LAND_FILL_VALUE)

        # Sea ice: masked pixels filled with zero concentration.
        ice_key = ("ci", NO_LEVEL)
        if ice_key in data:
            ice = data[ice_key]
            if not np.ma.isMaskedArray(ice):
                ice = np.ma.masked_invalid(ice)
            data[ice_key] = ice.filled(0)

        tp_key = ("tp", NO_LEVEL)
        if tp_key in data:
            # Scale by 1000/3600 — presumably m/hour -> kg m-2 s-1; confirm
            # against the ERA5 store's "tp" units.
            water_density = 1000
            seconds_per_hour = 3600
            data[tp_key] = data[tp_key] * water_density / seconds_per_hour

        # ECMWF short name -> internal standard name (level preserved).
        renames = {
            "tclw": "cllvi", "tciw": "clivi", "2t": "tas", "10u": "uas",
            "10v": "vas", "100u": "100u", "100v": "100v", "msl": "pres_msl",
            "tp": "pr", "sstk": "sst", "ci": "sic", "tcwv": "prw",
            "u": "U", "v": "V", "t": "T", "z": "Z", "q": "Q",
        }
        for key, value in list(data.items()):
            if isinstance(key, tuple) and len(key) == 2:
                src_name, level = key
                if src_name in renames:
                    data[(renames[src_name], level)] = value


# ---------------------------------------------------------------------------
# Normalization statistics
# ---------------------------------------------------------------------------


def get_batch_info(
    config: VariableConfig,
    time_step: int = 1,
    time_unit: TimeUnit = TimeUnit.HOUR,
) -> BatchInfo:
    """Build channel names and normalization constants for *config*."""
    index = _get_index(config)
    return BatchInfo(
        channels=[_encode_channel(pair) for pair in index.tolist()],
        scales=_get_std(config),
        center=_get_mean(config),
        time_step=time_step,
        time_unit=time_unit,
    )


def _get_index(config: VariableConfig):
    """(variable, level) MultiIndex: 3D vars x levels, then 2D vars."""
    pairs = [(v, level) for v in config.variables_3d for level in config.levels]
    pairs += [(v, NO_LEVEL) for v in config.variables_2d]
    return pd.MultiIndex.from_tuples(pairs, names=["variable", "level"])


def _collect_fields(
    index,
    data: dict[tuple[str, int | None], np.ndarray],
    shape,
    prefix: Optional[str] = None,
) -> np.ndarray:
    """Stack fields from *data* into ``(len(index), *shape)``; missing -> NaN."""
    out = np.full((index.size,) + shape, np.nan, dtype=np.float32)
    for row, (var, lev) in enumerate(index):
        key = (var, lev) if prefix is None else (prefix, var, lev)
        if key in data:
            out[row] = data[key]
    return out


def _get_mean(config: VariableConfig) -> np.ndarray:
    """Per-channel means, level-matched to the nearest available stats row."""
    return _get_nearest_stats(config)["mean"].values


def _get_std(config: VariableConfig) -> np.ndarray:
    """Per-channel stds, level-matched to the nearest available stats row."""
    return _get_nearest_stats(config)["std"].values


def _encode_channel(channel) -> str:
    """Encode (name, level) as e.g. "T500" for 3D vars, or the bare name."""
    name, level = channel
    return name if level == NO_LEVEL else f"{name}{level}"


def _load_raw_stats(config: VariableConfig) -> pd.DataFrame:
    """Read the per-(variable, level) statistics CSV for this dataset flavor."""
    stats_files = {
        "ufs": "ufs_v0_stats.csv",
        "era5": "era5_13_levels_stats.csv",
    }
    try:
        file_name = stats_files[config.name]
    except KeyError:
        raise ValueError(f"Unknown dataset: {config.name}") from None
    path = pathlib.Path(__file__).parent.parent / "configs" / file_name
    return pd.read_csv(path).set_index(["variable", "level"])


def _get_nearest_stats(config: VariableConfig):
    """Look up stats per channel, snapping 3D levels to the nearest available."""
    raw = _load_raw_stats(config)
    snapped = []
    for var, level in _get_index(config):
        if level == NO_LEVEL:
            snapped.append((var, level))
        else:
            available = raw.loc[var].index.values
            snapped.append((var, available[np.abs(available - level).argmin()]))
    snapped = pd.MultiIndex.from_tuples(snapped, names=["variable", "level"])
    return raw.loc[snapped]


def open_era5_xarray(path: str | None = None, **kwargs) -> xarray.Dataset:
    """Open the ERA5 74-variable zarr dataset as xarray.

    Args:
        path: Zarr store path. If *None*, reads ``ERA5_74VAR`` env var.
    """
    return xarray.open_zarr(path or os.environ["ERA5_74VAR"], **kwargs)
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""UFS Unified observation loader for combined satellite and conventional data.

``UFSUnifiedLoader`` implements the ``ObsLoader`` protocol, loading
parquet-based observations produced by the ETL pipeline. It provides
quality-control filtering, normalization, and DA-window alignment.

Example::

    loader = UFSUnifiedLoader(
        data_path="/path/to/processed_obs",
        sensors=["atms", "mhs", "conv"],
        obs_context_hours=(-3, 3),
        normalization="zscore",
    )
    result = await loader.sel_time(pd.DatetimeIndex([...]))
    tables = result["obs"]  # list[pa.Table], one per timestamp
"""

import functools
import io
import os
from datetime import datetime
from typing import List, Literal

import fsspec
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.parquet as pq

from physicsnemo.experimental.datapipes.healda.configs.combined_schema import (
    GLOBAL_CHANNEL_ID,
    SENSOR_ID,
    get_combined_observation_schema,
)
from physicsnemo.experimental.datapipes.healda.configs.sensors import SENSOR_CONFIGS
from physicsnemo.experimental.datapipes.healda.transforms.obs_filtering import filter_observations

# Per-sensor channel index (0-based within each sensor), appended to output tables.
LOCAL_CHANNEL_ID = pa.field("local_channel_id", pa.uint16())


def get_channel_table(data_path: str, filesystem=None):
    """Load the channel metadata table used for normalization.

    Args:
        data_path: Root path to the processed observation data.
        filesystem: Optional fsspec filesystem for remote access.
    """
    # A throwaway loader with no sensors is enough to materialize the
    # cached ``channel_table`` property.
    return UFSUnifiedLoader(
        data_path,
        sensors=[],
        obs_context_hours=(-3, 3),
        normalization="zscore",
        filesystem=filesystem,
    ).channel_table


class UFSUnifiedLoader:
    """Unified loader for UFS observation data in combined parquet format.

    Handles both satellite and conventional observations, providing an async
    interface compatible with ``ObsERA5Dataset``.

    Args:
        data_path: Path to the processed observation data directory.
        sensors: List of sensor names to load (e.g. ``["atms", "mhs", "conv"]``).
        filesystem: Optional fsspec filesystem for remote access (e.g. S3).
        normalization: Normalization method (``"fixed_range"`` or ``"zscore"``).
        innovation_type: Innovation type (``"none"``, ``"adjusted"``, ``"unadjusted"``).
        qc_filter: Whether to apply quality-control filtering.
        filter_innovation: Whether to filter based on innovation values.
        check_corrected: Whether to validate corrected observation values.
        obs_context_hours: ``(start, end)`` hours relative to target time.
        data_spacing: Hours between data points (default 3).
        drop_obs_channel_ids: Global channel IDs to drop.
        conv_uv_in_situ_only: Exclude satellite UV (keep in-situ only).
        conv_gps_level1_only: Exclude GPS T/Q (keep bending angle).
    """

    def __init__(
        self,
        data_path: str,
        sensors: List[str],
        filesystem: fsspec.AbstractFileSystem | None = None,
        normalization: Literal["fixed_range", "zscore"] = "fixed_range",
        innovation_type: Literal["none", "adjusted", "unadjusted"] = "none",
        qc_filter: bool = False,
        filter_innovation: bool = False,
        check_corrected: bool = True,
        obs_context_hours: tuple[int, int] = (-24, 0),
        data_spacing: int = 3,
        drop_obs_channel_ids: list[int] | None = None,
        conv_uv_in_situ_only: bool = False,
        conv_gps_level1_only: bool = False,
    ):
        self.data_path = data_path
        self.sensors = sensors
        self.fs = filesystem
        self.normalization = normalization
        self.innovation_type = innovation_type
        self.qc_filter = qc_filter
        self.filter_innovation = filter_innovation
        self.check_corrected = check_corrected
        self.obs_context_hours = obs_context_hours
        self.data_spacing = data_spacing
        # Copy to avoid aliasing the caller's list.
        self.drop_obs_channel_ids = (
            list(drop_obs_channel_ids) if drop_obs_channel_ids is not None else []
        )
        self.conv_uv_in_situ_only = conv_uv_in_situ_only
        self.conv_gps_level1_only = conv_gps_level1_only

        # Fail fast on sensors that have no configuration entry.
        for sensor in self.sensors:
            if sensor not in SENSOR_CONFIGS:
                raise ValueError(
                    f"Unconfigured sensor: {sensor}. "
                    f"Available: {list(SENSOR_CONFIGS.keys())}"
                )

        self._channel_table = None

    @functools.cached_property
    def _base_schema(self) -> pa.Schema:
        # Schema of the on-disk combined observation tables.
        return get_combined_observation_schema()

    @functools.cached_property
    def _read_columns(self) -> list[str]:
        # Columns requested when reading parquet row groups.
        return self._base_schema.names

    @property
    def output_schema(self) -> pa.Schema:
        """On-disk schema plus the appended local-channel and sensor-id fields."""
        return self._base_schema.append(LOCAL_CHANNEL_ID).append(SENSOR_ID)

    @functools.cached_property
    def channel_table(self) -> pa.Table:
        """Load the channel table for normalization."""
        channel_table_path = os.path.join(self.data_path, "channel_table.parquet")
        if self.fs is not None:
            file = io.BytesIO(self.fs.cat_file(channel_table_path))
        else:
            file = channel_table_path

        table = pq.read_table(file)
        # Derive a 0-based channel index within each contiguous run of
        # sensor_id rows. At i == 0 the comparison wraps to the last row
        # (negative indexing), but offset is already initialized to 0 so
        # the first row's local id is 0 either way.
        sensor_id = np.asarray(table["sensor_id"])
        local_channel_ids = []
        offset = 0
        for i in range(len(sensor_id)):
            if sensor_id[i] != sensor_id[i - 1]:
                offset = i
            local_channel_ids.append(i - offset)
        array = pa.array(local_channel_ids).cast(LOCAL_CHANNEL_ID.type)
        return table.append_column(LOCAL_CHANNEL_ID, array)

    def _get_interval_times(self, dt: datetime) -> pd.DatetimeIndex:
        """Return the DA-window timestamps covering *dt*'s observation context.

        The first returned time is one ``data_spacing`` after the context
        start -- NOTE(review): this suggests each timestamp labels the END
        of its window; confirm against the ETL pipeline's convention.
        """
        start, end = self.obs_context_hours
        start += self.data_spacing
        return pd.date_range(
            dt + pd.Timedelta(hours=start),
            dt + pd.Timedelta(hours=end),
            freq=f"{self.data_spacing}h",
        )

    def _get_parquet_files_to_read(self, interval_times: pd.DatetimeIndex):
        """Yield ``(sensor, path)`` pairs for every sensor/date combination.

        Files are laid out as ``<data_path>/<sensor>/<YYYYMMDD>/0.parquet``.
        """
        required_dates = {t.strftime("%Y%m%d") for t in interval_times}
        for sensor in self.sensors:
            for date in required_dates:
                file_path = os.path.join(
                    self.data_path, sensor, f"{date}", "0.parquet"
                )
                yield (sensor, file_path)

    def _iterate_parquet_da_windows(self, parquet_path, target_windows):
        """Yield ``(window, table)`` for row groups overlapping *target_windows*.

        Uses the per-row-group min/max statistics of the ``DA_window``
        column to skip row groups without reading them. Missing files are
        silently skipped (treated as no data).

        NOTE(review): when a row group spans several target windows, all of
        its matching rows are attributed to the LAST matching window.
        """
        try:
            if self.fs is not None:
                file = io.BytesIO(self.fs.cat_file(parquet_path))
            else:
                file = parquet_path

            parquet = pq.ParquetFile(file)
            schema = parquet.schema_arrow
            da_idx = schema.get_field_index("DA_window")

            for row_group_idx in range(parquet.num_row_groups):
                stats = (
                    parquet.metadata.row_group(row_group_idx)
                    .column(da_idx)
                    .statistics
                )
                row_group_lo, row_group_hi = stats.min, stats.max

                this_window = None
                for w in target_windows:
                    if row_group_lo <= w <= row_group_hi:
                        this_window = w

                if this_window is None:
                    continue

                table = parquet.read_row_group(
                    row_group_idx, columns=self._read_columns
                )

                # Only filter rows when the row group is not already
                # homogeneous in DA_window.
                if row_group_lo != row_group_hi:
                    mask = pc.is_in(
                        table["DA_window"], pa.array(list(target_windows))
                    )
                    table = table.filter(mask)

                if table.num_rows == 0:
                    continue

                yield this_window, table
        except (FileNotFoundError, OSError):
            return

    def _filter_observations(self, table: pa.Table) -> pa.Table:
        """Apply QC and conventional-obs filtering (see ``filter_observations``)."""
        return filter_observations(
            table,
            self.qc_filter,
            conv_uv_in_situ_only=self.conv_uv_in_situ_only,
            conv_gps_level1_only=self.conv_gps_level1_only,
        )

    def _normalize_observations(self, table: pa.Table) -> pa.Table:
        """Normalize the ``Observation`` column in place (by replacement).

        ``fixed_range`` maps values by dividing by 400 (presumably a
        brightness-temperature range in K -- TODO confirm); ``zscore`` uses
        the per-channel ``mean``/``stddev`` joined from the channel table.

        Raises:
            ValueError: For an unknown ``self.normalization``.
        """
        if self.normalization == "fixed_range":
            normalized = pc.divide(pc.subtract(table["Observation"], 0), 400 - 0)
        elif self.normalization == "zscore":
            normalized = pc.divide(
                pc.subtract(table["Observation"], table["mean"]),
                table["stddev"],
            )
        else:
            raise ValueError(f"Unknown normalization type: {self.normalization}")
        return table.set_column(
            table.schema.get_field_index("Observation"),
            "Observation",
            normalized,
        )

    # Channel-table columns joined in temporarily for filtering/normalization
    # and dropped again before output.
    _extra_channel_fields = ["min_valid", "max_valid", "is_conv", "mean", "stddev"]

    def _add_channel_metadata(self, table):
        """Join per-channel metadata onto *table* via the global channel id."""
        return table.join(
            self.channel_table.select(
                [
                    GLOBAL_CHANNEL_ID.name,
                    LOCAL_CHANNEL_ID.name,
                    SENSOR_ID.name,
                    *self._extra_channel_fields,
                ]
            ),
            GLOBAL_CHANNEL_ID.name,
        )

    async def sel_time(self, times: pd.DatetimeIndex) -> dict:
        """Load observation data for specified times.

        Args:
            times: Target times to load data for.

        Returns:
            ``{"obs": [pa.Table, ...]}``, one table per timestamp.
        """
        # Deduplicate the interval times across the whole batch so each
        # parquet row group is read at most once.
        all_times = set()
        for t in times:
            interval_times = self._get_interval_times(t)
            all_times.update(interval_times)

        interval_times = pd.DatetimeIndex(sorted(all_times))
        files_to_read = self._get_parquet_files_to_read(interval_times)

        tables = {}
        for sensor, file_path in files_to_read:
            for interval_time, table in self._iterate_parquet_da_windows(
                file_path, interval_times
            ):
                table = self._add_channel_metadata(table)
                table = self._filter_observations(table)
                if self.drop_obs_channel_ids:
                    mask = pc.is_in(
                        table[GLOBAL_CHANNEL_ID.name],
                        pa.array(self.drop_obs_channel_ids).cast(
                            table[GLOBAL_CHANNEL_ID.name].type
                        ),
                    )
                    table = table.filter(pc.invert(mask))
                table = self._normalize_observations(table)
                table = table.drop(self._extra_channel_fields)
                tables.setdefault(interval_time, []).append(table)

        def process(t):
            # Gather every per-sensor table whose window falls in t's context.
            all_tables = []
            for interval_time in self._get_interval_times(t):
                for table in tables.get(interval_time, []):
                    all_tables.append(table)

            if not all_tables:
                # ``empty`` is bound below, before process() is invoked.
                return empty

            table = pa.concat_tables(all_tables)
            return table.cast(self.output_schema)

        empty = self._get_empty_table()
        return {"obs": [process(t) for t in times]}

    def _get_empty_table(self):
        """Return a zero-row table matching ``output_schema``."""
        empty_arrays = []
        for field in self.output_schema:
            empty_arrays.append(pa.array([], type=field.type))
        return pa.table(empty_arrays, schema=self.output_schema)
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Async zarr loader for 2D and 3D atmospheric variables.

``ZarrLoader`` provides concurrent I/O for multiple variables and levels
using ``asyncio.gather``. It is used by ``ERA5Loader`` to read ERA5 data
from zarr stores.
"""

import asyncio
import urllib.parse

import cftime
import numpy as np
import pandas as pd
import xarray as xr
import zarr
import zarr.storage
from zarr.core.sync import sync

NO_LEVEL = -1  # sentinel for 2D (surface) variables that lack a pressure level


def _is_local(path):
    """Return True when *path* has no URL scheme (a local filesystem path)."""
    url = urllib.parse.urlparse(path)
    return url.scheme == ""


async def _getitem(array, index):
    """Read an orthogonal selection from an async zarr array."""
    return await array.get_orthogonal_selection(index)


async def _getitem_static(array, num_times: int):
    """Read a time-invariant array and broadcast it along a new time axis."""
    field = await array.getitem((slice(None),) * array.ndim)
    field = field[None, ...]
    return np.broadcast_to(field, (num_times, *field.shape[1:]))


class ZarrLoader:
    """Load 2D and 3D variables from a zarr dataset with async I/O.

    Args:
        path: Zarr store path (local or remote).
        variables_3d: List of 3D variable names.
        variables_2d: List of 2D variable names.
        levels: Pressure levels to extract.
        level_coord_name: Name of the vertical coordinate in the zarr store.
        storage_options: fsspec storage options for remote stores.
        time_sel_method: Passed to ``pd.Index.get_indexer(method=)``.
        variables_static: List of static (time-invariant) variable names.
            Defaults to no static variables.
    """

    def __init__(
        self,
        *,
        path: zarr.storage.StoreLike,
        variables_3d,
        variables_2d,
        levels,
        level_coord_name: str = "",
        storage_options=None,
        time_sel_method: str | None = None,
        variables_static: list[str] | None = None,
    ):
        self.time_sel_method = time_sel_method
        self.variables_2d = variables_2d
        self.variables_3d = variables_3d
        self.levels = levels
        # Default to a fresh list rather than a shared mutable default
        # argument; copy caller-supplied lists to avoid aliasing.
        self.variables_static = list(variables_static) if variables_static else []

        if isinstance(path, str) and _is_local(path):
            # storage_options only apply to remote (fsspec) stores.
            storage_options = None

        self.group = sync(
            zarr.api.asynchronous.open_group(
                path,
                storage_options=storage_options,
                use_consolidated=True,
                mode="r",
            )
        )

        if self.variables_3d:
            # Map each requested pressure level to the nearest index of the
            # store's vertical coordinate.
            self.inds = sync(self._get_vertical_indices(level_coord_name, levels))

        self._arrays = {}
        # Static-only stores have no time coordinate to index.
        self._has_time = bool(self.variables_3d or self.variables_2d)
        if self._has_time:
            time_num, self.units, self.calendar = sync(self._get_time())
            if np.issubdtype(time_num.dtype, np.datetime64):
                self.times = pd.DatetimeIndex(time_num)
            else:
                # Numeric time values with non-standard calendars are
                # decoded through cftime.
                self.times = xr.CFTimeIndex(
                    cftime.num2date(time_num, units=self.units, calendar=self.calendar)
                )

    async def sel_time(self, times) -> dict[tuple[str, int], np.ndarray]:
        """Load data for the given times.

        Returns:
            Dict with keys ``(variable_name, level)`` where ``level == -1``
            for 2D variables.

        Raises:
            KeyError: If any requested time is absent from the store.
        """
        if self._has_time:
            index_in_loader = self.times.get_indexer(
                times, method=self.time_sel_method
            )
            if (index_in_loader == -1).any():
                raise KeyError("Index not found.")
        else:
            index_in_loader = np.arange(len(times))
        return await self._get(index_in_loader)

    async def _get_time(self):
        """Return the raw time values plus their units/calendar attributes."""
        time = await self.group.get("time")
        time_data = await time.getitem(slice(None))
        return time_data, time.attrs.get("units"), time.attrs.get("calendar")

    async def _get_vertical_indices(self, coord_name, levels):
        """Return nearest indices of *levels* in the store's vertical coord."""
        levels_var = await self.group.get(coord_name)
        levels_arr = await levels_var.getitem(slice(None))
        return pd.Index(levels_arr).get_indexer(levels, method="nearest")

    async def _get_array(self, name):
        """Fetch (and memoize) the async zarr array handle for *name*."""
        if name not in self._arrays:
            self._arrays[name] = await self.group.get(name)
        return self._arrays[name]

    async def _get(self, t) -> dict[tuple[str, int | None], np.ndarray]:
        """Issue all reads concurrently and assemble the keyed result dict.

        Raises:
            KeyError: If a configured variable does not exist in the store.
        """
        tasks = []
        keys = []

        for name in self.variables_3d:
            arr = await self._get_array(name)
            if arr is None:
                raise KeyError(name)
            for level, k in zip(self.levels, self.inds):
                key = (name, level)
                k_indexer = [k]
                value = _getitem(arr, (t, k_indexer))
                tasks.append(value)
                keys.append(key)

        for name in self.variables_2d:
            arr = await self._get_array(name)
            if arr is None:
                raise KeyError(name)
            key = (name, NO_LEVEL)
            value = _getitem(arr, (t,))
            tasks.append(value)
            keys.append(key)

        for name in self.variables_static:
            arr = await self._get_array(name)
            if arr is None:
                raise KeyError(name)
            key = (name, NO_LEVEL)
            value = _getitem_static(arr, len(t))
            tasks.append(value)
            keys.append(key)

        arrays = await asyncio.gather(*tasks)
        out = {}
        for key, array in zip(keys, arrays):
            name, _ = key
            if name in self.variables_3d:
                # Drop the singleton level axis selected by k_indexer.
                out[key] = np.squeeze(array, 1)
            else:
                out[key] = array

        return out
mode 100644 index 0000000000..d11c5712e5 --- /dev/null +++ b/physicsnemo/experimental/datapipes/healda/prefetch.py @@ -0,0 +1,179 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Background-thread prefetch with a dedicated CUDA stream. + +``prefetch_map`` wraps any iterable (typically a DataLoader) and applies a +transform function in a background thread using a separate CUDA stream. +This hides the CPU-to-GPU transfer and GPU featurization latency behind the +training forward/backward pass. +""" + +import dataclasses +import queue +import threading +from typing import Any, Callable, Iterable, Optional + +import torch +from torch.utils.data import DataLoader + + +class _Done: + pass + + +class _PrefetchIterator: + """Process batches asynchronously in a background thread. + + The background thread uses a separate CUDA stream so that data movement + and featurization do not block the main training stream. + + Args: + dataloader: Any iterable of batches. + transform: ``batch -> batch`` function to run on the background stream. + queue_size: Bounded queue depth (default 2). + cuda_stream: CUDA stream for background work (created if *None*). 
+ """ + + def __init__( + self, + dataloader: Iterable, + transform: Callable[[Any], Any], + queue_size: int = 2, + cuda_stream: Optional[torch.cuda.Stream] = None, + ): + self.dataloader = dataloader + self.transform = transform + self.queue_size = queue_size + self.cuda_stream = cuda_stream or torch.cuda.Stream() + + self.queue: queue.Queue = queue.Queue(maxsize=queue_size) + self.thread: Optional[threading.Thread] = None + self.stop_event = threading.Event() + + self.dataloader_iter = None + self._started = False + + # -- background worker -------------------------------------------------- + + def _worker(self): + try: + while not self.stop_event.is_set(): + try: + batch = next(self.dataloader_iter) + except StopIteration: + self.queue.put((_Done, None)) + break + + with torch.cuda.stream(self.cuda_stream): + processed_batch = self.transform(batch) + + self.cuda_stream.synchronize() + self.queue.put((processed_batch, None)) + except Exception as e: + self.queue.put((None, e)) + + # -- lifecycle ---------------------------------------------------------- + + def _start(self): + if self._started: + return + self.dataloader_iter = iter(self.dataloader) + self.stop_event.clear() + self.thread = threading.Thread(target=self._worker, daemon=True) + self.thread.start() + self._started = True + + def _stop(self): + if not self._started: + return + self.stop_event.set() + if self.thread and self.thread.is_alive(): + self.thread.join(timeout=1.0) + self._started = False + + # -- iterator protocol -------------------------------------------------- + + def __len__(self): + return len(self.dataloader) + + def __iter__(self): + self._start() + return self + + @staticmethod + def _record_stream(x, stream): + """Tell the CUDA caching allocator that *stream* also owns *x*.""" + if isinstance(x, torch.Tensor): + x.record_stream(stream) + elif isinstance(x, list): + for item in x: + _PrefetchIterator._record_stream(item, stream) + elif isinstance(x, dict): + for item in x.values(): 
+ _PrefetchIterator._record_stream(item, stream) + elif dataclasses.is_dataclass(x): + for field in dataclasses.fields(x): + _PrefetchIterator._record_stream(getattr(x, field.name), stream) + + def __next__(self): + if not self._started: + raise RuntimeError("Iterator not started. Call __iter__ first.") + + batch, error = self.queue.get() + + if error is not None: + raise error + + if batch is _Done: + self._stop() + raise StopIteration + + # Inform the allocator that the consumer stream also uses these tensors + self._record_stream(batch, torch.cuda.current_stream()) + return batch + + def __del__(self): + self._stop() + + +def prefetch_map( + dataloader: DataLoader, + transform: Callable[[Any], Any], + queue_size: int = 2, + cuda_stream: Optional[torch.cuda.Stream] = None, +) -> _PrefetchIterator: + """Wrap a DataLoader with background prefetching and GPU transforms. + + Args: + dataloader: Source of batches (typically a PyTorch DataLoader). + transform: Function applied to each batch on a background CUDA stream. + queue_size: Maximum number of pre-processed batches to buffer. + cuda_stream: CUDA stream for background work (created if *None*). + + Returns: + An iterable that yields pre-processed, GPU-resident batches. + + Example:: + + loader = prefetch_map( + dataloader, + lambda batch: transform.device_transform(batch, device), + queue_size=1, + ) + for batch in loader: + # batch is already on GPU + loss = model(**batch) + """ + return _PrefetchIterator(dataloader, transform, queue_size, cuda_stream) diff --git a/physicsnemo/experimental/datapipes/healda/protocols.py b/physicsnemo/experimental/datapipes/healda/protocols.py new file mode 100644 index 0000000000..cf13295407 --- /dev/null +++ b/physicsnemo/experimental/datapipes/healda/protocols.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Protocols defining the extension points for HealDA data loading.

Custom data sources and transforms can be plugged into the HealDA pipeline by
implementing these protocols. No inheritance is required -- any class with the
right method signatures will satisfy the protocol.

See ``UFSUnifiedLoader`` for a reference ``ObsLoader`` implementation, and
``ERA5ObsTransform`` for a reference ``Transform`` / ``DeviceTransform``
implementation.
"""

from typing import Any, Protocol, runtime_checkable

import cftime
import pandas as pd
import torch


# runtime_checkable allows ``isinstance(obj, ObsLoader)``; note that such
# checks only verify the presence of ``sel_time``, not its signature.
@runtime_checkable
class ObsLoader(Protocol):
    """Load observations for a set of timestamps.

    Implementations fetch observation data (satellite, conventional, or other)
    for the requested times and return it as a dictionary. The method is async
    to allow concurrent I/O when composed with other loaders.

    The canonical return key is ``"obs"`` mapping to a ``list[pa.Table]``
    with one table per requested timestamp. Custom loaders may use different
    keys as long as the downstream ``Transform`` expects them.

    Example::

        class MyObsLoader:
            async def sel_time(self, times):
                tables = [load_obs_for(t) for t in times]
                return {"obs": tables}
    """

    async def sel_time(self, times: pd.DatetimeIndex) -> dict[str, list[Any]]:
        """Load observation data for the given timestamps.

        Args:
            times: Timestamps to load observations for.

        Returns:
            Dictionary mapping field names to per-timestep data lists.
        """
        ...


class Transform(Protocol):
    """CPU-side batch transform, called inside DataLoader worker processes.

    Converts raw loaded frames (state arrays + observation tables) into an
    intermediate batch dictionary suitable for ``pin_memory`` and collation.
    Must NOT use CUDA.

    Args:
        times: ``list[list[cftime.datetime]]`` shaped ``(batch, time_per_sample)``.
        frames: ``list[list[dict]]`` shaped ``(batch, time_per_sample)``.
            Each inner dict has keys from the loaders (e.g. ``"state"``, ``"obs"``).

    Returns:
        Batch dict with tensors (``target``, ``condition``, time encodings, etc.).
    """

    def transform(
        self,
        times: list[list[cftime.datetime]],
        frames: list[list[dict[str, Any]]],
    ) -> dict[str, Any]: ...


class DeviceTransform(Protocol):
    """GPU-side transform, called in the ``prefetch_map`` background thread.

    Moves the CPU batch to the target device and performs GPU-accelerated
    featurization (e.g. observation metadata computation, HEALPix pixel lookup).

    Args:
        batch: Output of ``Transform.transform()``.
        device: Target ``torch.device``.

    Returns:
        Batch dict with all tensors on the target device.
    """

    def device_transform(
        self, batch: dict[str, Any], device: torch.device
    ) -> dict[str, Any]: ...
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Samplers and multi-loader utilities for cache-friendly distributed training. + +``ChunkedDistributedSampler`` yields indices in contiguous chunks so that +data backed by chunked storage (e.g. zarr) benefits from sequential I/O. + +``RoundRobinLoader`` interleaves multiple DataLoaders — typically one per +worker, each with its own ``ChunkedDistributedSampler`` — to provide an +iterable-style interface over map-style datasets with per-worker chunk +affinity. +""" + +import random + +import torch +import torch.distributed +import torch.utils.data + + +# --------------------------------------------------------------------------- +# Chunked distributed sampler +# --------------------------------------------------------------------------- + + +class ChunkedDistributedSampler(torch.utils.data.Sampler): + """A distributed sampler that yields indices in contiguous chunks. + + Within each chunk, indices are sequential (optionally shuffled within the + chunk). Chunks themselves can be shuffled across epochs. This pattern is + critical when the underlying dataset caches data at chunk granularity, as + it ensures sequential access within each cache window. + + The sampler is infinite: after exhausting all chunks it advances the epoch + and re-shuffles. + + Args: + dataset: Map-style dataset. + chunk_size: Number of contiguous indices per chunk. + rank: This worker's global rank. + num_replicas: Total number of workers (data-parallel * per-GPU workers). + shuffle: Whether to shuffle the order of chunks. 
+ shuffle_within_chunk: Whether to shuffle indices within each chunk. + drop_last: Whether to drop incomplete trailing chunks. + seed: Random seed (broadcast from rank 0 when distributed). + sampler_fn: Optional custom sampler over chunk indices. + """ + + def __init__( + self, + dataset: torch.utils.data.Dataset, + chunk_size: int = 1, + rank=0, + num_replicas=1, + shuffle=False, + shuffle_within_chunk=False, + drop_last=True, + seed=42, + sampler_fn=None, + ): + super().__init__() + self.n = len(dataset) + nchunks = self.n // chunk_size + chunks = list(range(nchunks)) + + if torch.distributed.is_initialized(): + seed = torch.tensor(seed).cuda() + torch.distributed.broadcast(seed, src=0) + seed = seed.item() + + self._chunk_sampler = ( + sampler_fn(chunks) + if sampler_fn is not None + else torch.utils.data.DistributedSampler( + chunks, + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, + seed=seed, + drop_last=drop_last, + ) + ) + self.chunk_size = chunk_size + self.shuffle_within_chunk = shuffle_within_chunk + self.seed = seed + self.rank = rank + self.epoch = 0 + self.index_within_chunk = 0 + self._chunk_iter = iter(self._chunk_sampler) + self._current_chunk_indices = None + + if self.shuffle_within_chunk: + self.rng = random.Random(seed + rank) + + def set_epoch(self, epoch): + try: + self._chunk_sampler.set_epoch(epoch) + except AttributeError: + pass + self.epoch = epoch + + def __len__(self): + return self.n + + def __iter__(self): + return self + + def __next__(self): + if self.index_within_chunk == 0: + try: + self.active_chunk = next(self._chunk_iter) + except StopIteration: + self.set_epoch(self.epoch + 1) + self._chunk_iter = iter(self._chunk_sampler) + raise StopIteration() + + chunk_start = self.active_chunk * self.chunk_size + self._current_chunk_indices = list( + range(chunk_start, chunk_start + self.chunk_size) + ) + + if self.shuffle_within_chunk: + self.rng.shuffle(self._current_chunk_indices) + + i = 
self._current_chunk_indices[self.index_within_chunk] + self.index_within_chunk = (self.index_within_chunk + 1) % self.chunk_size + return i + + +# --------------------------------------------------------------------------- +# Round-robin loader +# --------------------------------------------------------------------------- + + +class RoundRobinLoader(torch.utils.data.IterableDataset): + """Round-robin interleaving of multiple map-style DataLoaders. + + This converts map-style datasets to iterable-style by cycling through + the given DataLoaders in round-robin order, removing exhausted ones + until all are done. + + Typical usage: create one ``DataLoader`` per worker, each backed by a + ``ChunkedDistributedSampler`` with a unique rank, then wrap them with + ``RoundRobinLoader``. + + Args: + dataloaders: List of DataLoader instances to interleave. + """ + + def __init__(self, dataloaders: list[torch.utils.data.DataLoader]): + super().__init__() + self.dataloaders = dataloaders + + def __len__(self): + return sum(len(dl) for dl in self.dataloaders) + + def __iter__(self): + iterators = [iter(dl) for dl in self.dataloaders] + active_indices = list(range(len(self.dataloaders))) + + while active_indices: + for idx in list(active_indices): + try: + yield next(iterators[idx]) + except StopIteration: + active_indices.remove(idx) diff --git a/physicsnemo/experimental/datapipes/healda/time_utils.py b/physicsnemo/experimental/datapipes/healda/time_utils.py new file mode 100644 index 0000000000..c5ffa48195 --- /dev/null +++ b/physicsnemo/experimental/datapipes/healda/time_utils.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
def as_pydatetime(time) -> datetime.datetime:
    """Convert a cftime or stdlib datetime to a Python ``datetime``.

    cftime datetimes are converted to timezone-aware (UTC) values; stdlib
    datetimes are returned unchanged (note: they may be naive — callers that
    need UTC must attach a timezone themselves).

    Raises:
        NotImplementedError: For unsupported input types.
    """
    # Check the stdlib type first: cftime.datetime is not a subclass of
    # datetime.datetime, so reordering these checks does not change behavior,
    # and stdlib inputs no longer touch the cftime module at all.
    if isinstance(time, datetime.datetime):
        return time
    if isinstance(time, cftime.datetime):
        return datetime.datetime(*cftime.to_tuple(time), tzinfo=datetime.timezone.utc)
    raise NotImplementedError(type(time))


def as_numpy(time) -> np.ndarray:
    """Standardize time input to ``np.ndarray`` of ``np.datetime64``."""
    if hasattr(time, "values"):  # pandas Index / xarray-style containers
        return time.values
    if isinstance(time, (pd.Timestamp, datetime.datetime)):
        return np.array([np.datetime64(time)])
    if isinstance(time, cftime.datetime):
        return as_numpy(as_pydatetime(time))
    if isinstance(time, np.datetime64):
        return np.array([time])
    # Fallback: any iterable of datetime-like values.
    return np.array([np.datetime64(t) for t in time])


def as_timestamp(time) -> np.ndarray:
    """Return *time* as integer Unix timestamps (seconds since epoch)."""
    return as_numpy(time).astype("datetime64[s]").astype(int)


def second_of_day(time):
    """Return seconds elapsed since the start of the day for *time*.

    Fix: zero out microseconds as well when computing the start of the day.
    Previously the sub-second component was carried into ``begin_of_day`` and
    cancelled out of the difference, silently dropping the fractional part.
    """
    begin_of_day = time.replace(hour=0, minute=0, second=0, microsecond=0)
    return (time - begin_of_day).total_seconds()


def as_cftime(timestamp) -> "cftime.DatetimeGregorian":
    """Convert a pandas Timestamp (or similar) to ``cftime.DatetimeGregorian``."""
    return cftime.DatetimeGregorian(
        timestamp.year,
        timestamp.month,
        timestamp.day,
        timestamp.hour,
        timestamp.minute,
        timestamp.second,
    )
timestamp.month, + timestamp.day, + timestamp.hour, + timestamp.minute, + timestamp.second, + ) diff --git a/physicsnemo/experimental/datapipes/healda/transforms/__init__.py b/physicsnemo/experimental/datapipes/healda/transforms/__init__.py new file mode 100644 index 0000000000..3159bfe656 --- /dev/null +++ b/physicsnemo/experimental/datapipes/healda/transforms/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/physicsnemo/experimental/datapipes/healda/transforms/era5_obs.py b/physicsnemo/experimental/datapipes/healda/transforms/era5_obs.py new file mode 100644 index 0000000000..77adb2a87e --- /dev/null +++ b/physicsnemo/experimental/datapipes/healda/transforms/era5_obs.py @@ -0,0 +1,412 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Two-stage transform for ERA5 state + observation data. + +``ERA5ObsTransform`` implements both the ``Transform`` and ``DeviceTransform`` +protocols: + +- **Stage 1** (``transform``): CPU-side preprocessing in DataLoader workers. + Normalizes state, encodes observations to tensors, computes time encodings. + +- **Stage 2** (``device_transform``): GPU-side transfer and featurization in + the ``prefetch_map`` background thread. Moves tensors to device, computes + observation metadata features, creates ``UnifiedObservation``. +""" + +import dataclasses +import datetime +import functools +import warnings + +import cftime +import earth2grid +import numpy as np +import pyarrow as pa +import pyarrow.compute as pc +import torch + +from physicsnemo.experimental.datapipes.healda.configs.static_data import load_lfrac, load_orography +from physicsnemo.experimental.datapipes.healda.configs.variable_configs import VARIABLE_CONFIGS +from physicsnemo.experimental.datapipes.healda.loaders.era5 import get_batch_info +from physicsnemo.experimental.datapipes.healda.transforms import obs_features, obs_features_ext +from physicsnemo.experimental.datapipes.healda.types import UnifiedObservation, VariableConfig + +warnings.filterwarnings( + "ignore", + message="The given NumPy array is not writable, and PyTorch does not support non-writable tensors", +) + + +def _cftime_to_timestamp(time: cftime.DatetimeGregorian) -> float: + return datetime.datetime( + *cftime.to_tuple(time), tzinfo=datetime.timezone.utc + ).timestamp() + + +def _reorder_nest_to_hpxpad(x): + x = torch.as_tensor(x) + src_order = earth2grid.healpix.NEST + dst_order = earth2grid.healpix.HEALPIX_PAD_XY + return earth2grid.healpix.reorder(x, src_order, dst_order) + + +def _compute_second_of_day(time: cftime.datetime): + day_start = time.replace(hour=0, minute=0, second=0) + return (time - day_start) / 
datetime.timedelta(seconds=1) + + +def _compute_day_of_year(time: cftime.datetime): + day_start = time.replace(hour=0, minute=0, second=0) + year_start = day_start.replace(month=1, day=1) + return (time - year_start) / datetime.timedelta(seconds=86400) + + +def _compute_timestamp(time: cftime.datetime): + return int(_cftime_to_timestamp(time)) + + +def _get_static_condition(HPX_LEVEL, variable_config) -> torch.Tensor: + lfrac = load_lfrac(HPX_LEVEL) + orography = load_orography() + # Global mean and std computed over the UFS HEALPix level-6 training data. + orog_scale, orog_mean = 627.3885284872, 232.56013904090733 + lfrac_scale, lfrac_mean = 0.4695501683565522, 0.3410480857539571 + data = { + "orog": (orography - orog_mean) / orog_scale, + "lfrac": (lfrac - lfrac_mean) / lfrac_scale, + } + arrays = [torch.as_tensor(data[name]) for name in variable_config.variables_static] + array = torch.stack(arrays).float() # c x + return array.unsqueeze(1) + + +@dataclasses.dataclass +class ERA5ObsTransform: + """Two-stage batch transform for ERA5 state + observation data. + + Implements both ``Transform`` and ``DeviceTransform`` protocols. + + Args: + variable_config: Which variables and levels to normalize. + hpx_level: HEALPix level for observation pixel lookup. + hpx_level_condition: HEALPix level for static conditioning data. + extended_features: Whether to use extended (30-feature) observation + encoding instead of the standard 28-feature encoding. 
+ """ + + variable_config: VariableConfig = VARIABLE_CONFIGS["era5"] + hpx_level: int = 10 + hpx_level_condition: int = 6 + extended_features: bool = False + + def __post_init__(self): + batch_info = get_batch_info(self.variable_config) + self.mean = np.array(batch_info.center)[:, None] + self.std = np.array(batch_info.scales)[:, None] + + @functools.cached_property + def _grid(self): + return earth2grid.healpix.Grid( + self.hpx_level, pixel_order=earth2grid.healpix.NEST + ) + + # ------------------------------------------------------------------ + # Obs processing helpers + # ------------------------------------------------------------------ + + @staticmethod + def _sort_by_record_batch(table: pa.Table, column_name: str) -> pa.Table: + record_batches_order = [] + for batch in table.to_batches(): + if batch.num_rows == 0: + continue + group_value = batch[column_name][0] + record_batches_order.append((group_value, batch)) + + if not record_batches_order: + return table + + record_batches_order.sort(key=lambda x: x[0].as_py()) + return pa.Table.from_batches([batch for _, batch in record_batches_order]) + + @staticmethod + def _append_batch_time_info_chunked( + table: pa.Table, b: int, t: int, timestamp: int + ) -> pa.Table: + b_idx_type = pa.int16() + t_idx_type = pa.int16() + time_type = pa.int64() + ref_col = table.column(0) + + b_idx_chunks, t_idx_chunks, time_chunks = [], [], [] + for chunk in ref_col.chunks: + L = len(chunk) + if L == 0: + b_idx_chunks.append(pa.array([], type=b_idx_type)) + t_idx_chunks.append(pa.array([], type=t_idx_type)) + time_chunks.append(pa.array([], type=time_type)) + continue + + b_idx_chunks.append(pa.array(np.full(L, b, dtype=np.int16), type=b_idx_type)) + t_idx_chunks.append(pa.array(np.full(L, t, dtype=np.int16), type=t_idx_type)) + time_chunks.append(pa.array(np.full(L, timestamp, dtype=np.int64), type=time_type)) + + out = table.append_column("batch_idx", pa.chunked_array(b_idx_chunks, type=b_idx_type)) + out = 
out.append_column("time_idx", pa.chunked_array(t_idx_chunks, type=t_idx_type)) + out = out.append_column("target_time", pa.chunked_array(time_chunks, type=time_type)) + return out + + @staticmethod + def _build_observation_lengths_3d(obs_table: pa.Table, frame_times): + B, T = len(frame_times), len(frame_times[0]) + sensor_ids = set() + counts_map = {} + + for batch in obs_table.to_batches(): + if batch.num_rows == 0: + continue + s_id = int(batch["sensor_id"][0].as_py()) + b_id = int(batch["batch_idx"][0].as_py()) + t_id = int(batch["time_idx"][0].as_py()) + n = batch.num_rows + sensor_ids.add(s_id) + if s_id not in counts_map: + counts_map[s_id] = torch.zeros((B, T), dtype=torch.int32) + counts_map[s_id][b_id, t_id] += n + + active_sensor_ids = sorted(sensor_ids) + S = len(active_sensor_ids) + + if not sensor_ids: + lengths_3d = torch.zeros((0, B, T), dtype=torch.int32) + sensor_id_to_local = torch.zeros((0,), dtype=torch.int32) + return lengths_3d, sensor_id_to_local + + max_sensor_id = max(sensor_ids) + lengths_3d = torch.zeros((S, B, T), dtype=torch.int32) + for s_local, s_id in enumerate(active_sensor_ids): + lengths_3d[s_local] = counts_map[s_id] + + sensor_id_to_local = torch.full((max_sensor_id + 1,), -1, dtype=torch.int32) + for local_idx, sensor_id in enumerate(active_sensor_ids): + sensor_id_to_local[sensor_id] = local_idx + + return lengths_3d, sensor_id_to_local + + def _process_obs(self, target_times, frames): + all_obs_with_indices = [] + for b_idx, sample_frames in enumerate(frames): + for t_idx, frame_dict in enumerate(sample_frames): + table = frame_dict["obs"] + table_with_indices = self._append_batch_time_info_chunked( + table, b_idx, t_idx, + _compute_timestamp(target_times[b_idx][t_idx]), + ) + all_obs_with_indices.append(table_with_indices) + + obs = pa.concat_tables(all_obs_with_indices) + obs = self._sort_by_record_batch(obs, "sensor_id") + + lengths_3d, sensor_id_to_local = self._build_observation_lengths_3d( + obs, target_times + ) + + 
obs_tensors = {} + required_columns = { + "latitude": "Latitude", + "longitude": "Longitude", + "observation": "Observation", + "global_channel_id": "Global_Channel_ID", + "sat_zenith_angle": "Sat_Zenith_Angle", + "sol_zenith_angle": "Sol_Zenith_Angle", + "local_channel_id": "local_channel_id", + "height": "Height", + "pressure": "Pressure", + "scan_angle": "Scan_Angle", + } + + for tensor_key, column_name in required_columns.items(): + obs_tensors[tensor_key] = torch.from_numpy(obs[column_name].to_numpy()) + + arr = obs["Absolute_Obs_Time"].to_numpy().astype("datetime64[ns]", copy=False) + obs_tensors["absolute_obs_time"] = torch.from_numpy(arr.view(np.int64)) + obs_tensors["target_time_sec"] = torch.from_numpy(obs["target_time"].to_numpy()) + + platform_id = pc.fill_null(obs["Platform_ID"], 0) + obs_tensors["platform_id"] = torch.from_numpy(platform_id.to_numpy()) + + obs_type = pc.fill_null(obs["Observation_Type"], 0) + obs_tensors["observation_type"] = torch.from_numpy(obs_type.to_numpy()) + + return obs_tensors, lengths_3d, sensor_id_to_local + + def _get_target(self, frames) -> torch.Tensor: + all_state = [f["state"] for sample in frames for f in sample] + batch_size = len(frames) + state = np.stack(all_state) + state = state.reshape((batch_size, -1) + state.shape[1:]) + state = (state - self.mean) / self.std + target = torch.from_numpy(state) + b, t, c, x = range(4) + out = target.permute(b, c, t, x) + return _reorder_nest_to_hpxpad(out) + + @functools.cached_property + def _static_condition(self): + condition = _get_static_condition( + self.hpx_level_condition, self.variable_config + ) + condition = condition.unsqueeze(0) + return _reorder_nest_to_hpxpad(condition) + + # ------------------------------------------------------------------ + # Stage 1: CPU transform (DataLoader workers) + # ------------------------------------------------------------------ + + def transform(self, times, frames): + """CPU-side batch transform. 
+ + Args: + times: ``list[list[cftime]]`` shaped ``(batch, time_per_sample)``. + frames: ``list[list[dict]]`` shaped ``(batch, time_per_sample)``. + + Returns: + Batch dict with ``target``, ``unified_obs``, ``condition``, and + time encodings. + """ + out = {} + + def _apply_time_func(func): + return torch.from_numpy(np.vectorize(func)(times)) + + if "obs" in frames[0][0].keys(): + out["unified_obs"] = self._process_obs(times, frames) + out["target"] = self._get_target(frames).float() + out["second_of_day"] = _apply_time_func(_compute_second_of_day).float() + out["day_of_year"] = _apply_time_func(_compute_day_of_year).float() + out["timestamp"] = _apply_time_func(_compute_timestamp) + b, _, t, _ = out["target"].shape + condition = self._static_condition.float() + if condition.shape[0] not in (1, b): + raise ValueError( + f"condition batch dim {condition.shape[0]} must be 1 or target batch {b}" + ) + if condition.shape[2] not in (1, t): + raise ValueError( + f"condition time dim {condition.shape[2]} must be 1 or target time {t}" + ) + if condition.shape[0] == 1 and b > 1: + condition = condition.expand(b, -1, -1, -1) + if condition.shape[2] == 1 and t > 1: + condition = condition.expand(-1, -1, t, -1) + out["condition"] = condition.clone() + out["labels"] = torch.empty([len(frames), 0]) + return out + + # ------------------------------------------------------------------ + # Stage 2: GPU transform (prefetch_map background thread) + # ------------------------------------------------------------------ + + def device_transform(self, batch, device): + """GPU-side transform: move to device and compute observation features. + + Args: + batch: Output of ``transform()``. + device: Target ``torch.device``. + + Returns: + Batch dict with all tensors on device and ``unified_obs`` as + ``UnifiedObservation``. 
+ """ + batch = batch.copy() + out = {} + for key in batch: + if key == "unified_obs": + obs_tensors, lengths, sensor_id_to_local = batch["unified_obs"] + out[key] = self._device_transform_unified_obs( + obs_tensors, lengths, sensor_id_to_local, device + ) + else: + out[key] = batch[key].to(device, non_blocking=True) + return out + + def _device_transform_unified_obs( + self, obs_tensors, lengths, sensor_id_to_local, device + ): + def _to_device(tensor, non_blocking=True): + if isinstance(tensor, torch.Tensor): + return tensor.to(device, non_blocking=non_blocking) + else: + return torch.from_numpy(tensor).to(device, non_blocking=non_blocking) + + obs_tensors = {key: _to_device(val) for key, val in obs_tensors.items()} + + obs_time_ns = obs_tensors["absolute_obs_time"] + lat_tensor = obs_tensors["latitude"] + lon_tensor = obs_tensors["longitude"] + height_tensor = obs_tensors["height"] + pressure_tensor = obs_tensors["pressure"] + scan_angle_tensor = obs_tensors["scan_angle"] + sat_zenith_tensor = obs_tensors["sat_zenith_angle"] + sol_zenith_tensor = obs_tensors["sol_zenith_angle"] + platform_id_tensor = obs_tensors["platform_id"].int() + obs_type_tensor = obs_tensors["observation_type"].int() + pix = self._grid.ang2pix(lon_tensor, lat_tensor).int() + local_channel_id_tensor = obs_tensors["local_channel_id"].int() + global_channel_id_tensor = obs_tensors["global_channel_id"].int() + observation_tensor = obs_tensors["observation"] + + if self.extended_features: + meta = obs_features_ext.compute_unified_metadata( + obs_tensors["target_time_sec"], + time=obs_time_ns, + lon=lon_tensor, + lat=lat_tensor, + height=height_tensor, + pressure=pressure_tensor, + scan_angle=scan_angle_tensor, + sat_zenith_angle=sat_zenith_tensor, + sol_zenith_angle=sol_zenith_tensor, + ) + else: + meta = obs_features.compute_unified_metadata( + obs_tensors["target_time_sec"], + time=obs_time_ns, + lon=lon_tensor, + height=height_tensor, + pressure=pressure_tensor, + 
def identity_collate(obj):
    """Pass the batch through unchanged; assembly already happened in __getitems__."""
    return obj
+from typing import Literal +import math +import torch +import triton +import triton.language as tl +from triton.language.extra.libdevice import fast_sinf as fsin, fast_cosf as fcos, isnan + +N_FEATURES = 28 + + +def _compute_unified_metadata_reference( + target_time_sec: torch.Tensor, # int64 seconds + lon: torch.Tensor, + time: torch.Tensor, # int64 nanoseconds + # Raw metadata fields + height: torch.Tensor | None = None, + pressure: torch.Tensor | None = None, + scan_angle: torch.Tensor | None = None, + sat_zenith_angle: torch.Tensor | None = None, + sol_zenith_angle: torch.Tensor | None = None, +) -> torch.Tensor: + """ + Compute unified metadata from raw fields. + + Features are concatenated in the following order: + - Local solar time (4 features): Fourier encoding with 2 frequencies + - Relative time features (2 features): normalized time difference and its square + - Height features (8 features, NaN for satellite): Fourier encoding with 4 frequencies + - Pressure features (8 features, NaN for satellite): Fourier encoding with 4 frequencies + - Scan angle features (2 features, NaN for conventional): normalized scan angle and its square + - Satellite zenith features (2 features, NaN for conventional): cos(θ_sat) and cos(θ_sat)² + - Solar zenith features (2 features, NaN for conventional): cos(θ_sun) and sin(θ_sun) + + Note: time inputs use int64 to preserve precision. Float conversion happens only + after magnitude reduction to avoid precision loss with large Unix timestamps. 
+ """ + device = lon.device + n_obs = lon.shape[0] + + lst = local_solar_time(lon, time) + + # Build metadata features as a list + metadata_features = [] + + # Local solar time features (4 features) + local_solar_time_features = fourier_features( + lst / 24.0, 2 + ) # 2 frequencies = 4 features + metadata_features.append(local_solar_time_features) + + # Relative time features (2 features) + target_time_ns = target_time_sec * 1_000_000_000 + dt_sec = (time - target_time_ns).float() * 1e-9 + relative_time_hours = dt_sec / 3600.0 + dt_norm = relative_time_hours / 24.0 # Normalize + time_norm_features = torch.stack([dt_norm, dt_norm**2], dim=-1) + metadata_features.append(time_norm_features) + + # Height features (8 features, NaN for satellite) + if height is not None: + height_norm = normalize( + height, + "linear", + 100.0, # height_min + 60000.0, # height_max + 0.5, # height_power + ) + height_features = fourier_features(height_norm, 4) # 4 frequencies = 8 features + metadata_features.append(height_features) + else: + # Add NaN tensor for height features + metadata_features.append( + torch.full((n_obs, 8), float("nan"), device=device, dtype=torch.float32) + ) + + # Pressure features (8 features, NaN for satellite) + if pressure is not None: + pressure_norm = normalize( + pressure, + "linear", + 10.0, # pressure_min + 1100.0, # pressure_max + 3.0, # pressure_power + ) + pressure_features = fourier_features( + pressure_norm, 4 + ) # 4 frequencies = 8 features + metadata_features.append(pressure_features) + else: + # Add NaN tensor for pressure features + metadata_features.append( + torch.full((n_obs, 8), float("nan"), device=device, dtype=torch.float32) + ) + + # Scan angle features (2 features, NaN for conventional) + if scan_angle is not None: + xi_norm = scan_angle / 50.0 # ~[-1,1] as in existing code + scan_angle_features = torch.stack([xi_norm, xi_norm**2], dim=-1) + metadata_features.append(scan_angle_features) + else: + # Add NaN tensor for scan angle features 
+ metadata_features.append( + torch.full((n_obs, 2), float("nan"), device=device, dtype=torch.float32) + ) + + # Satellite zenith features (2 features, NaN for conventional) + if sat_zenith_angle is not None: + cos_theta_sat = torch.cos(torch.deg2rad(sat_zenith_angle)) + sat_zenith_features = torch.stack([cos_theta_sat, cos_theta_sat**2], dim=-1) + metadata_features.append(sat_zenith_features) + else: + metadata_features.append( + torch.full((n_obs, 2), float("nan"), device=device, dtype=torch.float32) + ) + + # Solar zenith features (2 features, NaN for conventional) + if sol_zenith_angle is not None: + cos_theta_sun = torch.cos(torch.deg2rad(sol_zenith_angle)) + sin_theta_sun = torch.sin(torch.deg2rad(sol_zenith_angle)) + sol_zenith_features = torch.stack([cos_theta_sun, sin_theta_sun], dim=-1) + metadata_features.append(sol_zenith_features) + else: + # Add NaN tensor for solar zenith features + metadata_features.append( + torch.full((n_obs, 2), float("nan"), device=device, dtype=torch.float32) + ) + + # Concatenate all features + metadata = torch.cat(metadata_features, dim=-1) + metadata = metadata.nan_to_num(0.0) + + return metadata + + +def normalize( + x: torch.Tensor, + scale: Literal["linear", "log", "power"], + x_min: float, + x_max: float, + power: float, +) -> torch.Tensor: + # map x onto [0,1] using chosen scale + if scale == "linear": + return torch.clamp(x / x_max, 0.0, 1.0) + elif scale == "log": + # ensure positive + return (torch.log(x + x_min) - math.log(x_min)) / ( + math.log(x_max + x_min) - math.log(x_min) + ) + elif scale == "power": + x_lin = torch.clamp(x / x_max, 0.0, 1.0) + return x_lin.pow(power) + else: + raise ValueError(f"Unknown scale '{scale}'") + + +def fourier_features(x_norm: torch.Tensor, num_freqs: int) -> torch.Tensor: + # x_norm: (N,) in [0,1] + # produce (N, 2*num_freqs) of sin/cos features + device = x_norm.device + freqs = torch.arange(1, num_freqs + 1, device=device, dtype=x_norm.dtype) * ( + 2 * math.pi + ) + x_expanded = 
x_norm.unsqueeze(-1) * freqs # (N, num_freqs) + sin_features = torch.sin(x_expanded) + cos_features = torch.cos(x_expanded) + return torch.cat([sin_features, cos_features], dim=-1) + + +def local_solar_time( + lon_deg: torch.Tensor, + abs_time_ns: torch.Tensor, +) -> torch.Tensor: + # Approximate without equation of time correction + sec_of_day = (abs_time_ns // 1_000_000_000) % 86400 + utc_hours = sec_of_day.float() / 3600.0 + lst = (utc_hours + lon_deg / 15.0) % 24.0 + return lst + + +######################################################### +# Triton implementations +######################################################### + + +@triton.jit +def _fourier_store(out_ptr, base, offset, angle, valid, m, NUM_FREQS: tl.constexpr): + """Store fourier features [sin(k*angle), cos(k*angle)] for k=1..NUM_FREQS, zeroed when !valid.""" + for k in tl.static_range(1, NUM_FREQS + 1): + tl.store( + out_ptr + base + offset + k - 1, + tl.where(valid, fsin(angle * k), 0.0), + mask=m, + ) + tl.store( + out_ptr + base + offset + NUM_FREQS + k - 1, + tl.where(valid, fcos(angle * k), 0.0), + mask=m, + ) + + +@triton.jit +def _metadata_kernel( + lon_ptr, + time_ptr, + target_ptr, + height_ptr, + press_ptr, + scan_ptr, + sat_ptr, + sol_ptr, + out_ptr, + N, + BLOCK: tl.constexpr, +): + """Compute unified metadata using a single Triton kernel. + Using torch.compile on compute_unified_metadata runs into issues with torch dynamo when using DistributedDataParallel. + Seems compiling a function that isn't in main network/main thread does not work. 
+ """ + + pid = tl.program_id(0) + off = pid * BLOCK + tl.arange(0, BLOCK) + m = off < N + + lon = tl.load(lon_ptr + off, mask=m, other=0.0).to(tl.float32) + time_ns = tl.load(time_ptr + off, mask=m, other=0) + target_s = tl.load(target_ptr + off, mask=m, other=0) + height = tl.load(height_ptr + off, mask=m, other=0.0).to(tl.float32) + pressure = tl.load(press_ptr + off, mask=m, other=0.0).to(tl.float32) + scan = tl.load(scan_ptr + off, mask=m, other=0.0).to(tl.float32) + sat_zen = tl.load(sat_ptr + off, mask=m, other=0.0).to(tl.float32) + sol_zen = tl.load(sol_ptr + off, mask=m, other=0.0).to(tl.float32) + + # Fields are NaN when the observation type doesn't carry that metadata + # (e.g. satellite obs lack height/pressure, conventional obs lack zenith angles). + height_valid = ~isnan(height) + pressure_valid = ~isnan(pressure) + scan_valid = ~isnan(scan) + sat_zen_valid = ~isnan(sat_zen) + sol_zen_valid = ~isnan(sol_zen) + + TWO_PI: tl.constexpr = 6.283185307179586 + DEG2RAD: tl.constexpr = 0.017453292519943295 + base = off * 28 + idx = 0 + + # ======== Local Solar Time — fourier(2) -> 4 features ======== + sod = (time_ns // 1000000000) % 86400 + utc_hr = sod.to(tl.float32) / 3600.0 + lst = (utc_hr + lon / 15.0) % 24.0 + lst_angle = lst / 24.0 * TWO_PI + _fourier_store(out_ptr, base, idx, lst_angle, True, m, 2) + idx += 4 # 2 freqs * 2 (sin+cos) + + # ======== Relative Time -> 2 features ======== + dt_days = (time_ns - target_s * 1000000000).to(tl.float32) * 1e-9 / 86400.0 + tl.store(out_ptr + base + idx, dt_days, mask=m) + tl.store(out_ptr + base + idx + 1, dt_days * dt_days, mask=m) + idx += 2 + + # ======== Height — fourier(4) -> 8 features ======== + height_angle = tl.where( + height_valid, tl.minimum(tl.maximum(height / 60000.0, 0.0), 1.0) * TWO_PI, 0.0 + ) + _fourier_store(out_ptr, base, idx, height_angle, height_valid, m, 4) + idx += 8 # 4 freqs * 2 + + # ======== Pressure — fourier(4) -> 8 features ======== + press_angle = tl.where( + pressure_valid, + 
def compute_unified_metadata(
    target_time_sec: torch.Tensor,
    time: torch.Tensor,
    lon: torch.Tensor,
    height: torch.Tensor,
    pressure: torch.Tensor,
    scan_angle: torch.Tensor,
    sat_zenith_angle: torch.Tensor,
    sol_zenith_angle: torch.Tensor,
) -> torch.Tensor:
    """Compute the 28-feature observation metadata.

    Dispatches to the pure-PyTorch reference implementation for CPU inputs
    and to the fused Triton kernel for CUDA inputs; both produce the same
    (N, N_FEATURES) float32 layout.
    """
    if not lon.is_cuda:
        # Triton requires device tensors; fall back to the reference path.
        return _compute_unified_metadata_reference(
            target_time_sec,
            lon=lon,
            time=time,
            height=height,
            pressure=pressure,
            scan_angle=scan_angle,
            sat_zenith_angle=sat_zenith_angle,
            sol_zenith_angle=sol_zenith_angle,
        )

    n_obs = lon.shape[0]
    out = torch.empty(n_obs, N_FEATURES, dtype=torch.float32, device=lon.device)
    if n_obs == 0:
        return out

    block_size = 256
    launch_grid = ((n_obs + block_size - 1) // block_size,)
    _metadata_kernel[launch_grid](
        lon,
        time,
        target_time_sec,
        height,
        pressure,
        scan_angle,
        sat_zenith_angle,
        sol_zenith_angle,
        out,
        n_obs,
        BLOCK=block_size,
        num_warps=4,
    )
    return out
a/physicsnemo/experimental/datapipes/healda/transforms/obs_features_ext.py b/physicsnemo/experimental/datapipes/healda/transforms/obs_features_ext.py new file mode 100644 index 0000000000..3fb8a5fe65 --- /dev/null +++ b/physicsnemo/experimental/datapipes/healda/transforms/obs_features_ext.py @@ -0,0 +1,318 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import math
import torch
import triton
import triton.language as tl
from triton.language.extra.libdevice import fast_sinf as fsin, fast_cosf as fcos, isnan

# Total per-observation feature width produced by this module.
N_FEATURES = 30
# Layout:
# [0:4)   LST fourier(2)                        — shared
# [4:6)   relative-time [dt_days, dt_days²]     — shared (polynomial)
# [6:8)   relative-time fourier(1)              — shared (cyclic)
# [8:10)  latitude [sin(lat), cos(lat)]         — shared
# [10:30) BRANCH (conv OR sat):
#   conv: height fourier(5) [10:20) + pressure fourier(5) [20:30)
#   sat:  scan fourier(3) [10:16) + sat_zen fourier(4) [16:24) + sol_zen fourier(3) [24:30)
#
# Normalization conventions (normalize to [0,1] by physical max):
#   height:   h / 60000 -> [0, 1]
#   pressure: p / 1100  -> [0, 1]
#   scan:     ξ / 50    -> ~[-1, 1]        (data range: ~[-50, +50] deg)
#   sat_zen:  θ / 90    -> ~[-0.67, +0.67] (data range: ~[-60, +60] deg, signed)
#   sol_zen:  θ / 180   -> ~[0.06, 0.94]   (data range: ~[10, 170] deg)


def _compute_unified_metadata_reference(
    target_time_sec: torch.Tensor,
    lon: torch.Tensor,
    lat: torch.Tensor,
    time: torch.Tensor,
    height: torch.Tensor,
    pressure: torch.Tensor,
    scan_angle: torch.Tensor,
    sat_zenith_angle: torch.Tensor,
    sol_zenith_angle: torch.Tensor,
) -> torch.Tensor:
    """Reference (CPU-friendly) implementation of unified metadata v2.

    Conv/sat specialization: height validity determines which branch fills
    slots 10-29. Every slot carries signal — no zero padding.

    Args:
        target_time_sec: Analysis target time in seconds since epoch, shape (N,).
        lon: Longitude in degrees, shape (N,).
        lat: Latitude in degrees, shape (N,).
        time: Observation time in nanoseconds since epoch, shape (N,).
        height: Height in meters; NaN marks a satellite observation, shape (N,).
        pressure: Pressure (conv branch input), shape (N,).
        scan_angle: Scan angle in degrees (sat branch input), shape (N,).
        sat_zenith_angle: Satellite zenith angle in degrees (sat branch), shape (N,).
        sol_zenith_angle: Solar zenith angle in degrees (sat branch), shape (N,).

    Returns:
        float32 tensor of shape (N, 30) laid out per the module-level comment.
    """
    device = lon.device
    n_obs = lon.shape[0]
    out = torch.zeros(n_obs, N_FEATURES, dtype=torch.float32, device=device)

    if n_obs == 0:
        return out

    # NaN height <=> satellite observation; finite height <=> conventional.
    is_conv = ~torch.isnan(height)

    TWO_PI = 2 * math.pi

    # --- shared: LST fourier(2) -> 4 ---
    lst = local_solar_time(lon, time)
    out[:, 0:4] = fourier_features(lst / 24.0 * TWO_PI, 2)

    # --- shared: relative time polynomial -> 2 ---
    target_time_ns = target_time_sec * 1_000_000_000
    dt_days = (time - target_time_ns).float() * 1e-9 / 86400.0
    out[:, 4] = dt_days
    out[:, 5] = dt_days**2

    # --- shared: relative time fourier(1) -> 2 ---
    out[:, 6:8] = fourier_features(dt_days, 1)

    # --- shared: latitude -> 2 ---
    lat_rad = torch.deg2rad(lat)
    out[:, 8] = torch.sin(lat_rad)
    out[:, 9] = torch.cos(lat_rad)

    # --- conv branch: height fourier(5) + pressure fourier(5) ---
    if is_conv.any():
        c = is_conv
        h_norm = torch.clamp(height[c] / 60000.0, 0.0, 1.0)
        out[c, 10:20] = fourier_features(h_norm * TWO_PI, 5)
        p_norm = torch.clamp(pressure[c] / 1100.0, 0.0, 1.0)
        out[c, 20:30] = fourier_features(p_norm * TWO_PI, 5)

    # --- sat branch: scan fourier(3) + sat_zen fourier(4) + sol_zen fourier(3) ---
    is_sat = ~is_conv
    if is_sat.any():
        s = is_sat
        out[s, 10:16] = fourier_features(scan_angle[s] / 50.0 * TWO_PI, 3)
        out[s, 16:24] = fourier_features(sat_zenith_angle[s] / 90.0 * TWO_PI, 4)
        out[s, 24:30] = fourier_features(sol_zenith_angle[s] / 180.0 * TWO_PI, 3)

    return out


def fourier_features(x_norm: torch.Tensor, num_freqs: int) -> torch.Tensor:
    """Return ``[sin(kx), ..., cos(kx), ...]`` for k = 1..num_freqs.

    Output shape is ``x_norm.shape + (2 * num_freqs,)`` with all sine terms
    first, then all cosine terms (this ordering is what ``_fourier_store``
    reproduces on the GPU side).
    """
    device = x_norm.device
    freqs = torch.arange(1, num_freqs + 1, device=device, dtype=x_norm.dtype)
    x_expanded = x_norm.unsqueeze(-1) * freqs
    sin_features = torch.sin(x_expanded)
    cos_features = torch.cos(x_expanded)
    return torch.cat([sin_features, cos_features], dim=-1)


def local_solar_time(
    lon_deg: torch.Tensor,
    abs_time_ns: torch.Tensor,
) -> torch.Tensor:
    """Local solar time in hours [0, 24) from UTC time and longitude.

    15 degrees of longitude corresponds to one hour of solar offset; the
    final ``% 24`` wraps any longitude convention into a valid hour.
    """
    sec_of_day = (abs_time_ns // 1_000_000_000) % 86400
    utc_hours = sec_of_day.float() / 3600.0
    lst = (utc_hours + lon_deg / 15.0) % 24.0
    return lst


#########################################################
# Triton implementations
#########################################################


@triton.jit
def _fourier_store(out_ptr, base, offset, x_norm, valid, m, NUM_FREQS: tl.constexpr):
    """Store sin(kx), cos(kx) for k=1..NUM_FREQS. Matches fourier_features().

    ``valid`` zeroes the value lane-wise; ``m`` masks the store itself
    (out-of-range program lanes never write).
    """
    for k in tl.static_range(1, NUM_FREQS + 1):
        angle = x_norm * k
        # Sine terms occupy [offset, offset + NUM_FREQS) ...
        tl.store(
            out_ptr + base + offset + k - 1,
            tl.where(valid, fsin(angle), 0.0),
            mask=m,
        )
        # ... cosine terms the following NUM_FREQS slots, matching the
        # torch.cat([sin, cos]) layout of fourier_features().
        tl.store(
            out_ptr + base + offset + NUM_FREQS + k - 1,
            tl.where(valid, fcos(angle), 0.0),
            mask=m,
        )


@triton.jit
def _metadata_kernel(
    lat_ptr,
    lon_ptr,
    time_ptr,
    target_ptr,
    height_ptr,
    press_ptr,
    scan_ptr,
    sat_ptr,
    sol_ptr,
    out_ptr,
    N,
    BLOCK: tl.constexpr,
    N_FEAT: tl.constexpr,
):
    """Extended observation metadata — 30 features per observation.

    Compared to the standard 28-feature encoding (obs_features.py):
    - Adds latitude encoding (sin/cos) as shared features [8:10).
    - Uses Fourier encoding for relative time instead of raw polynomial only.
    - Replaces NaN-padded conv/sat columns with mutually exclusive branches:
      slots 10-29 are written by exactly one branch per row (conv or sat,
      selected by NaN in height), so every feature carries signal.

    Triton implementation because torch.compile on the equivalent
    `_compute_unified_metadata_reference()` hits dynamo errors under
    multi-gpu DDP training (compiling a function in a non-main thread).
    """
    pid = tl.program_id(0)
    off = pid * BLOCK + tl.arange(0, BLOCK)
    m = off < N

    lat = tl.load(lat_ptr + off, mask=m, other=0.0).to(tl.float32)
    lon = tl.load(lon_ptr + off, mask=m, other=0.0).to(tl.float32)
    time_ns = tl.load(time_ptr + off, mask=m, other=0)
    target_s = tl.load(target_ptr + off, mask=m, other=0)
    height = tl.load(height_ptr + off, mask=m, other=0.0).to(tl.float32)
    pressure = tl.load(press_ptr + off, mask=m, other=0.0).to(tl.float32)
    scan = tl.load(scan_ptr + off, mask=m, other=0.0).to(tl.float32)
    sat_zen = tl.load(sat_ptr + off, mask=m, other=0.0).to(tl.float32)
    sol_zen = tl.load(sol_ptr + off, mask=m, other=0.0).to(tl.float32)

    # Branch selection mirrors the reference: finite height <=> conventional.
    is_conv = ~isnan(height)
    m_conv = m & is_conv
    m_sat = m & ~is_conv

    DEG2RAD: tl.constexpr = 0.017453292519943295
    base = off * N_FEAT
    idx = 0
    TWO_PI: tl.constexpr = 6.283185307179586
    # ======== Shared: LST fourier(2) -> 4 features ========
    sod = (time_ns // 1000000000) % 86400
    utc_hr = sod.to(tl.float32) / 3600.0
    lst = (utc_hr + lon / 15.0) % 24.0
    lst_norm = lst / 24.0
    _fourier_store(out_ptr, base, idx, lst_norm * TWO_PI, True, m, 2)
    idx += 4

    # ======== Shared: Relative time polynomial -> 2 features ========
    dt_days = (time_ns - target_s * 1000000000).to(tl.float32) * 1e-9 / 86400.0
    tl.store(out_ptr + base + idx, dt_days, mask=m)
    tl.store(out_ptr + base + idx + 1, dt_days * dt_days, mask=m)
    idx += 2

    # ======== Shared: Relative time fourier(1) -> 2 features ========
    _fourier_store(out_ptr, base, idx, dt_days, True, m, 1)
    idx += 2

    # ======== Shared: Latitude -> 2 features ========
    lat_rad = lat * DEG2RAD
    tl.store(out_ptr + base + idx, fsin(lat_rad), mask=m)
    tl.store(out_ptr + base + idx + 1, fcos(lat_rad), mask=m)
    idx += 2

    # ======== Conv branch [idx:idx+20) — guarded by m_conv ========
    branch = idx
    h_norm = tl.minimum(tl.maximum(height / 60000.0, 0.0), 1.0) * TWO_PI
    _fourier_store(out_ptr, base, branch, h_norm, True, m_conv, 5)
    branch += 10

    p_norm = tl.minimum(tl.maximum(pressure / 1100.0, 0.0), 1.0) * TWO_PI
    _fourier_store(out_ptr, base, branch, p_norm, True, m_conv, 5)

    # ======== Sat branch [idx:idx+20) — guarded by m_sat ========
    # Restarts at the same slot range: conv and sat masks are disjoint,
    # so exactly one branch writes slots [idx, idx+20) for each row.
    branch = idx
    scan_norm = scan / 50.0 * TWO_PI
    _fourier_store(out_ptr, base, branch, scan_norm, True, m_sat, 3)
    branch += 6

    sat_norm = sat_zen / 90.0 * TWO_PI
    _fourier_store(out_ptr, base, branch, sat_norm, True, m_sat, 4)
    branch += 8

    sol_norm = sol_zen / 180.0 * TWO_PI
    _fourier_store(out_ptr, base, branch, sol_norm, True, m_sat, 3)


def compute_unified_metadata(
    target_time_sec: torch.Tensor,
    time: torch.Tensor,
    lon: torch.Tensor,
    lat: torch.Tensor,
    height: torch.Tensor,
    pressure: torch.Tensor,
    scan_angle: torch.Tensor,
    sat_zenith_angle: torch.Tensor,
    sol_zenith_angle: torch.Tensor,
) -> torch.Tensor:
    """Compute unified metadata features (v2) for observations.

    Dispatches to the Triton kernel on CUDA tensors, and to the pure-torch
    reference implementation otherwise. NOTE: the parameter order here
    (``time`` before ``lon``/``lat``) differs from the reference function;
    the reference is therefore invoked with keyword arguments.

    Args:
        target_time_sec: Target time in seconds since epoch, shape (N,)
        time: Observation time in nanoseconds, shape (N,)
        lon: Longitude in degrees, shape (N,)
        lat: Latitude in degrees, shape (N,)
        height: Height in meters (NaN for satellite obs), shape (N,)
        pressure: Pressure in hPa (NaN for satellite obs), shape (N,)
        scan_angle: Scan angle in degrees (NaN for conv obs), shape (N,)
        sat_zenith_angle: Satellite zenith angle in degrees (NaN for conv obs), shape (N,)
        sol_zenith_angle: Solar zenith angle in degrees (NaN for conv obs), shape (N,)

    Returns:
        Tensor of shape (N, 30) with unified metadata features.

    Raises:
        ValueError: If any input's leading dimension differs from ``lon``'s.
    """
    # Validate input shapes (lon defines N, hence absent from the list).
    N = lon.shape[0]
    for name, tensor in [
        ("target_time_sec", target_time_sec),
        ("time", time),
        ("lat", lat),
        ("height", height),
        ("pressure", pressure),
        ("scan_angle", scan_angle),
        ("sat_zenith_angle", sat_zenith_angle),
        ("sol_zenith_angle", sol_zenith_angle),
    ]:
        if tensor.shape[0] != N:
            raise ValueError(f"{name} has length {tensor.shape[0]}, expected {N}")

    if not lon.is_cuda:
        return _compute_unified_metadata_reference(
            target_time_sec,
            lon=lon,
            lat=lat,
            time=time,
            height=height,
            pressure=pressure,
            scan_angle=scan_angle,
            sat_zenith_angle=sat_zenith_angle,
            sol_zenith_angle=sol_zenith_angle,
        )

    # torch.empty is safe: every feature slot of every valid row is written
    # by the kernel (shared slots + exactly one of the conv/sat branches).
    out = torch.empty(N, N_FEATURES, dtype=torch.float32, device=lon.device)
    if N == 0:
        return out
    BLOCK = 256
    grid = ((N + BLOCK - 1) // BLOCK,)
    _metadata_kernel[grid](
        lat,
        lon,
        time,
        target_time_sec,
        height,
        pressure,
        scan_angle,
        sat_zenith_angle,
        sol_zenith_angle,
        out,
        N,
        BLOCK=BLOCK,
        N_FEAT=N_FEATURES,
        num_warps=4,
    )
    return out
"""Quality-control filtering for observation data.

Applies range checks, height/pressure limits, and optional QC flag filtering
to PyArrow observation tables. Used by ``UFSUnifiedLoader`` after joining
channel metadata.
"""

import pyarrow as pa
import pyarrow.compute as pc

from physicsnemo.experimental.datapipes.healda.configs.sensors import (
    CONV_GPS_LEVEL2_CHANNELS,
    CONV_UV_CHANNELS,
    CONV_UV_IN_SITU_TYPES,
    QCLimits,
)

# Column references for filtering expressions. These are unbound pyarrow
# compute expressions: they only resolve to data when passed to
# ``table.filter(...)`` (or used inside ``pc.if_else``), so defining them
# at module level is cheap and side-effect free.
height = pc.field("Height")
pressure = pc.field("Pressure")
obs = pc.field("Observation")
analysis_use = pc.field("Analysis_Use_Flag")
qc_flag = pc.field("QC_Flag")
min_valid = pc.field("min_valid")
max_valid = pc.field("max_valid")
local_id = pc.field("local_channel_id")
is_conv = pc.field("is_conv")
obs_type = pc.field("Observation_Type")


def _get_conv_filter_expr(
    table: pa.Table,
    qc_filter: bool = False,
    uv_in_situ_only: bool = False,
    gps_level1_only: bool = False,
):
    """Build filter expression for conventional observations.

    The returned value is an unbound pyarrow expression; ``table`` is only
    consulted to read the ``Observation_Type`` column's arrow type so the
    in-situ membership array is built with a matching type.
    """
    # NOTE(review): local channel ids <= 2 appear to denote GPS channels,
    # which get a different pressure floor — confirm against sensor configs.
    is_gps = local_id <= 2

    height_ok = pc.is_finite(height) & (
        (height >= QCLimits.HEIGHT_MIN) & (height <= QCLimits.HEIGHT_MAX)
    )

    # GPS observations use a separate (presumably lower) pressure floor.
    min_pressure = pc.if_else(
        is_gps,
        pa.scalar(QCLimits.PRESSURE_MIN_GPS),
        pa.scalar(QCLimits.PRESSURE_MIN_DEFAULT),
    )
    pressure_ok = pc.is_finite(pressure)
    pressure_ok &= (pressure >= min_pressure) & (pressure <= QCLimits.PRESSURE_MAX)

    # Conventional obs must pass BOTH the height and pressure range checks.
    ok = pressure_ok & height_ok

    if qc_filter:
        ok &= analysis_use == pa.scalar(1)

    if uv_in_situ_only:
        # Drop UV wind channels unless the observation type is in-situ.
        is_uv_channel = pc.is_in(local_id, pa.array(CONV_UV_CHANNELS))
        is_in_situ = pc.is_in(
            obs_type,
            pa.array(CONV_UV_IN_SITU_TYPES, type=table["Observation_Type"].type),
        )
        ok &= ~is_uv_channel | is_in_situ

    if gps_level1_only:
        # Exclude level-2 GPS retrieval channels, keeping bending angle.
        ok &= ~pc.is_in(local_id, pa.array(CONV_GPS_LEVEL2_CHANNELS))

    return ok


def filter_observations(
    table: pa.Table,
    qc_filter: bool = False,
    conv_uv_in_situ_only: bool = False,
    conv_gps_level1_only: bool = False,
) -> pa.Table:
    """Filter observations by range, QC flags, and conventional-specific criteria.

    Args:
        table: PyArrow table with observation data (must include channel metadata
            columns ``min_valid``, ``max_valid``, ``is_conv``, ``local_channel_id``).
        qc_filter: Whether to apply QC flag / analysis-use filtering.
        conv_uv_in_situ_only: Exclude satellite UV winds (keep in-situ only).
        conv_gps_level1_only: Exclude GPS T/Q retrievals (keep bending angle).

    Returns:
        Filtered PyArrow table.
    """
    # Base predicate shared by both branches: finite value within the
    # per-channel valid range.
    ok = pc.is_finite(obs)
    ok &= obs >= min_valid
    ok &= obs <= max_valid

    # pyarrow expressions are immutable, so ``sat_ok &= ...`` rebinds
    # sat_ok to a new expression without altering ``ok``.
    sat_ok = ok
    if qc_filter:
        sat_ok &= qc_flag == 0

    conv_filter = _get_conv_filter_expr(
        table, qc_filter, conv_uv_in_situ_only, conv_gps_level1_only
    )
    # Per-row branch: conventional rows use conv_filter, satellite rows sat_ok.
    ok &= pc.if_else(is_conv, conv_filter, sat_ok)

    return table.filter(ok)
+"""Core data types for HealDA data loading.""" + +from __future__ import annotations + +import dataclasses +import json +from datetime import timedelta +from enum import Enum +from typing import Any, Optional, TypedDict + +import numpy as np +import torch + + +# --------------------------------------------------------------------------- +# Time unit enum +# --------------------------------------------------------------------------- + + +class TimeUnit(Enum): + """Time units supported by the dataset. + + Values are the pandas frequency strings (offset aliases). + """ + + HOUR = "h" + DAY = "D" + MINUTE = "min" + SECOND = "s" + + def to_timedelta(self, steps: float) -> timedelta: + return { + TimeUnit.HOUR: timedelta(hours=steps), + TimeUnit.DAY: timedelta(days=steps), + TimeUnit.MINUTE: timedelta(minutes=steps), + TimeUnit.SECOND: timedelta(seconds=steps), + }[self] + + +# --------------------------------------------------------------------------- +# Variable configuration +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass(frozen=True) +class VariableConfig: + """Describes the variables and pressure levels for a dataset.""" + + name: str + variables_2d: list[str] + variables_3d: list[str] + levels: list[int] + variables_static: list[str] = dataclasses.field(default_factory=list) + + +# --------------------------------------------------------------------------- +# Batch info (normalization metadata) +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass +class BatchInfo: + """Metadata about batch channels and normalization constants.""" + + channels: list[str] + time_step: int = 1 # Time (in units ``time_unit``) between consecutive frames + time_unit: TimeUnit = TimeUnit.HOUR + scales: Any | None = None + center: Any | None = None + + def __post_init__(self): + if isinstance(self.time_unit, str): + raise ValueError("Time unit is a str. 
Should be a TimeUnit.") + + @staticmethod + def loads(s): + kw = json.loads(s) + if "time_unit" in kw: + kw["time_unit"] = TimeUnit(kw["time_unit"]) + kw.pop("residual_normalization", None) + return BatchInfo(**kw) + + def asdict(self): + out = {} + out["channels"] = self.channels + out["time_step"] = self.time_step + out["time_unit"] = self.time_unit.value + if self.scales is not None: + out["scales"] = np.asarray(self.scales).tolist() + else: + out["scales"] = None + if self.center is not None: + out["center"] = np.asarray(self.center).tolist() + else: + out["center"] = None + return out + + def sel_channels(self, channels: list[str]): + channels = list(channels) + index = np.array([self.channels.index(ch) for ch in channels]) + scales = None + if self.scales is not None: + scales = np.asarray(self.scales)[index] + center = None + if self.center is not None: + center = np.asarray(self.center)[index] + return BatchInfo( + time_step=self.time_step, + time_unit=self.time_unit, + channels=channels, + scales=scales, + center=center, + ) + + def denormalize(self, x): + scales = torch.as_tensor(self.scales).to(x) + scales = scales.view(-1, 1, 1) + center = torch.as_tensor(self.center).to(x) + center = center.view(-1, 1, 1) + return x * scales + center + + def get_time_delta(self, t: int) -> timedelta: + """Get time offset of the *t*-th frame in a frame sequence.""" + total_steps = t * self.time_step + return self.time_unit.to_timedelta(total_steps) + + +# --------------------------------------------------------------------------- +# Unified observation structure +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass +class UnifiedObservation: + """Unified observation structure for both satellite and conventional observations.""" + + obs: torch.Tensor # (n_obs,) observation values + time: torch.Tensor # (n_obs,) timestamps in ns since epoch + float_metadata: torch.Tensor # (n_obs, n_features) + + # Integer metadata 
(each shape (n_obs,)) + pix: torch.Tensor # HEALPix pixel index (NEST) + local_channel: torch.Tensor + platform: torch.Tensor + obs_type: torch.Tensor + global_channel: torch.Tensor + + hpx_level: int # HEALPix level that ``pix`` is defined at + + lengths: torch.Tensor | None = ( + None # 3D: (n_active_sensors, batch, time) per-window obs counts + ) + sensor_id_to_local: torch.Tensor | None = ( + None # (max_sensor_id+1,) map: sensor_id -> local_idx (-1 if inactive) + ) + + @classmethod + def empty( + cls, + device: str = "cpu", + hpx_level: int = 8, + batch_dims: tuple[int, int] = (1, 1), + ) -> UnifiedObservation: + B, T = batch_dims + return cls( + obs=torch.empty(0, device=device), + time=torch.empty(0, dtype=torch.long, device=device), + float_metadata=torch.empty((0, 28), device=device), + pix=torch.empty(0, dtype=torch.long, device=device), + local_channel=torch.empty(0, dtype=torch.long, device=device), + platform=torch.empty(0, dtype=torch.long, device=device), + obs_type=torch.empty(0, dtype=torch.long, device=device), + global_channel=torch.empty(0, dtype=torch.long, device=device), + lengths=torch.zeros(1, B, T, dtype=torch.long, device=device), + sensor_id_to_local=torch.zeros(1, dtype=torch.long, device=device), + hpx_level=hpx_level, + ) + + @property + def batch_dims(self): + """Return ``(batch, time)`` shape from 3D offsets ``(S, B, T)``.""" + if self.lengths is not None: + return self.lengths.shape[-2:] + else: + return () + + def __repr__(self): + nobs = self.obs.shape[0] + return f"UnifiedObservation({nobs=}, batch_dims={self.batch_dims})" + + def to(self, device=None, dtype=None, non_blocking=True): + """Move all tensors to *device* and/or convert *dtype*.""" + + def _move(x): + if x is None: + return None + return x.to(device=device, dtype=dtype, non_blocking=non_blocking) + + return UnifiedObservation( + obs=_move(self.obs), + time=_move(self.time), + float_metadata=_move(self.float_metadata), + pix=_move(self.pix), + 
local_channel=_move(self.local_channel), + platform=_move(self.platform), + obs_type=_move(self.obs_type), + global_channel=_move(self.global_channel), + hpx_level=self.hpx_level, + lengths=_move(self.lengths), + sensor_id_to_local=_move(self.sensor_id_to_local), + ) + + +# --------------------------------------------------------------------------- +# Batch TypedDict +# --------------------------------------------------------------------------- + + +class Batch(TypedDict): + """A batch of model inputs produced by the data pipeline.""" + + target: torch.Tensor # (b, c, t, x) + condition: torch.Tensor # (b, c_cond, t, x) + second_of_day: torch.Tensor # (b, t) + day_of_year: torch.Tensor # (b, t) + labels: torch.Tensor # (b, num_classes) + timestamp: torch.Tensor # (b, t) + unified_obs: Optional[UnifiedObservation] + + +def empty_batch( + *, + batch_gpu: int, + out_channels: int, + condition_channels: int, + time_length: int, + x_size: int, + device: torch.device | str, +) -> Batch: + """Create an empty batch with the given dimensions.""" + if x_size <= 0: + raise ValueError(f"x_size must be positive, got {x_size}") + + return { + "target": torch.empty( + [batch_gpu, out_channels, time_length, x_size], device=device + ), + "condition": torch.empty( + [batch_gpu, condition_channels, time_length, x_size], device=device + ), + "second_of_day": torch.empty([batch_gpu, time_length], device=device), + "day_of_year": torch.empty([batch_gpu, time_length], device=device), + "labels": torch.empty([batch_gpu, 0], device=device), + "timestamp": torch.empty( + [batch_gpu, time_length], dtype=torch.long, device=device + ), + "unified_obs": UnifiedObservation.empty( + device=device, batch_dims=(batch_gpu, time_length) + ), + } + + +# --------------------------------------------------------------------------- +# Sensor-level splitting +# --------------------------------------------------------------------------- + + +@torch.compiler.disable +def split_by_sensor( + obs: 
UnifiedObservation, target_sensor_ids: list[int] +) -> dict[int, UnifiedObservation]: + """Slice a ``UnifiedObservation`` into per-sensor sub-objects. + + Uses precomputed ``lengths`` and ``sensor_id_to_local`` for efficient + splitting without per-element sensor ID checks. + """ + if obs.lengths is None or obs.sensor_id_to_local is None: + raise ValueError("lengths is required for split_by_sensor") + + lengths = obs.lengths # [S, B, T] + sensor_id_to_local = obs.sensor_id_to_local + + device = obs.obs.device + B, T = obs.batch_dims + + sizes = lengths.sum(dim=(1, 2)).tolist() + obs_fields = [ + obs.obs, + obs.time, + obs.float_metadata, + obs.pix, + obs.local_channel, + obs.platform, + obs.obs_type, + obs.global_channel, + ] + splits = [torch.split(f, sizes) for f in obs_fields] + + out = {} + for sensor_id in target_sensor_ids: + if sensor_id < len(sensor_id_to_local): + s_local = int(sensor_id_to_local[sensor_id].item()) + else: + s_local = -1 + + single_sensor_map = torch.full( + (sensor_id + 1,), -1, dtype=torch.int32, device=device + ) + single_sensor_map[sensor_id] = 0 + + if s_local < 0: + sensor_lengths = torch.zeros((1, B, T), dtype=lengths.dtype, device=device) + out[sensor_id] = UnifiedObservation( + obs=obs.obs[:0], + time=obs.time[:0], + float_metadata=obs.float_metadata[:0], + pix=obs.pix[:0], + local_channel=obs.local_channel[:0], + platform=obs.platform[:0], + obs_type=obs.obs_type[:0], + global_channel=obs.global_channel[:0], + hpx_level=obs.hpx_level, + lengths=sensor_lengths, + sensor_id_to_local=single_sensor_map, + ) + else: + out[sensor_id] = UnifiedObservation( + obs=splits[0][s_local], + time=splits[1][s_local], + float_metadata=splits[2][s_local], + pix=splits[3][s_local], + local_channel=splits[4][s_local], + platform=splits[5][s_local], + obs_type=splits[6][s_local], + global_channel=splits[7][s_local], + hpx_level=obs.hpx_level, + lengths=lengths[s_local : s_local + 1], + sensor_id_to_local=single_sensor_map, + ) + + return out From 
7a9aa880aac66fd285d6c837c1b8a50d79e64340 Mon Sep 17 00:00:00 2001 From: Peter Harrington Date: Wed, 8 Apr 2026 18:37:40 -0700 Subject: [PATCH 02/10] Cleanup and address imports --- examples/weather/healda/README.md | 81 ++++++---- examples/weather/healda/test/conftest.py | 5 +- .../healda/test/test_combined_schema.py | 19 ++- examples/weather/healda/test/test_features.py | 13 +- examples/weather/healda/test/test_indexing.py | 20 ++- .../weather/healda/test/test_obs_filtering.py | 17 ++- examples/weather/healda/test/test_prefetch.py | 5 +- examples/weather/healda/test/test_samplers.py | 10 +- .../weather/healda/test/test_time_utils.py | 12 +- examples/weather/healda/test/test_types.py | 18 ++- .../datapipes/healda/configs/__init__.py | 5 +- .../healda/configs/combined_schema.py | 9 +- .../datapipes/healda/configs/sensors.py | 5 +- .../datapipes/healda/configs/static_data.py | 12 +- .../healda/configs/variable_configs.py | 5 +- .../experimental/datapipes/healda/dataset.py | 5 +- .../experimental/datapipes/healda/indexing.py | 5 +- .../datapipes/healda/loaders/__init__.py | 5 +- .../datapipes/healda/loaders/era5.py | 10 +- .../datapipes/healda/loaders/ufs_obs.py | 13 +- .../datapipes/healda/loaders/zarr_loader.py | 23 +-- .../experimental/datapipes/healda/prefetch.py | 5 +- .../datapipes/healda/protocols.py | 5 +- .../experimental/datapipes/healda/samplers.py | 5 +- .../datapipes/healda/time_utils.py | 5 +- .../datapipes/healda/transforms/__init__.py | 5 +- .../datapipes/healda/transforms/era5_obs.py | 14 +- .../healda/transforms/obs_features.py | 5 +- .../healda/transforms/obs_features_ext.py | 5 +- .../healda/transforms/obs_filtering.py | 11 +- .../experimental/datapipes/healda/types.py | 5 +- pyproject.toml | 9 ++ uv.lock | 144 +++++++++++++----- 33 files changed, 350 insertions(+), 165 deletions(-) diff --git a/examples/weather/healda/README.md b/examples/weather/healda/README.md index fcda6ef1a9..06cc298c36 100644 --- a/examples/weather/healda/README.md +++ 
b/examples/weather/healda/README.md @@ -1,22 +1,37 @@ # HealDA — AI-based Data Assimilation on the HEALPix Grid -> 🏗️🏗️ **This recipe is under active construction.** Structure and functionality are subject to changes 🏗️🏗️ +> **This recipe is under active construction.** +> Structure and functionality are subject to changes. -HealDA is a stateless assimilation model that produces a single global weather analysis from conventional and satellite observations. It operates on a HEALPix level-6 padded XY grid and outputs ERA5-compatible atmospheric variables. +HealDA is a stateless assimilation model that produces a single +global weather analysis from conventional and satellite +observations. It operates on a HEALPix level-6 padded XY grid +and outputs ERA5-compatible atmospheric variables. -This example provides a recipe to train HealDA, with support for extension to custom data. +This example provides a recipe to train HealDA, with support +for extension to custom data. ## Setup -Start by installing PhysicsNeMo (if not already installed) with the `datapipes-extras` optional dependency group, along with the packages in `requirements.txt`. Then, copy this folder (`examples/weather/healda`) to a system with a GPU available. Also, prepare a dataset that can serve training data according to the protocols outlined in the [Generalized Data Loading](#generalized-data-loading) section below. +Start by installing PhysicsNeMo (if not already installed) with +the `healda` optional dependency group, along with the packages +in `requirements.txt`. Then, copy this folder +(`examples/weather/healda`) to a system with a GPU available. +Also, prepare a dataset that can serve training data according +to the protocols outlined in the +[Generalized Data Loading](#generalized-data-loading) section +below. ## Generalized Data Loading -The ``physicsnemo.experimental.datapipes.healda`` package provides a composable data loading pipeline with clear extension points. 
The architecture separates components into loaders, transforms, datasets, and sampling infrastructure. +The `physicsnemo.experimental.datapipes.healda` package provides +a composable data loading pipeline with clear extension points. +The architecture separates components into loaders, transforms, +datasets, and sampling infrastructure. ### Architecture -``` +```text ObsERA5Dataset(era5_data, obs_loader, transform) | Temporal windowing via FrameIndexGenerator | __getitems__ -> get() per index -> transform.transform() @@ -34,25 +49,29 @@ Training loop (GPU-ready batch) ### Key Protocols -Custom data sources and transforms plug in via these protocols (see `physicsnemo.experimental.datapipes.healda.protocols`): +Custom data sources and transforms plug in via these protocols +(see `physicsnemo.experimental.datapipes.healda.protocols`): **`ObsLoader`** — the observation loading interface: + ```python class MyObsLoader: - async def sel_time(self, times: pd.DatetimeIndex) -> dict[str, list[Any]]: - """Return {"obs": [pa.Table_per_time, ...]}""" + async def sel_time(self, times): + """Return {"obs": [pa.Table, ...]}""" ... ``` -**`Transform`** / **`DeviceTransform`** — two-stage batch processing: +**`Transform`** / **`DeviceTransform`** — two-stage batch +processing: + ```python class MyTransform: - def transform(self, times, frames) -> dict[str, Any]: - """CPU-side: normalize, encode observations, time features.""" + def transform(self, times, frames): + """CPU-side: normalize, encode obs, time features.""" ... - def device_transform(self, batch, device) -> dict[str, Any]: - """GPU-side: move to device, compute observation features.""" + def device_transform(self, batch, device): + """GPU-side: move to device, compute obs features.""" ... 
``` @@ -60,19 +79,21 @@ class MyTransform: | Component | Module | Description | |---|---|---| -| `ObsERA5Dataset` | `healda.dataset` | Map-style dataset combining ERA5 state + observations | -| `UFSUnifiedLoader` | `healda.loaders.ufs_obs` | Parquet-based observation loader (satellite + conventional) | -| `ERA5Loader` | `healda.loaders.era5` | Async ERA5 zarr loader (not used by ObsERA5Dataset directly) | -| `ERA5ObsTransform` | `healda.transforms.era5_obs` | Two-stage transform with Triton feature kernels | -| `ChunkedDistributedSampler` | `healda.samplers` | Cache-friendly distributed sampler | -| `RoundRobinLoader` | `healda.samplers` | Multi-loader interleaving | -| `prefetch_map` | `healda.prefetch` | Background CUDA stream prefetching | +| `ObsERA5Dataset` | `dataset` | ERA5 state + observations | +| `UFSUnifiedLoader` | `loaders.ufs_obs` | Parquet obs loader | +| `ERA5Loader` | `loaders.era5` | Async ERA5 zarr loader | +| `ERA5ObsTransform` | `transforms.era5_obs` | Two-stage transform | +| `ChunkedDistributedSampler` | `samplers` | Distributed sampler | +| `RoundRobinLoader` | `samplers` | Multi-loader interleave | +| `prefetch_map` | `prefetch` | CUDA stream prefetching | -All modules above are under `physicsnemo.experimental.datapipes` (abbreviated as `healda` in the table). +All modules above are under +`physicsnemo.experimental.datapipes.healda`. 
### Writing a Custom Observation Loader -Implement `async def sel_time(times)` returning a dict with observation data per timestamp: +Implement `async def sel_time(times)` returning a dict with +observation data per timestamp: ```python class GOESRadianceLoader: @@ -89,10 +110,17 @@ class GOESRadianceLoader: ``` Then pass it to the dataset: + ```python -from physicsnemo.experimental.datapipes.healda import ObsERA5Dataset -from physicsnemo.experimental.datapipes.healda.transforms.era5_obs import ERA5ObsTransform -from physicsnemo.experimental.datapipes.healda.configs.variable_configs import VARIABLE_CONFIGS +from physicsnemo.experimental.datapipes.healda import ( + ObsERA5Dataset, +) +from physicsnemo.experimental.datapipes.healda.transforms.era5_obs import ( + ERA5ObsTransform, +) +from physicsnemo.experimental.datapipes.healda.configs.variable_configs import ( + VARIABLE_CONFIGS, +) dataset = ObsERA5Dataset( era5_data=era5_xr["data"], @@ -101,4 +129,3 @@ dataset = ObsERA5Dataset( variable_config=VARIABLE_CONFIGS["era5"], ) ``` - diff --git a/examples/weather/healda/test/conftest.py b/examples/weather/healda/test/conftest.py index 73553c39ed..3a4ff72671 100644 --- a/examples/weather/healda/test/conftest.py +++ b/examples/weather/healda/test/conftest.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/examples/weather/healda/test/test_combined_schema.py b/examples/weather/healda/test/test_combined_schema.py index 48c9a43ebe..d03b0efd08 100644 --- a/examples/weather/healda/test/test_combined_schema.py +++ b/examples/weather/healda/test/test_combined_schema.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -20,14 +21,22 @@ get_channel_table_schema, get_combined_observation_schema, ) -from physicsnemo.experimental.datapipes.healda.configs.sensors import SENSOR_CONFIGS, SENSOR_NAME_TO_ID +from physicsnemo.experimental.datapipes.healda.configs.sensors import ( + SENSOR_CONFIGS, + SENSOR_NAME_TO_ID, +) def test_combined_schema_has_required_fields(): schema = get_combined_observation_schema() required = [ - "Latitude", "Longitude", "Absolute_Obs_Time", "DA_window", - "Platform_ID", "Observation", "Global_Channel_ID", + "Latitude", + "Longitude", + "Absolute_Obs_Time", + "DA_window", + "Platform_ID", + "Observation", + "Global_Channel_ID", ] for name in required: assert name in schema.names, f"Missing required field: {name}" diff --git 
a/examples/weather/healda/test/test_features.py b/examples/weather/healda/test/test_features.py index 080a8e210c..120b870084 100644 --- a/examples/weather/healda/test/test_features.py +++ b/examples/weather/healda/test/test_features.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -21,8 +22,12 @@ import pytest import torch -from physicsnemo.experimental.datapipes.healda.transforms import obs_features as standard -from physicsnemo.experimental.datapipes.healda.transforms import obs_features_ext as extended +from physicsnemo.experimental.datapipes.healda.transforms import ( + obs_features as standard, +) +from physicsnemo.experimental.datapipes.healda.transforms import ( + obs_features_ext as extended, +) def _make_obs_data(n, device, include_lat=False): diff --git a/examples/weather/healda/test/test_indexing.py b/examples/weather/healda/test/test_indexing.py index b722c689ff..b9b6032bec 100644 --- a/examples/weather/healda/test/test_indexing.py +++ b/examples/weather/healda/test/test_indexing.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -17,7 +18,10 @@ import numpy as np import torch -from physicsnemo.experimental.datapipes.healda.indexing import FrameIndexGenerator, split_array_contiguous +from physicsnemo.experimental.datapipes.healda.indexing import ( + FrameIndexGenerator, + split_array_contiguous, +) def test_split_array_contiguous_single_segment(): @@ -63,10 +67,12 @@ def test_frame_index_generator_model_rank_slicing(): def test_frame_index_generator_multiple_segments(): """Test frame index generation across non-contiguous segments.""" - times = np.concatenate([ - np.arange(0, 10), # [0, 1, ..., 9] - np.arange(20, 35), # [20, 21, ..., 34] - ]) + times = np.concatenate( + [ + np.arange(0, 10), # [0, 1, ..., 9] + np.arange(20, 35), # [20, 21, ..., 34] + ] + ) generator = FrameIndexGenerator( times=times, time_length=3, frame_step=1, model_rank=0, model_world_size=1 diff --git a/examples/weather/healda/test/test_obs_filtering.py b/examples/weather/healda/test/test_obs_filtering.py index fa54ed205a..5fd364da3f 100644 --- a/examples/weather/healda/test/test_obs_filtering.py +++ b/examples/weather/healda/test/test_obs_filtering.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -18,7 +19,9 @@ import pyarrow as pa from physicsnemo.experimental.datapipes.healda.configs.sensors import SENSOR_OFFSET -from physicsnemo.experimental.datapipes.healda.transforms.obs_filtering import filter_observations +from physicsnemo.experimental.datapipes.healda.transforms.obs_filtering import ( + filter_observations, +) def _make_filter_test_table(): @@ -37,10 +40,14 @@ def _make_filter_test_table(): return pa.table( { - "Observation": np.array([100.0, 200.0, 500.0, 50.0, 60.0], dtype=np.float32), + "Observation": np.array( + [100.0, 200.0, 500.0, 50.0, 60.0], dtype=np.float32 + ), "Global_Channel_ID": np.array(channels, dtype=np.uint16), "Pressure": np.array([500.0, 800.0, 600.0, 400.0, 300.0], dtype=np.float32), - "Height": np.array([1000.0, 5000.0, 100.0, 2000.0, 3000.0], dtype=np.float32), + "Height": np.array( + [1000.0, 5000.0, 100.0, 2000.0, 3000.0], dtype=np.float32 + ), "Observation_Type": np.array([200, 210, 220, 230, 280], dtype=np.uint16), "QC_Flag": np.array([0, 0, 0, 0, 0], dtype=np.int32), "Analysis_Use_Flag": np.array([1, 1, 0, 1, 1], dtype=np.int8), diff --git a/examples/weather/healda/test/test_prefetch.py b/examples/weather/healda/test/test_prefetch.py index f2c1b000fd..a0080bb09f 100644 --- a/examples/weather/healda/test/test_prefetch.py +++ b/examples/weather/healda/test/test_prefetch.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/examples/weather/healda/test/test_samplers.py b/examples/weather/healda/test/test_samplers.py index 948a7a0c0f..32922f535d 100644 --- a/examples/weather/healda/test/test_samplers.py +++ b/examples/weather/healda/test/test_samplers.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -19,7 +20,10 @@ import torch import torch.utils.data -from physicsnemo.experimental.datapipes.healda.samplers import ChunkedDistributedSampler, RoundRobinLoader +from physicsnemo.experimental.datapipes.healda.samplers import ( + ChunkedDistributedSampler, + RoundRobinLoader, +) def test_chunked_sampler_sequential(): diff --git a/examples/weather/healda/test/test_time_utils.py b/examples/weather/healda/test/test_time_utils.py index aa4cd27664..942d492285 100644 --- a/examples/weather/healda/test/test_time_utils.py +++ b/examples/weather/healda/test/test_time_utils.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -20,7 +21,12 @@ import numpy as np import pandas as pd -from physicsnemo.experimental.datapipes.healda.time_utils import as_cftime, as_numpy, as_pydatetime, as_timestamp +from physicsnemo.experimental.datapipes.healda.time_utils import ( + as_cftime, + as_numpy, + as_pydatetime, + as_timestamp, +) def test_as_numpy_from_pandas_index(): diff --git a/examples/weather/healda/test/test_types.py b/examples/weather/healda/test/test_types.py index d83fa3b919..4fff22621e 100644 --- a/examples/weather/healda/test/test_types.py +++ b/examples/weather/healda/test/test_types.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -17,7 +18,10 @@ import pytest import torch -from physicsnemo.experimental.datapipes.healda.types import UnifiedObservation, split_by_sensor +from physicsnemo.experimental.datapipes.healda.types import ( + UnifiedObservation, + split_by_sensor, +) def make_realistic_obs( @@ -42,7 +46,9 @@ def make_realistic_obs( for b in range(B): for t in range(T): lengths_3d[s_local, b, t] = sum( - 1 for obs in all_obs if obs[0] == s_id and obs[1] == b and obs[2] == t + 1 + for obs in all_obs + if obs[0] == s_id and obs[1] == b and obs[2] == t ) sensor_id_to_local = torch.full((max(sensors) + 1,), -1, dtype=torch.int32) @@ -151,7 +157,9 @@ def test_split_handles_sparse_windows(): obs = UnifiedObservation( obs=torch.arange(nobs, dtype=torch.float32).unsqueeze(1).expand(nobs, 3), time=torch.zeros(nobs, dtype=torch.long), - float_metadata=torch.arange(nobs, dtype=torch.float32).unsqueeze(1).expand(nobs, 5), + float_metadata=torch.arange(nobs, dtype=torch.float32) + .unsqueeze(1) + .expand(nobs, 5), pix=torch.arange(nobs, dtype=torch.long), local_channel=torch.zeros(nobs, dtype=torch.long), platform=torch.zeros(nobs, dtype=torch.long), diff --git a/physicsnemo/experimental/datapipes/healda/configs/__init__.py b/physicsnemo/experimental/datapipes/healda/configs/__init__.py index 3159bfe656..af85283aa4 100644 --- a/physicsnemo/experimental/datapipes/healda/configs/__init__.py +++ b/physicsnemo/experimental/datapipes/healda/configs/__init__.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/physicsnemo/experimental/datapipes/healda/configs/combined_schema.py b/physicsnemo/experimental/datapipes/healda/configs/combined_schema.py index fb2138c099..fa9cac925a 100644 --- a/physicsnemo/experimental/datapipes/healda/configs/combined_schema.py +++ b/physicsnemo/experimental/datapipes/healda/configs/combined_schema.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -20,7 +21,9 @@ location for multi-component observations. 
""" -import pyarrow as pa +from physicsnemo.core.version_check import OptionalImport + +pa = OptionalImport("pyarrow") GLOBAL_CHANNEL_ID = pa.field("Global_Channel_ID", pa.uint16(), nullable=False) SENSOR_ID = pa.field("sensor_id", pa.uint16()) diff --git a/physicsnemo/experimental/datapipes/healda/configs/sensors.py b/physicsnemo/experimental/datapipes/healda/configs/sensors.py index e1322c9dea..92b25d7a8c 100644 --- a/physicsnemo/experimental/datapipes/healda/configs/sensors.py +++ b/physicsnemo/experimental/datapipes/healda/configs/sensors.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/physicsnemo/experimental/datapipes/healda/configs/static_data.py b/physicsnemo/experimental/datapipes/healda/configs/static_data.py index 5d4a2d5ff4..f4b94f992d 100644 --- a/physicsnemo/experimental/datapipes/healda/configs/static_data.py +++ b/physicsnemo/experimental/datapipes/healda/configs/static_data.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -21,10 +22,13 @@ import functools import os -import earth2grid import numpy as np import torch -import zarr + +from physicsnemo.core.version_check import OptionalImport + +earth2grid = OptionalImport("earth2grid") +zarr = OptionalImport("zarr") @functools.cache diff --git a/physicsnemo/experimental/datapipes/healda/configs/variable_configs.py b/physicsnemo/experimental/datapipes/healda/configs/variable_configs.py index 25d11e816e..0e2698acb2 100644 --- a/physicsnemo/experimental/datapipes/healda/configs/variable_configs.py +++ b/physicsnemo/experimental/datapipes/healda/configs/variable_configs.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/physicsnemo/experimental/datapipes/healda/dataset.py b/physicsnemo/experimental/datapipes/healda/dataset.py index 3f20a7a3a4..1ad9103766 100644 --- a/physicsnemo/experimental/datapipes/healda/dataset.py +++ b/physicsnemo/experimental/datapipes/healda/dataset.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/physicsnemo/experimental/datapipes/healda/indexing.py b/physicsnemo/experimental/datapipes/healda/indexing.py index d86de4fd51..2c0daa3e1e 100644 --- a/physicsnemo/experimental/datapipes/healda/indexing.py +++ b/physicsnemo/experimental/datapipes/healda/indexing.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/physicsnemo/experimental/datapipes/healda/loaders/__init__.py b/physicsnemo/experimental/datapipes/healda/loaders/__init__.py index 3159bfe656..af85283aa4 100644 --- a/physicsnemo/experimental/datapipes/healda/loaders/__init__.py +++ b/physicsnemo/experimental/datapipes/healda/loaders/__init__.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/physicsnemo/experimental/datapipes/healda/loaders/era5.py b/physicsnemo/experimental/datapipes/healda/loaders/era5.py index 09470cb094..6ee1c15b0b 100644 --- a/physicsnemo/experimental/datapipes/healda/loaders/era5.py +++ b/physicsnemo/experimental/datapipes/healda/loaders/era5.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -25,7 +26,10 @@ import numpy as np import pandas as pd -import xarray + +from physicsnemo.core.version_check import OptionalImport + +xarray = OptionalImport("xarray") from physicsnemo.experimental.datapipes.healda.loaders.zarr_loader import NO_LEVEL, ZarrLoader from physicsnemo.experimental.datapipes.healda.types import BatchInfo, TimeUnit, VariableConfig diff --git a/physicsnemo/experimental/datapipes/healda/loaders/ufs_obs.py b/physicsnemo/experimental/datapipes/healda/loaders/ufs_obs.py index 351db52777..27fbe7fba2 100644 --- a/physicsnemo/experimental/datapipes/healda/loaders/ufs_obs.py +++ b/physicsnemo/experimental/datapipes/healda/loaders/ufs_obs.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -39,9 +40,11 @@ import fsspec import numpy as np import pandas as pd -import pyarrow as pa -import pyarrow.compute as pc -import pyarrow.parquet as pq +from physicsnemo.core.version_check import OptionalImport + +pa = OptionalImport("pyarrow") +pc = OptionalImport("pyarrow.compute") +pq = OptionalImport("pyarrow.parquet") from physicsnemo.experimental.datapipes.healda.configs.combined_schema import ( GLOBAL_CHANNEL_ID, diff --git a/physicsnemo/experimental/datapipes/healda/loaders/zarr_loader.py b/physicsnemo/experimental/datapipes/healda/loaders/zarr_loader.py index e9c07aa810..220d19c5e5 100644 --- a/physicsnemo/experimental/datapipes/healda/loaders/zarr_loader.py +++ b/physicsnemo/experimental/datapipes/healda/loaders/zarr_loader.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -19,16 +20,20 @@ from zarr stores. 
""" +from __future__ import annotations + import asyncio import urllib.parse import cftime import numpy as np import pandas as pd -import xarray as xr -import zarr -import zarr.storage -from zarr.core.sync import sync + +from physicsnemo.core.version_check import OptionalImport + +xr = OptionalImport("xarray") +zarr = OptionalImport("zarr") +_zarr_sync = OptionalImport("zarr.core.sync") NO_LEVEL = -1 # sentinel for 2D (surface) variables that lack a pressure level @@ -83,7 +88,7 @@ def __init__( if isinstance(path, str) and _is_local(path): storage_options = None - self.group = sync( + self.group = _zarr_sync.sync( zarr.api.asynchronous.open_group( path, storage_options=storage_options, @@ -93,12 +98,12 @@ def __init__( ) if self.variables_3d: - self.inds = sync(self._get_vertical_indices(level_coord_name, levels)) + self.inds = _zarr_sync.sync(self._get_vertical_indices(level_coord_name, levels)) self._arrays = {} self._has_time = bool(self.variables_3d or self.variables_2d) if self._has_time: - time_num, self.units, self.calendar = sync(self._get_time()) + time_num, self.units, self.calendar = _zarr_sync.sync(self._get_time()) if np.issubdtype(time_num.dtype, np.datetime64): self.times = pd.DatetimeIndex(time_num) else: diff --git a/physicsnemo/experimental/datapipes/healda/prefetch.py b/physicsnemo/experimental/datapipes/healda/prefetch.py index d11c5712e5..273251feac 100644 --- a/physicsnemo/experimental/datapipes/healda/prefetch.py +++ b/physicsnemo/experimental/datapipes/healda/prefetch.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/physicsnemo/experimental/datapipes/healda/protocols.py b/physicsnemo/experimental/datapipes/healda/protocols.py index cf13295407..c12eb428d0 100644 --- a/physicsnemo/experimental/datapipes/healda/protocols.py +++ b/physicsnemo/experimental/datapipes/healda/protocols.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/physicsnemo/experimental/datapipes/healda/samplers.py b/physicsnemo/experimental/datapipes/healda/samplers.py index 5254e91c16..6ba75b6e13 100644 --- a/physicsnemo/experimental/datapipes/healda/samplers.py +++ b/physicsnemo/experimental/datapipes/healda/samplers.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/physicsnemo/experimental/datapipes/healda/time_utils.py b/physicsnemo/experimental/datapipes/healda/time_utils.py index c5ffa48195..9629816975 100644 --- a/physicsnemo/experimental/datapipes/healda/time_utils.py +++ b/physicsnemo/experimental/datapipes/healda/time_utils.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/physicsnemo/experimental/datapipes/healda/transforms/__init__.py b/physicsnemo/experimental/datapipes/healda/transforms/__init__.py index 3159bfe656..af85283aa4 100644 --- a/physicsnemo/experimental/datapipes/healda/transforms/__init__.py +++ b/physicsnemo/experimental/datapipes/healda/transforms/__init__.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/physicsnemo/experimental/datapipes/healda/transforms/era5_obs.py b/physicsnemo/experimental/datapipes/healda/transforms/era5_obs.py index 77adb2a87e..99bff19a4c 100644 --- a/physicsnemo/experimental/datapipes/healda/transforms/era5_obs.py +++ b/physicsnemo/experimental/datapipes/healda/transforms/era5_obs.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -31,12 +32,15 @@ import warnings import cftime -import earth2grid import numpy as np -import pyarrow as pa -import pyarrow.compute as pc import torch +from physicsnemo.core.version_check import OptionalImport + +earth2grid = OptionalImport("earth2grid") +pa = OptionalImport("pyarrow") +pc = OptionalImport("pyarrow.compute") + from physicsnemo.experimental.datapipes.healda.configs.static_data import load_lfrac, load_orography from physicsnemo.experimental.datapipes.healda.configs.variable_configs import VARIABLE_CONFIGS from physicsnemo.experimental.datapipes.healda.loaders.era5 import get_batch_info diff --git a/physicsnemo/experimental/datapipes/healda/transforms/obs_features.py b/physicsnemo/experimental/datapipes/healda/transforms/obs_features.py index 5345d4d827..c9b8457bc9 100644 --- a/physicsnemo/experimental/datapipes/healda/transforms/obs_features.py +++ b/physicsnemo/experimental/datapipes/healda/transforms/obs_features.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/physicsnemo/experimental/datapipes/healda/transforms/obs_features_ext.py b/physicsnemo/experimental/datapipes/healda/transforms/obs_features_ext.py index 3fb8a5fe65..a586cdbb51 100644 --- a/physicsnemo/experimental/datapipes/healda/transforms/obs_features_ext.py +++ b/physicsnemo/experimental/datapipes/healda/transforms/obs_features_ext.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/physicsnemo/experimental/datapipes/healda/transforms/obs_filtering.py b/physicsnemo/experimental/datapipes/healda/transforms/obs_filtering.py index 12debfe99f..1d611810c8 100644 --- a/physicsnemo/experimental/datapipes/healda/transforms/obs_filtering.py +++ b/physicsnemo/experimental/datapipes/healda/transforms/obs_filtering.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -19,8 +20,10 @@ channel metadata. """ -import pyarrow as pa -import pyarrow.compute as pc +from physicsnemo.core.version_check import OptionalImport + +pa = OptionalImport("pyarrow") +pc = OptionalImport("pyarrow.compute") from physicsnemo.experimental.datapipes.healda.configs.sensors import ( CONV_GPS_LEVEL2_CHANNELS, diff --git a/physicsnemo/experimental/datapipes/healda/types.py b/physicsnemo/experimental/datapipes/healda/types.py index ade0931806..6cb6b0c02a 100644 --- a/physicsnemo/experimental/datapipes/healda/types.py +++ b/physicsnemo/experimental/datapipes/healda/types.py @@ -1,11 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/pyproject.toml b/pyproject.toml index 5c78e4f283..649fc8eac4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ conflicts = [ torch-sparse = ["torch"] torch-cluster = ["torch"] torch-scatter = ["torch"] +earth2grid = ["setuptools", "torch"] [[tool.uv.index]] name = "nvidia" @@ -145,6 +146,8 @@ natten = [ { index = "natten-cu130-whl", extra = "natten-cu13" }, ] +earth2grid = { url = "https://github.com/NVlabs/earth2grid/archive/8fdff5a78d324f8d25afe224915301b3169bffe2.tar.gz" } + ##################################################################### # Flags Controlling the local build of physicsnemo ##################################################################### @@ -280,6 +283,12 @@ perf = [ "transformer_engine[pytorch]", ] +healda = [ + "nvidia-physicsnemo[datapipes-extras]", + "pyarrow>=14.0.0", + "earth2grid", +] + ##################################################################### # Linting configuration diff --git a/uv.lock b/uv.lock index 1123a38262..eb9c622c0c 100644 --- a/uv.lock +++ b/uv.lock @@ -294,8 +294,8 @@ name = "astunparse" version = "1.6.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "six", marker = "extra == 'extra-18-nvidia-physicsnemo-cu12' or extra == 'extra-18-nvidia-physicsnemo-cu13' or (extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13')" }, - { name = "wheel", marker = "extra == 'extra-18-nvidia-physicsnemo-cu12' or extra == 'extra-18-nvidia-physicsnemo-cu13' or (extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13')" }, + { name = "six" }, + { name = "wheel" }, ] sdist = { url = 
"https://files.pythonhosted.org/packages/f3/af/4182184d3c338792894f34a62672919db7ca008c89abee9b564dd34d8029/astunparse-1.6.3.tar.gz", hash = "sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872", size = 18290, upload-time = "2019-12-22T18:12:13.129Z" } wheels = [ @@ -988,7 +988,7 @@ name = "cuda-core" version = "0.6.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy", marker = "extra == 'extra-18-nvidia-physicsnemo-cu12' or extra == 'extra-18-nvidia-physicsnemo-cu13' or (extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13')" }, + { name = "numpy" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/7a/69/8361fa2873fdc86d298a01f70ca3ea4a13f59711e75312dd0ce3d411c05f/cuda_core-0.6.0-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:70c3cd2ae0fa82cd6681be636051b247bcd4c4c3249c35bd982034cefb5adca3", size = 21597027, upload-time = "2026-02-23T18:59:24.216Z" }, @@ -1545,10 +1545,10 @@ name = "dm-tree" version = "0.1.9" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "absl-py", marker = "extra == 'extra-18-nvidia-physicsnemo-cu12' or extra == 'extra-18-nvidia-physicsnemo-cu13' or (extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13')" }, - { name = "attrs", marker = "extra == 'extra-18-nvidia-physicsnemo-cu12' or extra == 'extra-18-nvidia-physicsnemo-cu13' or (extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13')" }, - { name = "numpy", marker = "extra == 'extra-18-nvidia-physicsnemo-cu12' or extra == 'extra-18-nvidia-physicsnemo-cu13' or (extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13')" }, - { name = "wrapt", marker = "extra == 'extra-18-nvidia-physicsnemo-cu12' or extra == 'extra-18-nvidia-physicsnemo-cu13' or (extra == 
'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13')" }, + { name = "absl-py" }, + { name = "attrs" }, + { name = "numpy" }, + { name = "wrapt" }, ] sdist = { url = "https://files.pythonhosted.org/packages/a6/83/ce29720ccf934c6cfa9b9c95ebbe96558386e66886626066632b5e44afed/dm_tree-0.1.9.tar.gz", hash = "sha256:a4c7db3d3935a5a2d5e4b383fc26c6b0cd6f78c6d4605d3e7b518800ecd5342b", size = 35623, upload-time = "2025-01-30T20:45:37.13Z" } wheels = [ @@ -1613,6 +1613,50 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0c/d5/c5db1ea3394c6e1732fb3286b3bd878b59507a8f77d32a2cebda7d7b7cd4/donfig-0.8.1.post1-py3-none-any.whl", hash = "sha256:2a3175ce74a06109ff9307d90a230f81215cbac9a751f4d1c6194644b8204f9d", size = 21592, upload-time = "2024-05-23T14:13:55.283Z" }, ] +[[package]] +name = "earth2grid" +version = "2025.11.1+torch210" +source = { url = "https://github.com/NVlabs/earth2grid/archive/8fdff5a78d324f8d25afe224915301b3169bffe2.tar.gz" } +dependencies = [ + { name = "einops" }, + { name = "numpy" }, + { name = "scipy" }, + { name = "torch", version = "2.10.0", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-18-nvidia-physicsnemo-cu12' and extra == 'extra-18-nvidia-physicsnemo-cu13') or (extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13') or (extra != 'extra-18-nvidia-physicsnemo-cu12' and extra != 'extra-18-nvidia-physicsnemo-cu13')" }, + { name = "torch", version = "2.10.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-18-nvidia-physicsnemo-cu12' or (extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13')" }, + { name = "torch", version = "2.10.0+cu130", source = { registry = "https://download.pytorch.org/whl/cu130" }, marker = "extra == 'extra-18-nvidia-physicsnemo-cu13' or (extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and 
extra == 'extra-18-nvidia-physicsnemo-natten-cu13')" }, +] +sdist = { hash = "sha256:55f54084fd6a1fa505b68a2ab067aba7cab76f94d158dc8720dbb15c4e3ac235" } + +[package.metadata] +requires-dist = [ + { name = "black", marker = "extra == 'test'", specifier = ">=21.5b2" }, + { name = "bump2version", marker = "extra == 'dev'", specifier = ">=1.0.1" }, + { name = "coverage", marker = "extra == 'test'", specifier = ">=7.0.0" }, + { name = "einops", specifier = ">=0.7.0" }, + { name = "flake8", marker = "extra == 'test'", specifier = ">=3.9.2" }, + { name = "flake8-docstrings", marker = "extra == 'test'", specifier = ">=1.6.0" }, + { name = "isort", marker = "extra == 'test'", specifier = ">=5.8.0" }, + { name = "matplotlib", marker = "extra == 'test'" }, + { name = "matplotlib", marker = "extra == 'viz'" }, + { name = "mypy", marker = "extra == 'test'", specifier = ">=0.900" }, + { name = "netcdf4", marker = "extra == 'all'", specifier = ">=1.6.5" }, + { name = "numpy", specifier = ">=1.23.3" }, + { name = "pip", marker = "extra == 'dev'", specifier = ">=20.3.1" }, + { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=2.12.0" }, + { name = "pytest", marker = "extra == 'test'", specifier = ">=6.2.4" }, + { name = "pytest-cov", marker = "extra == 'test'", specifier = ">=2.12.0" }, + { name = "pytest-regtest", marker = "extra == 'test'", specifier = ">=1.5.1,<2" }, + { name = "pyvista", marker = "extra == 'viz'", specifier = ">=0.43.2" }, + { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1.5" }, + { name = "scipy" }, + { name = "toml", marker = "extra == 'dev'", specifier = ">=0.10.2" }, + { name = "torch", specifier = ">=2.10,<2.11" }, + { name = "tox", marker = "extra == 'dev'", specifier = ">=3.20.1" }, + { name = "twine", marker = "extra == 'dev'", specifier = ">=3.3.0" }, + { name = "virtualenv", marker = "extra == 'dev'", specifier = ">=20.2.2" }, +] +provides-extras = ["all", "viz", "test", "dev"] + [[package]] name = "einops" version = "0.8.2" 
@@ -3311,8 +3355,8 @@ name = "numba" version = "0.61.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "llvmlite", marker = "extra == 'extra-18-nvidia-physicsnemo-cu12' or extra == 'extra-18-nvidia-physicsnemo-cu13' or (extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13')" }, - { name = "numpy", marker = "extra == 'extra-18-nvidia-physicsnemo-cu12' or extra == 'extra-18-nvidia-physicsnemo-cu13' or (extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13')" }, + { name = "llvmlite" }, + { name = "numpy" }, ] sdist = { url = "https://files.pythonhosted.org/packages/1c/a0/e21f57604304aa03ebb8e098429222722ad99176a4f979d34af1d1ee80da/numba-0.61.2.tar.gz", hash = "sha256:8750ee147940a6637b80ecf7f95062185ad8726c8c28a2295b8ec1160a196f7d", size = 2820615, upload-time = "2025-04-09T02:58:07.659Z" } wheels = [ @@ -3341,9 +3385,9 @@ dependencies = [ { name = "cuda-bindings", version = "12.9.4", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-18-nvidia-physicsnemo-cu12' or (extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13')" }, { name = "cuda-bindings", version = "13.0.3", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'emscripten' and sys_platform != 'win32' and extra == 'extra-18-nvidia-physicsnemo-cu13') or (sys_platform != 'emscripten' and sys_platform != 'win32' and extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13') or (sys_platform == 'emscripten' and extra == 'extra-18-nvidia-physicsnemo-cu12' and extra == 'extra-18-nvidia-physicsnemo-cu13') or (sys_platform == 'emscripten' and extra == 'extra-18-nvidia-physicsnemo-cu13' and extra == 'extra-18-nvidia-physicsnemo-natten-cu12') or (sys_platform == 'emscripten' and extra == 'extra-18-nvidia-physicsnemo-cu13' and extra != 
'extra-18-nvidia-physicsnemo-natten-cu13') or (sys_platform == 'emscripten' and extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13') or (sys_platform == 'win32' and extra == 'extra-18-nvidia-physicsnemo-cu12' and extra == 'extra-18-nvidia-physicsnemo-cu13') or (sys_platform == 'win32' and extra == 'extra-18-nvidia-physicsnemo-cu13' and extra == 'extra-18-nvidia-physicsnemo-natten-cu12') or (sys_platform == 'win32' and extra == 'extra-18-nvidia-physicsnemo-cu13' and extra != 'extra-18-nvidia-physicsnemo-natten-cu13') or (sys_platform == 'win32' and extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13')" }, { name = "cuda-bindings", version = "13.2.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'emscripten' and sys_platform != 'win32' and extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13') or (sys_platform == 'emscripten' and extra == 'extra-18-nvidia-physicsnemo-cu13' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13') or (sys_platform == 'emscripten' and extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13') or (sys_platform == 'win32' and extra == 'extra-18-nvidia-physicsnemo-cu13' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13') or (sys_platform == 'win32' and extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13') or (extra == 'extra-18-nvidia-physicsnemo-cu12' and extra == 'extra-18-nvidia-physicsnemo-cu13') or (extra == 'extra-18-nvidia-physicsnemo-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13')" }, - { name = "cuda-core", marker = "extra == 'extra-18-nvidia-physicsnemo-cu12' or extra == 'extra-18-nvidia-physicsnemo-cu13' or (extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 
'extra-18-nvidia-physicsnemo-natten-cu13')" }, - { name = "numba", marker = "extra == 'extra-18-nvidia-physicsnemo-cu12' or extra == 'extra-18-nvidia-physicsnemo-cu13' or (extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13')" }, - { name = "packaging", marker = "extra == 'extra-18-nvidia-physicsnemo-cu12' or extra == 'extra-18-nvidia-physicsnemo-cu13' or (extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13')" }, + { name = "cuda-core" }, + { name = "numba" }, + { name = "packaging" }, ] sdist = { url = "https://files.pythonhosted.org/packages/aa/cd/9017506815047ee30ad404e3c469788676a6abeaaff8014d07a0180cdfbc/numba_cuda-0.22.2.tar.gz", hash = "sha256:e8c19bc1174dfc3596259381fa708f1c3397a618bdbbaa5d068bcc56af8fd921", size = 1340447, upload-time = "2025-12-19T01:08:57.73Z" } wheels = [ @@ -4661,6 +4705,16 @@ gnns = [ { name = "vtk" }, { name = "wandb" }, ] +healda = [ + { name = "dask" }, + { name = "earth2grid" }, + { name = "netcdf4" }, + { name = "pyarrow" }, + { name = "tensordict" }, + { name = "tfrecord" }, + { name = "xarray" }, + { name = "zarr" }, +] mesh-extras = [ { name = "matplotlib" }, { name = "pyacvd" }, @@ -4726,6 +4780,8 @@ requires-dist = [ { name = "cupy-cuda12x", marker = "extra == 'cu12'" }, { name = "cupy-cuda13x", marker = "extra == 'cu13'", specifier = "==13.6.0" }, { name = "dask", marker = "extra == 'datapipes-extras'" }, + { name = "dask", marker = "extra == 'healda'" }, + { name = "earth2grid", marker = "extra == 'healda'", url = "https://github.com/NVlabs/earth2grid/archive/8fdff5a78d324f8d25afe224915301b3169bffe2.tar.gz" }, { name = "einops", specifier = ">=0.8.1" }, { name = "gitpython", specifier = ">=3.1.40" }, { name = "h5py", specifier = ">=3.15.1" }, @@ -4744,6 +4800,7 @@ requires-dist = [ { name = "natten", marker = "extra == 'natten-cu12'", specifier = ">=0.21.5", index = "https://whl.natten.org/cu128/torch2.10.0", conflict 
= { package = "nvidia-physicsnemo", extra = "natten-cu12" } }, { name = "natten", marker = "extra == 'natten-cu13'", specifier = ">=0.21.5", index = "https://whl.natten.org/cu130/torch2.10.0", conflict = { package = "nvidia-physicsnemo", extra = "natten-cu13" } }, { name = "netcdf4", marker = "extra == 'datapipes-extras'" }, + { name = "netcdf4", marker = "extra == 'healda'" }, { name = "numpy", specifier = ">=1.22.4" }, { name = "nvidia-dali-cuda120", marker = "extra == 'cu12'", index = "https://pypi.nvidia.com/" }, { name = "nvidia-dali-cuda130", marker = "extra == 'cu13'", index = "https://pypi.nvidia.com/" }, @@ -4753,6 +4810,7 @@ requires-dist = [ { name = "packaging", specifier = ">=24.2" }, { name = "pandas", specifier = ">=2.2.0" }, { name = "pyacvd", marker = "extra == 'mesh-extras'", specifier = ">=0.3.2" }, + { name = "pyarrow", marker = "extra == 'healda'", specifier = ">=14.0.0" }, { name = "pylibraft-cu12", marker = "extra == 'cu12'", index = "https://pypi.nvidia.com/" }, { name = "pylibraft-cu13", marker = "extra == 'cu13'", index = "https://pypi.nvidia.com/" }, { name = "pyvista", marker = "extra == 'gnns'", specifier = ">=0.46.4" }, @@ -4769,8 +4827,10 @@ requires-dist = [ { name = "stl", marker = "extra == 'utils-extras'" }, { name = "tensordict", specifier = ">=0.10.0" }, { name = "tensordict", marker = "extra == 'datapipes-extras'", specifier = ">=0.11.0" }, + { name = "tensordict", marker = "extra == 'healda'", specifier = ">=0.11.0" }, { name = "termcolor", specifier = ">=3.2.0" }, { name = "tfrecord", marker = "extra == 'datapipes-extras'" }, + { name = "tfrecord", marker = "extra == 'healda'" }, { name = "timm", specifier = ">=1.0.22" }, { name = "torch", specifier = ">=2.5.0" }, { name = "torch", marker = "extra == 'cu12'", specifier = ">=2.5.0", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "nvidia-physicsnemo", extra = "cu12" } }, @@ -4796,9 +4856,11 @@ requires-dist = [ { name = "wandb", marker = "extra == 
'utils-extras'" }, { name = "warp-lang", specifier = ">=1.5.0" }, { name = "xarray", marker = "extra == 'datapipes-extras'", specifier = ">=2025.6.1" }, + { name = "xarray", marker = "extra == 'healda'", specifier = ">=2025.6.1" }, { name = "zarr", marker = "extra == 'datapipes-extras'", specifier = ">=3.0.0" }, + { name = "zarr", marker = "extra == 'healda'", specifier = ">=3.0.0" }, ] -provides-extras = ["cu12", "cu13", "datapipes-extras", "gnns", "mesh-extras", "model-extras", "natten-cu12", "natten-cu13", "nn-extras", "perf", "utils-extras"] +provides-extras = ["cu12", "cu13", "datapipes-extras", "gnns", "healda", "mesh-extras", "model-extras", "natten-cu12", "natten-cu13", "nn-extras", "perf", "utils-extras"] [package.metadata.requires-dev] dev = [ @@ -6531,18 +6593,18 @@ dependencies = [ { name = "typing-extensions", marker = "extra == 'extra-18-nvidia-physicsnemo-cu12' or (extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13')" }, ] wheels = [ - { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:85ed7944655ea6fd69377692e9cbfd7bba28d99696ceae79985e7caa99cf0a95" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:1d01ffaebf64715c0f507a39463149cb19e596ff702bd4bcf862601f2881dabc" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp311-cp311-win_amd64.whl", hash = "sha256:3523fda6e2cfab2b04ae09b1424681358e508bb3faa11ceb67004113d5e7acad" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:6f09cdf2415516be028ae82e6b985bcfc3eac37bc52ab401142689f6224516ca" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:628e89bd5110ced7debee2a57c69959725b7fbc64eab81a39dd70e46c7e28ba5" }, - { url = 
"https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp312-cp312-win_amd64.whl", hash = "sha256:fbde8f6a9ec8c76979a0d14df21c10b9e5cab6f0d106a73ca73e2179bc597cae" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:bdbcc703382f948e951c063448c9406bf38ce66c41dd698d9e2733fcf96c037a" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:7b4bd23ed63de97456fcc81c26fea9f02ee02ce1112111c4dac0d8cfe574b23e" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313-win_amd64.whl", hash = "sha256:4d1b0b49c54223c7c04050b49eac141d77b6edbc34aea1dfc74a6fdb661baa8c" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:f1f8b840c64b645a4bc61a393db48effb9c92b2dc26c8373873911f0750d1ea7" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:23f58258012bcf1c349cb22af387e33aadca7f83ea617b080e774eb41e4fe8ff" }, - { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313t-win_amd64.whl", hash = "sha256:01b216e097b17a5277cfb47c383cdcacf06abeadcb0daca0c76b59e72854c3b6" }, + { url = "https://download-r2.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:85ed7944655ea6fd69377692e9cbfd7bba28d99696ceae79985e7caa99cf0a95" }, + { url = "https://download-r2.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:1d01ffaebf64715c0f507a39463149cb19e596ff702bd4bcf862601f2881dabc" }, + { url = "https://download-r2.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp311-cp311-win_amd64.whl", hash = "sha256:3523fda6e2cfab2b04ae09b1424681358e508bb3faa11ceb67004113d5e7acad" }, + { url = 
"https://download-r2.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:6f09cdf2415516be028ae82e6b985bcfc3eac37bc52ab401142689f6224516ca" }, + { url = "https://download-r2.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:628e89bd5110ced7debee2a57c69959725b7fbc64eab81a39dd70e46c7e28ba5" }, + { url = "https://download-r2.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp312-cp312-win_amd64.whl", hash = "sha256:fbde8f6a9ec8c76979a0d14df21c10b9e5cab6f0d106a73ca73e2179bc597cae" }, + { url = "https://download-r2.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:bdbcc703382f948e951c063448c9406bf38ce66c41dd698d9e2733fcf96c037a" }, + { url = "https://download-r2.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:7b4bd23ed63de97456fcc81c26fea9f02ee02ce1112111c4dac0d8cfe574b23e" }, + { url = "https://download-r2.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313-win_amd64.whl", hash = "sha256:4d1b0b49c54223c7c04050b49eac141d77b6edbc34aea1dfc74a6fdb661baa8c" }, + { url = "https://download-r2.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:f1f8b840c64b645a4bc61a393db48effb9c92b2dc26c8373873911f0750d1ea7" }, + { url = "https://download-r2.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:23f58258012bcf1c349cb22af387e33aadca7f83ea617b080e774eb41e4fe8ff" }, + { url = "https://download-r2.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313t-win_amd64.whl", hash = "sha256:01b216e097b17a5277cfb47c383cdcacf06abeadcb0daca0c76b59e72854c3b6" }, ] [[package]] @@ -6625,18 +6687,18 @@ dependencies = [ { name = "typing-extensions", marker = "extra == 'extra-18-nvidia-physicsnemo-cu13' or (extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13')" }, ] 
wheels = [ - { url = "https://download.pytorch.org/whl/cu130/torch-2.10.0%2Bcu130-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:ea3239d544b2e569a8f47db5c7fa4fd42a2fe96aefb84bb1eda45ce213020fd2" }, - { url = "https://download.pytorch.org/whl/cu130/torch-2.10.0%2Bcu130-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:22cfa45e73f1e8c64f4012737987a727d01d152121b93d196b0ca22f39a3f8e3" }, - { url = "https://download.pytorch.org/whl/cu130/torch-2.10.0%2Bcu130-cp311-cp311-win_amd64.whl", hash = "sha256:218ae0f323d5ebe8f2770e46cbfb7bbff9af2c8d192d5187878d0964d43c8b71" }, - { url = "https://download.pytorch.org/whl/cu130/torch-2.10.0%2Bcu130-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:4fc8f67637f4c92b989a07d80ffe755e79a3510ca02ebf23ce66396fb277c88d" }, - { url = "https://download.pytorch.org/whl/cu130/torch-2.10.0%2Bcu130-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:858f0cbcc78d726fea9499eb3464faa98392fa093845a3262209bd226b7844d6" }, - { url = "https://download.pytorch.org/whl/cu130/torch-2.10.0%2Bcu130-cp312-cp312-win_amd64.whl", hash = "sha256:224649fa0ab181ec483cc368e3303dda1760e4ba31bea806b88979f855436aaa" }, - { url = "https://download.pytorch.org/whl/cu130/torch-2.10.0%2Bcu130-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:75780283308df9fede371eeda01e9607c8862a1803a2f2f31a08a2c0deaed342" }, - { url = "https://download.pytorch.org/whl/cu130/torch-2.10.0%2Bcu130-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:7e0d9922e9e91f780b2761a0c5ebac3c15c9740bab042e1b59149afa6d6474eb" }, - { url = "https://download.pytorch.org/whl/cu130/torch-2.10.0%2Bcu130-cp313-cp313-win_amd64.whl", hash = "sha256:48af94af745a9dd9b42be81ea15b56aba981666bcfe10394dceca6d9476a50fa" }, - { url = "https://download.pytorch.org/whl/cu130/torch-2.10.0%2Bcu130-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:46699da91f0367d8dfa1b606cb0352aaf190b5853f463010e75ff08f15a94e7d" }, - { url = 
"https://download.pytorch.org/whl/cu130/torch-2.10.0%2Bcu130-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:775d1fff07e302fb669d555a5005f781aa460aa80dff7a512e8e6e723f9def83" }, - { url = "https://download.pytorch.org/whl/cu130/torch-2.10.0%2Bcu130-cp313-cp313t-win_amd64.whl", hash = "sha256:b38e5b505b015903a51c2b3f12e50a9f152f92fe7e3992e79f504138cf90601d" }, + { url = "https://download-r2.pytorch.org/whl/cu130/torch-2.10.0%2Bcu130-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:ea3239d544b2e569a8f47db5c7fa4fd42a2fe96aefb84bb1eda45ce213020fd2" }, + { url = "https://download-r2.pytorch.org/whl/cu130/torch-2.10.0%2Bcu130-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:22cfa45e73f1e8c64f4012737987a727d01d152121b93d196b0ca22f39a3f8e3" }, + { url = "https://download-r2.pytorch.org/whl/cu130/torch-2.10.0%2Bcu130-cp311-cp311-win_amd64.whl", hash = "sha256:218ae0f323d5ebe8f2770e46cbfb7bbff9af2c8d192d5187878d0964d43c8b71" }, + { url = "https://download-r2.pytorch.org/whl/cu130/torch-2.10.0%2Bcu130-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:4fc8f67637f4c92b989a07d80ffe755e79a3510ca02ebf23ce66396fb277c88d" }, + { url = "https://download-r2.pytorch.org/whl/cu130/torch-2.10.0%2Bcu130-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:858f0cbcc78d726fea9499eb3464faa98392fa093845a3262209bd226b7844d6" }, + { url = "https://download-r2.pytorch.org/whl/cu130/torch-2.10.0%2Bcu130-cp312-cp312-win_amd64.whl", hash = "sha256:224649fa0ab181ec483cc368e3303dda1760e4ba31bea806b88979f855436aaa" }, + { url = "https://download-r2.pytorch.org/whl/cu130/torch-2.10.0%2Bcu130-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:75780283308df9fede371eeda01e9607c8862a1803a2f2f31a08a2c0deaed342" }, + { url = "https://download-r2.pytorch.org/whl/cu130/torch-2.10.0%2Bcu130-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:7e0d9922e9e91f780b2761a0c5ebac3c15c9740bab042e1b59149afa6d6474eb" }, + { url = 
"https://download-r2.pytorch.org/whl/cu130/torch-2.10.0%2Bcu130-cp313-cp313-win_amd64.whl", hash = "sha256:48af94af745a9dd9b42be81ea15b56aba981666bcfe10394dceca6d9476a50fa" }, + { url = "https://download-r2.pytorch.org/whl/cu130/torch-2.10.0%2Bcu130-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:46699da91f0367d8dfa1b606cb0352aaf190b5853f463010e75ff08f15a94e7d" }, + { url = "https://download-r2.pytorch.org/whl/cu130/torch-2.10.0%2Bcu130-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:775d1fff07e302fb669d555a5005f781aa460aa80dff7a512e8e6e723f9def83" }, + { url = "https://download-r2.pytorch.org/whl/cu130/torch-2.10.0%2Bcu130-cp313-cp313t-win_amd64.whl", hash = "sha256:b38e5b505b015903a51c2b3f12e50a9f152f92fe7e3992e79f504138cf90601d" }, ] [[package]] @@ -6924,9 +6986,9 @@ name = "treelite" version = "4.7.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy", marker = "extra == 'extra-18-nvidia-physicsnemo-cu12' or extra == 'extra-18-nvidia-physicsnemo-cu13' or (extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13')" }, - { name = "packaging", marker = "extra == 'extra-18-nvidia-physicsnemo-cu12' or extra == 'extra-18-nvidia-physicsnemo-cu13' or (extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13')" }, - { name = "scipy", marker = "extra == 'extra-18-nvidia-physicsnemo-cu12' or extra == 'extra-18-nvidia-physicsnemo-cu13' or (extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13')" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "scipy" }, ] sdist = { url = "https://files.pythonhosted.org/packages/9e/dd/78886789f87a6d9cb3d78241fdd750c13123ea4c64df03bcc717ee5b5d26/treelite-4.7.0.tar.gz", hash = "sha256:6d1a0d990f4972e77bad6b42a6e0b7d68527d790564bd42d7d8d48ae1f14dc4c", size = 110239, upload-time = "2026-03-06T23:25:38.477Z" } wheels = [ @@ -7146,7 
+7208,7 @@ name = "wheel" version = "0.46.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "packaging", marker = "extra == 'extra-18-nvidia-physicsnemo-cu12' or extra == 'extra-18-nvidia-physicsnemo-cu13' or (extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13')" }, + { name = "packaging" }, ] sdist = { url = "https://files.pythonhosted.org/packages/89/24/a2eb353a6edac9a0303977c4cb048134959dd2a51b48a269dfc9dde00c8a/wheel-0.46.3.tar.gz", hash = "sha256:e3e79874b07d776c40bd6033f8ddf76a7dad46a7b8aa1b2787a83083519a1803", size = 60605, upload-time = "2026-01-22T12:39:49.136Z" } wheels = [ From 71eccaf411dc1ad5ad4bb85d23f70bfff379982a Mon Sep 17 00:00:00 2001 From: Peter Harrington Date: Wed, 8 Apr 2026 18:46:40 -0700 Subject: [PATCH 03/10] Update precommit for examples tests --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 09803226dc..41ba5f4539 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,7 +31,7 @@ repos: language: python types: [python] additional_dependencies: ['interrogate==1.7.0'] - exclude: ^docs/|^physicsnemo/experimental/|^test/ + exclude: ^docs/|^physicsnemo/experimental/|^test/|^examples/.*/test/ - repo: https://github.com/igorshubovych/markdownlint-cli rev: v0.35.0 From 1caa5ef2e8d41a1cf08f14566830e299c0d91e72 Mon Sep 17 00:00:00 2001 From: Peter Harrington Date: Tue, 14 Apr 2026 16:15:58 -0700 Subject: [PATCH 04/10] integrate restartable sampler, other updates, migrate tests --- .github/CODEOWNERS | 1 + examples/weather/healda/README.md | 93 ++++++++-- examples/weather/healda/test/test_samplers.py | 147 --------------- .../experimental/datapipes/healda/__init__.py | 23 +-- .../experimental/datapipes/healda/dataset.py | 11 ++ .../experimental/datapipes/healda/samplers.py | 170 +++++------------- .../datapipes/healda/time_utils.py | 30 ++++ 
.../datapipes/healda/transforms/era5_obs.py | 163 ++++++++++------- .../experimental/datapipes/healda/types.py | 49 ++--- pyproject.toml | 9 +- .../datapipes/healda/__init__.py | 8 - .../datapipes/healda}/test_combined_schema.py | 4 +- .../datapipes/healda}/test_features.py | 4 +- .../datapipes/healda}/test_indexing.py | 0 .../datapipes/healda}/test_obs_filtering.py | 4 +- .../datapipes/healda}/test_prefetch.py | 0 test/datapipes/healda/test_samplers.py | 113 ++++++++++++ .../datapipes/healda}/test_time_utils.py | 5 +- .../datapipes/healda}/test_types.py | 26 +-- uv.lock | 50 +----- 20 files changed, 446 insertions(+), 464 deletions(-) delete mode 100644 examples/weather/healda/test/test_samplers.py rename examples/weather/healda/test/conftest.py => test/datapipes/healda/__init__.py (85%) rename {examples/weather/healda/test => test/datapipes/healda}/test_combined_schema.py (97%) rename {examples/weather/healda/test => test/datapipes/healda}/test_features.py (98%) rename {examples/weather/healda/test => test/datapipes/healda}/test_indexing.py (100%) rename {examples/weather/healda/test => test/datapipes/healda}/test_obs_filtering.py (98%) rename {examples/weather/healda/test => test/datapipes/healda}/test_prefetch.py (100%) create mode 100644 test/datapipes/healda/test_samplers.py rename {examples/weather/healda/test => test/datapipes/healda}/test_time_utils.py (97%) rename {examples/weather/healda/test => test/datapipes/healda}/test_types.py (87%) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 1bf35d9d0a..9391a46911 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -246,6 +246,7 @@ physicsnemo/models/vfgn/ @mnabian # Experimental deliberately has no codeowner physicsnemo/experimental/ +physicsnemo/experimental/datapipes/healda/ @pzharrington # ============================================================================== # EXAMPLES - Active Learning diff --git a/examples/weather/healda/README.md b/examples/weather/healda/README.md index 
06cc298c36..fce9c28f98 100644 --- a/examples/weather/healda/README.md +++ b/examples/weather/healda/README.md @@ -1,6 +1,6 @@ # HealDA — AI-based Data Assimilation on the HEALPix Grid -> **This recipe is under active construction.** +> **🏗️ This recipe is under active construction. 🏗️** > Structure and functionality are subject to changes. HealDA is a stateless assimilation model that produces a single @@ -8,9 +8,6 @@ global weather analysis from conventional and satellite observations. It operates on a HEALPix level-6 padded XY grid and outputs ERA5-compatible atmospheric variables. -This example provides a recipe to train HealDA, with support -for extension to custom data. - ## Setup Start by installing PhysicsNeMo (if not already installed) with @@ -36,11 +33,9 @@ ObsERA5Dataset(era5_data, obs_loader, transform) | Temporal windowing via FrameIndexGenerator | __getitems__ -> get() per index -> transform.transform() v -ChunkedDistributedSampler (contiguous chunks for cache locality) - | -DataLoader (1 worker each, pin_memory, persistent_workers) +RestartableDistributedSampler (stateful distributed sampling with checkpointing) | -RoundRobinLoader (interleaves per-worker DataLoaders) +DataLoader (pin_memory, persistent_workers) | prefetch_map(loader, transform.device_transform) | @@ -83,8 +78,7 @@ class MyTransform: | `UFSUnifiedLoader` | `loaders.ufs_obs` | Parquet obs loader | | `ERA5Loader` | `loaders.era5` | Async ERA5 zarr loader | | `ERA5ObsTransform` | `transforms.era5_obs` | Two-stage transform | -| `ChunkedDistributedSampler` | `samplers` | Distributed sampler | -| `RoundRobinLoader` | `samplers` | Multi-loader interleave | +| `RestartableDistributedSampler` | `samplers` | Stateful distributed sampler | | `prefetch_map` | `prefetch` | CUDA stream prefetching | All modules above are under @@ -125,7 +119,84 @@ from physicsnemo.experimental.datapipes.healda.configs.variable_configs import ( dataset = ObsERA5Dataset( era5_data=era5_xr["data"], 
obs_loader=GOESRadianceLoader(...), - transform=ERA5ObsTransform(...), + transform=ERA5ObsTransform(sensors=["goes"], ...), + variable_config=VARIABLE_CONFIGS["era5"], +) +``` + +### Putting It Together + +A complete training pipeline wires together all the +components — dataset, sampler, DataLoader, and GPU prefetch: + +```python +import torch +from torch.utils.data import DataLoader + +from physicsnemo.experimental.datapipes.healda import ( + ObsERA5Dataset, + RestartableDistributedSampler, + identity_collate, + prefetch_map, +) +from physicsnemo.experimental.datapipes.healda.loaders.ufs_obs import ( + UFSUnifiedLoader, +) +from physicsnemo.experimental.datapipes.healda.transforms.era5_obs import ( + ERA5ObsTransform, +) +from physicsnemo.experimental.datapipes.healda.configs.variable_configs import ( + VARIABLE_CONFIGS, +) + +sensors = ["atms", "mhs", "conv"] + +# 1. Build loaders +obs_loader = UFSUnifiedLoader( + data_path="/path/to/processed_obs", + sensors=sensors, + normalization="zscore", + obs_context_hours=(-21, 3), +) +transform = ERA5ObsTransform( variable_config=VARIABLE_CONFIGS["era5"], + sensors=sensors, ) + +# 2. Build dataset +dataset = ObsERA5Dataset( + era5_data=era5_xr["data"], + obs_loader=obs_loader, + transform=transform, + variable_config=VARIABLE_CONFIGS["era5"], + split="train", +) + +# 3. Sampler + DataLoader +sampler = RestartableDistributedSampler( + dataset, rank=rank, num_replicas=world_size, +) +sampler.set_epoch(0) +dataloader = DataLoader( + dataset, + sampler=sampler, + batch_size=2, + num_workers=8, + collate_fn=identity_collate, + pin_memory=True, + persistent_workers=True, +) + +# 4. GPU prefetch (hides CPU→GPU transfer behind training) +device = torch.device("cuda") +loader = prefetch_map( + dataloader, + lambda batch: transform.device_transform(batch, device), + queue_size=1, +) + +# 5. Training loop — batches arrive GPU-ready +for batch in loader: + loss = model(batch) + ... 
``` diff --git a/examples/weather/healda/test/test_samplers.py b/examples/weather/healda/test/test_samplers.py deleted file mode 100644 index 32922f535d..0000000000 --- a/examples/weather/healda/test/test_samplers.py +++ /dev/null @@ -1,147 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tests for ChunkedDistributedSampler and RoundRobinLoader.""" - -import itertools - -import torch -import torch.utils.data - -from physicsnemo.experimental.datapipes.healda.samplers import ( - ChunkedDistributedSampler, - RoundRobinLoader, -) - - -def test_chunked_sampler_sequential(): - """Indices within a chunk must be consecutive.""" - s = ChunkedDistributedSampler(list(range(100)), chunk_size=5) - it = iter(s) - visited = set() - for chunk in range(20): - last_i = 0 - for i in range(5): - idx = next(it) - if i > 0: - assert idx - last_i == 1 - last_i = idx - visited.add(idx) - - assert len(visited) == 100 - - -def test_chunked_sampler_with_islice(): - """Verify iter(sampler) continues rather than resetting.""" - dataset = list(range(100)) - sampler = ChunkedDistributedSampler(dataset, chunk_size=10, drop_last=False) - - iterator = iter(sampler) - first_10 = list(itertools.islice(iterator, 10)) - assert first_10 == list(range(10)) - - # Re-calling iter should continue, not restart - iterator2 = 
iter(sampler) - next_10 = list(itertools.islice(iterator2, 10)) - assert next_10 == list(range(10, 20)) - assert first_10 != next_10 - - -def test_shuffle_within_chunk(): - """Within-chunk shuffle randomizes order but preserves membership.""" - s = ChunkedDistributedSampler( - list(range(100)), - chunk_size=10, - shuffle=False, - shuffle_within_chunk=True, - seed=42, - ) - - indices = list(s) - assert sorted(indices) == list(range(100)) - - first_chunk = indices[:10] - assert sorted(first_chunk) == list(range(10)) - assert first_chunk != list(range(10)) # order should differ - - -def test_shuffle_epoch_changes_chunks(): - """Epoch auto-increment produces different chunk orderings.""" - s = ChunkedDistributedSampler( - list(range(100)), - chunk_size=10, - shuffle=True, - shuffle_within_chunk=True, - seed=42, - ) - - epoch1 = list(s) - epoch2 = list(s) - - assert sorted(epoch1) == list(range(100)) - assert sorted(epoch2) == list(range(100)) - assert sorted(epoch1[:10]) != sorted(epoch2[:10]) - - -# --------------------------------------------------------------------------- -# RoundRobinLoader tests -# --------------------------------------------------------------------------- - - -def test_round_robin_loader(): - """Round-robin interleaving across three loaders.""" - loader1 = torch.utils.data.DataLoader(list(range(0, 10)), batch_size=2) - loader2 = torch.utils.data.DataLoader(list(range(10, 15)), batch_size=2) - loader3 = torch.utils.data.DataLoader(list(range(15, 20)), batch_size=2) - - rr = RoundRobinLoader([loader1, loader2, loader3]) - assert len(rr) == len(loader1) + len(loader2) + len(loader3) - - batches = list(rr) - assert len(batches) == 11 - - # First round - assert torch.equal(batches[0], torch.tensor([0, 1])) - assert torch.equal(batches[1], torch.tensor([10, 11])) - assert torch.equal(batches[2], torch.tensor([15, 16])) - - -def test_round_robin_loader_uneven(): - """Uneven loader lengths — shorter ones drop out first.""" - loader1 = 
torch.utils.data.DataLoader(list(range(0, 20)), batch_size=2) - loader2 = torch.utils.data.DataLoader(list(range(20, 22)), batch_size=2) - - rr = RoundRobinLoader([loader1, loader2]) - batches = list(rr) - assert len(batches) == 11 - - assert torch.equal(batches[0], torch.tensor([0, 1])) - assert torch.equal(batches[1], torch.tensor([20, 21])) - assert torch.equal(batches[2], torch.tensor([2, 3])) - - -def test_round_robin_loader_empty(): - rr = RoundRobinLoader([]) - assert list(rr) == [] - - -def test_round_robin_loader_single(): - loader = torch.utils.data.DataLoader(list(range(10)), batch_size=3) - rr = RoundRobinLoader([loader]) - expected = list(torch.utils.data.DataLoader(list(range(10)), batch_size=3)) - actual = list(rr) - assert len(actual) == len(expected) - for a, e in zip(actual, expected): - assert torch.equal(a, e) diff --git a/physicsnemo/experimental/datapipes/healda/__init__.py b/physicsnemo/experimental/datapipes/healda/__init__.py index 14315e6e3b..ce7cf29af4 100644 --- a/physicsnemo/experimental/datapipes/healda/__init__.py +++ b/physicsnemo/experimental/datapipes/healda/__init__.py @@ -25,8 +25,7 @@ - :class:`UFSUnifiedLoader` — parquet-based observation loader - :class:`ERA5ObsTransform` — two-stage transform with Triton feature kernels - :func:`prefetch_map` — background CUDA stream prefetching -- :class:`ChunkedDistributedSampler` — cache-friendly distributed sampler -- :class:`RoundRobinLoader` — multi-loader round-robin interleaving +- :class:`RestartableDistributedSampler` — stateful distributed sampler with checkpoint support Protocols for custom loaders/transforms: @@ -35,12 +34,9 @@ - :class:`DeviceTransform` — GPU-side batch transform """ -from physicsnemo.experimental.datapipes.healda.dataset import ObsERA5Dataset -from physicsnemo.experimental.datapipes.healda.indexing import ( - FrameIndexGenerator, - MultiCoordIndex, - get_flat_indexer, - split_array_contiguous, +from physicsnemo.experimental.datapipes.healda.dataset import ( + 
ObsERA5Dataset, + identity_collate, ) from physicsnemo.experimental.datapipes.healda.prefetch import prefetch_map from physicsnemo.experimental.datapipes.healda.protocols import ( @@ -49,8 +45,7 @@ Transform, ) from physicsnemo.experimental.datapipes.healda.samplers import ( - ChunkedDistributedSampler, - RoundRobinLoader, + RestartableDistributedSampler, ) from physicsnemo.experimental.datapipes.healda.types import ( Batch, @@ -79,10 +74,6 @@ "split_by_sensor", # Infrastructure "prefetch_map", - "ChunkedDistributedSampler", - "RoundRobinLoader", - "FrameIndexGenerator", - "MultiCoordIndex", - "get_flat_indexer", - "split_array_contiguous", + "RestartableDistributedSampler", + "identity_collate", ] diff --git a/physicsnemo/experimental/datapipes/healda/dataset.py b/physicsnemo/experimental/datapipes/healda/dataset.py index 1ad9103766..9c68556e94 100644 --- a/physicsnemo/experimental/datapipes/healda/dataset.py +++ b/physicsnemo/experimental/datapipes/healda/dataset.py @@ -189,3 +189,14 @@ def __getitems__(self, indexes): """ times, objs = zip(*[self.get(i) for i in indexes]) return self.transform.transform(times, objs) + + +def identity_collate(obj): + """Identity collate function for use with ``ObsERA5Dataset``. + + Since ``__getitems__`` already returns an assembled batch dict, no + collation is needed. Pass this as ``collate_fn`` to the DataLoader:: + + DataLoader(dataset, sampler=sampler, collate_fn=identity_collate) + """ + return obj diff --git a/physicsnemo/experimental/datapipes/healda/samplers.py b/physicsnemo/experimental/datapipes/healda/samplers.py index 6ba75b6e13..9ed54c76f7 100644 --- a/physicsnemo/experimental/datapipes/healda/samplers.py +++ b/physicsnemo/experimental/datapipes/healda/samplers.py @@ -13,167 +13,79 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Samplers and multi-loader utilities for cache-friendly distributed training. +"""Stateful distributed sampler with checkpoint support. -``ChunkedDistributedSampler`` yields indices in contiguous chunks so that -data backed by chunked storage (e.g. zarr) benefits from sequential I/O. - -``RoundRobinLoader`` interleaves multiple DataLoaders — typically one per -worker, each with its own ``ChunkedDistributedSampler`` — to provide an -iterable-style interface over map-style datasets with per-worker chunk -affinity. +``RestartableDistributedSampler`` is a distributed sampler that tracks its +iteration state so that training can be resumed from a checkpoint without +replaying already-seen samples. """ -import random - import torch -import torch.distributed import torch.utils.data -# --------------------------------------------------------------------------- -# Chunked distributed sampler -# --------------------------------------------------------------------------- - - -class ChunkedDistributedSampler(torch.utils.data.Sampler): - """A distributed sampler that yields indices in contiguous chunks. +class RestartableDistributedSampler(torch.utils.data.Sampler): + """A stateful distributed sampler that automatically loops over the dataset. - Within each chunk, indices are sequential (optionally shuffled within the - chunk). Chunks themselves can be shuffled across epochs. This pattern is - critical when the underlying dataset caches data at chunk granularity, as - it ensures sequential access within each cache window. - - The sampler is infinite: after exhausting all chunks it advances the epoch - and re-shuffles. + Each epoch generates a rank-specific random permutation. The sampler + tracks its position within the permutation so that ``restart()`` can + resume from an exact checkpoint. Args: - dataset: Map-style dataset. - chunk_size: Number of contiguous indices per chunk. + dataset: Map-style dataset (used only for ``len``). rank: This worker's global rank. 
- num_replicas: Total number of workers (data-parallel * per-GPU workers). - shuffle: Whether to shuffle the order of chunks. - shuffle_within_chunk: Whether to shuffle indices within each chunk. - drop_last: Whether to drop incomplete trailing chunks. - seed: Random seed (broadcast from rank 0 when distributed). - sampler_fn: Optional custom sampler over chunk indices. + num_replicas: Total number of data-parallel workers. + shuffle: Whether to shuffle (always True in practice). + drop_last: Drop remainder so all ranks get the same count. + seed: Base random seed. """ def __init__( self, dataset: torch.utils.data.Dataset, - chunk_size: int = 1, rank=0, num_replicas=1, - shuffle=False, - shuffle_within_chunk=False, + shuffle=True, drop_last=True, seed=42, - sampler_fn=None, ): super().__init__() - self.n = len(dataset) - nchunks = self.n // chunk_size - chunks = list(range(nchunks)) - - if torch.distributed.is_initialized(): - seed = torch.tensor(seed).cuda() - torch.distributed.broadcast(seed, src=0) - seed = seed.item() - - self._chunk_sampler = ( - sampler_fn(chunks) - if sampler_fn is not None - else torch.utils.data.DistributedSampler( - chunks, - num_replicas=num_replicas, - rank=rank, - shuffle=shuffle, - seed=seed, - drop_last=drop_last, - ) - ) - self.chunk_size = chunk_size - self.shuffle_within_chunk = shuffle_within_chunk + self.iteration = 0 + self.epoch = 0 + self.len = len(dataset) self.seed = seed + self.permutation = None self.rank = rank - self.epoch = 0 - self.index_within_chunk = 0 - self._chunk_iter = iter(self._chunk_sampler) - self._current_chunk_indices = None + self.num_replicas = num_replicas - if self.shuffle_within_chunk: - self.rng = random.Random(seed + rank) + def __len__(self): + return self.len // self.num_replicas def set_epoch(self, epoch): - try: - self._chunk_sampler.set_epoch(epoch) - except AttributeError: - pass self.epoch = epoch + self.iteration = 0 + rng = torch.Generator().manual_seed(self.seed + self.epoch + self.rank) + 
permutation = torch.randperm(self.len, generator=rng) - def __len__(self): - return self.n + rem = self.len % self.num_replicas + if rem > 0: + permutation = permutation[:-rem] + self.permutation = permutation[self.rank :: self.num_replicas] + + def restart(self, epoch, iteration, seed=None): + """Resume from a checkpoint.""" + self.seed = seed or self.seed + self.set_epoch(epoch) + self.iteration = iteration def __iter__(self): return self def __next__(self): - if self.index_within_chunk == 0: - try: - self.active_chunk = next(self._chunk_iter) - except StopIteration: - self.set_epoch(self.epoch + 1) - self._chunk_iter = iter(self._chunk_sampler) - raise StopIteration() - - chunk_start = self.active_chunk * self.chunk_size - self._current_chunk_indices = list( - range(chunk_start, chunk_start + self.chunk_size) - ) - - if self.shuffle_within_chunk: - self.rng.shuffle(self._current_chunk_indices) - - i = self._current_chunk_indices[self.index_within_chunk] - self.index_within_chunk = (self.index_within_chunk + 1) % self.chunk_size - return i - - -# --------------------------------------------------------------------------- -# Round-robin loader -# --------------------------------------------------------------------------- - - -class RoundRobinLoader(torch.utils.data.IterableDataset): - """Round-robin interleaving of multiple map-style DataLoaders. - - This converts map-style datasets to iterable-style by cycling through - the given DataLoaders in round-robin order, removing exhausted ones - until all are done. - - Typical usage: create one ``DataLoader`` per worker, each backed by a - ``ChunkedDistributedSampler`` with a unique rank, then wrap them with - ``RoundRobinLoader``. - - Args: - dataloaders: List of DataLoader instances to interleave. 
- """ - - def __init__(self, dataloaders: list[torch.utils.data.DataLoader]): - super().__init__() - self.dataloaders = dataloaders - - def __len__(self): - return sum(len(dl) for dl in self.dataloaders) - - def __iter__(self): - iterators = [iter(dl) for dl in self.dataloaders] - active_indices = list(range(len(self.dataloaders))) + if self.iteration >= len(self): + self.set_epoch(self.epoch + 1) + raise StopIteration() - while active_indices: - for idx in list(active_indices): - try: - yield next(iterators[idx]) - except StopIteration: - active_indices.remove(idx) + idx = self.permutation[self.iteration].item() + self.iteration += 1 + return idx diff --git a/physicsnemo/experimental/datapipes/healda/time_utils.py b/physicsnemo/experimental/datapipes/healda/time_utils.py index 9629816975..9a1bc1b96a 100644 --- a/physicsnemo/experimental/datapipes/healda/time_utils.py +++ b/physicsnemo/experimental/datapipes/healda/time_utils.py @@ -68,3 +68,33 @@ def as_cftime(timestamp) -> cftime.DatetimeGregorian: timestamp.minute, timestamp.second, ) + + +# --------------------------------------------------------------------------- +# cftime-based time encodings (used by transforms) +# --------------------------------------------------------------------------- + + +def cftime_to_timestamp(time: cftime.datetime) -> float: + """Convert a cftime datetime to a Unix timestamp (seconds since epoch).""" + return datetime.datetime( + *cftime.to_tuple(time), tzinfo=datetime.timezone.utc + ).timestamp() + + +def compute_second_of_day(time: cftime.datetime) -> float: + """Return seconds elapsed since midnight for *time*.""" + day_start = time.replace(hour=0, minute=0, second=0) + return (time - day_start) / datetime.timedelta(seconds=1) + + +def compute_day_of_year(time: cftime.datetime) -> float: + """Return fractional day-of-year for *time*.""" + day_start = time.replace(hour=0, minute=0, second=0) + year_start = day_start.replace(month=1, day=1) + return (time - year_start) / 
datetime.timedelta(seconds=86400) + + +def compute_timestamp(time: cftime.datetime) -> int: + """Return integer Unix timestamp for *time*.""" + return int(cftime_to_timestamp(time)) diff --git a/physicsnemo/experimental/datapipes/healda/transforms/era5_obs.py b/physicsnemo/experimental/datapipes/healda/transforms/era5_obs.py index 99bff19a4c..e532abfb1e 100644 --- a/physicsnemo/experimental/datapipes/healda/transforms/era5_obs.py +++ b/physicsnemo/experimental/datapipes/healda/transforms/era5_obs.py @@ -27,11 +27,9 @@ """ import dataclasses -import datetime import functools import warnings -import cftime import numpy as np import torch @@ -41,9 +39,20 @@ pa = OptionalImport("pyarrow") pc = OptionalImport("pyarrow.compute") +from physicsnemo.experimental.datapipes.healda.configs.sensors import ( + NPLATFORMS, + PLATFORM_NAME_TO_ID, + SENSOR_CONFIGS, + SENSOR_NAME_TO_ID, +) from physicsnemo.experimental.datapipes.healda.configs.static_data import load_lfrac, load_orography from physicsnemo.experimental.datapipes.healda.configs.variable_configs import VARIABLE_CONFIGS from physicsnemo.experimental.datapipes.healda.loaders.era5 import get_batch_info +from physicsnemo.experimental.datapipes.healda.time_utils import ( + compute_day_of_year, + compute_second_of_day, + compute_timestamp, +) from physicsnemo.experimental.datapipes.healda.transforms import obs_features, obs_features_ext from physicsnemo.experimental.datapipes.healda.types import UnifiedObservation, VariableConfig @@ -53,12 +62,6 @@ ) -def _cftime_to_timestamp(time: cftime.DatetimeGregorian) -> float: - return datetime.datetime( - *cftime.to_tuple(time), tzinfo=datetime.timezone.utc - ).timestamp() - - def _reorder_nest_to_hpxpad(x): x = torch.as_tensor(x) src_order = earth2grid.healpix.NEST @@ -66,25 +69,10 @@ def _reorder_nest_to_hpxpad(x): return earth2grid.healpix.reorder(x, src_order, dst_order) -def _compute_second_of_day(time: cftime.datetime): - day_start = time.replace(hour=0, minute=0, second=0) - 
return (time - day_start) / datetime.timedelta(seconds=1) - - -def _compute_day_of_year(time: cftime.datetime): - day_start = time.replace(hour=0, minute=0, second=0) - year_start = day_start.replace(month=1, day=1) - return (time - year_start) / datetime.timedelta(seconds=86400) - - -def _compute_timestamp(time: cftime.datetime): - return int(_cftime_to_timestamp(time)) - - def _get_static_condition(HPX_LEVEL, variable_config) -> torch.Tensor: lfrac = load_lfrac(HPX_LEVEL) orography = load_orography() - # Global mean and std computed over the UFS HEALPix level-6 training data. + # Precomputed global mean/std over the UFS HEALPix level-6 grid (2000–2023 ERA5). orog_scale, orog_mean = 627.3885284872, 232.56013904090733 lfrac_scale, lfrac_mean = 0.4695501683565522, 0.3410480857539571 data = { @@ -96,6 +84,30 @@ def _get_static_condition(HPX_LEVEL, variable_config) -> torch.Tensor: return array.unsqueeze(1) +def _map_platform_to_local( + platform: torch.Tensor, + lengths: torch.Tensor, + ordered_sensor_ids: torch.Tensor, + platform_luts: dict[int, torch.Tensor], + device: torch.device, +) -> torch.Tensor: + local_platform = torch.zeros_like(platform) + prev_end = 0 + for s_local, sensor_id in enumerate(ordered_sensor_ids.tolist()): + count = int(lengths[s_local].sum().item()) + end = prev_end + count + if end <= prev_end: + continue + lut = platform_luts.get(sensor_id) + if lut is None: + raise ValueError(f"Missing platform lookup table for sensor_id={sensor_id}") + lut = lut.to(device) + sensor_platform = platform[prev_end:end].long().clamp_(0, lut.shape[0] - 1) + local_platform[prev_end:end] = lut[sensor_platform] + prev_end = end + return local_platform + + @dataclasses.dataclass class ERA5ObsTransform: """Two-stage batch transform for ERA5 state + observation data. @@ -108,12 +120,17 @@ class ERA5ObsTransform: hpx_level_condition: HEALPix level for static conditioning data. 
extended_features: Whether to use extended (30-feature) observation encoding instead of the standard 28-feature encoding. + sensors: Ordered list of sensor names (keys of ``SENSOR_CONFIGS``, + e.g. ``["atms", "mhs", "conv"]``). Controls observation + grouping in the ``lengths`` tensor and per-sensor platform ID + remapping. Must match the sensors passed to the obs loader. """ variable_config: VariableConfig = VARIABLE_CONFIGS["era5"] hpx_level: int = 10 hpx_level_condition: int = 6 extended_features: bool = False + sensors: list[str] = dataclasses.field(default_factory=list) def __post_init__(self): batch_info = get_batch_info(self.variable_config) @@ -123,9 +140,32 @@ def __post_init__(self): @functools.cached_property def _grid(self): return earth2grid.healpix.Grid( - self.hpx_level, pixel_order=earth2grid.healpix.NEST + self.hpx_level, pixel_order=earth2grid.healpix.HEALPIX_PAD_XY + ) + + @functools.cached_property + def _ordered_sensor_ids(self) -> torch.Tensor: + if not self.sensors: + return torch.zeros((0,), dtype=torch.int32) + return torch.tensor( + [SENSOR_NAME_TO_ID[sensor_name] for sensor_name in self.sensors], + dtype=torch.int32, ) + @functools.cached_property + def _platform_luts(self) -> dict[int, torch.Tensor]: + luts: dict[int, torch.Tensor] = {} + for sensor_name in self.sensors: + sensor_id = SENSOR_NAME_TO_ID[sensor_name] + platform_ids = [ + PLATFORM_NAME_TO_ID[p] for p in SENSOR_CONFIGS[sensor_name].platforms + ] + lut = torch.zeros(NPLATFORMS, dtype=torch.long) + for local_platform_id, global_platform_id in enumerate(platform_ids): + lut[global_platform_id] = local_platform_id + luts[sensor_id] = lut + return luts + # ------------------------------------------------------------------ # Obs processing helpers # ------------------------------------------------------------------ @@ -173,9 +213,10 @@ def _append_batch_time_info_chunked( return out @staticmethod - def _build_observation_lengths_3d(obs_table: pa.Table, frame_times): + def 
_build_observation_lengths_3d( + obs_table: pa.Table, frame_times, ordered_sensor_ids: torch.Tensor + ): B, T = len(frame_times), len(frame_times[0]) - sensor_ids = set() counts_map = {} for batch in obs_table.to_batches(): @@ -185,46 +226,40 @@ def _build_observation_lengths_3d(obs_table: pa.Table, frame_times): b_id = int(batch["batch_idx"][0].as_py()) t_id = int(batch["time_idx"][0].as_py()) n = batch.num_rows - sensor_ids.add(s_id) if s_id not in counts_map: counts_map[s_id] = torch.zeros((B, T), dtype=torch.int32) counts_map[s_id][b_id, t_id] += n - active_sensor_ids = sorted(sensor_ids) - S = len(active_sensor_ids) + S = int(ordered_sensor_ids.numel()) + if S == 0: + return torch.zeros((0, B, T), dtype=torch.int32) - if not sensor_ids: - lengths_3d = torch.zeros((0, B, T), dtype=torch.int32) - sensor_id_to_local = torch.zeros((0,), dtype=torch.int32) - return lengths_3d, sensor_id_to_local - - max_sensor_id = max(sensor_ids) lengths_3d = torch.zeros((S, B, T), dtype=torch.int32) - for s_local, s_id in enumerate(active_sensor_ids): - lengths_3d[s_local] = counts_map[s_id] - - sensor_id_to_local = torch.full((max_sensor_id + 1,), -1, dtype=torch.int32) - for local_idx, sensor_id in enumerate(active_sensor_ids): - sensor_id_to_local[sensor_id] = local_idx + for s_local, s_id in enumerate(ordered_sensor_ids.tolist()): + if s_id in counts_map: + lengths_3d[s_local] = counts_map[s_id] - return lengths_3d, sensor_id_to_local + return lengths_3d def _process_obs(self, target_times, frames): + if not self.sensors: + raise ValueError("ERA5ObsTransform requires configured sensors.") + all_obs_with_indices = [] for b_idx, sample_frames in enumerate(frames): for t_idx, frame_dict in enumerate(sample_frames): table = frame_dict["obs"] table_with_indices = self._append_batch_time_info_chunked( table, b_idx, t_idx, - _compute_timestamp(target_times[b_idx][t_idx]), + compute_timestamp(target_times[b_idx][t_idx]), ) all_obs_with_indices.append(table_with_indices) obs = 
pa.concat_tables(all_obs_with_indices) obs = self._sort_by_record_batch(obs, "sensor_id") - lengths_3d, sensor_id_to_local = self._build_observation_lengths_3d( - obs, target_times + lengths_3d = self._build_observation_lengths_3d( + obs, target_times, self._ordered_sensor_ids ) obs_tensors = {} @@ -254,7 +289,7 @@ def _process_obs(self, target_times, frames): obs_type = pc.fill_null(obs["Observation_Type"], 0) obs_tensors["observation_type"] = torch.from_numpy(obs_type.to_numpy()) - return obs_tensors, lengths_3d, sensor_id_to_local + return (obs_tensors, lengths_3d) def _get_target(self, frames) -> torch.Tensor: all_state = [f["state"] for sample in frames for f in sample] @@ -298,9 +333,9 @@ def _apply_time_func(func): if "obs" in frames[0][0].keys(): out["unified_obs"] = self._process_obs(times, frames) out["target"] = self._get_target(frames).float() - out["second_of_day"] = _apply_time_func(_compute_second_of_day).float() - out["day_of_year"] = _apply_time_func(_compute_day_of_year).float() - out["timestamp"] = _apply_time_func(_compute_timestamp) + out["second_of_day"] = _apply_time_func(compute_second_of_day).float() + out["day_of_year"] = _apply_time_func(compute_day_of_year).float() + out["timestamp"] = _apply_time_func(compute_timestamp) b, _, t, _ = out["target"].shape condition = self._static_condition.float() if condition.shape[0] not in (1, b): @@ -338,17 +373,15 @@ def device_transform(self, batch, device): out = {} for key in batch: if key == "unified_obs": - obs_tensors, lengths, sensor_id_to_local = batch["unified_obs"] + obs_tensors, lengths = batch["unified_obs"] out[key] = self._device_transform_unified_obs( - obs_tensors, lengths, sensor_id_to_local, device + obs_tensors, lengths, device ) else: out[key] = batch[key].to(device, non_blocking=True) return out - def _device_transform_unified_obs( - self, obs_tensors, lengths, sensor_id_to_local, device - ): + def _device_transform_unified_obs(self, obs_tensors, lengths, device): def 
_to_device(tensor, non_blocking=True): if isinstance(tensor, torch.Tensor): return tensor.to(device, non_blocking=non_blocking) @@ -396,21 +429,25 @@ def _to_device(tensor, non_blocking=True): sol_zenith_angle=sol_zenith_tensor, ) + lengths = _to_device(lengths) + local_platform = _map_platform_to_local( + platform=platform_id_tensor, + lengths=lengths, + ordered_sensor_ids=self._ordered_sensor_ids.to(device), + platform_luts=self._platform_luts, + device=device, + ) + return UnifiedObservation( obs=observation_tensor, time=obs_time_ns, float_metadata=meta, pix=pix, local_channel=local_channel_id_tensor, - platform=platform_id_tensor, + platform=local_platform, obs_type=obs_type_tensor, global_channel=global_channel_id_tensor, + global_platform=platform_id_tensor, hpx_level=self.hpx_level, - lengths=_to_device(lengths), - sensor_id_to_local=_to_device(sensor_id_to_local), + lengths=lengths, ) - - -def identity_collate(obj): - """Identity collate function — batch is already assembled by __getitems__.""" - return obj diff --git a/physicsnemo/experimental/datapipes/healda/types.py b/physicsnemo/experimental/datapipes/healda/types.py index 6cb6b0c02a..2b71b7e703 100644 --- a/physicsnemo/experimental/datapipes/healda/types.py +++ b/physicsnemo/experimental/datapipes/healda/types.py @@ -161,13 +161,11 @@ class UnifiedObservation: global_channel: torch.Tensor hpx_level: int # HEALPix level that ``pix`` is defined at + global_platform: torch.Tensor | None = None lengths: torch.Tensor | None = ( None # 3D: (n_active_sensors, batch, time) per-window obs counts ) - sensor_id_to_local: torch.Tensor | None = ( - None # (max_sensor_id+1,) map: sensor_id -> local_idx (-1 if inactive) - ) @classmethod def empty( @@ -186,9 +184,9 @@ def empty( platform=torch.empty(0, dtype=torch.long, device=device), obs_type=torch.empty(0, dtype=torch.long, device=device), global_channel=torch.empty(0, dtype=torch.long, device=device), - lengths=torch.zeros(1, B, T, dtype=torch.long, 
device=device), - sensor_id_to_local=torch.zeros(1, dtype=torch.long, device=device), + global_platform=torch.empty(0, dtype=torch.long, device=device), hpx_level=hpx_level, + lengths=torch.zeros(1, B, T, dtype=torch.long, device=device), ) @property @@ -221,8 +219,8 @@ def _move(x): obs_type=_move(self.obs_type), global_channel=_move(self.global_channel), hpx_level=self.hpx_level, + global_platform=_move(self.global_platform), lengths=_move(self.lengths), - sensor_id_to_local=_move(self.sensor_id_to_local), ) @@ -286,15 +284,14 @@ def split_by_sensor( ) -> dict[int, UnifiedObservation]: """Slice a ``UnifiedObservation`` into per-sensor sub-objects. - Uses precomputed ``lengths`` and ``sensor_id_to_local`` for efficient - splitting without per-element sensor ID checks. + ``target_sensor_ids`` must list sensor IDs in the same order as the + sensor dimension (S) of ``obs.lengths``. Position ``s_local`` in + ``target_sensor_ids`` corresponds to index ``s_local`` in ``lengths[S]``. """ - if obs.lengths is None or obs.sensor_id_to_local is None: + if obs.lengths is None: raise ValueError("lengths is required for split_by_sensor") lengths = obs.lengths # [S, B, T] - sensor_id_to_local = obs.sensor_id_to_local - device = obs.obs.device B, T = obs.batch_dims @@ -309,21 +306,19 @@ def split_by_sensor( obs.obs_type, obs.global_channel, ] + if obs.global_platform is not None: + obs_fields.append(obs.global_platform) splits = [torch.split(f, sizes) for f in obs_fields] + global_platform_idx = 8 if obs.global_platform is not None else None - out = {} - for sensor_id in target_sensor_ids: - if sensor_id < len(sensor_id_to_local): - s_local = int(sensor_id_to_local[sensor_id].item()) - else: - s_local = -1 - - single_sensor_map = torch.full( - (sensor_id + 1,), -1, dtype=torch.int32, device=device + if len(target_sensor_ids) < len(sizes): + raise ValueError( + "target_sensor_ids must include the configured sensor order for split_by_sensor" ) - single_sensor_map[sensor_id] = 0 - 
if s_local < 0: + out = {} + for s_local, sensor_id in enumerate(target_sensor_ids): + if s_local >= len(sizes): sensor_lengths = torch.zeros((1, B, T), dtype=lengths.dtype, device=device) out[sensor_id] = UnifiedObservation( obs=obs.obs[:0], @@ -334,9 +329,11 @@ def split_by_sensor( platform=obs.platform[:0], obs_type=obs.obs_type[:0], global_channel=obs.global_channel[:0], + global_platform=( + obs.global_platform[:0] if obs.global_platform is not None else None + ), hpx_level=obs.hpx_level, lengths=sensor_lengths, - sensor_id_to_local=single_sensor_map, ) else: out[sensor_id] = UnifiedObservation( @@ -348,9 +345,13 @@ def split_by_sensor( platform=splits[5][s_local], obs_type=splits[6][s_local], global_channel=splits[7][s_local], + global_platform=( + splits[global_platform_idx][s_local] + if global_platform_idx is not None + else None + ), hpx_level=obs.hpx_level, lengths=lengths[s_local : s_local + 1], - sensor_id_to_local=single_sensor_map, ) return out diff --git a/pyproject.toml b/pyproject.toml index 649fc8eac4..e3cb876d18 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,6 @@ conflicts = [ torch-sparse = ["torch"] torch-cluster = ["torch"] torch-scatter = ["torch"] -earth2grid = ["setuptools", "torch"] [[tool.uv.index]] name = "nvidia" @@ -146,8 +145,6 @@ natten = [ { index = "natten-cu130-whl", extra = "natten-cu13" }, ] -earth2grid = { url = "https://github.com/NVlabs/earth2grid/archive/8fdff5a78d324f8d25afe224915301b3169bffe2.tar.gz" } - ##################################################################### # Flags Controlling the local build of physicsnemo ##################################################################### @@ -286,7 +283,11 @@ perf = [ healda = [ "nvidia-physicsnemo[datapipes-extras]", "pyarrow>=14.0.0", - "earth2grid", + "triton>=3.0.0", + "fsspec>=2023.5.0", + # earth2grid is required at runtime but temporarily excluded from the + # install group due to previous torch version pinning and --no-build-isolation + # 
conflicts with other packages; install it manually for now. ] diff --git a/examples/weather/healda/test/conftest.py b/test/datapipes/healda/__init__.py similarity index 85% rename from examples/weather/healda/test/conftest.py rename to test/datapipes/healda/__init__.py index 3a4ff72671..af85283aa4 100644 --- a/examples/weather/healda/test/conftest.py +++ b/test/datapipes/healda/__init__.py @@ -13,11 +13,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import pytest -import torch - - -@pytest.fixture -def device(): - return "cuda" if torch.cuda.is_available() else "cpu" diff --git a/examples/weather/healda/test/test_combined_schema.py b/test/datapipes/healda/test_combined_schema.py similarity index 97% rename from examples/weather/healda/test/test_combined_schema.py rename to test/datapipes/healda/test_combined_schema.py index d03b0efd08..0c3b4a478b 100644 --- a/examples/weather/healda/test/test_combined_schema.py +++ b/test/datapipes/healda/test_combined_schema.py @@ -15,7 +15,9 @@ # limitations under the License. 
"""Tests for the combined observation schema and sensor config consistency.""" -import pyarrow as pa +import pytest + +pa = pytest.importorskip("pyarrow") from physicsnemo.experimental.datapipes.healda.configs.combined_schema import ( get_channel_table_schema, diff --git a/examples/weather/healda/test/test_features.py b/test/datapipes/healda/test_features.py similarity index 98% rename from examples/weather/healda/test/test_features.py rename to test/datapipes/healda/test_features.py index 120b870084..ad2669e960 100644 --- a/examples/weather/healda/test/test_features.py +++ b/test/datapipes/healda/test_features.py @@ -22,10 +22,10 @@ import pytest import torch +triton = pytest.importorskip("triton") + from physicsnemo.experimental.datapipes.healda.transforms import ( obs_features as standard, -) -from physicsnemo.experimental.datapipes.healda.transforms import ( obs_features_ext as extended, ) diff --git a/examples/weather/healda/test/test_indexing.py b/test/datapipes/healda/test_indexing.py similarity index 100% rename from examples/weather/healda/test/test_indexing.py rename to test/datapipes/healda/test_indexing.py diff --git a/examples/weather/healda/test/test_obs_filtering.py b/test/datapipes/healda/test_obs_filtering.py similarity index 98% rename from examples/weather/healda/test/test_obs_filtering.py rename to test/datapipes/healda/test_obs_filtering.py index 5fd364da3f..e99cfcda82 100644 --- a/examples/weather/healda/test/test_obs_filtering.py +++ b/test/datapipes/healda/test_obs_filtering.py @@ -16,7 +16,9 @@ """Tests for observation quality-control filtering.""" import numpy as np -import pyarrow as pa +import pytest + +pa = pytest.importorskip("pyarrow") from physicsnemo.experimental.datapipes.healda.configs.sensors import SENSOR_OFFSET from physicsnemo.experimental.datapipes.healda.transforms.obs_filtering import ( diff --git a/examples/weather/healda/test/test_prefetch.py b/test/datapipes/healda/test_prefetch.py similarity index 100% rename from 
examples/weather/healda/test/test_prefetch.py rename to test/datapipes/healda/test_prefetch.py diff --git a/test/datapipes/healda/test_samplers.py b/test/datapipes/healda/test_samplers.py new file mode 100644 index 0000000000..d37a56bcc2 --- /dev/null +++ b/test/datapipes/healda/test_samplers.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for RestartableDistributedSampler.""" + +import pytest + +from physicsnemo.experimental.datapipes.healda.samplers import ( + RestartableDistributedSampler, +) + + +def test_basic_iteration(): + """Sampler yields all indices for the rank and respects __len__.""" + dataset = list(range(100)) + sampler = RestartableDistributedSampler(dataset, rank=0, num_replicas=1, seed=42) + sampler.set_epoch(0) + + indices = list(sampler) + assert len(indices) == 100 + assert sorted(indices) == list(range(100)) + + +def test_epoch_auto_advance(): + """Exhausting an epoch raises StopIteration, then next epoch starts.""" + dataset = list(range(20)) + sampler = RestartableDistributedSampler(dataset, rank=0, num_replicas=1, seed=7) + sampler.set_epoch(0) + + epoch0 = list(sampler) + assert len(epoch0) == 20 + + # After StopIteration, epoch has advanced; next iteration gives a new epoch + epoch1 = list(sampler) + assert len(epoch1) == 20 + assert sorted(epoch1) == list(range(20)) + # Different permutation (with high probability) + assert epoch0 != epoch1 + + +def test_restart_resumes_correctly(): + """restart() resumes from exact checkpoint position.""" + dataset = list(range(50)) + sampler = RestartableDistributedSampler(dataset, rank=0, num_replicas=1, seed=42) + sampler.set_epoch(0) + + # Collect first 20 indices + first_20 = [] + for _ in range(20): + first_20.append(next(sampler)) + + # Collect remaining + remaining = [] + try: + while True: + remaining.append(next(sampler)) + except StopIteration: + pass + + # Now restart at position 20 and verify we get the same remaining + sampler.restart(epoch=0, iteration=20, seed=42) + restarted_remaining = list(sampler) + assert restarted_remaining == remaining + + +def test_reproducible(): + """Same seed/rank/epoch produces identical permutation.""" + dataset = list(range(100)) + s1 = RestartableDistributedSampler(dataset, rank=0, num_replicas=1, seed=42) + s2 = RestartableDistributedSampler(dataset, rank=0, num_replicas=1, 
seed=42) + s1.set_epoch(0) + s2.set_epoch(0) + assert list(s1) == list(s2) + + +def test_multi_replica_independent(): + """Different ranks get independent permutations of the same length.""" + dataset = list(range(100)) + s0 = RestartableDistributedSampler(dataset, rank=0, num_replicas=2, seed=42) + s1 = RestartableDistributedSampler(dataset, rank=1, num_replicas=2, seed=42) + s0.set_epoch(0) + s1.set_epoch(0) + + idx0 = list(s0) + idx1 = list(s1) + + assert len(idx0) == 50 + assert len(idx1) == 50 + # Each rank gets a valid subset of indices + assert all(0 <= i < 100 for i in idx0) + assert all(0 <= i < 100 for i in idx1) + # Rank-dependent seed produces different orderings + assert idx0 != idx1 + + +def test_len_drops_remainder(): + """Length accounts for dropping remainder across replicas.""" + dataset = list(range(103)) + sampler = RestartableDistributedSampler(dataset, rank=0, num_replicas=4, seed=0) + # 103 // 4 = 25, remainder 3 dropped + assert len(sampler) == 25 diff --git a/examples/weather/healda/test/test_time_utils.py b/test/datapipes/healda/test_time_utils.py similarity index 97% rename from examples/weather/healda/test/test_time_utils.py rename to test/datapipes/healda/test_time_utils.py index 942d492285..6ab430e846 100644 --- a/examples/weather/healda/test/test_time_utils.py +++ b/test/datapipes/healda/test_time_utils.py @@ -17,7 +17,10 @@ import datetime -import cftime +import pytest + +cftime = pytest.importorskip("cftime") + import numpy as np import pandas as pd diff --git a/examples/weather/healda/test/test_types.py b/test/datapipes/healda/test_types.py similarity index 87% rename from examples/weather/healda/test/test_types.py rename to test/datapipes/healda/test_types.py index 4fff22621e..0da6a1d563 100644 --- a/examples/weather/healda/test/test_types.py +++ b/test/datapipes/healda/test_types.py @@ -51,10 +51,6 @@ def make_realistic_obs( if obs[0] == s_id and obs[1] == b and obs[2] == t ) - sensor_id_to_local = torch.full((max(sensors) + 1,), 
-1, dtype=torch.int32) - for local_idx, s_id in enumerate(sensors): - sensor_id_to_local[s_id] = local_idx - nobs = len(all_obs) return UnifiedObservation( obs=values.unsqueeze(1).expand(nobs, 3), @@ -65,9 +61,9 @@ def make_realistic_obs( platform=torch.zeros(nobs, dtype=torch.long), obs_type=torch.zeros(nobs, dtype=torch.long), global_channel=torch.zeros(nobs, dtype=torch.long), + global_platform=torch.zeros(nobs, dtype=torch.long), hpx_level=6, lengths=lengths_3d, - sensor_id_to_local=sensor_id_to_local, ) @@ -102,7 +98,9 @@ def test_split_lengths_match_obs_count(): def test_split_empty_sensor(): + """Extra sensor_ids beyond the configured list produce empty sub-objects.""" obs = make_realistic_obs(B=1, T=1, sensors=[0, 1]) + # target_sensor_ids must include the configured order first, extras appended split = split_by_sensor(obs, [0, 1, 2]) assert split[2].obs.shape[0] == 0 @@ -121,7 +119,6 @@ def test_split_requires_lengths(): global_channel=torch.zeros(10, dtype=torch.long), hpx_level=6, lengths=None, - sensor_id_to_local=None, ) with pytest.raises(ValueError, match="lengths is required"): @@ -150,10 +147,6 @@ def test_split_handles_sparse_windows(): lengths_3d[0, :, :] = 2 lengths_3d[1, 1, 2] = 3 - sensor_id_to_local = torch.full((5,), -1, dtype=torch.int32) - for local_idx, s_id in enumerate(sensors): - sensor_id_to_local[s_id] = local_idx - obs = UnifiedObservation( obs=torch.arange(nobs, dtype=torch.float32).unsqueeze(1).expand(nobs, 3), time=torch.zeros(nobs, dtype=torch.long), @@ -165,13 +158,15 @@ def test_split_handles_sparse_windows(): platform=torch.zeros(nobs, dtype=torch.long), obs_type=torch.zeros(nobs, dtype=torch.long), global_channel=torch.zeros(nobs, dtype=torch.long), + global_platform=torch.zeros(nobs, dtype=torch.long), hpx_level=6, lengths=lengths_3d, - sensor_id_to_local=sensor_id_to_local, ) assert obs.batch_dims == (2, 3) + # Positional: target_sensor_ids[0]=0 -> lengths[0], target_sensor_ids[1]=4 -> lengths[1] + # Extra sensor 99 is 
beyond len(sizes) -> empty split = split_by_sensor(obs, [0, 4, 99]) s0 = split[0] @@ -186,3 +181,12 @@ def test_split_handles_sparse_windows(): s99 = split[99] assert s99.obs.shape[0] == 0 assert torch.all(s99.lengths == 0) + + +def test_split_global_platform_propagated(): + """global_platform is sliced correctly through split_by_sensor.""" + obs = make_realistic_obs(B=1, T=1, sensors=[0, 1]) + split = split_by_sensor(obs, [0, 1]) + for sid in [0, 1]: + assert split[sid].global_platform is not None + assert split[sid].global_platform.shape[0] == split[sid].obs.shape[0] diff --git a/uv.lock b/uv.lock index eb9c622c0c..f900e4dd67 100644 --- a/uv.lock +++ b/uv.lock @@ -1613,50 +1613,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0c/d5/c5db1ea3394c6e1732fb3286b3bd878b59507a8f77d32a2cebda7d7b7cd4/donfig-0.8.1.post1-py3-none-any.whl", hash = "sha256:2a3175ce74a06109ff9307d90a230f81215cbac9a751f4d1c6194644b8204f9d", size = 21592, upload-time = "2024-05-23T14:13:55.283Z" }, ] -[[package]] -name = "earth2grid" -version = "2025.11.1+torch210" -source = { url = "https://github.com/NVlabs/earth2grid/archive/8fdff5a78d324f8d25afe224915301b3169bffe2.tar.gz" } -dependencies = [ - { name = "einops" }, - { name = "numpy" }, - { name = "scipy" }, - { name = "torch", version = "2.10.0", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-18-nvidia-physicsnemo-cu12' and extra == 'extra-18-nvidia-physicsnemo-cu13') or (extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13') or (extra != 'extra-18-nvidia-physicsnemo-cu12' and extra != 'extra-18-nvidia-physicsnemo-cu13')" }, - { name = "torch", version = "2.10.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "extra == 'extra-18-nvidia-physicsnemo-cu12' or (extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13')" }, - { name = "torch", version = 
"2.10.0+cu130", source = { registry = "https://download.pytorch.org/whl/cu130" }, marker = "extra == 'extra-18-nvidia-physicsnemo-cu13' or (extra == 'extra-18-nvidia-physicsnemo-natten-cu12' and extra == 'extra-18-nvidia-physicsnemo-natten-cu13')" }, -] -sdist = { hash = "sha256:55f54084fd6a1fa505b68a2ab067aba7cab76f94d158dc8720dbb15c4e3ac235" } - -[package.metadata] -requires-dist = [ - { name = "black", marker = "extra == 'test'", specifier = ">=21.5b2" }, - { name = "bump2version", marker = "extra == 'dev'", specifier = ">=1.0.1" }, - { name = "coverage", marker = "extra == 'test'", specifier = ">=7.0.0" }, - { name = "einops", specifier = ">=0.7.0" }, - { name = "flake8", marker = "extra == 'test'", specifier = ">=3.9.2" }, - { name = "flake8-docstrings", marker = "extra == 'test'", specifier = ">=1.6.0" }, - { name = "isort", marker = "extra == 'test'", specifier = ">=5.8.0" }, - { name = "matplotlib", marker = "extra == 'test'" }, - { name = "matplotlib", marker = "extra == 'viz'" }, - { name = "mypy", marker = "extra == 'test'", specifier = ">=0.900" }, - { name = "netcdf4", marker = "extra == 'all'", specifier = ">=1.6.5" }, - { name = "numpy", specifier = ">=1.23.3" }, - { name = "pip", marker = "extra == 'dev'", specifier = ">=20.3.1" }, - { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=2.12.0" }, - { name = "pytest", marker = "extra == 'test'", specifier = ">=6.2.4" }, - { name = "pytest-cov", marker = "extra == 'test'", specifier = ">=2.12.0" }, - { name = "pytest-regtest", marker = "extra == 'test'", specifier = ">=1.5.1,<2" }, - { name = "pyvista", marker = "extra == 'viz'", specifier = ">=0.43.2" }, - { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1.5" }, - { name = "scipy" }, - { name = "toml", marker = "extra == 'dev'", specifier = ">=0.10.2" }, - { name = "torch", specifier = ">=2.10,<2.11" }, - { name = "tox", marker = "extra == 'dev'", specifier = ">=3.20.1" }, - { name = "twine", marker = "extra == 'dev'", 
specifier = ">=3.3.0" }, - { name = "virtualenv", marker = "extra == 'dev'", specifier = ">=20.2.2" }, -] -provides-extras = ["all", "viz", "test", "dev"] - [[package]] name = "einops" version = "0.8.2" @@ -4707,11 +4663,12 @@ gnns = [ ] healda = [ { name = "dask" }, - { name = "earth2grid" }, + { name = "fsspec" }, { name = "netcdf4" }, { name = "pyarrow" }, { name = "tensordict" }, { name = "tfrecord" }, + { name = "triton" }, { name = "xarray" }, { name = "zarr" }, ] @@ -4781,8 +4738,8 @@ requires-dist = [ { name = "cupy-cuda13x", marker = "extra == 'cu13'", specifier = "==13.6.0" }, { name = "dask", marker = "extra == 'datapipes-extras'" }, { name = "dask", marker = "extra == 'healda'" }, - { name = "earth2grid", marker = "extra == 'healda'", url = "https://github.com/NVlabs/earth2grid/archive/8fdff5a78d324f8d25afe224915301b3169bffe2.tar.gz" }, { name = "einops", specifier = ">=0.8.1" }, + { name = "fsspec", marker = "extra == 'healda'", specifier = ">=2023.5.0" }, { name = "gitpython", specifier = ">=3.1.40" }, { name = "h5py", specifier = ">=3.15.1" }, { name = "hydra-core", specifier = ">=1.3.2" }, @@ -4845,6 +4802,7 @@ requires-dist = [ { name = "tqdm", specifier = ">=4.60.0" }, { name = "transformer-engine", extras = ["pytorch"], marker = "extra == 'perf'" }, { name = "treelib", specifier = ">=1.2.5" }, + { name = "triton", marker = "extra == 'healda'", specifier = ">=3.0.0" }, { name = "vtk", marker = "extra == 'gnns'" }, { name = "vtk", marker = "extra == 'mesh-extras'", specifier = ">=9.6.0" }, { name = "vtk", marker = "extra == 'model-extras'" }, From 31b7e8c3beb5d9a74089ea042dfdf04f396e3e3c Mon Sep 17 00:00:00 2001 From: Peter Harrington Date: Tue, 14 Apr 2026 16:39:41 -0700 Subject: [PATCH 05/10] move imports, cleanup --- .../healda/configs/variable_configs.py | 40 +------------------ .../datapipes/healda/loaders/era5.py | 2 +- .../experimental/datapipes/healda/prefetch.py | 3 +- test/datapipes/healda/test_combined_schema.py | 4 +- 
test/datapipes/healda/test_features.py | 6 ++- test/datapipes/healda/test_obs_filtering.py | 5 +-- test/datapipes/healda/test_samplers.py | 8 +--- test/datapipes/healda/test_time_utils.py | 7 ++-- test/datapipes/healda/test_types.py | 1 - 9 files changed, 17 insertions(+), 59 deletions(-) diff --git a/physicsnemo/experimental/datapipes/healda/configs/variable_configs.py b/physicsnemo/experimental/datapipes/healda/configs/variable_configs.py index 0e2698acb2..5856c39e13 100644 --- a/physicsnemo/experimental/datapipes/healda/configs/variable_configs.py +++ b/physicsnemo/experimental/datapipes/healda/configs/variable_configs.py @@ -19,27 +19,6 @@ VARIABLE_CONFIGS = {} -VARIABLE_CONFIGS["default"] = VariableConfig( - name="ufs", - levels=[1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50], - variables_3d=["Q", "U", "V", "T", "Z"], - variables_2d=[ - "tas", - "uas", - "vas", - "rlut", - "rsut", - "pressfc", - "pr", - "rsds", - "sst", - "sic", - "hfls", - "huss", - ], - variables_static=["orog", "lfrac"], -) - VARIABLE_CONFIGS["era5"] = VariableConfig( name="era5", levels=[1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50], @@ -56,21 +35,4 @@ "sic", ], variables_static=["orog", "lfrac"], -) - -VARIABLE_CONFIGS["gfs"] = VariableConfig( - name="gfs", - levels=[1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50], - variables_3d=["U", "V", "T", "Z", "Q"], - variables_2d=[ - "tcwv", - "tas", - "uas", - "vas", - "100u", - "100v", - "pres_msl", - "sp", - ], - variables_static=["orog", "lfrac"], -) +) \ No newline at end of file diff --git a/physicsnemo/experimental/datapipes/healda/loaders/era5.py b/physicsnemo/experimental/datapipes/healda/loaders/era5.py index 6ee1c15b0b..b711cc6ef0 100644 --- a/physicsnemo/experimental/datapipes/healda/loaders/era5.py +++ b/physicsnemo/experimental/datapipes/healda/loaders/era5.py @@ -201,7 +201,7 @@ def _get_nearest_stats(config: VariableConfig): return raw.loc[mapped_idx] -def open_era5_xarray(path: 
str | None = None, **kwargs) -> xarray.Dataset: +def open_era5_xarray(path: str | None = None, **kwargs) -> "xarray.Dataset": """Open the ERA5 74-variable zarr dataset as xarray. Args: diff --git a/physicsnemo/experimental/datapipes/healda/prefetch.py b/physicsnemo/experimental/datapipes/healda/prefetch.py index 273251feac..360de0ba50 100644 --- a/physicsnemo/experimental/datapipes/healda/prefetch.py +++ b/physicsnemo/experimental/datapipes/healda/prefetch.py @@ -18,7 +18,8 @@ ``prefetch_map`` wraps any iterable (typically a DataLoader) and applies a transform function in a background thread using a separate CUDA stream. This hides the CPU-to-GPU transfer and GPU featurization latency behind the -training forward/backward pass. +training forward/backward pass (better parallelism). It works best when +`pin_memory=True` is used on the DataLoader. """ import dataclasses diff --git a/test/datapipes/healda/test_combined_schema.py b/test/datapipes/healda/test_combined_schema.py index 0c3b4a478b..a1d16cd77e 100644 --- a/test/datapipes/healda/test_combined_schema.py +++ b/test/datapipes/healda/test_combined_schema.py @@ -17,8 +17,6 @@ import pytest -pa = pytest.importorskip("pyarrow") - from physicsnemo.experimental.datapipes.healda.configs.combined_schema import ( get_channel_table_schema, get_combined_observation_schema, @@ -28,6 +26,8 @@ SENSOR_NAME_TO_ID, ) +pa = pytest.importorskip("pyarrow") + def test_combined_schema_has_required_fields(): schema = get_combined_observation_schema() diff --git a/test/datapipes/healda/test_features.py b/test/datapipes/healda/test_features.py index ad2669e960..062528a749 100644 --- a/test/datapipes/healda/test_features.py +++ b/test/datapipes/healda/test_features.py @@ -22,13 +22,15 @@ import pytest import torch -triton = pytest.importorskip("triton") - from physicsnemo.experimental.datapipes.healda.transforms import ( obs_features as standard, +) +from physicsnemo.experimental.datapipes.healda.transforms import ( obs_features_ext as 
extended, ) +triton = pytest.importorskip("triton") + def _make_obs_data(n, device, include_lat=False): g = torch.Generator(device=device) diff --git a/test/datapipes/healda/test_obs_filtering.py b/test/datapipes/healda/test_obs_filtering.py index e99cfcda82..e0fc73d19b 100644 --- a/test/datapipes/healda/test_obs_filtering.py +++ b/test/datapipes/healda/test_obs_filtering.py @@ -18,13 +18,13 @@ import numpy as np import pytest -pa = pytest.importorskip("pyarrow") - from physicsnemo.experimental.datapipes.healda.configs.sensors import SENSOR_OFFSET from physicsnemo.experimental.datapipes.healda.transforms.obs_filtering import ( filter_observations, ) +pa = pytest.importorskip("pyarrow") + def _make_filter_test_table(): """Create a minimal table with channel metadata columns required for filtering.""" @@ -38,7 +38,6 @@ def _make_filter_test_table(): conv_offset + 6, conv_offset + 7, ] - n = len(channels) return pa.table( { diff --git a/test/datapipes/healda/test_samplers.py b/test/datapipes/healda/test_samplers.py index d37a56bcc2..44ef1c1dcc 100644 --- a/test/datapipes/healda/test_samplers.py +++ b/test/datapipes/healda/test_samplers.py @@ -15,8 +15,6 @@ # limitations under the License. 
"""Tests for RestartableDistributedSampler.""" -import pytest - from physicsnemo.experimental.datapipes.healda.samplers import ( RestartableDistributedSampler, ) @@ -56,10 +54,8 @@ def test_restart_resumes_correctly(): sampler = RestartableDistributedSampler(dataset, rank=0, num_replicas=1, seed=42) sampler.set_epoch(0) - # Collect first 20 indices - first_20 = [] - for _ in range(20): - first_20.append(next(sampler)) + # Consume first 20 indices + _ = [next(sampler) for _ in range(20)] # Collect remaining remaining = [] diff --git a/test/datapipes/healda/test_time_utils.py b/test/datapipes/healda/test_time_utils.py index 6ab430e846..29e9945554 100644 --- a/test/datapipes/healda/test_time_utils.py +++ b/test/datapipes/healda/test_time_utils.py @@ -17,12 +17,9 @@ import datetime -import pytest - -cftime = pytest.importorskip("cftime") - import numpy as np import pandas as pd +import pytest from physicsnemo.experimental.datapipes.healda.time_utils import ( as_cftime, @@ -31,6 +28,8 @@ as_timestamp, ) +cftime = pytest.importorskip("cftime") + def test_as_numpy_from_pandas_index(): idx = pd.date_range("2020-01-01", periods=3, freq="h") diff --git a/test/datapipes/healda/test_types.py b/test/datapipes/healda/test_types.py index 0da6a1d563..2fb39d0d8a 100644 --- a/test/datapipes/healda/test_types.py +++ b/test/datapipes/healda/test_types.py @@ -133,7 +133,6 @@ def test_lengths_nonnegative(): def test_split_handles_sparse_windows(): """Sensor missing from some (b,t) windows.""" B, T = 2, 3 - sensors = [0, 4] all_obs = [] for b in range(B): From 2165efc8f6b23f86c2d5f0b154394e9340476b61 Mon Sep 17 00:00:00 2001 From: Peter Harrington Date: Tue, 14 Apr 2026 16:52:13 -0700 Subject: [PATCH 06/10] ruff check fix --- test/datapipes/healda/test_combined_schema.py | 8 ++++---- test/datapipes/healda/test_features.py | 8 ++++---- test/datapipes/healda/test_obs_filtering.py | 10 ++++++---- test/datapipes/healda/test_time_utils.py | 6 +++--- 4 files changed, 17 insertions(+), 15 
deletions(-) diff --git a/test/datapipes/healda/test_combined_schema.py b/test/datapipes/healda/test_combined_schema.py index a1d16cd77e..6d0f7656b5 100644 --- a/test/datapipes/healda/test_combined_schema.py +++ b/test/datapipes/healda/test_combined_schema.py @@ -17,17 +17,17 @@ import pytest -from physicsnemo.experimental.datapipes.healda.configs.combined_schema import ( +pa = pytest.importorskip("pyarrow") + +from physicsnemo.experimental.datapipes.healda.configs.combined_schema import ( # noqa: E402 get_channel_table_schema, get_combined_observation_schema, ) -from physicsnemo.experimental.datapipes.healda.configs.sensors import ( +from physicsnemo.experimental.datapipes.healda.configs.sensors import ( # noqa: E402 SENSOR_CONFIGS, SENSOR_NAME_TO_ID, ) -pa = pytest.importorskip("pyarrow") - def test_combined_schema_has_required_fields(): schema = get_combined_observation_schema() diff --git a/test/datapipes/healda/test_features.py b/test/datapipes/healda/test_features.py index 062528a749..3c77ecdb9e 100644 --- a/test/datapipes/healda/test_features.py +++ b/test/datapipes/healda/test_features.py @@ -22,15 +22,15 @@ import pytest import torch -from physicsnemo.experimental.datapipes.healda.transforms import ( +triton = pytest.importorskip("triton") + +from physicsnemo.experimental.datapipes.healda.transforms import ( # noqa: E402 obs_features as standard, ) -from physicsnemo.experimental.datapipes.healda.transforms import ( +from physicsnemo.experimental.datapipes.healda.transforms import ( # noqa: E402 obs_features_ext as extended, ) -triton = pytest.importorskip("triton") - def _make_obs_data(n, device, include_lat=False): g = torch.Generator(device=device) diff --git a/test/datapipes/healda/test_obs_filtering.py b/test/datapipes/healda/test_obs_filtering.py index e0fc73d19b..ca5144bae8 100644 --- a/test/datapipes/healda/test_obs_filtering.py +++ b/test/datapipes/healda/test_obs_filtering.py @@ -18,13 +18,15 @@ import numpy as np import pytest -from 
physicsnemo.experimental.datapipes.healda.configs.sensors import SENSOR_OFFSET -from physicsnemo.experimental.datapipes.healda.transforms.obs_filtering import ( +pa = pytest.importorskip("pyarrow") + +from physicsnemo.experimental.datapipes.healda.configs.sensors import ( # noqa: E402 + SENSOR_OFFSET, +) +from physicsnemo.experimental.datapipes.healda.transforms.obs_filtering import ( # noqa: E402 filter_observations, ) -pa = pytest.importorskip("pyarrow") - def _make_filter_test_table(): """Create a minimal table with channel metadata columns required for filtering.""" diff --git a/test/datapipes/healda/test_time_utils.py b/test/datapipes/healda/test_time_utils.py index 29e9945554..65a0945373 100644 --- a/test/datapipes/healda/test_time_utils.py +++ b/test/datapipes/healda/test_time_utils.py @@ -21,15 +21,15 @@ import pandas as pd import pytest -from physicsnemo.experimental.datapipes.healda.time_utils import ( +cftime = pytest.importorskip("cftime") + +from physicsnemo.experimental.datapipes.healda.time_utils import ( # noqa: E402 as_cftime, as_numpy, as_pydatetime, as_timestamp, ) -cftime = pytest.importorskip("cftime") - def test_as_numpy_from_pandas_index(): idx = pd.date_range("2020-01-01", periods=3, freq="h") From 46db004f48dda4f1a42a4cad78f88c2d3d921744 Mon Sep 17 00:00:00 2001 From: Peter Harrington Date: Tue, 14 Apr 2026 18:29:20 -0700 Subject: [PATCH 07/10] skip prefetch on CPU --- test/datapipes/healda/test_prefetch.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/datapipes/healda/test_prefetch.py b/test/datapipes/healda/test_prefetch.py index a0080bb09f..80ec6128ea 100644 --- a/test/datapipes/healda/test_prefetch.py +++ b/test/datapipes/healda/test_prefetch.py @@ -16,10 +16,12 @@ """Tests for prefetch_map background processing.""" import pytest +import torch from physicsnemo.experimental.datapipes.healda.prefetch import prefetch_map +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def 
test_prefetch_map_basic(): """Prefetch with a simple doubling transform.""" data = list(range(10)) @@ -27,6 +29,7 @@ def test_prefetch_map_basic(): assert list(loader) == list(range(0, 20, 2)) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_prefetch_map_error_propagation(): """Exceptions in the background thread propagate to the consumer.""" data = list(range(4)) From 229abfb4dbddbfeac4aeafc1b56fad38f2ec109c Mon Sep 17 00:00:00 2001 From: Peter Harrington Date: Thu, 16 Apr 2026 16:57:40 -0700 Subject: [PATCH 08/10] Rename to local_platform --- examples/weather/healda/requirements.txt | 4 +- .../weather/healda/scripts/compare_loaders.py | 587 ++++++++++++++++++ .../datapipes/healda/transforms/era5_obs.py | 2 +- .../experimental/datapipes/healda/types.py | 12 +- test/datapipes/healda/test_types.py | 6 +- 5 files changed, 598 insertions(+), 13 deletions(-) create mode 100755 examples/weather/healda/scripts/compare_loaders.py diff --git a/examples/weather/healda/requirements.txt b/examples/weather/healda/requirements.txt index 1889a50887..37a87fa890 100644 --- a/examples/weather/healda/requirements.txt +++ b/examples/weather/healda/requirements.txt @@ -3,7 +3,5 @@ cftime pyarrow dotenv earth2grid @ git+https://github.com/NVlabs/earth2grid.git@main -healpy matplotlib -joblib -icechunk \ No newline at end of file +joblib \ No newline at end of file diff --git a/examples/weather/healda/scripts/compare_loaders.py b/examples/weather/healda/scripts/compare_loaders.py new file mode 100755 index 0000000000..e9c42a27ba --- /dev/null +++ b/examples/weather/healda/scripts/compare_loaders.py @@ -0,0 +1,587 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Compare outputs between the ported data loader and the reference implementation. + +Modes: + smoketest — 3 sample indices, fast sanity check (~1 min) + full — 50 indices spread across train split (~10 min) + +Usage: + # From examples/weather/healda/ + python scripts/compare_loaders.py smoketest + python scripts/compare_loaders.py full + python scripts/compare_loaders.py full --indices 0 100 500 1000 + +Requires: + - The reference codebase importable (healda-reference/src on PYTHONPATH, or + the healda package installed). + - Environment variables from .env (ERA5_74VAR, UFS_OBS_PATH, etc.). 
+""" + +from __future__ import annotations + +import argparse +import os +import sys +import time + +import numpy as np +import pandas as pd +import torch + +# --------------------------------------------------------------------------- +# Path setup +# --------------------------------------------------------------------------- +RECIPE_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +REFERENCE_ROOT = os.path.join(os.path.dirname(RECIPE_ROOT), "healda-reference") + +# Load .env +from dotenv import load_dotenv + +load_dotenv(os.path.join(RECIPE_ROOT, ".env")) + + +# ============================================================================ +# Ported loader construction +# ============================================================================ + + +def build_ported_dataset(split="train", sensors=None): + """Construct ObsERA5Dataset from the ported data/ package.""" + from physicsnemo.experimental.datapipes.healda.configs.variable_configs import VARIABLE_CONFIGS + from physicsnemo.experimental.datapipes.healda.dataset import ObsERA5Dataset + from physicsnemo.experimental.datapipes.healda.loaders.ufs_obs import UFSUnifiedLoader + from physicsnemo.experimental.datapipes.healda.transforms.era5_obs import ERA5ObsTransform + + variable_config = VARIABLE_CONFIGS["era5"] + + if sensors is None: + sensors = ["atms", "mhs", "amsua", "amsub"] + + obs_path = os.environ["UFS_OBS_PATH"] + obs_loader = UFSUnifiedLoader( + data_path=obs_path, + sensors=sensors, + normalization="zscore", + obs_context_hours=(-21, 3), + ) + + transform = ERA5ObsTransform(variable_config=variable_config, sensors=sensors) + + import xarray + + era5_path = os.environ["ERA5_74VAR"] + era5_ds = xarray.open_zarr(era5_path, chunks=None) + era5_data = era5_ds["data"] + + dataset = ObsERA5Dataset( + era5_data=era5_data, + obs_loader=obs_loader, + transform=transform, + variable_config=variable_config, + split=split, + ) + return dataset + + +# 
============================================================================ +# Reference loader construction +# ============================================================================ + + +def build_reference_dataset(split="train", sensors=None): + """Construct ObsERA5Dataset from the reference healda-reference codebase. + + Requires healda-reference/src and healda-reference/ on sys.path. + """ + # Add reference paths + ref_src = os.path.join(REFERENCE_ROOT, "src") + ref_private = REFERENCE_ROOT + for p in [ref_src, ref_private]: + if p not in sys.path: + sys.path.insert(0, p) + + import dotenv as _dotenv + + _dotenv.load_dotenv(os.path.join(RECIPE_ROOT, ".env")) + + from healda.config.models import ObsConfig + from private.fcn3_dataset import ObsERA5Dataset as RefObsERA5Dataset + + # Build obs_config matching default sensors + use_conv = sensors is not None and "conv" in sensors + obs_config = ObsConfig( + use_obs=True, + innovation_type="none", + context_start=-21, + context_end=3, + use_conv=use_conv, + ) + + dataset = RefObsERA5Dataset( + split=split, + time_length=1, + frame_step=1, + model_rank=0, + model_world_size=1, + obs_config=obs_config, + ) + return dataset + + +# ============================================================================ +# Comparison logic +# ============================================================================ + + +def compare_single_sample(ported_ds, ref_ds, idx: int, verbose: bool = True): + """Compare a single sample between ported and reference datasets. + + Returns a dict of comparison results. 
+ """ + results = {"idx": idx, "pass": True, "errors": []} + + # --- Raw data comparison (before transform) --- + try: + t0 = time.time() + ported_times, ported_objs = ported_ds.get(idx) + ported_elapsed = time.time() - t0 + + t0 = time.time() + ref_times, ref_objs = ref_ds.get(idx) + ref_elapsed = time.time() - t0 + + results["ported_time_s"] = ported_elapsed + results["ref_time_s"] = ref_elapsed + + except Exception as e: + results["pass"] = False + results["errors"].append(f"Loading failed: {e}") + return results + + # Compare timestamps + for i, (pt, rt) in enumerate(zip(ported_times, ref_times)): + if str(pt) != str(rt): + results["pass"] = False + results["errors"].append(f"Time mismatch at frame {i}: {pt} vs {rt}") + + # Compare state arrays + for i, (po, ro) in enumerate(zip(ported_objs, ref_objs)): + p_state = po["state"] + r_state = ro["state"] + + if p_state.shape != r_state.shape: + results["pass"] = False + results["errors"].append( + f"State shape mismatch at frame {i}: {p_state.shape} vs {r_state.shape}" + ) + continue + + max_diff = np.max(np.abs(p_state - r_state)) + results[f"state_frame{i}_maxdiff"] = float(max_diff) + + if max_diff > 1e-6: + results["pass"] = False + results["errors"].append( + f"State value mismatch at frame {i}: max_diff={max_diff:.2e}" + ) + + # Compare observation tables + # Note: ported and reference may produce rows in different order (due to + # platform grouping within parquet row-groups). This is benign — the + # downstream transform processes all obs in a window together. We sort + # both tables by a canonical key before value comparison. 
+ import pyarrow.compute as pc + + def _sort_obs_table(table): + """Sort by (Global_Channel_ID, Latitude, Longitude, Absolute_Obs_Time) + to produce a deterministic row order for comparison.""" + sort_keys = [ + ("Global_Channel_ID", "ascending"), + ("Latitude", "ascending"), + ("Longitude", "ascending"), + ("Absolute_Obs_Time", "ascending"), + ] + indices = pc.sort_indices(table, sort_keys=sort_keys) + return table.take(indices) + + for i, (po, ro) in enumerate(zip(ported_objs, ref_objs)): + p_obs = po.get("obs") + r_obs = ro.get("obs") or ro.get("obs_v2") # reference uses legacy key + + if p_obs is None and r_obs is None: + continue + if (p_obs is None) != (r_obs is None): + results["pass"] = False + results["errors"].append(f"Obs presence mismatch at frame {i}") + continue + + p_nrows = p_obs.num_rows + r_nrows = r_obs.num_rows + results[f"obs_frame{i}_nrows_ported"] = p_nrows + results[f"obs_frame{i}_nrows_ref"] = r_nrows + + if p_nrows != r_nrows: + results["pass"] = False + results["errors"].append( + f"Obs row count mismatch at frame {i}: {p_nrows} vs {r_nrows}" + ) + continue + + if p_nrows > 0: + # Compare schemas + p_cols = set(p_obs.schema.names) + r_cols = set(r_obs.schema.names) + if p_cols != r_cols: + results["pass"] = False + results["errors"].append( + f"Obs schema mismatch at frame {i}: " + f"ported_only={p_cols - r_cols}, ref_only={r_cols - p_cols}" + ) + continue + + # Sort both tables to canonical order before comparison + p_sorted = _sort_obs_table(p_obs) + r_sorted = _sort_obs_table(r_obs) + + # Compare observation values + p_vals = p_sorted["Observation"].to_numpy() + r_vals = r_sorted["Observation"].to_numpy() + obs_max_diff = np.nanmax(np.abs(p_vals - r_vals)) + results[f"obs_frame{i}_val_maxdiff"] = float(obs_max_diff) + if obs_max_diff > 1e-5: + results["pass"] = False + results["errors"].append( + f"Obs value mismatch at frame {i}: max_diff={obs_max_diff:.2e}" + ) + + # Also verify Global_Channel_ID sets match + p_gcids = 
set(p_obs["Global_Channel_ID"].to_pylist()) + r_gcids = set(r_obs["Global_Channel_ID"].to_pylist()) + if p_gcids != r_gcids: + results["pass"] = False + results["errors"].append( + f"Obs GCID set mismatch at frame {i}: " + f"ported_only={p_gcids - r_gcids}, ref_only={r_gcids - p_gcids}" + ) + + if verbose: + status = "PASS" if results["pass"] else "FAIL" + timing = ( + f"ported={results.get('ported_time_s', 0):.2f}s " + f"ref={results.get('ref_time_s', 0):.2f}s" + ) + print(f" [{status}] idx={idx:6d} {timing}") + for err in results["errors"]: + print(f" {err}") + + return results + + +def compare_transformed_sample(ported_ds, ref_ds, idx: int, verbose: bool = True): + """Compare transformed (batched) output between ported and reference. + + Uses __getitems__ to exercise the full transform pipeline. + """ + results = {"idx": idx, "pass": True, "errors": []} + + try: + t0 = time.time() + ported_batch = ported_ds.__getitems__([idx]) + ported_elapsed = time.time() - t0 + + t0 = time.time() + ref_batch = ref_ds.__getitems__([idx]) + ref_elapsed = time.time() - t0 + + results["ported_transform_s"] = ported_elapsed + results["ref_transform_s"] = ref_elapsed + + except Exception as e: + results["pass"] = False + results["errors"].append(f"Transform failed: {e}") + if verbose: + print(f" [FAIL] idx={idx:6d} Transform error: {e}") + return results + + # Compare batch dict keys + p_keys = set(ported_batch.keys()) + r_keys = set(ref_batch.keys()) + if p_keys != r_keys: + results["errors"].append( + f"Batch key mismatch: ported_only={p_keys - r_keys}, ref_only={r_keys - p_keys}" + ) + # Don't fail — extra/missing keys may be intentional + + # Compare tensor fields + for key in sorted(p_keys & r_keys): + pv = ported_batch[key] + rv = ref_batch[key] + + if isinstance(pv, torch.Tensor) and isinstance(rv, torch.Tensor): + if pv.shape != rv.shape: + results["pass"] = False + results["errors"].append( + f"Shape mismatch for '{key}': {pv.shape} vs {rv.shape}" + ) + continue + + if 
pv.numel() == 0: + continue + max_diff = (pv.float() - rv.float()).abs().max().item() + results[f"{key}_maxdiff"] = max_diff + + # Use loose tolerance for float transforms + tol = 1e-4 if pv.is_floating_point() else 0 + if max_diff > tol: + results["pass"] = False + results["errors"].append( + f"Value mismatch for '{key}': max_diff={max_diff:.2e}" + ) + + elif isinstance(pv, tuple) and isinstance(rv, tuple): + # unified_obs is a tuple (obs_tensors, lengths_3d) + # Row ordering may differ between ported and reference (benign — + # within each sensor group, platforms can appear in different order + # depending on parquet row-group layout). We sort both by + # (global_channel_id, latitude, longitude) before comparing values. + if len(pv) != len(rv): + results["pass"] = False + results["errors"].append( + f"Tuple length mismatch for '{key}': {len(pv)} vs {len(rv)}" + ) + continue + + if isinstance(pv[0], dict) and isinstance(rv[0], dict): + p_obs_keys = set(pv[0].keys()) + r_obs_keys = set(rv[0].keys()) + if p_obs_keys != r_obs_keys: + results["errors"].append( + f"Obs tensor key mismatch: " + f"ported_only={p_obs_keys - r_obs_keys}, " + f"ref_only={r_obs_keys - p_obs_keys}" + ) + + # Build a stable sort index using torch.lexsort-style + # multi-key sorting: (gcid, abs_time, lat, lon, observation) + def _sort_idx(obs_dict): + gcid = obs_dict.get("global_channel_id") + lat = obs_dict.get("latitude") + lon = obs_dict.get("longitude") + obs_time = obs_dict.get("absolute_obs_time") + obs_val = obs_dict.get("observation") + if gcid is None or gcid.numel() == 0: + return None + # Stack columns as (N, K) float64 for lexicographic sort. + # torch.lexsort isn't available, so we use numpy. 
+ cols = [gcid.double().cpu().numpy()] + if obs_time is not None: + cols.append(obs_time.double().cpu().numpy()) + if lat is not None: + cols.append(lat.double().cpu().numpy()) + if lon is not None: + cols.append(lon.double().cpu().numpy()) + if obs_val is not None: + cols.append(obs_val.double().cpu().numpy()) + # np.lexsort sorts by last key first, so reverse + order = np.lexsort(cols[::-1]) + return torch.from_numpy(order).long() + + p_order = _sort_idx(pv[0]) + r_order = _sort_idx(rv[0]) + + for obs_key in sorted(p_obs_keys & r_obs_keys): + pt = pv[0][obs_key] + rt = rv[0][obs_key] + if pt.shape != rt.shape: + results["pass"] = False + results["errors"].append( + f"Obs tensor shape mismatch for '{obs_key}': " + f"{pt.shape} vs {rt.shape}" + ) + elif pt.numel() > 0: + # Apply sort order before comparison + ps = pt[p_order] if p_order is not None else pt + rs = rt[r_order] if r_order is not None else rt + d = (ps.float() - rs.float()).abs().max().item() + results[f"obs_{obs_key}_maxdiff"] = d + if d > 1e-4: + results["pass"] = False + results["errors"].append( + f"Obs tensor mismatch for '{obs_key}': " + f"max_diff={d:.2e}" + ) + + # Compare lengths_3d (sensor, batch, time) — these count obs per + # sensor/window and are order-independent as long as sensor_id + # mapping matches. 
+ for ti, name in [(1, "lengths")]: + if ti < len(pv) and ti < len(rv): + pt, rt = pv[ti], rv[ti] + if isinstance(pt, torch.Tensor) and isinstance(rt, torch.Tensor): + if pt.shape != rt.shape: + results["pass"] = False + results["errors"].append( + f"{name} shape mismatch: {pt.shape} vs {rt.shape}" + ) + elif not torch.equal(pt, rt): + results["pass"] = False + results["errors"].append(f"{name} value mismatch") + + if verbose: + status = "PASS" if results["pass"] else "FAIL" + timing = ( + f"ported={results.get('ported_transform_s', 0):.2f}s " + f"ref={results.get('ref_transform_s', 0):.2f}s" + ) + print(f" [{status}] idx={idx:6d} {timing}") + for err in results["errors"]: + print(f" {err}") + + return results + + +# ============================================================================ +# Main driver +# ============================================================================ + + +def get_indices(mode: str, ds_len: int, custom_indices=None): + """Return sample indices based on mode.""" + if custom_indices: + return [i for i in custom_indices if i < ds_len] + + if mode == "smoketest": + # 3 indices: start, middle, near end + return [0, ds_len // 2, ds_len - 1] + + elif mode == "full": + # 50 indices spread across the dataset + n = min(50, ds_len) + step = max(1, ds_len // n) + return list(range(0, ds_len, step))[:n] + + else: + raise ValueError(f"Unknown mode: {mode}") + + +def main(): + parser = argparse.ArgumentParser( + description="Compare ported vs reference data loader outputs." + ) + parser.add_argument( + "mode", + choices=["smoketest", "full"], + help="smoketest: 3 indices, fast. full: 50 indices.", + ) + parser.add_argument( + "--indices", + type=int, + nargs="*", + default=None, + help="Override indices to compare.", + ) + parser.add_argument( + "--split", default="train", help="Dataset split (default: train)." 
+ ) + parser.add_argument( + "--sensors", + nargs="*", + default=None, + help="Sensor list (default: atms mhs amsua amsub).", + ) + parser.add_argument( + "--transform", + action="store_true", + help="Also compare transformed (__getitems__) output.", + ) + parser.add_argument( + "--no-raw", + action="store_true", + help="Skip raw (get) comparison, only do transform.", + ) + args = parser.parse_args() + + print("=" * 70) + print(f"Loader comparison — mode={args.mode}, split={args.split}") + print("=" * 70) + + # Build datasets + print("\nBuilding ported dataset...") + t0 = time.time() + ported_ds = build_ported_dataset(split=args.split, sensors=args.sensors) + print(f" Done in {time.time() - t0:.1f}s (len={len(ported_ds)})") + + print("Building reference dataset...") + t0 = time.time() + ref_ds = build_reference_dataset(split=args.split, sensors=args.sensors) + print(f" Done in {time.time() - t0:.1f}s (len={len(ref_ds)})") + + # Verify lengths match + if len(ported_ds) != len(ref_ds): + print( + f"\nWARNING: Dataset lengths differ! " + f"ported={len(ported_ds)} vs ref={len(ref_ds)}" + ) + + indices = get_indices( + args.mode, min(len(ported_ds), len(ref_ds)), args.indices + ) + print(f"\nComparing {len(indices)} samples: {indices[:10]}{'...' 
if len(indices) > 10 else ''}") + + # --- Raw comparison --- + if not args.no_raw: + print(f"\n--- Raw comparison (get) ---") + raw_results = [] + for idx in indices: + r = compare_single_sample(ported_ds, ref_ds, idx) + raw_results.append(r) + + n_pass = sum(1 for r in raw_results if r["pass"]) + n_fail = len(raw_results) - n_pass + print(f"\nRaw: {n_pass}/{len(raw_results)} passed, {n_fail} failed") + + # --- Transform comparison --- + if args.transform: + print(f"\n--- Transform comparison (__getitems__) ---") + xform_results = [] + for idx in indices: + r = compare_transformed_sample(ported_ds, ref_ds, idx) + xform_results.append(r) + + n_pass = sum(1 for r in xform_results if r["pass"]) + n_fail = len(xform_results) - n_pass + print(f"\nTransform: {n_pass}/{len(xform_results)} passed, {n_fail} failed") + + # --- Summary --- + print("\n" + "=" * 70) + all_results = [] + if not args.no_raw: + all_results.extend(raw_results) + if args.transform: + all_results.extend(xform_results) + + n_total = len(all_results) + n_pass = sum(1 for r in all_results if r["pass"]) + if n_total == n_pass: + print(f"ALL {n_total} CHECKS PASSED") + else: + print(f"{n_total - n_pass}/{n_total} CHECKS FAILED") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/physicsnemo/experimental/datapipes/healda/transforms/era5_obs.py b/physicsnemo/experimental/datapipes/healda/transforms/era5_obs.py index e532abfb1e..23541c7259 100644 --- a/physicsnemo/experimental/datapipes/healda/transforms/era5_obs.py +++ b/physicsnemo/experimental/datapipes/healda/transforms/era5_obs.py @@ -444,7 +444,7 @@ def _to_device(tensor, non_blocking=True): float_metadata=meta, pix=pix, local_channel=local_channel_id_tensor, - platform=local_platform, + local_platform=local_platform, obs_type=obs_type_tensor, global_channel=global_channel_id_tensor, global_platform=platform_id_tensor, diff --git a/physicsnemo/experimental/datapipes/healda/types.py b/physicsnemo/experimental/datapipes/healda/types.py 
index 2b71b7e703..8e39a68399 100644 --- a/physicsnemo/experimental/datapipes/healda/types.py +++ b/physicsnemo/experimental/datapipes/healda/types.py @@ -156,7 +156,7 @@ class UnifiedObservation: # Integer metadata (each shape (n_obs,)) pix: torch.Tensor # HEALPix pixel index (NEST) local_channel: torch.Tensor - platform: torch.Tensor + local_platform: torch.Tensor obs_type: torch.Tensor global_channel: torch.Tensor @@ -181,7 +181,7 @@ def empty( float_metadata=torch.empty((0, 28), device=device), pix=torch.empty(0, dtype=torch.long, device=device), local_channel=torch.empty(0, dtype=torch.long, device=device), - platform=torch.empty(0, dtype=torch.long, device=device), + local_platform=torch.empty(0, dtype=torch.long, device=device), obs_type=torch.empty(0, dtype=torch.long, device=device), global_channel=torch.empty(0, dtype=torch.long, device=device), global_platform=torch.empty(0, dtype=torch.long, device=device), @@ -215,7 +215,7 @@ def _move(x): float_metadata=_move(self.float_metadata), pix=_move(self.pix), local_channel=_move(self.local_channel), - platform=_move(self.platform), + local_platform=_move(self.local_platform), obs_type=_move(self.obs_type), global_channel=_move(self.global_channel), hpx_level=self.hpx_level, @@ -302,7 +302,7 @@ def split_by_sensor( obs.float_metadata, obs.pix, obs.local_channel, - obs.platform, + obs.local_platform, obs.obs_type, obs.global_channel, ] @@ -326,7 +326,7 @@ def split_by_sensor( float_metadata=obs.float_metadata[:0], pix=obs.pix[:0], local_channel=obs.local_channel[:0], - platform=obs.platform[:0], + local_platform=obs.local_platform[:0], obs_type=obs.obs_type[:0], global_channel=obs.global_channel[:0], global_platform=( @@ -342,7 +342,7 @@ def split_by_sensor( float_metadata=splits[2][s_local], pix=splits[3][s_local], local_channel=splits[4][s_local], - platform=splits[5][s_local], + local_platform=splits[5][s_local], obs_type=splits[6][s_local], global_channel=splits[7][s_local], global_platform=( diff --git 
a/test/datapipes/healda/test_types.py b/test/datapipes/healda/test_types.py index 2fb39d0d8a..d55d9a9734 100644 --- a/test/datapipes/healda/test_types.py +++ b/test/datapipes/healda/test_types.py @@ -58,7 +58,7 @@ def make_realistic_obs( float_metadata=values.unsqueeze(1).expand(nobs, 5), pix=torch.arange(nobs, dtype=torch.long), local_channel=torch.zeros(nobs, dtype=torch.long), - platform=torch.zeros(nobs, dtype=torch.long), + local_platform=torch.zeros(nobs, dtype=torch.long), obs_type=torch.zeros(nobs, dtype=torch.long), global_channel=torch.zeros(nobs, dtype=torch.long), global_platform=torch.zeros(nobs, dtype=torch.long), @@ -114,7 +114,7 @@ def test_split_requires_lengths(): float_metadata=torch.randn(10, 5), pix=torch.zeros(10, dtype=torch.long), local_channel=torch.zeros(10, dtype=torch.long), - platform=torch.zeros(10, dtype=torch.long), + local_platform=torch.zeros(10, dtype=torch.long), obs_type=torch.zeros(10, dtype=torch.long), global_channel=torch.zeros(10, dtype=torch.long), hpx_level=6, @@ -154,7 +154,7 @@ def test_split_handles_sparse_windows(): .expand(nobs, 5), pix=torch.arange(nobs, dtype=torch.long), local_channel=torch.zeros(nobs, dtype=torch.long), - platform=torch.zeros(nobs, dtype=torch.long), + local_platform=torch.zeros(nobs, dtype=torch.long), obs_type=torch.zeros(nobs, dtype=torch.long), global_channel=torch.zeros(nobs, dtype=torch.long), global_platform=torch.zeros(nobs, dtype=torch.long), From c3a2c0c4b0af84ae9b3b51b9fe777d9f5e671f18 Mon Sep 17 00:00:00 2001 From: Peter Harrington Date: Thu, 16 Apr 2026 17:04:02 -0700 Subject: [PATCH 09/10] Revert precommit change --- .pre-commit-config.yaml | 2 +- examples/weather/healda/requirements.txt | 4 +- .../weather/healda/scripts/compare_loaders.py | 587 ------------------ 3 files changed, 2 insertions(+), 591 deletions(-) delete mode 100755 examples/weather/healda/scripts/compare_loaders.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 41ba5f4539..09803226dc 100644 
--- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,7 +31,7 @@ repos: language: python types: [python] additional_dependencies: ['interrogate==1.7.0'] - exclude: ^docs/|^physicsnemo/experimental/|^test/|^examples/.*/test/ + exclude: ^docs/|^physicsnemo/experimental/|^test/ - repo: https://github.com/igorshubovych/markdownlint-cli rev: v0.35.0 diff --git a/examples/weather/healda/requirements.txt b/examples/weather/healda/requirements.txt index 37a87fa890..2921ccd5e5 100644 --- a/examples/weather/healda/requirements.txt +++ b/examples/weather/healda/requirements.txt @@ -1,6 +1,4 @@ -# nvidia-physicsnemo[datapipes-extras] -cftime -pyarrow +nvidia-physicsnemo[healda] dotenv earth2grid @ git+https://github.com/NVlabs/earth2grid.git@main matplotlib diff --git a/examples/weather/healda/scripts/compare_loaders.py b/examples/weather/healda/scripts/compare_loaders.py deleted file mode 100755 index e9c42a27ba..0000000000 --- a/examples/weather/healda/scripts/compare_loaders.py +++ /dev/null @@ -1,587 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Compare outputs between the ported data loader and the reference implementation. 
- -Modes: - smoketest — 3 sample indices, fast sanity check (~1 min) - full — 50 indices spread across train split (~10 min) - -Usage: - # From examples/weather/healda/ - python scripts/compare_loaders.py smoketest - python scripts/compare_loaders.py full - python scripts/compare_loaders.py full --indices 0 100 500 1000 - -Requires: - - The reference codebase importable (healda-reference/src on PYTHONPATH, or - the healda package installed). - - Environment variables from .env (ERA5_74VAR, UFS_OBS_PATH, etc.). -""" - -from __future__ import annotations - -import argparse -import os -import sys -import time - -import numpy as np -import pandas as pd -import torch - -# --------------------------------------------------------------------------- -# Path setup -# --------------------------------------------------------------------------- -RECIPE_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -REFERENCE_ROOT = os.path.join(os.path.dirname(RECIPE_ROOT), "healda-reference") - -# Load .env -from dotenv import load_dotenv - -load_dotenv(os.path.join(RECIPE_ROOT, ".env")) - - -# ============================================================================ -# Ported loader construction -# ============================================================================ - - -def build_ported_dataset(split="train", sensors=None): - """Construct ObsERA5Dataset from the ported data/ package.""" - from physicsnemo.experimental.datapipes.healda.configs.variable_configs import VARIABLE_CONFIGS - from physicsnemo.experimental.datapipes.healda.dataset import ObsERA5Dataset - from physicsnemo.experimental.datapipes.healda.loaders.ufs_obs import UFSUnifiedLoader - from physicsnemo.experimental.datapipes.healda.transforms.era5_obs import ERA5ObsTransform - - variable_config = VARIABLE_CONFIGS["era5"] - - if sensors is None: - sensors = ["atms", "mhs", "amsua", "amsub"] - - obs_path = os.environ["UFS_OBS_PATH"] - obs_loader = UFSUnifiedLoader( - data_path=obs_path, - 
sensors=sensors, - normalization="zscore", - obs_context_hours=(-21, 3), - ) - - transform = ERA5ObsTransform(variable_config=variable_config, sensors=sensors) - - import xarray - - era5_path = os.environ["ERA5_74VAR"] - era5_ds = xarray.open_zarr(era5_path, chunks=None) - era5_data = era5_ds["data"] - - dataset = ObsERA5Dataset( - era5_data=era5_data, - obs_loader=obs_loader, - transform=transform, - variable_config=variable_config, - split=split, - ) - return dataset - - -# ============================================================================ -# Reference loader construction -# ============================================================================ - - -def build_reference_dataset(split="train", sensors=None): - """Construct ObsERA5Dataset from the reference healda-reference codebase. - - Requires healda-reference/src and healda-reference/ on sys.path. - """ - # Add reference paths - ref_src = os.path.join(REFERENCE_ROOT, "src") - ref_private = REFERENCE_ROOT - for p in [ref_src, ref_private]: - if p not in sys.path: - sys.path.insert(0, p) - - import dotenv as _dotenv - - _dotenv.load_dotenv(os.path.join(RECIPE_ROOT, ".env")) - - from healda.config.models import ObsConfig - from private.fcn3_dataset import ObsERA5Dataset as RefObsERA5Dataset - - # Build obs_config matching default sensors - use_conv = sensors is not None and "conv" in sensors - obs_config = ObsConfig( - use_obs=True, - innovation_type="none", - context_start=-21, - context_end=3, - use_conv=use_conv, - ) - - dataset = RefObsERA5Dataset( - split=split, - time_length=1, - frame_step=1, - model_rank=0, - model_world_size=1, - obs_config=obs_config, - ) - return dataset - - -# ============================================================================ -# Comparison logic -# ============================================================================ - - -def compare_single_sample(ported_ds, ref_ds, idx: int, verbose: bool = True): - """Compare a single sample between ported and 
reference datasets. - - Returns a dict of comparison results. - """ - results = {"idx": idx, "pass": True, "errors": []} - - # --- Raw data comparison (before transform) --- - try: - t0 = time.time() - ported_times, ported_objs = ported_ds.get(idx) - ported_elapsed = time.time() - t0 - - t0 = time.time() - ref_times, ref_objs = ref_ds.get(idx) - ref_elapsed = time.time() - t0 - - results["ported_time_s"] = ported_elapsed - results["ref_time_s"] = ref_elapsed - - except Exception as e: - results["pass"] = False - results["errors"].append(f"Loading failed: {e}") - return results - - # Compare timestamps - for i, (pt, rt) in enumerate(zip(ported_times, ref_times)): - if str(pt) != str(rt): - results["pass"] = False - results["errors"].append(f"Time mismatch at frame {i}: {pt} vs {rt}") - - # Compare state arrays - for i, (po, ro) in enumerate(zip(ported_objs, ref_objs)): - p_state = po["state"] - r_state = ro["state"] - - if p_state.shape != r_state.shape: - results["pass"] = False - results["errors"].append( - f"State shape mismatch at frame {i}: {p_state.shape} vs {r_state.shape}" - ) - continue - - max_diff = np.max(np.abs(p_state - r_state)) - results[f"state_frame{i}_maxdiff"] = float(max_diff) - - if max_diff > 1e-6: - results["pass"] = False - results["errors"].append( - f"State value mismatch at frame {i}: max_diff={max_diff:.2e}" - ) - - # Compare observation tables - # Note: ported and reference may produce rows in different order (due to - # platform grouping within parquet row-groups). This is benign — the - # downstream transform processes all obs in a window together. We sort - # both tables by a canonical key before value comparison. 
- import pyarrow.compute as pc - - def _sort_obs_table(table): - """Sort by (Global_Channel_ID, Latitude, Longitude, Absolute_Obs_Time) - to produce a deterministic row order for comparison.""" - sort_keys = [ - ("Global_Channel_ID", "ascending"), - ("Latitude", "ascending"), - ("Longitude", "ascending"), - ("Absolute_Obs_Time", "ascending"), - ] - indices = pc.sort_indices(table, sort_keys=sort_keys) - return table.take(indices) - - for i, (po, ro) in enumerate(zip(ported_objs, ref_objs)): - p_obs = po.get("obs") - r_obs = ro.get("obs") or ro.get("obs_v2") # reference uses legacy key - - if p_obs is None and r_obs is None: - continue - if (p_obs is None) != (r_obs is None): - results["pass"] = False - results["errors"].append(f"Obs presence mismatch at frame {i}") - continue - - p_nrows = p_obs.num_rows - r_nrows = r_obs.num_rows - results[f"obs_frame{i}_nrows_ported"] = p_nrows - results[f"obs_frame{i}_nrows_ref"] = r_nrows - - if p_nrows != r_nrows: - results["pass"] = False - results["errors"].append( - f"Obs row count mismatch at frame {i}: {p_nrows} vs {r_nrows}" - ) - continue - - if p_nrows > 0: - # Compare schemas - p_cols = set(p_obs.schema.names) - r_cols = set(r_obs.schema.names) - if p_cols != r_cols: - results["pass"] = False - results["errors"].append( - f"Obs schema mismatch at frame {i}: " - f"ported_only={p_cols - r_cols}, ref_only={r_cols - p_cols}" - ) - continue - - # Sort both tables to canonical order before comparison - p_sorted = _sort_obs_table(p_obs) - r_sorted = _sort_obs_table(r_obs) - - # Compare observation values - p_vals = p_sorted["Observation"].to_numpy() - r_vals = r_sorted["Observation"].to_numpy() - obs_max_diff = np.nanmax(np.abs(p_vals - r_vals)) - results[f"obs_frame{i}_val_maxdiff"] = float(obs_max_diff) - if obs_max_diff > 1e-5: - results["pass"] = False - results["errors"].append( - f"Obs value mismatch at frame {i}: max_diff={obs_max_diff:.2e}" - ) - - # Also verify Global_Channel_ID sets match - p_gcids = 
set(p_obs["Global_Channel_ID"].to_pylist()) - r_gcids = set(r_obs["Global_Channel_ID"].to_pylist()) - if p_gcids != r_gcids: - results["pass"] = False - results["errors"].append( - f"Obs GCID set mismatch at frame {i}: " - f"ported_only={p_gcids - r_gcids}, ref_only={r_gcids - p_gcids}" - ) - - if verbose: - status = "PASS" if results["pass"] else "FAIL" - timing = ( - f"ported={results.get('ported_time_s', 0):.2f}s " - f"ref={results.get('ref_time_s', 0):.2f}s" - ) - print(f" [{status}] idx={idx:6d} {timing}") - for err in results["errors"]: - print(f" {err}") - - return results - - -def compare_transformed_sample(ported_ds, ref_ds, idx: int, verbose: bool = True): - """Compare transformed (batched) output between ported and reference. - - Uses __getitems__ to exercise the full transform pipeline. - """ - results = {"idx": idx, "pass": True, "errors": []} - - try: - t0 = time.time() - ported_batch = ported_ds.__getitems__([idx]) - ported_elapsed = time.time() - t0 - - t0 = time.time() - ref_batch = ref_ds.__getitems__([idx]) - ref_elapsed = time.time() - t0 - - results["ported_transform_s"] = ported_elapsed - results["ref_transform_s"] = ref_elapsed - - except Exception as e: - results["pass"] = False - results["errors"].append(f"Transform failed: {e}") - if verbose: - print(f" [FAIL] idx={idx:6d} Transform error: {e}") - return results - - # Compare batch dict keys - p_keys = set(ported_batch.keys()) - r_keys = set(ref_batch.keys()) - if p_keys != r_keys: - results["errors"].append( - f"Batch key mismatch: ported_only={p_keys - r_keys}, ref_only={r_keys - p_keys}" - ) - # Don't fail — extra/missing keys may be intentional - - # Compare tensor fields - for key in sorted(p_keys & r_keys): - pv = ported_batch[key] - rv = ref_batch[key] - - if isinstance(pv, torch.Tensor) and isinstance(rv, torch.Tensor): - if pv.shape != rv.shape: - results["pass"] = False - results["errors"].append( - f"Shape mismatch for '{key}': {pv.shape} vs {rv.shape}" - ) - continue - - if 
pv.numel() == 0: - continue - max_diff = (pv.float() - rv.float()).abs().max().item() - results[f"{key}_maxdiff"] = max_diff - - # Use loose tolerance for float transforms - tol = 1e-4 if pv.is_floating_point() else 0 - if max_diff > tol: - results["pass"] = False - results["errors"].append( - f"Value mismatch for '{key}': max_diff={max_diff:.2e}" - ) - - elif isinstance(pv, tuple) and isinstance(rv, tuple): - # unified_obs is a tuple (obs_tensors, lengths_3d) - # Row ordering may differ between ported and reference (benign — - # within each sensor group, platforms can appear in different order - # depending on parquet row-group layout). We sort both by - # (global_channel_id, latitude, longitude) before comparing values. - if len(pv) != len(rv): - results["pass"] = False - results["errors"].append( - f"Tuple length mismatch for '{key}': {len(pv)} vs {len(rv)}" - ) - continue - - if isinstance(pv[0], dict) and isinstance(rv[0], dict): - p_obs_keys = set(pv[0].keys()) - r_obs_keys = set(rv[0].keys()) - if p_obs_keys != r_obs_keys: - results["errors"].append( - f"Obs tensor key mismatch: " - f"ported_only={p_obs_keys - r_obs_keys}, " - f"ref_only={r_obs_keys - p_obs_keys}" - ) - - # Build a stable sort index using torch.lexsort-style - # multi-key sorting: (gcid, abs_time, lat, lon, observation) - def _sort_idx(obs_dict): - gcid = obs_dict.get("global_channel_id") - lat = obs_dict.get("latitude") - lon = obs_dict.get("longitude") - obs_time = obs_dict.get("absolute_obs_time") - obs_val = obs_dict.get("observation") - if gcid is None or gcid.numel() == 0: - return None - # Stack columns as (N, K) float64 for lexicographic sort. - # torch.lexsort isn't available, so we use numpy. 
- cols = [gcid.double().cpu().numpy()] - if obs_time is not None: - cols.append(obs_time.double().cpu().numpy()) - if lat is not None: - cols.append(lat.double().cpu().numpy()) - if lon is not None: - cols.append(lon.double().cpu().numpy()) - if obs_val is not None: - cols.append(obs_val.double().cpu().numpy()) - # np.lexsort sorts by last key first, so reverse - order = np.lexsort(cols[::-1]) - return torch.from_numpy(order).long() - - p_order = _sort_idx(pv[0]) - r_order = _sort_idx(rv[0]) - - for obs_key in sorted(p_obs_keys & r_obs_keys): - pt = pv[0][obs_key] - rt = rv[0][obs_key] - if pt.shape != rt.shape: - results["pass"] = False - results["errors"].append( - f"Obs tensor shape mismatch for '{obs_key}': " - f"{pt.shape} vs {rt.shape}" - ) - elif pt.numel() > 0: - # Apply sort order before comparison - ps = pt[p_order] if p_order is not None else pt - rs = rt[r_order] if r_order is not None else rt - d = (ps.float() - rs.float()).abs().max().item() - results[f"obs_{obs_key}_maxdiff"] = d - if d > 1e-4: - results["pass"] = False - results["errors"].append( - f"Obs tensor mismatch for '{obs_key}': " - f"max_diff={d:.2e}" - ) - - # Compare lengths_3d (sensor, batch, time) — these count obs per - # sensor/window and are order-independent as long as sensor_id - # mapping matches. 
- for ti, name in [(1, "lengths")]: - if ti < len(pv) and ti < len(rv): - pt, rt = pv[ti], rv[ti] - if isinstance(pt, torch.Tensor) and isinstance(rt, torch.Tensor): - if pt.shape != rt.shape: - results["pass"] = False - results["errors"].append( - f"{name} shape mismatch: {pt.shape} vs {rt.shape}" - ) - elif not torch.equal(pt, rt): - results["pass"] = False - results["errors"].append(f"{name} value mismatch") - - if verbose: - status = "PASS" if results["pass"] else "FAIL" - timing = ( - f"ported={results.get('ported_transform_s', 0):.2f}s " - f"ref={results.get('ref_transform_s', 0):.2f}s" - ) - print(f" [{status}] idx={idx:6d} {timing}") - for err in results["errors"]: - print(f" {err}") - - return results - - -# ============================================================================ -# Main driver -# ============================================================================ - - -def get_indices(mode: str, ds_len: int, custom_indices=None): - """Return sample indices based on mode.""" - if custom_indices: - return [i for i in custom_indices if i < ds_len] - - if mode == "smoketest": - # 3 indices: start, middle, near end - return [0, ds_len // 2, ds_len - 1] - - elif mode == "full": - # 50 indices spread across the dataset - n = min(50, ds_len) - step = max(1, ds_len // n) - return list(range(0, ds_len, step))[:n] - - else: - raise ValueError(f"Unknown mode: {mode}") - - -def main(): - parser = argparse.ArgumentParser( - description="Compare ported vs reference data loader outputs." - ) - parser.add_argument( - "mode", - choices=["smoketest", "full"], - help="smoketest: 3 indices, fast. full: 50 indices.", - ) - parser.add_argument( - "--indices", - type=int, - nargs="*", - default=None, - help="Override indices to compare.", - ) - parser.add_argument( - "--split", default="train", help="Dataset split (default: train)." 
- ) - parser.add_argument( - "--sensors", - nargs="*", - default=None, - help="Sensor list (default: atms mhs amsua amsub).", - ) - parser.add_argument( - "--transform", - action="store_true", - help="Also compare transformed (__getitems__) output.", - ) - parser.add_argument( - "--no-raw", - action="store_true", - help="Skip raw (get) comparison, only do transform.", - ) - args = parser.parse_args() - - print("=" * 70) - print(f"Loader comparison — mode={args.mode}, split={args.split}") - print("=" * 70) - - # Build datasets - print("\nBuilding ported dataset...") - t0 = time.time() - ported_ds = build_ported_dataset(split=args.split, sensors=args.sensors) - print(f" Done in {time.time() - t0:.1f}s (len={len(ported_ds)})") - - print("Building reference dataset...") - t0 = time.time() - ref_ds = build_reference_dataset(split=args.split, sensors=args.sensors) - print(f" Done in {time.time() - t0:.1f}s (len={len(ref_ds)})") - - # Verify lengths match - if len(ported_ds) != len(ref_ds): - print( - f"\nWARNING: Dataset lengths differ! " - f"ported={len(ported_ds)} vs ref={len(ref_ds)}" - ) - - indices = get_indices( - args.mode, min(len(ported_ds), len(ref_ds)), args.indices - ) - print(f"\nComparing {len(indices)} samples: {indices[:10]}{'...' 
if len(indices) > 10 else ''}") - - # --- Raw comparison --- - if not args.no_raw: - print(f"\n--- Raw comparison (get) ---") - raw_results = [] - for idx in indices: - r = compare_single_sample(ported_ds, ref_ds, idx) - raw_results.append(r) - - n_pass = sum(1 for r in raw_results if r["pass"]) - n_fail = len(raw_results) - n_pass - print(f"\nRaw: {n_pass}/{len(raw_results)} passed, {n_fail} failed") - - # --- Transform comparison --- - if args.transform: - print(f"\n--- Transform comparison (__getitems__) ---") - xform_results = [] - for idx in indices: - r = compare_transformed_sample(ported_ds, ref_ds, idx) - xform_results.append(r) - - n_pass = sum(1 for r in xform_results if r["pass"]) - n_fail = len(xform_results) - n_pass - print(f"\nTransform: {n_pass}/{len(xform_results)} passed, {n_fail} failed") - - # --- Summary --- - print("\n" + "=" * 70) - all_results = [] - if not args.no_raw: - all_results.extend(raw_results) - if args.transform: - all_results.extend(xform_results) - - n_total = len(all_results) - n_pass = sum(1 for r in all_results if r["pass"]) - if n_total == n_pass: - print(f"ALL {n_total} CHECKS PASSED") - else: - print(f"{n_total - n_pass}/{n_total} CHECKS FAILED") - sys.exit(1) - - -if __name__ == "__main__": - main() From 47af3af38537e6af8b44c7f3e1f29554b907517f Mon Sep 17 00:00:00 2001 From: Peter Harrington Date: Fri, 17 Apr 2026 15:43:30 -0700 Subject: [PATCH 10/10] greptile feedback --- examples/weather/healda/README.md | 1 - .../experimental/datapipes/healda/dataset.py | 1 - .../experimental/datapipes/healda/indexing.py | 8 +-- .../datapipes/healda/loaders/ufs_obs.py | 50 ++++++------------- .../experimental/datapipes/healda/samplers.py | 9 ++-- test/datapipes/healda/test_samplers.py | 6 ++- 6 files changed, 27 insertions(+), 48 deletions(-) diff --git a/examples/weather/healda/README.md b/examples/weather/healda/README.md index fce9c28f98..0f1b2f874f 100644 --- a/examples/weather/healda/README.md +++ 
b/examples/weather/healda/README.md @@ -155,7 +155,6 @@ sensors = ["atms", "mhs", "conv"] obs_loader = UFSUnifiedLoader( data_path="/path/to/processed_obs", sensors=sensors, - normalization="zscore", obs_context_hours=(-21, 3), ) transform = ERA5ObsTransform( diff --git a/physicsnemo/experimental/datapipes/healda/dataset.py b/physicsnemo/experimental/datapipes/healda/dataset.py index 9c68556e94..6c0e77a157 100644 --- a/physicsnemo/experimental/datapipes/healda/dataset.py +++ b/physicsnemo/experimental/datapipes/healda/dataset.py @@ -33,7 +33,6 @@ obs_loader = UFSUnifiedLoader( data_path="/path/to/obs", sensors=["atms", "mhs", "conv"], - normalization="zscore", obs_context_hours=(-21, 3), ) transform = ERA5ObsTransform(variable_config=VARIABLE_CONFIGS["era5"]) diff --git a/physicsnemo/experimental/datapipes/healda/indexing.py b/physicsnemo/experimental/datapipes/healda/indexing.py index 2c0daa3e1e..b2143db416 100644 --- a/physicsnemo/experimental/datapipes/healda/indexing.py +++ b/physicsnemo/experimental/datapipes/healda/indexing.py @@ -36,8 +36,8 @@ def split_array_contiguous(x): This detects gaps in a time array (e.g. year boundaries, missing data) and returns a list of contiguous segments. 
""" - if x.size == 0: - return [] + if x.size <= 1: + return [x] if x.size == 1 else [] d = x[1] - x[0] segments = [] @@ -135,10 +135,10 @@ def generate_frame_indices(self, sample_indices: torch.Tensor) -> list[list[int] def _map_logical_to_physical(self, logical_idx: int) -> int: """Map a logical sample index to a physical frame index across segments.""" - if logical_idx >= self.total_samples: + if logical_idx >= self.valid_length: raise IndexError( f"Sample index {logical_idx} out of bounds " - f"for {self.total_samples} samples" + f"for {self.valid_length} valid samples" ) segment_idx = 0 diff --git a/physicsnemo/experimental/datapipes/healda/loaders/ufs_obs.py b/physicsnemo/experimental/datapipes/healda/loaders/ufs_obs.py index 27fbe7fba2..337b217728 100644 --- a/physicsnemo/experimental/datapipes/healda/loaders/ufs_obs.py +++ b/physicsnemo/experimental/datapipes/healda/loaders/ufs_obs.py @@ -25,7 +25,6 @@ data_path="/path/to/processed_obs", sensors=["atms", "mhs", "conv"], obs_context_hours=(-3, 3), - normalization="zscore", ) result = await loader.sel_time(pd.DatetimeIndex([...])) tables = result["obs"] # list[pa.Table], one per timestamp @@ -68,7 +67,6 @@ def get_channel_table(data_path: str, filesystem=None): data_path, sensors=[], obs_context_hours=(-3, 3), - normalization="zscore", filesystem=filesystem, ).channel_table @@ -83,7 +81,6 @@ class UFSUnifiedLoader: data_path: Path to the processed observation data directory. sensors: List of sensor names to load (e.g. ``["atms", "mhs", "conv"]``). filesystem: Optional fsspec filesystem for remote access (e.g. S3). - normalization: Normalization method (``"fixed_range"`` or ``"zscore"``). innovation_type: Innovation type (``"none"``, ``"adjusted"``, ``"unadjusted"``). qc_filter: Whether to apply quality-control filtering. filter_innovation: Whether to filter based on innovation values. 
@@ -100,7 +97,6 @@ def __init__( data_path: str, sensors: List[str], filesystem: fsspec.AbstractFileSystem | None = None, - normalization: Literal["fixed_range", "zscore"] = "fixed_range", innovation_type: Literal["none", "adjusted", "unadjusted"] = "none", qc_filter: bool = False, filter_innovation: bool = False, @@ -114,7 +110,6 @@ def __init__( self.data_path = data_path self.sensors = sensors self.fs = filesystem - self.normalization = normalization self.innovation_type = innovation_type self.qc_filter = qc_filter self.filter_innovation = filter_innovation @@ -162,7 +157,7 @@ def channel_table(self) -> pa.Table: local_channel_ids = [] offset = 0 for i in range(len(sensor_id)): - if sensor_id[i] != sensor_id[i - 1]: + if i == 0 or sensor_id[i] != sensor_id[i - 1]: offset = i local_channel_ids.append(i - offset) array = pa.array(local_channel_ids).cast(LOCAL_CHANNEL_ID.type) @@ -203,30 +198,18 @@ def _iterate_parquet_da_windows(self, parquet_path, target_windows): .column(da_idx) .statistics ) - row_group_lo, row_group_hi = stats.min, stats.max - - this_window = None - for w in target_windows: - if row_group_lo <= w <= row_group_hi: - this_window = w - - if this_window is None: - continue - - table = parquet.read_row_group( - row_group_idx, columns=self._read_columns - ) - - if row_group_lo != row_group_hi: - mask = pc.is_in( - table["DA_window"], pa.array(list(target_windows)) + row_group_window = stats.min + if row_group_window != stats.max: + raise ValueError( + f"Expected one DA_window per row group, got " + f"[{stats.min}, {stats.max}] for {parquet_path} row_group={row_group_idx}" ) - table = table.filter(mask) - + if row_group_window not in target_windows: + continue + table = parquet.read_row_group(row_group_idx, columns=self._read_columns) if table.num_rows == 0: continue - - yield this_window, table + yield row_group_window, table except (FileNotFoundError, OSError): return @@ -239,15 +222,10 @@ def _filter_observations(self, table: pa.Table) -> 
pa.Table: ) def _normalize_observations(self, table: pa.Table) -> pa.Table: - if self.normalization == "fixed_range": - normalized = pc.divide(pc.subtract(table["Observation"], 0), 400 - 0) - elif self.normalization == "zscore": - normalized = pc.divide( - pc.subtract(table["Observation"], table["mean"]), - table["stddev"], - ) - else: - raise ValueError(f"Unknown normalization type: {self.normalization}") + normalized = pc.divide( + pc.subtract(table["Observation"], table["mean"]), + table["stddev"], + ) return table.set_column( table.schema.get_field_index("Observation"), "Observation", diff --git a/physicsnemo/experimental/datapipes/healda/samplers.py b/physicsnemo/experimental/datapipes/healda/samplers.py index 9ed54c76f7..c0e6748b6b 100644 --- a/physicsnemo/experimental/datapipes/healda/samplers.py +++ b/physicsnemo/experimental/datapipes/healda/samplers.py @@ -27,9 +27,10 @@ class RestartableDistributedSampler(torch.utils.data.Sampler): """A stateful distributed sampler that automatically loops over the dataset. - Each epoch generates a rank-specific random permutation. The sampler - tracks its position within the permutation so that ``restart()`` can - resume from an exact checkpoint. + Each epoch generates a shared random permutation across ranks, then + partitions it by stride so every sample is visited exactly once per epoch. + The sampler tracks its position within the permutation so that + ``restart()`` can resume from an exact checkpoint. Args: dataset: Map-style dataset (used only for ``len``). 
@@ -64,7 +65,7 @@ def __len__(self): def set_epoch(self, epoch): self.epoch = epoch self.iteration = 0 - rng = torch.Generator().manual_seed(self.seed + self.epoch + self.rank) + rng = torch.Generator().manual_seed(self.seed + self.epoch) permutation = torch.randperm(self.len, generator=rng) rem = self.len % self.num_replicas diff --git a/test/datapipes/healda/test_samplers.py b/test/datapipes/healda/test_samplers.py index 44ef1c1dcc..ca0621ad4e 100644 --- a/test/datapipes/healda/test_samplers.py +++ b/test/datapipes/healda/test_samplers.py @@ -82,7 +82,7 @@ def test_reproducible(): def test_multi_replica_independent(): - """Different ranks get independent permutations of the same length.""" + """Different ranks receive disjoint slices of the shared permutation.""" dataset = list(range(100)) s0 = RestartableDistributedSampler(dataset, rank=0, num_replicas=2, seed=42) s1 = RestartableDistributedSampler(dataset, rank=1, num_replicas=2, seed=42) @@ -97,8 +97,10 @@ def test_multi_replica_independent(): # Each rank gets a valid subset of indices assert all(0 <= i < 100 for i in idx0) assert all(0 <= i < 100 for i in idx1) - # Rank-dependent seed produces different orderings + # Ranks visit different indices (stride-partitioned shared permutation) assert idx0 != idx1 + # Together they cover every sample exactly once + assert sorted(idx0 + idx1) == list(range(100)) def test_len_drops_remainder():