From 1ab99d8d6d2cb8a3eb8713f885f0e8c00140da46 Mon Sep 17 00:00:00 2001 From: Sam Avis Date: Mon, 1 Jun 2026 11:12:24 +0100 Subject: [PATCH 01/21] Add preprocessing functions --- src/tctrack/utils/__init__.py | 2 + src/tctrack/utils/preprocessing.py | 282 +++++++++++++++++++++++++++++ 2 files changed, 284 insertions(+) create mode 100644 src/tctrack/utils/preprocessing.py diff --git a/src/tctrack/utils/__init__.py b/src/tctrack/utils/__init__.py index 57bb764..7b8ecf5 100644 --- a/src/tctrack/utils/__init__.py +++ b/src/tctrack/utils/__init__.py @@ -1,8 +1,10 @@ """Package providing utility functions for the user.""" +from . import preprocessing from .metadata import load_tracker_metadata, read_tracker_metadata __all__ = [ "load_tracker_metadata", + "preprocessing", "read_tracker_metadata", ] diff --git a/src/tctrack/utils/preprocessing.py b/src/tctrack/utils/preprocessing.py new file mode 100644 index 0000000..1e070c1 --- /dev/null +++ b/src/tctrack/utils/preprocessing.py @@ -0,0 +1,282 @@ +"""Module with preprocessing functions that can be used with batching.""" + +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, TypeAlias, TypeVar + +import cf +import numpy as np + +FilePaths: TypeAlias = str | Sequence[str] + + +@dataclass(frozen=True) +class FieldSelect: + """File name(s) plus the NetCDF variable name to select.""" + + files: FilePaths + var_name: str + + +FieldSource: TypeAlias = FilePaths | FieldSelect | cf.Field + + +T = TypeVar("T", cf.Field, list[cf.Field]) + + +def _write_output(result: T, output_file: str | None) -> T: + """Optionally write output before returning.""" + if output_file is not None: + cf.write(result, output_file) # type: ignore[operator] + return result + + +def read_files( + input_files: FilePaths, + output_file: str | None = None, + *, + select: str | None = None, +) -> list[cf.Field]: + """Read fields from separate files. 
Optionally write to a single output file.""" + fields = list(cf.read(input_files, select=select)) # type: ignore[operator] + return _write_output(fields, output_file) + + +def combine_time( + input_files: FilePaths, + time_bounds: tuple[str, str] | None = None, + output_file: str | None = None, +) -> list[cf.Field]: + """Combine files in time, with optional time bounds of [start, end).""" + fields = read_files(input_files) + + if time_bounds is not None: + time_interval = cf.wi( + cf.dt(time_bounds[0]), cf.dt(time_bounds[1]), open_upper=True + ) + fields = [field.subspace(T=time_interval) for field in fields] + + return _write_output(fields, output_file) + + +def separate_variables( + input_files: FilePaths, + output_files: dict[str, str], +) -> list[cf.Field]: + """Split variables into separate files. The output keys are nc variable names.""" + fields = {field.nc_get_variable(): field for field in read_files(input_files)} + + for var_name, output_file in output_files.items(): + if var_name not in fields: + msg = f"A variable to save ({var_name}) is not provided in the inputs." + raise ValueError(msg) + cf.write(fields[var_name], output_file) # type: ignore[operator] + + return list(fields.values()) + + +def _load_field(source: FieldSource) -> cf.Field: + """Load a field from NetCDF file(s) .""" + if isinstance(source, cf.Field): + return source + + if isinstance(source, FieldSelect): + fields = read_files(source.files, select=f"ncvar%{source.var_name}") + if not fields: + msg = f"No field with NetCDF variable name '{source.var_name}' was found." + raise ValueError(msg) + return fields[0] + + if isinstance(source, (str, Sequence)): + fields = read_files(source) + if len(fields) != 1: + msg = ( + f"Expected one field from '{source}', but found {len(fields)}. " + "Use FieldSelect(files, variable_name) to select a field." + ) + raise ValueError(msg) + return fields[0] + + msg = ( + "Invalid input type for the field source. 
" + "Allowed types are cf.Field, FieldSelect, or string filepath(s)." + ) + raise ValueError(msg) + + +def subsample_field( + input_: FieldSource, + subspace_kwargs: dict[str, Any], + output_file: str | None = None, + *, + squeeze: bool = True, +) -> cf.Field: + """Subsample a field using ``cf.Field.subspace``.""" + if not subspace_kwargs: + msg = "At least one subspace selector must be provided to 'subspace_kwargs'." + raise ValueError(msg) + + field = _load_field(input_) + subset = field.subspace(**subspace_kwargs) + if squeeze: + subset.squeeze(inplace=True) + return _write_output(subset, output_file) + + +def collapse_field( + input_: FieldSource, + method: str, + axes: str | Sequence[str], + output_file: str | None = None, + *, + squeeze: bool = True, +) -> cf.Field: + """Collapse a field over one or more axes.""" + field = _load_field(input_) + collapsed = field.collapse(method, axes=axes) + if squeeze: + collapsed.squeeze(inplace=True) + return _write_output(collapsed, output_file) + + +def calculate_curl_xy( + input_x: FieldSource, + input_y: FieldSource, + variable_name: str, + variable_info: dict[str, str], + output_file: str | None = None, +) -> cf.Field: + """Calculate the curl of two fields which are x and y components of a vector.""" + field_x = _load_field(input_x) + field_y = _load_field(input_y) + + curl = cf.curl_xy(field_x, field_y, radius="earth") + curl.nc_set_variable(variable_name) + for name, value in variable_info.items(): + if name == "units": + curl.override_units(value, inplace=True) + else: + curl.set_property(name, value) + return _write_output(curl, output_file) + + +def calculate_vorticity( + input_u: FieldSource, + input_v: FieldSource, + output_file: str | None = None, +) -> cf.Field: + """Calculate vorticity from colocated velocity fields.""" + return calculate_curl_xy( + input_u, + input_v, + variable_name="vorticity", + variable_info={ + "standard_name": "atmosphere_upward_absolute_vorticity", + "units": "s-1", + }, + 
output_file=output_file, + ) + + +def replace_fill_value( + input_: FieldSource, + fill_value: float, + output_file: str | None = None, +) -> cf.Field: + """Replace masked values in a field using ``cf.Field.filled``.""" + field = _load_field(input_) + field.filled(fill_value=fill_value, inplace=True) + return _write_output(field, output_file) + + +def set_netcdf_variable_name( + input_: FieldSource, + field_name: str, + output_file: str | None = None, + *, + coord_names: dict[str, str] | None = None, +) -> cf.Field: + """Set NetCDF variable names for a field and (optionally) its coordinates.""" + field = _load_field(input_) + field.nc_set_variable(field_name) + for coordinate, variable_name in (coord_names or {}).items(): + field.coordinate(coordinate).nc_set_variable(variable_name) + return _write_output(field, output_file) + + +def regrid_to_field( + input_: FieldSource, + target: FieldSource | cf.Domain, + output_file: str | None = None, + *, + method: str = "linear", +) -> cf.Field: + """Regrid a field onto the grid of another field / domain.""" + field = _load_field(input_) + if not isinstance(target, cf.Domain): + target = _load_field(target) + + regridded = field.regrids(target, method=method) + regridded.nc_clear_dataset_chunksizes() # Avoids a possible error when writing + return _write_output(regridded, output_file) + + +def regrid_to_lat_lon( + input_: FieldSource, + latitude: np.ndarray, + longitude: np.ndarray, + output_file: str | None = None, + *, + method: str = "linear", +) -> cf.Field: + """Regrid a field onto a latitude-longitude grid.""" + field = _load_field(input_) + + domain = field.domain.copy() + lat_coord = domain.dimension_coordinate("latitude") + lat_coord.set_data(latitude, inplace=True) + lat_coord.del_bounds() + lon_coord = domain.dimension_coordinate("longitude") + lon_coord.set_data(longitude, inplace=True) + lon_coord.del_bounds() + + regridded = field.regrids((lat_coord, lon_coord), method=method) + 
regridded.nc_clear_dataset_chunksizes() # Avoids a possible error when writing + return _write_output(regridded, output_file) + + +def gaussian_grid(n: int) -> tuple[np.ndarray, np.ndarray]: + """Create regular Gaussian latitude and longitude coordinates.""" + latitude = np.degrees(np.arcsin(np.polynomial.legendre.leggauss(2 * n)[0])) + longitude = np.arange(0.0, 360.0, 360.0 / (4 * n)) + return latitude, longitude + + +def regrid_to_gaussian( + input_: FieldSource, + n: int, + output_file: str | None = None, + *, + method: str = "linear", +) -> cf.Field: + """Regrid a field onto a regular Gaussian grid with n lat points per hemisphere.""" + lat, lon = gaussian_grid(n) + return regrid_to_lat_lon(input_, lat, lon, output_file=output_file, method=method) + + +__all__ = [ + "FieldSelect", + "calculate_curl_xy", + "calculate_vorticity", + "collapse_field", + "combine_time", + "gaussian_grid", + "read_files", + "regrid_to_field", + "regrid_to_gaussian", + "regrid_to_lat_lon", + "replace_fill_value", + "separate_variables", + "set_netcdf_variable_name", + "subsample_field", +] From 8c60bcdba81eb8d474162b6c5d316284950bbb86 Mon Sep 17 00:00:00 2001 From: Sam Avis Date: Mon, 1 Jun 2026 11:14:13 +0100 Subject: [PATCH 02/21] Allow wildcard inputs --- src/tctrack/utils/preprocessing.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/tctrack/utils/preprocessing.py b/src/tctrack/utils/preprocessing.py index 1e070c1..9d40f93 100644 --- a/src/tctrack/utils/preprocessing.py +++ b/src/tctrack/utils/preprocessing.py @@ -1,5 +1,6 @@ """Module with preprocessing functions that can be used with batching.""" +import glob from collections.abc import Sequence from dataclasses import dataclass from typing import Any, TypeAlias, TypeVar @@ -21,6 +22,24 @@ class FieldSelect: FieldSource: TypeAlias = FilePaths | FieldSelect | cf.Field +def _expand_input_paths(paths: FilePaths) -> list[str]: + """Expand wildcard input paths into a list of file 
paths.""" + path_list = [paths] if isinstance(paths, str) else list(paths) + expanded_paths: list[str] = [] + + for path in path_list: + matches = sorted(glob.glob(path)) + if matches: + expanded_paths.extend(matches) + continue + if glob.has_magic(path): + msg = f"No files matched input pattern '{path}'." + raise FileNotFoundError(msg) + expanded_paths.append(path) + + return expanded_paths + + T = TypeVar("T", cf.Field, list[cf.Field]) @@ -38,7 +57,7 @@ def read_files( select: str | None = None, ) -> list[cf.Field]: """Read fields from separate files. Optionally write to a single output file.""" - fields = list(cf.read(input_files, select=select)) # type: ignore[operator] + fields = list(cf.read(_expand_input_paths(input_files), select=select)) # type: ignore[operator] return _write_output(fields, output_file) From 4e85cfe59596079fbcf50d950946f1a945a1c370 Mon Sep 17 00:00:00 2001 From: Sam Avis Date: Mon, 1 Jun 2026 14:36:25 +0100 Subject: [PATCH 03/21] Add preprocessing tests --- tests/unit/utils/test_preprocessing.py | 270 +++++++++++++++++++++++++ 1 file changed, 270 insertions(+) create mode 100644 tests/unit/utils/test_preprocessing.py diff --git a/tests/unit/utils/test_preprocessing.py b/tests/unit/utils/test_preprocessing.py new file mode 100644 index 0000000..513d868 --- /dev/null +++ b/tests/unit/utils/test_preprocessing.py @@ -0,0 +1,270 @@ +"""Unit tests for preprocessing helper utilities.""" + +from pathlib import Path + +import cf +import numpy as np +import pytest + +from tctrack.utils.preprocessing import ( + FieldSelect, + _load_field, + calculate_vorticity, + collapse_field, + combine_time, + gaussian_grid, + read_files, + regrid_to_field, + regrid_to_gaussian, + replace_fill_value, + separate_variables, + set_netcdf_variable_name, + subsample_field, +) + + +def make_field(var_name: str, time: str | None = None) -> cf.Field: + """Create a small example field with an optional time value.""" + standard_names = { + "mslp": 
"air_pressure_at_mean_sea_level", + "u": "eastward_wind", + "v": "northward_wind", + } + + field = cf.example_field(0).copy() + field.nc_set_variable(var_name) + field.set_property("standard_name", standard_names[var_name]) + if time is not None: + field.coordinate("T").set_data([cf.dt(time)], inplace=True) + return field + + +def write_fields(fields: list[cf.Field], path: Path) -> str: + """Write one or more fields to a file and return the file path.""" + cf.write(fields, str(path)) # type: ignore[operator] + return str(path) + + +class TestPreprocessing: + """Tests for preprocessing functions.""" + + def test_read_files_combines_fields(self, tmp_path): + """Test read_files accepts files with multiple fields.""" + input_file = write_fields( + [make_field("mslp"), make_field("u")], + tmp_path / "input.nc", + ) + + fields = read_files(input_file) + + assert [field.nc_get_variable() for field in fields] == ["mslp", "u"] + + def test_read_files_combines_time(self, tmp_path): + """Test read_files accepts and combines fields split temporally over files.""" + input_files = [ + write_fields(make_field("mslp", "2000-01-01"), tmp_path / "a.nc"), + write_fields(make_field("mslp", "2000-01-02"), tmp_path / "b.nc"), + ] + + fields = read_files(input_files) + + assert len(fields) == 1 + assert fields[0].coordinate("T").size == 2 + + def test_read_files_wildcard(self, tmp_path): + """Test read_files correctly expands wildcard filepaths.""" + write_fields(make_field("mslp"), tmp_path / "a.nc") + write_fields(make_field("u"), tmp_path / "b.nc") + + fields = read_files(str(tmp_path / "*.nc")) + + assert len(fields) == 2 + + def test_read_files_wildcard_no_matches(self, tmp_path): + """Test read_files fails for wildcard paths with no matches.""" + with pytest.raises(FileNotFoundError, match="No files matched input pattern"): + read_files(str(tmp_path / "*.nc")) + + def test_read_files_output_file(self, tmp_path): + """Test read_files writes the combined output when requested.""" + 
input_file = write_fields(make_field("mslp"), tmp_path / "input.nc") + output_file = tmp_path / "output.nc" + + fields = read_files(input_file, str(output_file)) + + assert len(fields) == 1 + assert output_file.exists() + assert cf.read(str(output_file))[0].nc_get_variable() == "mslp" + + def test_combine_time_bounds(self, tmp_path): + """Test combine_time correctly selects data in time bounds.""" + input_files = [ + write_fields(make_field("mslp", "2000-01-01"), tmp_path / "a.nc"), + write_fields(make_field("mslp", "2000-01-02"), tmp_path / "b.nc"), + ] + + fields = combine_time(input_files, ("2000-01-01", "2000-01-02")) + + assert len(fields) == 1 # Same field (mslp) + assert fields[0].coordinate("T").size == 1 # Upper bound is excluded + + def test_separate_varibles(self, tmp_path): + """Test separate_variables correctly splits variables across multiple files.""" + input_file = write_fields( + [make_field("mslp"), make_field("u")], + tmp_path / "input.nc", + ) + output_files = { + "mslp": str(tmp_path / "mslp.nc"), + "u": str(tmp_path / "u.nc"), + } + + fields = separate_variables(input_file, output_files) + + assert [field.nc_get_variable() for field in fields] == ["mslp", "u"] + assert read_files(output_files["mslp"])[0] == fields[0] + assert read_files(output_files["u"])[0] == fields[1] + + def test_separate_varibles_invalid(self, tmp_path): + """Test separate_variables fails if an invalid variable name is given.""" + input_file = write_fields(make_field("mslp"), tmp_path / "input.nc") + + with pytest.raises(ValueError, match=r"A variable to save \(invalid\)"): + separate_variables(input_file, {"invalid": str(tmp_path / "output.nc")}) + + def test_load_field_accepts_fields(self): + """Test _load_field accepts in-memory fields.""" + field = make_field("mslp") + + assert _load_field(field) is field + + def test_load_field_accepts_files(self, tmp_path): + """Test _load_field accepts a filename / list of files.""" + file_a = write_fields(make_field("mslp", 
"2000-01-01"), tmp_path / "a.nc") + file_b = write_fields(make_field("mslp", "2000-01-02"), tmp_path / "b.nc") + + field = _load_field([file_a, file_b]) + + assert field.nc_get_variable() == "mslp" + assert field.coordinate("T").size == 2 + + def test_load_field_rejects_multifield_files(self, tmp_path): + """Test _load_field rejects files with multiple fields.""" + input_file = write_fields( + [make_field("u"), make_field("v")], + tmp_path / "input.nc", + ) + + with pytest.raises( + ValueError, + match=r"Use FieldSelect\(files, variable_name\) to select a field", + ): + _load_field(input_file) + + def test_load_field_accepts_field_select(self, tmp_path): + """Test _load_field selects a field when given a FieldSelect object.""" + input_file = write_fields( + [make_field("mslp"), make_field("u")], + tmp_path / "input.nc", + ) + + field = _load_field(FieldSelect(input_file, "u")) + + assert field.nc_get_variable() == "u" + + def test_load_field_rejects_missing_field_select(self, tmp_path): + """Test _load_field fails when a selected variable is missing.""" + input_file = write_fields(make_field("mslp"), tmp_path / "input.nc") + + with pytest.raises( + ValueError, + match="No field with NetCDF variable name 'invalid' was found", + ): + _load_field(FieldSelect(input_file, "invalid")) + + def test_subsample_field(self): + """Test subsample_field works correctly.""" + field = make_field("mslp") + + subset = subsample_field(field, {"X": slice(2)}) + + assert subset.shape == (field.shape[0], 2) + + def test_subsample_field_rejects_empty_subspace_kwargs(self): + """Test subsample_field fails when no selectors are provided.""" + with pytest.raises(ValueError, match="At least one subspace selector"): + subsample_field(make_field("mslp"), {}) + + def test_collapse_field(self): + """Test collapse_field works correctly.""" + field = make_field("mslp") + + collapsed = collapse_field(field, "mean", "X") + + assert collapsed.shape == (field.axis_size("latitude"),) + assert 
np.allclose( + collapsed.array, field.collapse("mean", axes="X", squeeze=True).array + ) + + def test_calculate_vorticity(self): + """Test calculate_vorticity works correctly.""" + field_u = make_field("u") + field_v = make_field("v") + + vorticity = calculate_vorticity(field_u, field_v) + + assert vorticity.nc_get_variable() == "vorticity" + assert ( + vorticity.get_property("standard_name") + == "atmosphere_upward_absolute_vorticity" + ) + assert vorticity.get_property("units") == "s-1" + + def test_replace_fill_value(self): + """Test replace_fill_value works correctly.""" + field = make_field("mslp") + field[0, 0] = cf.masked + + filled = replace_fill_value(field, -1.0) + + assert filled.array[0, 0] == pytest.approx(-1.0) + + def test_set_netcdf_variable_name(self): + """Test set_netcdf_variable_name works correctly.""" + field = make_field("mslp") + + renamed = set_netcdf_variable_name( + field, + "pressure", + coord_names={"X": "longitude", "Y": "latitude"}, + ) + + assert renamed.nc_get_variable() == "pressure" + assert renamed.coordinate("X").nc_get_variable() == "longitude" + assert renamed.coordinate("Y").nc_get_variable() == "latitude" + + def test_regrid_to_field(self): + """Test regrid_to_field works correctly.""" + target = make_field("u") + + regridded = regrid_to_field(make_field("v"), target) + + assert regridded.shape == target.shape + + def test_gaussian_grid(self): + """Test Gaussian grid helper returns the expected coordinate sizes.""" + latitude, longitude = gaussian_grid(4) + + assert len(latitude) == 8 + assert len(longitude) == 16 + assert longitude[0] == pytest.approx(0.0) + assert longitude[-1] == pytest.approx(337.5) + assert latitude[0] == pytest.approx(-latitude[-1]) + + def test_regrid_to_gaussian(self): + """Test regrid_to_gaussian works correctly.""" + field = make_field("mslp") + + regridded = regrid_to_gaussian(field, 4) + + assert regridded.shape == (8, 16) From e04b586b7c574bd17752c25bd30b1b05148156a7 Mon Sep 17 00:00:00 2001 
From: Sam Avis Date: Mon, 1 Jun 2026 15:56:59 +0100 Subject: [PATCH 04/21] Update the preprocessing api docs --- docs/api/utils_api.rst | 8 + src/tctrack/utils/preprocessing.py | 291 +++++++++++++++++++++++++---- 2 files changed, 267 insertions(+), 32 deletions(-) diff --git a/docs/api/utils_api.rst b/docs/api/utils_api.rst index 48db760..9b9e326 100644 --- a/docs/api/utils_api.rst +++ b/docs/api/utils_api.rst @@ -4,4 +4,12 @@ Utility functions .. automodule:: tctrack.utils .. autofunction:: load_tracker_metadata + .. autofunction:: read_tracker_metadata + +Preprocessing functions +----------------------- + +.. automodule:: tctrack.utils.preprocessing + :members: + :member-order: bysource diff --git a/src/tctrack/utils/preprocessing.py b/src/tctrack/utils/preprocessing.py index 9d40f93..e4805ca 100644 --- a/src/tctrack/utils/preprocessing.py +++ b/src/tctrack/utils/preprocessing.py @@ -13,16 +13,32 @@ @dataclass(frozen=True) class FieldSelect: - """File name(s) plus the NetCDF variable name to select.""" + """Class containing the file name(s) plus the NetCDF variable name to select. - files: FilePaths + Necessary for choosing a variable from files which contain multiple. + + Parameters + ---------- + files : str | Sequence[str] + Input file path(s) to read from. ``glob`` pattern matching allowed. + var_name : str + NetCDF variable name to select from the input files. + """ + + files: str | Sequence[str] var_name: str -FieldSource: TypeAlias = FilePaths | FieldSelect | cf.Field +FieldSource: TypeAlias = str | Sequence[str] | FieldSelect | cf.Field +"""Type alias for the allowed sources for ``cf.Field`` arguments. + +The ``cf.Field`` can be passed directly or using the path(s) CF-NetCDF file(s). +If the file(s) contain multiple fields then :class:`FieldSelect` should be used to +specify which to use. 
+""" -def _expand_input_paths(paths: FilePaths) -> list[str]: +def _expand_input_paths(paths: str | Sequence[str]) -> list[str]: """Expand wildcard input paths into a list of file paths.""" path_list = [paths] if isinstance(paths, str) else list(paths) expanded_paths: list[str] = [] @@ -51,22 +67,52 @@ def _write_output(result: T, output_file: str | None) -> T: def read_files( - input_files: FilePaths, + input_files: str | Sequence[str], output_file: str | None = None, *, select: str | None = None, ) -> list[cf.Field]: - """Read fields from separate files. Optionally write to a single output file.""" + """Read fields from one or more files. + + Parameters + ---------- + input_files : str | Sequence[str] + Input file path(s) to read. ``glob`` pattern matching allowed. + output_file : str | None, optional + Output file to write the loaded fields to. + select : str | None, optional + Optional field selection for ``cf.read``. + + Returns + ------- + list[cf.Field] + The list of fields read from the input files. + """ fields = list(cf.read(_expand_input_paths(input_files), select=select)) # type: ignore[operator] return _write_output(fields, output_file) def combine_time( - input_files: FilePaths, + input_files: str | Sequence[str], time_bounds: tuple[str, str] | None = None, output_file: str | None = None, ) -> list[cf.Field]: - """Combine files in time, with optional time bounds of [start, end).""" + """Combine files in time. + + Parameters + ---------- + input_files : str | Sequence[str] + Input file path(s) to combine. ``glob`` pattern matching allowed. + time_bounds : tuple[str, str] | None, optional + Optional start and end datetime strings. The end bound is open / exclusive. + output_file : str | None, optional + Output file to write the result to. + + Returns + ------- + list[cf.Field] + The list of combined fields. 
+ """ fields = read_files(input_files) if time_bounds is not None: @@ -79,10 +125,23 @@ def combine_time( def separate_variables( - input_files: FilePaths, + input_files: str | Sequence[str], output_files: dict[str, str], ) -> list[cf.Field]: - """Split variables into separate files. The output keys are nc variable names.""" + """Split variables into separate files. + + Parameters + ---------- + input_files : str | Sequence[str] + Input file path(s) to read. ``glob`` pattern matching allowed. + output_files : dict[str, str] + Mapping from NetCDF variable name to output file path. + + Returns + ------- + list[cf.Field] + The list of fields read from the input files. + """ fields = {field.nc_get_variable(): field for field in read_files(input_files)} for var_name, output_file in output_files.items(): @@ -95,7 +154,7 @@ def separate_variables( def _load_field(source: FieldSource) -> cf.Field: - """Load a field from NetCDF file(s) .""" + """Load a single field from an in-memory field or file input.""" if isinstance(source, cf.Field): return source @@ -130,7 +189,24 @@ def subsample_field( *, squeeze: bool = True, ) -> cf.Field: - """Subsample a field using ``cf.Field.subspace``.""" + """Subsample a field using ``cf.Field.subspace``. + + Parameters + ---------- + input_ : FieldSource + A field, file path(s), or :class:`FieldSelect` describing which field to load. + subspace_kwargs : dict[str, Any] + Keyword arguments passed to ``cf.Field.subspace``. + output_file : str | None, optional + Output file to write the result to. + squeeze : bool, optional + Whether to squeeze size-1 dimensions after subspacing. + + Returns + ------- + cf.Field + The subsampled field. + """ if not subspace_kwargs: msg = "At least one subspace selector must be provided to 'subspace_kwargs'." raise ValueError(msg) @@ -150,7 +226,26 @@ def collapse_field( *, squeeze: bool = True, ) -> cf.Field: - """Collapse a field over one or more axes.""" + """Collapse a field over one or more axes. 
+ + Parameters + ---------- + input_ : FieldSource + A field, file path(s), or :class:`FieldSelect` describing which field to load. + method : str + Collapse method passed to ``cf.Field.collapse``. E.g. ``"mean"``, ``"minimum"``. + axes : str | Sequence[str] + Axis or axes to collapse over. + output_file : str | None, optional + Output file to write the collapsed field to. + squeeze : bool, optional + Whether to squeeze size-1 dimensions after collapsing. + + Returns + ------- + cf.Field + The collapsed field. + """ field = _load_field(input_) collapsed = field.collapse(method, axes=axes) if squeeze: @@ -165,7 +260,26 @@ def calculate_curl_xy( variable_info: dict[str, str], output_file: str | None = None, ) -> cf.Field: - """Calculate the curl of two fields which are x and y components of a vector.""" + """Calculate the curl of x and y vector components. + + Parameters + ---------- + input_x : FieldSource + Field for the x component. + input_y : FieldSource + Field for the y component. + variable_name : str + NetCDF variable name for the output field. + variable_info : dict[str, str] + Field properties to set on the output. + output_file : str | None, optional + Output file to write the curl field to. + + Returns + ------- + cf.Field + Curl field derived from the two inputs. + """ field_x = _load_field(input_x) field_y = _load_field(input_y) @@ -184,7 +298,22 @@ def calculate_vorticity( input_v: FieldSource, output_file: str | None = None, ) -> cf.Field: - """Calculate vorticity from colocated velocity fields.""" + """Calculate vorticity from colocated velocity fields. + + Parameters + ---------- + input_u : FieldSource + Field for the eastward velocity component. + input_v : FieldSource + Field for the northward velocity component. + output_file : str | None, optional + Output file to write the vorticity field to. + + Returns + ------- + cf.Field + Vorticity field. 
+ """ return calculate_curl_xy( input_u, input_v, @@ -202,7 +331,22 @@ def replace_fill_value( fill_value: float, output_file: str | None = None, ) -> cf.Field: - """Replace masked values in a field using ``cf.Field.filled``.""" + """Replace masked values in a field using ``cf.Field.filled``. + + Parameters + ---------- + input_ : FieldSource + A field, file path(s), or :class:`FieldSelect` describing which field to load. + fill_value : float + Value for missing data. + output_file : str | None, optional + Output file to write the updated field to. + + Returns + ------- + cf.Field + Field with fill value replaced. + """ field = _load_field(input_) field.filled(fill_value=fill_value, inplace=True) return _write_output(field, output_file) @@ -215,7 +359,25 @@ def set_netcdf_variable_name( *, coord_names: dict[str, str] | None = None, ) -> cf.Field: - """Set NetCDF variable names for a field and (optionally) its coordinates.""" + """Set NetCDF variable names for a field and, optionally, its coordinates. + + Parameters + ---------- + input_ : FieldSource + A field, file path(s), or :class:`FieldSelect` describing which field to load. + field_name : str + NetCDF variable name for the field. + output_file : str | None, optional + Output file to write the updated field to. + coord_names : dict[str, str] | None, optional + Optional updated NetCDF variable names for coordinates. Keys are the standard + names. + + Returns + ------- + cf.Field + Field with updated NetCDF variable names. + """ field = _load_field(input_) field.nc_set_variable(field_name) for coordinate, variable_name in (coord_names or {}).items(): @@ -230,7 +392,24 @@ def regrid_to_field( *, method: str = "linear", ) -> cf.Field: - """Regrid a field onto the grid of another field / domain.""" + """Regrid a field onto the grid of another field or domain. + + Parameters + ---------- + input_ : FieldSource + A field, file path(s), or :class:`FieldSelect` describing the field to regrid. 
+ target : FieldSource | cf.Domain + Target field or domain that supplies the destination grid. + output_file : str | None, optional + Output file to write the regridded field to. + method : str, optional + Regridding method passed to ``cf.Field.regrids``. + + Returns + ------- + cf.Field + Regridded field. + """ field = _load_field(input_) if not isinstance(target, cf.Domain): target = _load_field(target) @@ -248,7 +427,26 @@ def regrid_to_lat_lon( *, method: str = "linear", ) -> cf.Field: - """Regrid a field onto a latitude-longitude grid.""" + """Regrid a field onto a latitude-longitude grid. + + Parameters + ---------- + input_ : FieldSource + A field, file path(s), or :class:`FieldSelect` describing the field to regrid. + latitude : np.ndarray + Latitude coordinate values for the target grid. + longitude : np.ndarray + Longitude coordinate values for the target grid. + output_file : str | None, optional + Output file to write the regridded field to. + method : str, optional + Regridding method passed to ``cf.Field.regrids``. + + Returns + ------- + cf.Field + Regridded field on the requested latitude-longitude grid. + """ field = _load_field(input_) domain = field.domain.copy() @@ -265,7 +463,18 @@ def regrid_to_lat_lon( def gaussian_grid(n: int) -> tuple[np.ndarray, np.ndarray]: - """Create regular Gaussian latitude and longitude coordinates.""" + """Create regular Gaussian latitude and longitude coordinates. + + Parameters + ---------- + n : int + Number of latitude points per hemisphere. + + Returns + ------- + tuple[np.ndarray, np.ndarray] + Latitude and longitude coordinate arrays. 
+ """ latitude = np.degrees(np.arcsin(np.polynomial.legendre.leggauss(2 * n)[0])) longitude = np.arange(0.0, 360.0, 360.0 / (4 * n)) return latitude, longitude @@ -278,24 +487,42 @@ def regrid_to_gaussian( *, method: str = "linear", ) -> cf.Field: - """Regrid a field onto a regular Gaussian grid with n lat points per hemisphere.""" + """Regrid a field onto a regular Gaussian grid. + + Parameters + ---------- + input_ : FieldSource + A field, file path(s), or :class:`FieldSelect` describing the field to regrid. + n : int + Number of latitude points per hemisphere for the target gaussian grid. + output_file : str | None, optional + Output file to write the regridded field to. + method : str, optional + Regridding method passed to ``cf.Field.regrids``. + + Returns + ------- + cf.Field + Regridded field on the Gaussian grid. + """ lat, lon = gaussian_grid(n) return regrid_to_lat_lon(input_, lat, lon, output_file=output_file, method=method) -__all__ = [ - "FieldSelect", +__all__ = [ # noqa: RUF022 # Prevent reorder for a more logical order in the api docs + "read_files", + "combine_time", + "separate_variables", + "subsample_field", + "collapse_field", "calculate_curl_xy", "calculate_vorticity", - "collapse_field", - "combine_time", - "gaussian_grid", - "read_files", - "regrid_to_field", - "regrid_to_gaussian", - "regrid_to_lat_lon", "replace_fill_value", - "separate_variables", "set_netcdf_variable_name", - "subsample_field", + "regrid_to_field", + "regrid_to_lat_lon", + "gaussian_grid", + "regrid_to_gaussian", + "FieldSource", + "FieldSelect", ] From 19bf4cf3ae012907310b84a1fb0ac17cbae3a660 Mon Sep 17 00:00:00 2001 From: Sam Avis Date: Mon, 1 Jun 2026 17:15:38 +0100 Subject: [PATCH 05/21] Move preprocessing to it's own module --- docs/api/index.rst | 1 + docs/api/preprocessing_api.rst | 6 ++++++ docs/api/utils_api.rst | 7 ------- src/tctrack/__init__.py | 4 ++-- src/tctrack/{utils => }/preprocessing.py | 2 +- src/tctrack/utils/__init__.py | 2 -- tests/unit/{utils => 
preprocessing}/test_preprocessing.py | 4 ++-- 7 files changed, 12 insertions(+), 14 deletions(-) create mode 100644 docs/api/preprocessing_api.rst rename src/tctrack/{utils => }/preprocessing.py (99%) rename tests/unit/{utils => preprocessing}/test_preprocessing.py (99%) diff --git a/docs/api/index.rst b/docs/api/index.rst index deec32c..623aae5 100644 --- a/docs/api/index.rst +++ b/docs/api/index.rst @@ -9,4 +9,5 @@ API Documentation TSTORMS Tempest Extremes TRACK + Preprocessing Utility functions diff --git a/docs/api/preprocessing_api.rst b/docs/api/preprocessing_api.rst new file mode 100644 index 0000000..282869c --- /dev/null +++ b/docs/api/preprocessing_api.rst @@ -0,0 +1,6 @@ +Preprocessing functions +======================= + +.. automodule:: tctrack.preprocessing + :members: + :member-order: bysource diff --git a/docs/api/utils_api.rst b/docs/api/utils_api.rst index 9b9e326..2bb7446 100644 --- a/docs/api/utils_api.rst +++ b/docs/api/utils_api.rst @@ -6,10 +6,3 @@ Utility functions .. autofunction:: load_tracker_metadata .. autofunction:: read_tracker_metadata - -Preprocessing functions ------------------------ - -.. 
automodule:: tctrack.utils.preprocessing - :members: - :member-order: bysource diff --git a/src/tctrack/__init__.py b/src/tctrack/__init__.py index 1a93770..0d93c9e 100644 --- a/src/tctrack/__init__.py +++ b/src/tctrack/__init__.py @@ -1,5 +1,5 @@ """Package providing tropical cyclone tracking utilities.""" -from tctrack import core, tempest_extremes, track, tstorms +from tctrack import core, preprocessing, tempest_extremes, track, tstorms, utils -__all__ = ["core", "tempest_extremes", "track", "tstorms"] +__all__ = ["core", "preprocessing", "tempest_extremes", "track", "tstorms", "utils"] diff --git a/src/tctrack/utils/preprocessing.py b/src/tctrack/preprocessing.py similarity index 99% rename from src/tctrack/utils/preprocessing.py rename to src/tctrack/preprocessing.py index e4805ca..5ed861d 100644 --- a/src/tctrack/utils/preprocessing.py +++ b/src/tctrack/preprocessing.py @@ -200,7 +200,7 @@ def subsample_field( output_file : str | None, optional Output file to write the result to. squeeze : bool, optional - Whether to squeeze size-1 dimensions after subspacing. + Whether to squeeze size-1 dimensions after subspacing. Default: ``False``. Returns ------- diff --git a/src/tctrack/utils/__init__.py b/src/tctrack/utils/__init__.py index 7b8ecf5..57bb764 100644 --- a/src/tctrack/utils/__init__.py +++ b/src/tctrack/utils/__init__.py @@ -1,10 +1,8 @@ """Package providing utility functions for the user.""" -from . 
import preprocessing from .metadata import load_tracker_metadata, read_tracker_metadata __all__ = [ "load_tracker_metadata", - "preprocessing", "read_tracker_metadata", ] diff --git a/tests/unit/utils/test_preprocessing.py b/tests/unit/preprocessing/test_preprocessing.py similarity index 99% rename from tests/unit/utils/test_preprocessing.py rename to tests/unit/preprocessing/test_preprocessing.py index 513d868..5afd181 100644 --- a/tests/unit/utils/test_preprocessing.py +++ b/tests/unit/preprocessing/test_preprocessing.py @@ -1,4 +1,4 @@ -"""Unit tests for preprocessing helper utilities.""" +"""Unit tests for preprocessing functions.""" from pathlib import Path @@ -6,7 +6,7 @@ import numpy as np import pytest -from tctrack.utils.preprocessing import ( +from tctrack.preprocessing import ( FieldSelect, _load_field, calculate_vorticity, From 6ebfc9063f1a539630796aff9c3217206c196edc Mon Sep 17 00:00:00 2001 From: Sam Avis Date: Mon, 1 Jun 2026 17:17:39 +0100 Subject: [PATCH 06/21] Add preprocessing function examples to the docs --- docs/data/preprocessing_data.rst | 95 +++++++++++++++++++++++++++++++- src/tctrack/preprocessing.py | 2 +- 2 files changed, 94 insertions(+), 3 deletions(-) diff --git a/docs/data/preprocessing_data.rst b/docs/data/preprocessing_data.rst index 7d56e24..4414042 100644 --- a/docs/data/preprocessing_data.rst +++ b/docs/data/preprocessing_data.rst @@ -6,6 +6,11 @@ we detail how to perform some of the typically required preprocessing steps usin cf-python library. Other tools can be used for the same tasks, however we focus on cf-python since it provides a uniform interface and it is a dependency of TCTrack. +We also provide simple wrapper functions in :mod:`tctrack.preprocessing` that can +simplify each of these tasks. Examples of these are given in each of the subsections +below. These functions also return the fields so the output files do not need to be +written every time. 
+ For full documentation of the routines described on these pages and more see the `cf python documentation `_. @@ -31,6 +36,14 @@ cf-python: # Write the combined data to a single file cf.write(field, "combined-output.nc") +Or, equivalently, using TCTrack: + +.. code-block:: python + + tctrack.preprocessing.combine_time( + input_files, ["1950-01-01", "1950-04-01"], "combined-output.nc" + ) + Combine Variables ----------------- @@ -46,6 +59,14 @@ read them in separately and then write them together: # Write the combined fields to a single file cf.write([field1, field2], "combined_file.nc") +Using TCTrack: + +.. code-block:: python + + tctrack.preprocessing.read_files( + ["var1_file.nc", "var2_file.nc"], "combined_file.nc" + ) + Separating Variables -------------------- @@ -61,6 +82,15 @@ If variables instead need to be separated into multiple files, such as in :doc:` cf.write(field1, "var1_file.nc") cf.write(field2, "var2_file.nc") +Using TCTrack: + +.. code-block:: python + + tctrack.preprocessing.separate_variables( + "combined_file.nc", + {"var1": "var1_file.nc", "var2": "var2_file.nc"}, + ) + Subsampling ----------- @@ -99,6 +129,19 @@ To remove the single-valued coordinate from the field use cf-python's # or, for a new field field3 = field2.squeeze() +Using TCTrack, single-valued coordinates can be removed using the ``squeeze`` argument +(see the first example below). + +.. code-block:: python + + field2 = tctrack.preprocessing.subsample_field( + "var1_file.nc", {"Z": [5]}, squeeze=True + ) + field3 = tctrack.preprocessing.subsample_field("var1_file.nc", {"X": [0, 5]}) + field4 = tctrack.preprocessing.subsample_field( + "var1_file.nc", {"Y": slice(3, -3, 2)} + ) + Operations ---------- @@ -122,6 +165,14 @@ For example, to calculate vorticity from coincident velocity data we can use ``c # Save the new variable to NetCDF cf.write(w_field, "vorticity_file.nc") +Using TCTrack: + +.. 
code-block:: python + + tctrack.preprocessing.calculate_vorticity( + "u_file.nc", "v_file.nc", "vorticity_file.nc" + ) + Or to take a mean over a coordinate: .. code-block:: python @@ -136,19 +187,29 @@ Or to take a mean over a coordinate: # Save the new variable to NetCDF cf.write(field_zonal_mean, "zonal_mean_file.nc") +Using TCTrack: + +.. code-block:: python + + tctrack.preprocessing.collapse_field("file.nc", "mean", "X", "zonal_mean_file.nc") + Setting Fill Values ^^^^^^^^^^^^^^^^^^^ Sometimes it us useful to replace fill values after an operation before writing to file. This can be done using cf-python's ``filled`` routine. For example, after to set any null or masked values to ``0.0`` after calculating -vorticity above use: +vorticity above use the following before writing to file. .. code-block:: python w_field.filled(fill_value=0.0, inplace=True) -before writing to file. +Using TCTrack: + +.. code-block:: python + + tctrack.preprocessing.replace_fill_value(w_field, 0.0, "output.nc") Set NetCDF Variable Name ------------------------ @@ -167,6 +228,14 @@ To set specfic NetCDF variable names for the fields and coordinates you can use # Save with the new netcdf variable names cf.write(field, "slp_file.nc") +Using TCTrack: + +.. code-block:: python + + tctrack.preprocessing.set_netcdf_variable_name( + "var1_file.nc", "slp", "slp_file.nc", coord_names={"latitude": "lat"} + ) + Regridding ---------- @@ -208,6 +277,22 @@ To regrid onto a new grid: Note that regridding can be performed inplace using ``inplace=True``. +Using TCTrack: + +.. 
code-block:: python + + # Regrid onto a different variable + tctrack.preprocessing.regrid_to_field( + "var1_file.nc", "var2_file.nc", "var1_regridded.nc" + ) + + # Regrid onto a new grid + latitude = np.arange(-90, 91, 1) + longitude = np.arange(-180, 181, 1) + tctrack.preprocessing.regrid_to_lat_lon( + "var1_file.nc", latitude, longitude, "var1_regridded.nc" + ) + Gaussian Grid ^^^^^^^^^^^^^ @@ -238,3 +323,9 @@ objects to be used for the regridding. # Regrid field = field.regrids((lat_coord, lon_coord), method="linear") field.nc_clear_dataset_chunksizes() # Avoids a possible error when writing + +Using TCTrack: + +.. code-block:: python + + tctrack.preprocessing.regrid_to_gaussian("var1_file.nc", 256, "var1_regridded.nc") diff --git a/src/tctrack/preprocessing.py b/src/tctrack/preprocessing.py index 5ed861d..1b7af31 100644 --- a/src/tctrack/preprocessing.py +++ b/src/tctrack/preprocessing.py @@ -187,7 +187,7 @@ def subsample_field( subspace_kwargs: dict[str, Any], output_file: str | None = None, *, - squeeze: bool = True, + squeeze: bool = False, ) -> cf.Field: """Subsample a field using ``cf.Field.subspace``. From 5303d88ec96b453ea98f03bf90f4ed1abce43e7b Mon Sep 17 00:00:00 2001 From: Sam Avis Date: Mon, 1 Jun 2026 23:19:34 +0100 Subject: [PATCH 07/21] Use preprocessing functions in tutorial Also fix issue by respecifying the netcdf backend --- src/tctrack/preprocessing.py | 6 +- tutorial/regrid.py | 150 +++++++++++++---------------------- 2 files changed, 58 insertions(+), 98 deletions(-) diff --git a/src/tctrack/preprocessing.py b/src/tctrack/preprocessing.py index 1b7af31..3f2110a 100644 --- a/src/tctrack/preprocessing.py +++ b/src/tctrack/preprocessing.py @@ -88,7 +88,11 @@ def read_files( list[cf.Field] The list of fields read from the input files. 
""" - fields = list(cf.read(_expand_input_paths(input_files), select=select)) # type: ignore[operator] + fields = list( + cf.read( # type: ignore[operator] + _expand_input_paths(input_files), select=select, netcdf_backend="netCDF4" + ) + ) return _write_output(fields, output_file) diff --git a/tutorial/regrid.py b/tutorial/regrid.py index 182a8b3..9f996ba 100644 --- a/tutorial/regrid.py +++ b/tutorial/regrid.py @@ -5,6 +5,8 @@ import cf +from tctrack import preprocessing + # Set up file structure data_dir = "data/" data_out = "data_processed/" @@ -12,56 +14,36 @@ # Define time window for data - ASO 1950 -ASO_1950 = cf.wi(cf.dt("1950-08-01"), cf.dt("1950-11-01"), open_upper=True) +time_bounds = ("1950-08-01", "1950-11-01") +time_window = cf.wi(cf.dt(time_bounds[0]), cf.dt(time_bounds[1]), open_upper=True) # ======== Tempest Extremes ======== # Extract ASO from annual data files print("Extracting subspace from psl...", end="", flush=True) -field_psl = cf.read( +field_psl = preprocessing.combine_time( f"{data_dir}/psl_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500101-19501230.nc", - netcdf_backend="netCDF4", -)[0] -field_psl = field_psl.subspace(T=ASO_1950) -print("writing data...", end="", flush=True) -cf.write( - field_psl, + time_bounds, f"{data_out}/psl_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500801-19501030.nc", -) +)[0] print("done.") print("Extracting subspace from sfcWind...", end="", flush=True) -field_sfcWind = cf.read( +preprocessing.combine_time( f"{data_dir}/sfcWind_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500101-19501230.nc", - netcdf_backend="netCDF4", -)[0] -field_sfcWind = field_sfcWind.subspace(T=ASO_1950) -print("writing data...", end="", flush=True) -cf.write( - field_sfcWind, + time_bounds, f"{data_out}/sfcWind_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500801-19501030.nc", ) -del field_sfcWind print("done.") -# Combine the monthly zg files into one, subspacing in time to the above resolutions: +# Combine the monthly zg files into one, 
subspacing in time print("Combining zg files into one...", end="", flush=True) -input_files = [ - f"{data_dir}/zg7h_Prim3hrPt_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_195008010000-195008302100.nc", - f"{data_dir}/zg7h_Prim3hrPt_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_195009010000-195009302100.nc", - f"{data_dir}/zg7h_Prim3hrPt_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_195010010000-195010302100.nc", -] - -field_zg_in = cf.read(input_files, netcdf_backend="netCDF4")[0] -# zg is 3hr data from 00:00 but we want daily at 12:00, so subspace with a slice -field_zg = field_zg_in.subspace(T=slice(4, None, 8)) -del field_zg_in -print("writing data...", end="", flush=True) -cf.write( - field_zg, +# zg is 3hr data from 00:00 but we want daily at 12:00, so set time of lower bound +preprocessing.combine_time( + f"{data_dir}/zg7h_*.nc", + (time_bounds[0] + " 12:00", time_bounds[1]), f"{data_out}/zg7h_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500801-19501030.nc", ) -del field_zg print("done.") # Copy orography across directly: @@ -74,43 +56,35 @@ # ======== TSTORMS ======== print("Renaming slp...", end="", flush=True) -field_psl.nc_set_variable("slp") -print("writing data...", end="", flush=True) -cf.write( +preprocessing.set_netcdf_variable_name( field_psl, + "slp", f"{data_out}/slp_day_ASO50.nc", ) del field_psl print("done.") print("Extracting subspace from uas and renaming...", end="", flush=True) -field_uas_in = cf.read( +field_uas = preprocessing.combine_time( f"{data_dir}/uas_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500101-19501230.nc", - netcdf_backend="netCDF4", + time_bounds, )[0] -field_uas = field_uas_in.subspace(T=ASO_1950) -del field_uas_in -field_uas.nc_set_variable("u_ref") -print("writing data...", end="", flush=True) -cf.write( +field_uas = preprocessing.set_netcdf_variable_name( field_uas, + "u_ref", f"{data_out}/u_ref_day_ASO50.nc", ) print("done.") print("Extracting subspace from vas and renaming...", end="", flush=True) -field_vas_in = cf.read( +field_vas = 
preprocessing.combine_time( f"{data_dir}/vas_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500701-19501230.nc", - netcdf_backend="netCDF4", + time_bounds, )[0] -field_vas = field_vas_in.subspace(T=ASO_1950) -field_vas.regrids(field_uas, method="linear", inplace=True) -field_vas.nc_clear_dataset_chunksizes() -del field_vas_in -field_vas.nc_set_variable("v_ref") -print("writing data...", end="", flush=True) -cf.write( +field_vas = preprocessing.regrid_to_field(field_vas, field_uas, method="linear") +preprocessing.set_netcdf_variable_name( field_vas, + "v_ref", f"{data_out}/v_ref_day_ASO50.nc", ) del field_vas @@ -119,55 +93,42 @@ print("Extracting subspace from ua and va to calculate vorticity and renaming...") print("\tExtracting subspace from ua for u850...", end="", flush=True) -field_ua = cf.read( +field_u850 = preprocessing.subsample_field( f"{data_dir}/ua_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500701-19501230.nc", - netcdf_backend="netCDF4", -)[0] -field_u850 = field_ua.subspace(T=ASO_1950, Z=[1]) -del field_ua -field_u850.squeeze(inplace=True) -field_u850_rg = field_u850.regrids(field_uas, method="linear") -del field_u850 -field_u850_rg.nc_clear_dataset_chunksizes() -field_u850_rg.nc_set_variable("u850") -print("writing data...", end="", flush=True) -cf.write( - field_u850_rg, + {"T": time_window, "Z": [1]}, + squeeze=True, +) +field_u850 = preprocessing.regrid_to_field(field_u850, field_uas, method="linear") +field_u850 = preprocessing.set_netcdf_variable_name( + field_u850, + "u850", f"{data_out}/u850_day_ASO50.nc", ) print("done.") print("\tExtracting subspace from va for v850...", end="", flush=True) -field_va = cf.read( +field_v850 = preprocessing.subsample_field( f"{data_dir}/va_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500701-19501230.nc", - netcdf_backend="netCDF4", -)[0] -field_v850 = field_va.subspace(T=ASO_1950, Z=[1]) -del field_va -field_v850.squeeze(inplace=True) -field_v850_rg = field_v850.regrids(field_uas, method="linear") + {"T": 
time_window, "Z": [1]}, + squeeze=True, +) +field_v850 = preprocessing.regrid_to_field(field_v850, field_uas, method="linear") del field_uas -del field_v850 -field_v850_rg.nc_clear_dataset_chunksizes() -field_v850_rg.nc_set_variable("v850") -print("writing data...", end="", flush=True) -cf.write( - field_v850_rg, +field_v850 = preprocessing.set_netcdf_variable_name( + field_v850, + "v850", f"{data_out}/v850_day_ASO50.nc", ) print("done.") print("\tCalculating vorticity for vort850...", end="", flush=True) -field_vort850 = cf.curl_xy(field_u850_rg, field_v850_rg, radius="earth") -del field_u850_rg -del field_v850_rg -field_vort850.filled(fill_value=0.0, inplace=True) -field_vort850.nc_set_variable("vort850") -field_vort850.set_property("standard_name", "atmosphere_upward_absolute_vorticity") -field_vort850.set_property("units", "s-1") -print("writing data...", end="", flush=True) -cf.write( +field_vort850 = preprocessing.calculate_vorticity(field_u850, field_v850) +del field_u850 +del field_v850 +field_vort850 = preprocessing.replace_fill_value(field_vort850, 0.0) +preprocessing.set_netcdf_variable_name( field_vort850, + "vort850", f"{data_out}/vort850_day_ASO50.nc", ) del field_vort850 @@ -176,19 +137,14 @@ print("done.") print("Extracting subspace and taking mean of ta and renaming...", end="", flush=True) -field_ta_full = cf.read( +field_ta = preprocessing.subsample_field( f"{data_dir}/ta_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500701-19501230.nc", - netcdf_backend="netCDF4", -)[0] -# Extract 500 and 250 pressure levels and take mean -field_ta = field_ta_full.subspace(T=ASO_1950, Z=slice(3, -3)) -del field_ta_full -field_ta.collapse("mean", axes="Z", inplace=True) -field_ta.squeeze(inplace=True) -field_ta.nc_set_variable("tm") -print("writing data...", end="", flush=True) -cf.write( + {"T": time_window, "Z": slice(3, -3)}, +) +field_ta = preprocessing.collapse_field(field_ta, "mean", axes="Z") +preprocessing.set_netcdf_variable_name( field_ta, + "tm", 
f"{data_out}/tm_day_ASO50.nc", ) del field_ta From 47e1f1cee319c041ddb8f0ad9b517e939ab18994 Mon Sep 17 00:00:00 2001 From: Sam Avis Date: Tue, 2 Jun 2026 08:57:36 +0100 Subject: [PATCH 08/21] Rename regrid.py to preprocess_data.py in tutorial --- docs/getting-started/tutorial.rst | 4 ++-- tutorial/{regrid.py => preprocess_data.py} | 0 2 files changed, 2 insertions(+), 2 deletions(-) rename tutorial/{regrid.py => preprocess_data.py} (100%) diff --git a/docs/getting-started/tutorial.rst b/docs/getting-started/tutorial.rst index 1cb7fc9..fecbc97 100644 --- a/docs/getting-started/tutorial.rst +++ b/docs/getting-started/tutorial.rst @@ -90,9 +90,9 @@ before running. Pre-processing of Data ---------------------- -From inside the conda environment run the regridding script to pre-process the data:: +From inside the conda environment run the script to pre-process the data:: - python regrid.py + python preprocess_data.py This will pre-process the downloaded data as required for our codes and place it in ``data_processed/``. 
diff --git a/tutorial/regrid.py b/tutorial/preprocess_data.py similarity index 100% rename from tutorial/regrid.py rename to tutorial/preprocess_data.py From 0dde74a6a24f482c7c1c9b0ac31d1553dcc7380c Mon Sep 17 00:00:00 2001 From: Sam Avis Date: Tue, 2 Jun 2026 09:22:56 +0100 Subject: [PATCH 09/21] Rename combine_time -> select_time_range --- docs/data/preprocessing_data.rst | 7 +++++-- src/tctrack/preprocessing.py | 20 +++++++++---------- .../unit/preprocessing/test_preprocessing.py | 8 ++++---- tutorial/preprocess_data.py | 10 +++++----- 4 files changed, 23 insertions(+), 22 deletions(-) diff --git a/docs/data/preprocessing_data.rst b/docs/data/preprocessing_data.rst index 4414042..2ab032f 100644 --- a/docs/data/preprocessing_data.rst +++ b/docs/data/preprocessing_data.rst @@ -36,11 +36,14 @@ cf-python: # Write the combined data to a single file cf.write(field, "combined-output.nc") -Or, equivalently, using TCTrack: +Or, equivalently, in TCTrack you can use +:func:`tctrack.preprocessing.select_time_range`, as below. All of the other +preprocessing functions can also be used to combine files if a specific time range is +not required. .. code-block:: python - tctrack.preprocessing.combine_time( + tctrack.preprocessing.select_time_range( input_files, ["1950-01-01", "1950-04-01"], "combined-output.nc" ) diff --git a/src/tctrack/preprocessing.py b/src/tctrack/preprocessing.py index 3f2110a..0f57fc3 100644 --- a/src/tctrack/preprocessing.py +++ b/src/tctrack/preprocessing.py @@ -96,19 +96,20 @@ def read_files( return _write_output(fields, output_file) -def combine_time( +def select_time_range( input_files: str | Sequence[str], - time_bounds: tuple[str, str] | None = None, + time_bounds: tuple[str, str], output_file: str | None = None, ) -> list[cf.Field]: - """Combine files in time. + """Combine files in time and select a time range. Parameters ---------- input_files : str | Sequence[str] Input file path(s) to combine. ``glob`` pattern matching allowed. 
- time_bounds : tuple[str, str] | None, optional - Optional start and end datetime strings. The end bound is open / exclusive. + time_bounds : tuple[str, str] + Start and end datetime strings in format ``"YYYY-MM-DD[ HH:MM]"``. + The end bound is open / exclusive. output_file : str | None, optional Output file to write the result to. @@ -119,11 +120,8 @@ def combine_time( """ fields = read_files(input_files) - if time_bounds is not None: - time_interval = cf.wi( - cf.dt(time_bounds[0]), cf.dt(time_bounds[1]), open_upper=True - ) - fields = [field.subspace(T=time_interval) for field in fields] + time_interval = cf.wi(cf.dt(time_bounds[0]), cf.dt(time_bounds[1]), open_upper=True) + fields = [field.subspace(T=time_interval) for field in fields] return _write_output(fields, output_file) @@ -515,7 +513,7 @@ def regrid_to_gaussian( __all__ = [ # noqa: RUF022 # Prevent reorder for a more logical order in the api docs "read_files", - "combine_time", + "select_time_range", "separate_variables", "subsample_field", "collapse_field", diff --git a/tests/unit/preprocessing/test_preprocessing.py b/tests/unit/preprocessing/test_preprocessing.py index 5afd181..2687a6b 100644 --- a/tests/unit/preprocessing/test_preprocessing.py +++ b/tests/unit/preprocessing/test_preprocessing.py @@ -11,12 +11,12 @@ _load_field, calculate_vorticity, collapse_field, - combine_time, gaussian_grid, read_files, regrid_to_field, regrid_to_gaussian, replace_fill_value, + select_time_range, separate_variables, set_netcdf_variable_name, subsample_field, @@ -96,14 +96,14 @@ def test_read_files_output_file(self, tmp_path): assert output_file.exists() assert cf.read(str(output_file))[0].nc_get_variable() == "mslp" - def test_combine_time_bounds(self, tmp_path): - """Test combine_time correctly selects data in time bounds.""" + def test_select_time_range_bounds(self, tmp_path): + """Test select_time_range correctly selects data in time bounds.""" input_files = [ write_fields(make_field("mslp", "2000-01-01"), 
tmp_path / "a.nc"), write_fields(make_field("mslp", "2000-01-02"), tmp_path / "b.nc"), ] - fields = combine_time(input_files, ("2000-01-01", "2000-01-02")) + fields = select_time_range(input_files, ("2000-01-01", "2000-01-02")) assert len(fields) == 1 # Same field (mslp) assert fields[0].coordinate("T").size == 1 # Upper bound is excluded diff --git a/tutorial/preprocess_data.py b/tutorial/preprocess_data.py index 9f996ba..f138e94 100644 --- a/tutorial/preprocess_data.py +++ b/tutorial/preprocess_data.py @@ -21,7 +21,7 @@ # ======== Tempest Extremes ======== # Extract ASO from annual data files print("Extracting subspace from psl...", end="", flush=True) -field_psl = preprocessing.combine_time( +field_psl = preprocessing.select_time_range( f"{data_dir}/psl_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500101-19501230.nc", time_bounds, f"{data_out}/psl_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500801-19501030.nc", @@ -29,7 +29,7 @@ print("done.") print("Extracting subspace from sfcWind...", end="", flush=True) -preprocessing.combine_time( +preprocessing.select_time_range( f"{data_dir}/sfcWind_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500101-19501230.nc", time_bounds, f"{data_out}/sfcWind_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500801-19501030.nc", @@ -39,7 +39,7 @@ # Combine the monthly zg files into one, subspacing in time print("Combining zg files into one...", end="", flush=True) # zg is 3hr data from 00:00 but we want daily at 12:00, so set time of lower bound -preprocessing.combine_time( +preprocessing.select_time_range( f"{data_dir}/zg7h_*.nc", (time_bounds[0] + " 12:00", time_bounds[1]), f"{data_out}/zg7h_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500801-19501030.nc", @@ -65,7 +65,7 @@ print("done.") print("Extracting subspace from uas and renaming...", end="", flush=True) -field_uas = preprocessing.combine_time( +field_uas = preprocessing.select_time_range( f"{data_dir}/uas_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500101-19501230.nc", time_bounds, 
)[0] @@ -77,7 +77,7 @@ print("done.") print("Extracting subspace from vas and renaming...", end="", flush=True) -field_vas = preprocessing.combine_time( +field_vas = preprocessing.select_time_range( f"{data_dir}/vas_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500701-19501230.nc", time_bounds, )[0] From e5aed24ba5fc6c2d230f89e6f3d0620ccd7ec013 Mon Sep 17 00:00:00 2001 From: Sam Avis Date: Tue, 2 Jun 2026 10:03:48 +0100 Subject: [PATCH 10/21] Add checks that esmpy is installed for regridding --- docs/getting-started/index.rst | 7 +++---- src/tctrack/preprocessing.py | 17 +++++++++++++++++ tests/unit/preprocessing/test_preprocessing.py | 16 ++++++++++++++++ 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/docs/getting-started/index.rst b/docs/getting-started/index.rst index 3281e90..267b037 100644 --- a/docs/getting-started/index.rst +++ b/docs/getting-started/index.rst @@ -101,11 +101,10 @@ noting that we may need to add the library to the dynamic path e.g.:: esmpy ~~~~~ -Any regridding of data with cf-python requires `esmpy -`_ and `ESMF +Any preprocessing that involves regridding of data using :mod:`tctrack.preprocessing` or +cf-python requires `esmpy `_ and `ESMF `_ as dependencies. This is not needed directly in the -TCTrack package but may be needed for initial pre-processing of data, such as in the -tutorial and described in the :doc:`../data/preprocessing_data` page. +tracking algorithms. 
These are not pip-installable but can be installed in a conda environment:: diff --git a/src/tctrack/preprocessing.py b/src/tctrack/preprocessing.py index 0f57fc3..552c727 100644 --- a/src/tctrack/preprocessing.py +++ b/src/tctrack/preprocessing.py @@ -1,6 +1,7 @@ """Module with preprocessing functions that can be used with batching.""" import glob +import importlib.util from collections.abc import Sequence from dataclasses import dataclass from typing import Any, TypeAlias, TypeVar @@ -8,6 +9,20 @@ import cf import numpy as np +ESMPY_AVAILABLE = importlib.util.find_spec("esmpy") is not None + + +def _require_esmpy() -> None: + """Guard to ensure that esmpy is installed for regridding.""" + if not ESMPY_AVAILABLE: + msg = ( + "Regridding requires esmpy to be installed. " + "See the dependency installation documentation:\n" + "https://tctrack.readthedocs.io/en/latest/getting-started/index.html#esmpy" + ) + raise ImportError(msg) + + FilePaths: TypeAlias = str | Sequence[str] @@ -412,6 +427,7 @@ def regrid_to_field( cf.Field Regridded field. """ + _require_esmpy() field = _load_field(input_) if not isinstance(target, cf.Domain): target = _load_field(target) @@ -449,6 +465,7 @@ def regrid_to_lat_lon( cf.Field Regridded field on the requested latitude-longitude grid. 
""" + _require_esmpy() field = _load_field(input_) domain = field.domain.copy() diff --git a/tests/unit/preprocessing/test_preprocessing.py b/tests/unit/preprocessing/test_preprocessing.py index 2687a6b..2f24c0e 100644 --- a/tests/unit/preprocessing/test_preprocessing.py +++ b/tests/unit/preprocessing/test_preprocessing.py @@ -6,6 +6,7 @@ import numpy as np import pytest +import tctrack.preprocessing from tctrack.preprocessing import ( FieldSelect, _load_field, @@ -243,6 +244,17 @@ def test_set_netcdf_variable_name(self): assert renamed.coordinate("X").nc_get_variable() == "longitude" assert renamed.coordinate("Y").nc_get_variable() == "latitude" + def test_regrid_esmpy_guard(self, monkeypatch): + """Test regridding fails clearly when esmpy is unavailable.""" + monkeypatch.setattr(tctrack.preprocessing, "ESMPY_AVAILABLE", False) + + with pytest.raises(ImportError, match="Regridding requires esmpy"): + regrid_to_field(make_field("v"), make_field("u")) + + @pytest.mark.skipif( + not tctrack.preprocessing.ESMPY_AVAILABLE, + reason="esmpy is not pip-installable so this currently fails in CI", + ) def test_regrid_to_field(self): """Test regrid_to_field works correctly.""" target = make_field("u") @@ -261,6 +273,10 @@ def test_gaussian_grid(self): assert longitude[-1] == pytest.approx(337.5) assert latitude[0] == pytest.approx(-latitude[-1]) + @pytest.mark.skipif( + not tctrack.preprocessing.ESMPY_AVAILABLE, + reason="esmpy is not pip-installable so this currently fails in CI", + ) def test_regrid_to_gaussian(self): """Test regrid_to_gaussian works correctly.""" field = make_field("mslp") From de5b4fe769da5e34b4cd74dd1fc79dfb21feb22c Mon Sep 17 00:00:00 2001 From: Sam Avis Date: Tue, 2 Jun 2026 10:43:11 +0100 Subject: [PATCH 11/21] Fix mistake in preprocessing script --- tutorial/preprocess_data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tutorial/preprocess_data.py b/tutorial/preprocess_data.py index f138e94..f7aae12 100644 --- 
a/tutorial/preprocess_data.py +++ b/tutorial/preprocess_data.py @@ -38,10 +38,10 @@ # Combine the monthly zg files into one, subspacing in time print("Combining zg files into one...", end="", flush=True) -# zg is 3hr data from 00:00 but we want daily at 12:00, so set time of lower bound -preprocessing.select_time_range( +# zg is 3hr data from 00:00 but we want daily at 12:00, so subspace with a slice +preprocessing.subsample_field( f"{data_dir}/zg7h_*.nc", - (time_bounds[0] + " 12:00", time_bounds[1]), + {"T": slice(4, None, 8)}, f"{data_out}/zg7h_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500801-19501030.nc", ) print("done.") From f15b6b9d3e0577e32ca7b1dbfe9554dd75f2d2a6 Mon Sep 17 00:00:00 2001 From: Sam Avis Date: Tue, 2 Jun 2026 11:00:35 +0100 Subject: [PATCH 12/21] Make output_file a kwarg in preprocessing fns This makes all optional arguments consistent in needing to be passed by keyword. I think it also makes it clearer. --- docs/data/preprocessing_data.rst | 25 ++++++++++++------- src/tctrack/preprocessing.py | 18 +++++++------ .../unit/preprocessing/test_preprocessing.py | 2 +- 3 files changed, 28 insertions(+), 17 deletions(-) diff --git a/docs/data/preprocessing_data.rst b/docs/data/preprocessing_data.rst index 2ab032f..f24dd93 100644 --- a/docs/data/preprocessing_data.rst +++ b/docs/data/preprocessing_data.rst @@ -44,7 +44,7 @@ not required. .. code-block:: python tctrack.preprocessing.select_time_range( - input_files, ["1950-01-01", "1950-04-01"], "combined-output.nc" + input_files, ["1950-01-01", "1950-04-01"], output_file="combined-output.nc" ) Combine Variables @@ -67,7 +67,7 @@ Using TCTrack: .. code-block:: python tctrack.preprocessing.read_files( - ["var1_file.nc", "var2_file.nc"], "combined_file.nc" + ["var1_file.nc", "var2_file.nc"], output_file="combined_file.nc" ) Separating Variables @@ -173,7 +173,7 @@ Using TCTrack: .. 
code-block:: python tctrack.preprocessing.calculate_vorticity( - "u_file.nc", "v_file.nc", "vorticity_file.nc" + "u_file.nc", "v_file.nc", output_file="vorticity_file.nc" ) Or to take a mean over a coordinate: @@ -194,7 +194,9 @@ Using TCTrack: .. code-block:: python - tctrack.preprocessing.collapse_field("file.nc", "mean", "X", "zonal_mean_file.nc") + tctrack.preprocessing.collapse_field( + "file.nc", "mean", "X", output_file="zonal_mean_file.nc" + ) Setting Fill Values ^^^^^^^^^^^^^^^^^^^ @@ -212,7 +214,7 @@ Using TCTrack: .. code-block:: python - tctrack.preprocessing.replace_fill_value(w_field, 0.0, "output.nc") + tctrack.preprocessing.replace_fill_value(w_field, 0.0, output_file="output.nc") Set NetCDF Variable Name ------------------------ @@ -236,7 +238,10 @@ Using TCTrack: .. code-block:: python tctrack.preprocessing.set_netcdf_variable_name( - "var1_file.nc", "slp", "slp_file.nc", coord_names={"latitude": "lat"} + "var1_file.nc", + "slp", + coord_names={"latitude": "lat"}, + output_file="slp_file.nc", ) Regridding @@ -286,14 +291,14 @@ Using TCTrack: # Regrid onto a different variable tctrack.preprocessing.regrid_to_field( - "var1_file.nc", "var2_file.nc", "var1_regridded.nc" + "var1_file.nc", "var2_file.nc", output_file="var1_regridded.nc" ) # Regrid onto a new grid latitude = np.arange(-90, 91, 1) longitude = np.arange(-180, 181, 1) tctrack.preprocessing.regrid_to_lat_lon( - "var1_file.nc", latitude, longitude, "var1_regridded.nc" + "var1_file.nc", latitude, longitude, output_file="var1_regridded.nc" ) Gaussian Grid @@ -331,4 +336,6 @@ Using TCTrack: .. 
code-block:: python - tctrack.preprocessing.regrid_to_gaussian("var1_file.nc", 256, "var1_regridded.nc") + tctrack.preprocessing.regrid_to_gaussian( + "var1_file.nc", 256, output_file="var1_regridded.nc" + ) diff --git a/src/tctrack/preprocessing.py b/src/tctrack/preprocessing.py index 552c727..b63a20d 100644 --- a/src/tctrack/preprocessing.py +++ b/src/tctrack/preprocessing.py @@ -83,8 +83,8 @@ def _write_output(result: T, output_file: str | None) -> T: def read_files( input_files: str | Sequence[str], - output_file: str | None = None, *, + output_file: str | None = None, select: str | None = None, ) -> list[cf.Field]: """Read fields from one or more files. @@ -114,6 +114,7 @@ def read_files( def select_time_range( input_files: str | Sequence[str], time_bounds: tuple[str, str], + *, output_file: str | None = None, ) -> list[cf.Field]: """Combine files in time and select a time range. @@ -202,8 +203,8 @@ def _load_field(source: FieldSource) -> cf.Field: def subsample_field( input_: FieldSource, subspace_kwargs: dict[str, Any], - output_file: str | None = None, *, + output_file: str | None = None, squeeze: bool = False, ) -> cf.Field: """Subsample a field using ``cf.Field.subspace``. @@ -239,8 +240,8 @@ def collapse_field( input_: FieldSource, method: str, axes: str | Sequence[str], - output_file: str | None = None, *, + output_file: str | None = None, squeeze: bool = True, ) -> cf.Field: """Collapse a field over one or more axes. @@ -275,6 +276,7 @@ def calculate_curl_xy( input_y: FieldSource, variable_name: str, variable_info: dict[str, str], + *, output_file: str | None = None, ) -> cf.Field: """Calculate the curl of x and y vector components. @@ -313,6 +315,7 @@ def calculate_curl_xy( def calculate_vorticity( input_u: FieldSource, input_v: FieldSource, + *, output_file: str | None = None, ) -> cf.Field: """Calculate vorticity from colocated velocity fields. 
@@ -346,6 +349,7 @@ def calculate_vorticity( def replace_fill_value( input_: FieldSource, fill_value: float, + *, output_file: str | None = None, ) -> cf.Field: """Replace masked values in a field using ``cf.Field.filled``. @@ -372,8 +376,8 @@ def replace_fill_value( def set_netcdf_variable_name( input_: FieldSource, field_name: str, - output_file: str | None = None, *, + output_file: str | None = None, coord_names: dict[str, str] | None = None, ) -> cf.Field: """Set NetCDF variable names for a field and, optionally, its coordinates. @@ -405,8 +409,8 @@ def set_netcdf_variable_name( def regrid_to_field( input_: FieldSource, target: FieldSource | cf.Domain, - output_file: str | None = None, *, + output_file: str | None = None, method: str = "linear", ) -> cf.Field: """Regrid a field onto the grid of another field or domain. @@ -441,8 +445,8 @@ def regrid_to_lat_lon( input_: FieldSource, latitude: np.ndarray, longitude: np.ndarray, - output_file: str | None = None, *, + output_file: str | None = None, method: str = "linear", ) -> cf.Field: """Regrid a field onto a latitude-longitude grid. @@ -502,8 +506,8 @@ def gaussian_grid(n: int) -> tuple[np.ndarray, np.ndarray]: def regrid_to_gaussian( input_: FieldSource, n: int, - output_file: str | None = None, *, + output_file: str | None = None, method: str = "linear", ) -> cf.Field: """Regrid a field onto a regular Gaussian grid. 
diff --git a/tests/unit/preprocessing/test_preprocessing.py b/tests/unit/preprocessing/test_preprocessing.py index 2f24c0e..e8c79fb 100644 --- a/tests/unit/preprocessing/test_preprocessing.py +++ b/tests/unit/preprocessing/test_preprocessing.py @@ -91,7 +91,7 @@ def test_read_files_output_file(self, tmp_path): input_file = write_fields(make_field("mslp"), tmp_path / "input.nc") output_file = tmp_path / "output.nc" - fields = read_files(input_file, str(output_file)) + fields = read_files(input_file, output_file=str(output_file)) assert len(fields) == 1 assert output_file.exists() From 6138244b48c6bee7838fa22f3f1d576e9f8cb87b Mon Sep 17 00:00:00 2001 From: Sam Avis Date: Thu, 4 Jun 2026 18:04:28 +0100 Subject: [PATCH 13/21] Change FieldSelect to a TypedDict --- src/tctrack/preprocessing.py | 19 +++++++++---------- .../unit/preprocessing/test_preprocessing.py | 12 ++++-------- 2 files changed, 13 insertions(+), 18 deletions(-) diff --git a/src/tctrack/preprocessing.py b/src/tctrack/preprocessing.py index b63a20d..77fd04f 100644 --- a/src/tctrack/preprocessing.py +++ b/src/tctrack/preprocessing.py @@ -3,8 +3,7 @@ import glob import importlib.util from collections.abc import Sequence -from dataclasses import dataclass -from typing import Any, TypeAlias, TypeVar +from typing import Any, TypedDict, TypeAlias, TypeVar import cf import numpy as np @@ -26,11 +25,10 @@ def _require_esmpy() -> None: FilePaths: TypeAlias = str | Sequence[str] -@dataclass(frozen=True) -class FieldSelect: - """Class containing the file name(s) plus the NetCDF variable name to select. +class FieldSelect(TypedDict): + """Dictionary containing the file name(s) plus the NetCDF variable name to select. - Necessary for choosing a variable from files which contain multiple. + This is necessary for choosing a variable from files which contain multiple. 
Parameters ---------- @@ -176,10 +174,11 @@ def _load_field(source: FieldSource) -> cf.Field: if isinstance(source, cf.Field): return source - if isinstance(source, FieldSelect): - fields = read_files(source.files, select=f"ncvar%{source.var_name}") + if isinstance(source, dict): + nc_var = source["var_name"] + fields = read_files(source["files"], select=f"ncvar%{nc_var}") if not fields: - msg = f"No field with NetCDF variable name '{source.var_name}' was found." + msg = f"No field with NetCDF variable name '{nc_var}' was found." raise ValueError(msg) return fields[0] @@ -188,7 +187,7 @@ def _load_field(source: FieldSource) -> cf.Field: if len(fields) != 1: msg = ( f"Expected one field from '{source}', but found {len(fields)}. " - "Use FieldSelect(files, variable_name) to select a field." + "Use {\"files\": files, \"var_name\": variable_name} to select a field." ) raise ValueError(msg) return fields[0] diff --git a/tests/unit/preprocessing/test_preprocessing.py b/tests/unit/preprocessing/test_preprocessing.py index e8c79fb..bf49cc2 100644 --- a/tests/unit/preprocessing/test_preprocessing.py +++ b/tests/unit/preprocessing/test_preprocessing.py @@ -8,7 +8,6 @@ import tctrack.preprocessing from tctrack.preprocessing import ( - FieldSelect, _load_field, calculate_vorticity, collapse_field, @@ -156,20 +155,17 @@ def test_load_field_rejects_multifield_files(self, tmp_path): tmp_path / "input.nc", ) - with pytest.raises( - ValueError, - match=r"Use FieldSelect\(files, variable_name\) to select a field", - ): + with pytest.raises(ValueError, match="Use {.*} to select a field"): _load_field(input_file) def test_load_field_accepts_field_select(self, tmp_path): - """Test _load_field selects a field when given a FieldSelect object.""" + """Test _load_field selects a field when given a FieldSelect dictionary.""" input_file = write_fields( [make_field("mslp"), make_field("u")], tmp_path / "input.nc", ) - field = _load_field(FieldSelect(input_file, "u")) + field = 
_load_field({"files": input_file, "var_name": "u"}) assert field.nc_get_variable() == "u" @@ -181,7 +177,7 @@ def test_load_field_rejects_missing_field_select(self, tmp_path): ValueError, match="No field with NetCDF variable name 'invalid' was found", ): - _load_field(FieldSelect(input_file, "invalid")) + _load_field({"files": input_file, "var_name": "invalid"}) def test_subsample_field(self): """Test subsample_field works correctly.""" From 800c45392df67322bbe652fc3a0230fb18f6ef84 Mon Sep 17 00:00:00 2001 From: Sam Avis Date: Fri, 5 Jun 2026 10:16:02 +0100 Subject: [PATCH 14/21] Fix formatting --- src/tctrack/preprocessing.py | 4 ++-- tests/unit/preprocessing/test_preprocessing.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/tctrack/preprocessing.py b/src/tctrack/preprocessing.py index 77fd04f..c5a0bb2 100644 --- a/src/tctrack/preprocessing.py +++ b/src/tctrack/preprocessing.py @@ -3,7 +3,7 @@ import glob import importlib.util from collections.abc import Sequence -from typing import Any, TypedDict, TypeAlias, TypeVar +from typing import Any, TypeAlias, TypedDict, TypeVar import cf import numpy as np @@ -187,7 +187,7 @@ def _load_field(source: FieldSource) -> cf.Field: if len(fields) != 1: msg = ( f"Expected one field from '{source}', but found {len(fields)}. " - "Use {\"files\": files, \"var_name\": variable_name} to select a field." + 'Use {"files": files, "var_name": variable_name} to select a field.' 
) raise ValueError(msg) return fields[0] diff --git a/tests/unit/preprocessing/test_preprocessing.py b/tests/unit/preprocessing/test_preprocessing.py index bf49cc2..dd2e30e 100644 --- a/tests/unit/preprocessing/test_preprocessing.py +++ b/tests/unit/preprocessing/test_preprocessing.py @@ -39,7 +39,7 @@ def make_field(var_name: str, time: str | None = None) -> cf.Field: return field -def write_fields(fields: list[cf.Field], path: Path) -> str: +def write_fields(fields: cf.Field | list[cf.Field], path: Path) -> str: """Write one or more fields to a file and return the file path.""" cf.write(fields, str(path)) # type: ignore[operator] return str(path) @@ -155,7 +155,7 @@ def test_load_field_rejects_multifield_files(self, tmp_path): tmp_path / "input.nc", ) - with pytest.raises(ValueError, match="Use {.*} to select a field"): + with pytest.raises(ValueError, match=r"Use {.*} to select a field"): _load_field(input_file) def test_load_field_accepts_field_select(self, tmp_path): From b0a69102460988cfe8dd5bc867fd0596c5f064d2 Mon Sep 17 00:00:00 2001 From: Sam Avis Date: Thu, 11 Jun 2026 16:45:17 +0100 Subject: [PATCH 15/21] fixup! Rename combine_time -> select_time_range --- src/tctrack/preprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tctrack/preprocessing.py b/src/tctrack/preprocessing.py index c5a0bb2..114d758 100644 --- a/src/tctrack/preprocessing.py +++ b/src/tctrack/preprocessing.py @@ -122,7 +122,7 @@ def select_time_range( input_files : str | Sequence[str] Input file path(s) to combine. ``glob`` pattern matching allowed. time_bounds : tuple[str, str] - Start and end datetime strings in format ``"YYYY-MM-DD[ HH:MM]". + Start and end datetime strings in format ``"YYYY-MM-DD[ HH:MM]"``. The end bound is open / exclusive. output_file : str | None, optional Output file to write the result to. 
From d7050e808738a1ec71e4a36f1204512d9c07a5ee Mon Sep 17 00:00:00 2001 From: Sam Avis Date: Wed, 17 Jun 2026 10:43:02 +0100 Subject: [PATCH 16/21] Correct absolute -> relative vorticity in metadata --- docs/data/preprocessing_data.rst | 2 +- src/tctrack/preprocessing.py | 2 +- tests/unit/preprocessing/test_preprocessing.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/data/preprocessing_data.rst b/docs/data/preprocessing_data.rst index f24dd93..182f33d 100644 --- a/docs/data/preprocessing_data.rst +++ b/docs/data/preprocessing_data.rst @@ -162,7 +162,7 @@ For example, to calculate vorticity from coincident velocity data we can use ``c # calculate vorticity w_field = cf.curl_xy(u_field, v_field, radius="earth") w_field.nc_set_variable("vorticity") - w_field.set_property("standard_name", "atmosphere_upward_absolute_vorticity") + w_field.set_property("standard_name", "atmosphere_upward_relative_vorticity") w_field.set_property("units", "s-1") # Save the new variable to NetCDF diff --git a/src/tctrack/preprocessing.py b/src/tctrack/preprocessing.py index 114d758..283a9de 100644 --- a/src/tctrack/preprocessing.py +++ b/src/tctrack/preprocessing.py @@ -338,7 +338,7 @@ def calculate_vorticity( input_v, variable_name="vorticity", variable_info={ - "standard_name": "atmosphere_upward_absolute_vorticity", + "standard_name": "atmosphere_upward_relative_vorticity", "units": "s-1", }, output_file=output_file, diff --git a/tests/unit/preprocessing/test_preprocessing.py b/tests/unit/preprocessing/test_preprocessing.py index dd2e30e..864b195 100644 --- a/tests/unit/preprocessing/test_preprocessing.py +++ b/tests/unit/preprocessing/test_preprocessing.py @@ -213,7 +213,7 @@ def test_calculate_vorticity(self): assert vorticity.nc_get_variable() == "vorticity" assert ( vorticity.get_property("standard_name") - == "atmosphere_upward_absolute_vorticity" + == "atmosphere_upward_relative_vorticity" ) assert vorticity.get_property("units") == "s-1" From 
497c338c9723b75269a2b953243cd60e51b08694 Mon Sep 17 00:00:00 2001 From: Sam Avis Date: Wed, 17 Jun 2026 11:00:03 +0100 Subject: [PATCH 17/21] fixup! Add preprocessing function examples to the docs --- docs/data/preprocessing_data.rst | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/data/preprocessing_data.rst b/docs/data/preprocessing_data.rst index 182f33d..f01b21d 100644 --- a/docs/data/preprocessing_data.rst +++ b/docs/data/preprocessing_data.rst @@ -201,10 +201,11 @@ Using TCTrack: Setting Fill Values ^^^^^^^^^^^^^^^^^^^ -Sometimes it us useful to replace fill values after an operation before writing to file. +Sometimes it is useful to replace the 'fill values' after an operation but before +writing to file. This can be done using cf-python's ``filled`` routine. -For example, after to set any null or masked values to ``0.0`` after calculating -vorticity above use the following before writing to file. +For example, to set any null or masked values to ``0.0`` (e.g. after calculating +vorticity above) use the following before writing to file. .. 
code-block:: python From fc12ab391c6e50d96f94959b78fc3704010f7474 Mon Sep 17 00:00:00 2001 From: Sam Avis Date: Wed, 17 Jun 2026 19:23:37 +0100 Subject: [PATCH 18/21] Fix error in calculation of curl --- src/tctrack/preprocessing.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/tctrack/preprocessing.py b/src/tctrack/preprocessing.py index 283a9de..39b2f96 100644 --- a/src/tctrack/preprocessing.py +++ b/src/tctrack/preprocessing.py @@ -302,6 +302,12 @@ def calculate_curl_xy( field_y = _load_field(input_y) curl = cf.curl_xy(field_x, field_y, radius="earth") + + # Negate the curl due to a suspected error in cf.curl_xy for spherical polar coords + # (In the first term the gradient is taken wrt latitude, not theta) + # (The second term is not negated) + curl.data = -curl.data + curl.nc_set_variable(variable_name) for name, value in variable_info.items(): if name == "units": From cab7a34fdeaa5417a427343ef51140f23b88b5ec Mon Sep 17 00:00:00 2001 From: Sam Avis Date: Wed, 17 Jun 2026 19:38:04 +0100 Subject: [PATCH 19/21] fixup! 
Make output_file a kwarg in preprocessing fns --- tutorial/preprocess_data.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tutorial/preprocess_data.py b/tutorial/preprocess_data.py index f7aae12..af4557d 100644 --- a/tutorial/preprocess_data.py +++ b/tutorial/preprocess_data.py @@ -24,7 +24,7 @@ field_psl = preprocessing.select_time_range( f"{data_dir}/psl_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500101-19501230.nc", time_bounds, - f"{data_out}/psl_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500801-19501030.nc", + output_file=f"{data_out}/psl_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500801-19501030.nc", )[0] print("done.") @@ -32,7 +32,7 @@ preprocessing.select_time_range( f"{data_dir}/sfcWind_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500101-19501230.nc", time_bounds, - f"{data_out}/sfcWind_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500801-19501030.nc", + output_file=f"{data_out}/sfcWind_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500801-19501030.nc", ) print("done.") @@ -42,7 +42,7 @@ preprocessing.subsample_field( f"{data_dir}/zg7h_*.nc", {"T": slice(4, None, 8)}, - f"{data_out}/zg7h_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500801-19501030.nc", + output_file=f"{data_out}/zg7h_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500801-19501030.nc", ) print("done.") @@ -59,7 +59,7 @@ preprocessing.set_netcdf_variable_name( field_psl, "slp", - f"{data_out}/slp_day_ASO50.nc", + output_file=f"{data_out}/slp_day_ASO50.nc", ) del field_psl print("done.") @@ -72,7 +72,7 @@ field_uas = preprocessing.set_netcdf_variable_name( field_uas, "u_ref", - f"{data_out}/u_ref_day_ASO50.nc", + output_file=f"{data_out}/u_ref_day_ASO50.nc", ) print("done.") @@ -85,7 +85,7 @@ preprocessing.set_netcdf_variable_name( field_vas, "v_ref", - f"{data_out}/v_ref_day_ASO50.nc", + output_file=f"{data_out}/v_ref_day_ASO50.nc", ) del field_vas print("done.") @@ -102,7 +102,7 @@ field_u850 = preprocessing.set_netcdf_variable_name( field_u850, "u850", 
- f"{data_out}/u850_day_ASO50.nc", + output_file=f"{data_out}/u850_day_ASO50.nc", ) print("done.") @@ -117,7 +117,7 @@ field_v850 = preprocessing.set_netcdf_variable_name( field_v850, "v850", - f"{data_out}/v850_day_ASO50.nc", + output_file=f"{data_out}/v850_day_ASO50.nc", ) print("done.") @@ -129,7 +129,7 @@ preprocessing.set_netcdf_variable_name( field_vort850, "vort850", - f"{data_out}/vort850_day_ASO50.nc", + output_file=f"{data_out}/vort850_day_ASO50.nc", ) del field_vort850 print("done.") @@ -145,7 +145,7 @@ preprocessing.set_netcdf_variable_name( field_ta, "tm", - f"{data_out}/tm_day_ASO50.nc", + output_file=f"{data_out}/tm_day_ASO50.nc", ) del field_ta print("done.") From e98ae6d13ea742df797497a14815ec590cbae731 Mon Sep 17 00:00:00 2001 From: Sam Avis Date: Thu, 18 Jun 2026 09:23:49 +0100 Subject: [PATCH 20/21] Simplify preprocessing function usage Return cf.Field objects instead of size-1 lists. Also allow size-1 lists to be passed as inputs to functions that expect single fields. --- src/tctrack/preprocessing.py | 44 +++++++++++++++---- .../unit/preprocessing/test_preprocessing.py | 25 ++++++++++- tutorial/preprocess_data.py | 6 +-- 3 files changed, 62 insertions(+), 13 deletions(-) diff --git a/src/tctrack/preprocessing.py b/src/tctrack/preprocessing.py index 39b2f96..5932fda 100644 --- a/src/tctrack/preprocessing.py +++ b/src/tctrack/preprocessing.py @@ -3,7 +3,7 @@ import glob import importlib.util from collections.abc import Sequence -from typing import Any, TypeAlias, TypedDict, TypeVar +from typing import Any, Literal, TypeAlias, TypedDict, overload import cf import numpy as np @@ -42,7 +42,7 @@ class FieldSelect(TypedDict): var_name: str -FieldSource: TypeAlias = str | Sequence[str] | FieldSelect | cf.Field +FieldSource: TypeAlias = str | Sequence[str] | FieldSelect | cf.Field | list[cf.Field] """Type alias for the allowed sources for ``cf.Field`` arguments. The ``cf.Field`` can be passed directly or using the path(s) CF-NetCDF file(s). 
@@ -69,14 +69,35 @@ def _expand_input_paths(paths: str | Sequence[str]) -> list[str]: return expanded_paths -T = TypeVar("T", cf.Field, list[cf.Field]) +@overload +def _write_output( + result: cf.Field, output_file: str | None, squeeze: bool = True +) -> cf.Field: ... -def _write_output(result: T, output_file: str | None) -> T: - """Optionally write output before returning.""" +@overload +def _write_output( + result: list[cf.Field], output_file: str | None, squeeze: Literal[True] = True +) -> cf.Field | list[cf.Field]: ... + + +@overload +def _write_output( + result: list[cf.Field], output_file: str | None, squeeze: Literal[False] +) -> list[cf.Field]: ... + + +def _write_output( + result: cf.Field | list[cf.Field], output_file: str | None, squeeze: bool = True +) -> cf.Field | list[cf.Field]: + """Optionally write output before returning and squeeze size-1 lists.""" if output_file is not None: cf.write(result, output_file) # type: ignore[operator] - return result + + if squeeze and isinstance(result, list) and len(result) == 1: + return result[0] + else: + return result def read_files( @@ -106,7 +127,7 @@ def read_files( _expand_input_paths(input_files), select=select, netcdf_backend="netCDF4" ) ) - return _write_output(fields, output_file) + return _write_output(fields, output_file, squeeze=False) def select_time_range( @@ -114,7 +135,7 @@ def select_time_range( time_bounds: tuple[str, str], *, output_file: str | None = None, -) -> list[cf.Field]: +) -> cf.Field | list[cf.Field]: """Combine files in time and select a time range. Parameters @@ -171,6 +192,13 @@ def separate_variables( def _load_field(source: FieldSource) -> cf.Field: """Load a single field from an in-memory field or file input.""" + if isinstance(source, list) and all(isinstance(s, cf.Field) for s in source): + if len(source) == 1: + return source[0] + else: + msg = "Expected one field but multiple were provided." 
+ raise ValueError(msg) + if isinstance(source, cf.Field): return source diff --git a/tests/unit/preprocessing/test_preprocessing.py b/tests/unit/preprocessing/test_preprocessing.py index 864b195..5a19066 100644 --- a/tests/unit/preprocessing/test_preprocessing.py +++ b/tests/unit/preprocessing/test_preprocessing.py @@ -101,12 +101,33 @@ def test_select_time_range_bounds(self, tmp_path): input_files = [ write_fields(make_field("mslp", "2000-01-01"), tmp_path / "a.nc"), write_fields(make_field("mslp", "2000-01-02"), tmp_path / "b.nc"), + write_fields(make_field("u", "2000-01-01"), tmp_path / "c.nc"), + write_fields(make_field("u", "2000-01-02"), tmp_path / "d.nc"), ] fields = select_time_range(input_files, ("2000-01-01", "2000-01-02")) - assert len(fields) == 1 # Same field (mslp) - assert fields[0].coordinate("T").size == 1 # Upper bound is excluded + # Same fields (mslp, u) + assert len(fields) == 2 + assert fields[0].nc_get_variable() == "mslp" + assert fields[1].nc_get_variable() == "u" + # Upper bound is excluded + assert fields[0].coordinate("T").size == 1 + assert fields[1].coordinate("T").size == 1 + + def test_select_time_range_squeeze(self, tmp_path): + """Test select_time_range squeezes size-1 list outputs.""" + input_files = [ + write_fields(make_field("mslp", "2000-01-01"), tmp_path / "a.nc"), + write_fields(make_field("mslp", "2000-01-02"), tmp_path / "b.nc"), + ] + + output = select_time_range(input_files, ("2000-01-01", "2000-01-02")) + + # Check it returns just the mslp field, not a list + assert isinstance(output, cf.Field) + assert output.nc_get_variable() == "mslp" + assert output.coordinate("T").size == 1 def test_separate_varibles(self, tmp_path): """Test separate_variables correctly splits variables across multiple files.""" diff --git a/tutorial/preprocess_data.py b/tutorial/preprocess_data.py index af4557d..8c227c6 100644 --- a/tutorial/preprocess_data.py +++ b/tutorial/preprocess_data.py @@ -25,7 +25,7 @@ 
f"{data_dir}/psl_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500101-19501230.nc", time_bounds, output_file=f"{data_out}/psl_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500801-19501030.nc", -)[0] +) print("done.") print("Extracting subspace from sfcWind...", end="", flush=True) @@ -68,7 +68,7 @@ field_uas = preprocessing.select_time_range( f"{data_dir}/uas_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500101-19501230.nc", time_bounds, -)[0] +) field_uas = preprocessing.set_netcdf_variable_name( field_uas, "u_ref", @@ -80,7 +80,7 @@ field_vas = preprocessing.select_time_range( f"{data_dir}/vas_day_HadGEM3-GC31-HM_hist-1950_r1i1p1f1_gn_19500701-19501230.nc", time_bounds, -)[0] +) field_vas = preprocessing.regrid_to_field(field_vas, field_uas, method="linear") preprocessing.set_netcdf_variable_name( field_vas, From c0213a2eb4ae9e3ef2664efe98ce1cdebf3e3c95 Mon Sep 17 00:00:00 2001 From: Sam Avis Date: Tue, 23 Jun 2026 20:45:18 +0100 Subject: [PATCH 21/21] fixup! Fix error in calculation of curl --- src/tctrack/preprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tctrack/preprocessing.py b/src/tctrack/preprocessing.py index 5932fda..b41b652 100644 --- a/src/tctrack/preprocessing.py +++ b/src/tctrack/preprocessing.py @@ -333,7 +333,7 @@ def calculate_curl_xy( # Negate the curl due to a suspected error in cf.curl_xy for spherical polar coords # (In the first term the gradient is taken wrt latitude, not theta) - # (The second term is not negated) + # (The second term should be the gradient of the southward windspeed) curl.data = -curl.data curl.nc_set_variable(variable_name)