From a0cd8b0a865c9d53547d405269c0e8a5b3f4dcba Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Fri, 13 Mar 2026 22:36:05 -0500 Subject: [PATCH] Update geotransolver for 2d and 3d use cases --- .../models/geotransolver/__init__.py | 11 +- .../models/geotransolver/context_projector.py | 402 ++++++++++++++---- .../experimental/models/geotransolver/gale.py | 389 ++++++++++++----- .../models/geotransolver/geotransolver.py | 127 +++++- .../geotransolver/test_geotransolver.py | 92 ++++ 5 files changed, 824 insertions(+), 197 deletions(-) diff --git a/physicsnemo/experimental/models/geotransolver/__init__.py b/physicsnemo/experimental/models/geotransolver/__init__.py index 109496b618..596f5da0a9 100644 --- a/physicsnemo/experimental/models/geotransolver/__init__.py +++ b/physicsnemo/experimental/models/geotransolver/__init__.py @@ -51,8 +51,12 @@ torch.Size([2, 1000, 3]) """ -from .context_projector import ContextProjector, GlobalContextBuilder -from .gale import GALE, GALE_block +from .context_projector import ( + ContextProjector, + GlobalContextBuilder, + StructuredContextProjector, +) +from .gale import GALE, GALE_block, GALEStructuredMesh2D, GALEStructuredMesh3D from .geotransolver import GeoTransolver, GeoTransolverMetaData __all__ = [ @@ -60,6 +64,9 @@ "GeoTransolverMetaData", "GALE", "GALE_block", + "GALEStructuredMesh2D", + "GALEStructuredMesh3D", "ContextProjector", "GlobalContextBuilder", + "StructuredContextProjector", ] \ No newline at end of file diff --git a/physicsnemo/experimental/models/geotransolver/context_projector.py b/physicsnemo/experimental/models/geotransolver/context_projector.py index 36031ec36c..547fa95642 100644 --- a/physicsnemo/experimental/models/geotransolver/context_projector.py +++ b/physicsnemo/experimental/models/geotransolver/context_projector.py @@ -49,7 +49,200 @@ import transformer_engine.pytorch as te -class ContextProjector(nn.Module): +def _compute_slices_from_projections_impl( + 
slice_projections: Float[torch.Tensor, "batch heads tokens slices"], + fx: Float[torch.Tensor, "batch heads tokens dim"], + temperature: torch.Tensor, + plus: bool, + proj_temperature: nn.Module | None = None, +) -> tuple[ + Float[torch.Tensor, "batch heads tokens slices"], + Float[torch.Tensor, "batch heads slices dim"], +]: + r"""Shared slice aggregation: temperature-weighted softmax then weighted sum over tokens. + + Used by both :class:`ContextProjector` and :class:`StructuredContextProjector` + to avoid duplicating the slice-weight and slice-token computation. + + Parameters + ---------- + slice_projections : torch.Tensor + Projection of each token onto each slice, shape :math:`(B, H, N, S)`. + fx : torch.Tensor + Latent features to aggregate per slice, shape :math:`(B, H, N, D)`. + temperature : torch.Tensor + Scalar temperature for softmax/gumbel, shape broadcastable to projections. + plus : bool + If ``True``, use Gumbel softmax with optional adaptive temperature. + proj_temperature : nn.Module or None, optional + If ``plus`` is ``True``, module mapping :math:`(B, H, N, D)` to adaptive + temperature; ignored otherwise. Default is ``None``. + + Returns + ------- + slice_weights : torch.Tensor + Normalized weights per token and slice, shape :math:`(B, H, N, S)`. + slice_token : torch.Tensor + Aggregated features per slice, shape :math:`(B, H, S, D)`. 
+ """ + if plus and proj_temperature is not None: + temp = temperature + proj_temperature(fx) + clamped_temp = torch.clamp(temp, min=0.01).to(slice_projections.dtype) + slice_weights = gumbel_softmax(slice_projections, clamped_temp) + else: + clamped_temp = torch.clamp(temperature, min=0.5, max=5).to( + slice_projections.dtype + ) + slice_weights = nn.functional.softmax( + slice_projections / clamped_temp, dim=-1 + ) + slice_weights = slice_weights.to(slice_projections.dtype) + slice_norm = slice_weights.sum(2) + normed_weights = slice_weights / (slice_norm[:, :, None, :] + 1e-2) + slice_token = torch.matmul(normed_weights.transpose(2, 3), fx) + return slice_weights, slice_token + + +def _structured_grid_to_conv_input( + x: Float[torch.Tensor, "batch tokens channels"], + batch: int, + tokens: int, + channels: int, + ndim: int, + spatial_shape: tuple[int, ...], +) -> Float[torch.Tensor, "batch channels ..."]: + r"""Reshape flat token tensor to spatial layout for Conv2d/Conv3d. + + Converts :math:`(B, N, C)` to :math:`(B, C, H, W)` for 2D or + :math:`(B, C, H, W, D)` for 3D so that structured context projectors + can apply spatial convolutions. Validates that :math:`N` matches the + grid size. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape :math:`(B, N, C)` (batch, tokens, channels). + batch : int + Batch size :math:`B`. + tokens : int + Number of tokens :math:`N` (must equal :math:`H \times W` or + :math:`H \times W \times D`). + channels : int + Channel dimension :math:`C`. + ndim : int + Number of spatial dimensions; must be 2 or 3. + spatial_shape : tuple[int, ...] + :math:`(H, W)` for 2D or :math:`(H, W, D)` for 3D. + + Returns + ------- + torch.Tensor + Reshaped tensor of shape :math:`(B, C, H, W)` or + :math:`(B, C, H, W, D)` for use as conv input. + + Raises + ------ + ValueError + If ``tokens`` does not match the product of ``spatial_shape``. 
+ """ + if ndim == 2: + H, W = spatial_shape + if tokens != H * W: + raise ValueError( + f"Expected N={H * W} tokens for 2D grid, got N={tokens}" + ) + return x.view(batch, H, W, channels).permute(0, 3, 1, 2) + H, W, D = spatial_shape + if tokens != H * W * D: + raise ValueError( + f"Expected N={H * W * D} tokens for 3D grid, got N={tokens}" + ) + return x.view(batch, H, W, D, channels).permute(0, 4, 1, 2, 3) + + +class _SliceToContextMixin: + r"""Internal mixin providing shared slice-to-context init and slice aggregation. + + Used by :class:`ContextProjector` and :class:`StructuredContextProjector` to + avoid duplicating in_project_slice, temperature, proj_temperature, and + compute_slices_from_projections. + """ + + def _init_slice_components( + self, + dim_head: int, + slice_num: int, + heads: int, + use_te: bool, + plus: bool, + ) -> None: + r"""Initialize slice projection, temperature, and optional adaptive temperature. + + Sets ``in_project_slice``, ``temperature``, and (when ``plus`` is True) + ``proj_temperature`` on this instance. Uses Transformer Engine linear + when ``use_te`` is True and TE is available. + + Parameters + ---------- + dim_head : int + Head dimension for the slice projection input. + slice_num : int + Number of slices (output dimension of ``in_project_slice``). + heads : int + Number of heads (used for temperature shape). + use_te : bool + Whether to prefer Transformer Engine for linear layers. + plus : bool + If True, add ``proj_temperature`` for Transolver++. 
+ """ + linear_layer = te.Linear if (use_te and TE_AVAILABLE) else nn.Linear + self.in_project_slice = linear_layer(dim_head, slice_num) + self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5) + if plus: + self.proj_temperature = nn.Sequential( + linear_layer(dim_head, slice_num), + nn.GELU(), + linear_layer(slice_num, 1), + nn.GELU(), + ) + + def compute_slices_from_projections( + self, + slice_projections: Float[torch.Tensor, "batch heads tokens slices"], + fx: Float[torch.Tensor, "batch heads tokens dim"], + ) -> tuple[ + Float[torch.Tensor, "batch heads tokens slices"], + Float[torch.Tensor, "batch heads slices dim"], + ]: + r"""Compute slice weights and slice tokens from projections and latent features. + + Delegates to :func:`_compute_slices_from_projections_impl` using this + instance's ``temperature``, ``plus``, and (when plus) ``proj_temperature``. + + Parameters + ---------- + slice_projections : torch.Tensor + Shape :math:`(B, H, N, S)`. + fx : torch.Tensor + Shape :math:`(B, H, N, D)`. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + ``(slice_weights, slice_token)`` with shapes :math:`(B, H, N, S)` + and :math:`(B, H, S, D)`. + """ + proj_temp = getattr(self, "proj_temperature", None) if self.plus else None + return _compute_slices_from_projections_impl( + slice_projections, + fx, + self.temperature, + self.plus, + proj_temperature=proj_temp, + ) + + +class ContextProjector(_SliceToContextMixin, nn.Module): r"""Projects context features onto physical state space. This context projector is conceptually similar to half of a GALE attention layer. 
@@ -136,19 +329,8 @@ def __init__( # Attention components self.softmax = nn.Softmax(dim=-1) self.dropout = nn.Dropout(dropout) - self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5) - - # Transolver++ adaptive temperature projection - if plus: - self.proj_temperature = nn.Sequential( - linear_layer(self.dim_head, slice_num), - nn.GELU(), - linear_layer(slice_num, 1), - nn.GELU(), - ) - # Slice projection layer maps from head dimension to slice space - self.in_project_slice = linear_layer(dim_head, slice_num) + self._init_slice_components(dim_head, slice_num, heads, use_te, plus) def project_input_onto_slices( self, x: Float[torch.Tensor, "batch tokens channels"] @@ -193,72 +375,6 @@ def project_input_onto_slices( ) return projected_x, feature_projection - def compute_slices_from_projections( - self, - slice_projections: Float[torch.Tensor, "batch heads tokens slices"], - fx: Float[torch.Tensor, "batch heads tokens dim"], - ) -> tuple[ - Float[torch.Tensor, "batch heads tokens slices"], - Float[torch.Tensor, "batch heads slices dim"], - ]: - r"""Compute slice weights and slice tokens from input projections and latent features. - - Parameters - ---------- - slice_projections : torch.Tensor - Projected input tensor of shape :math:`(B, H, N, S)` where :math:`B` is batch size, - :math:`H` is number of heads, :math:`N` is number of tokens, and :math:`S` is number of - slices, representing the projection of each token onto each slice for each - attention head. - fx : torch.Tensor - Latent feature tensor of shape :math:`(B, H, N, D)` where :math:`D` is head dimension, - representing the learned states to be aggregated by the slice weights. - - Returns - ------- - tuple[torch.Tensor, torch.Tensor] - - ``slice_weights``: Tensor of shape :math:`(B, H, N, S)`, normalized weights for - each slice per token and head. - - ``slice_token``: Tensor of shape :math:`(B, H, S, D)`, aggregated latent features - for each slice, head, and batch. 
- - Notes - ----- - The function computes a temperature-scaled softmax over the slice projections to - obtain slice weights, then aggregates the latent features for each slice using - these weights. The aggregated features are normalized by the sum of weights for - numerical stability. - """ - # Compute temperature-adjusted softmax weights - if self.plus: - # Transolver++ uses adaptive temperature with Gumbel softmax - temperature = self.temperature + self.proj_temperature(fx) - clamped_temp = torch.clamp(temperature, min=0.01).to( - slice_projections.dtype - ) - slice_weights = gumbel_softmax(slice_projections, clamped_temp) - else: - # Standard Transolver uses fixed temperature with regular softmax - clamped_temp = torch.clamp(self.temperature, min=0.5, max=5).to( - slice_projections.dtype - ) - slice_weights = nn.functional.softmax( - slice_projections / clamped_temp, dim=-1 - ) - - # Ensure weights match the computation dtype - slice_weights = slice_weights.to(slice_projections.dtype) - - # Aggregate features by slice weights with normalization - # Normalize first to prevent overflow in reduced precision - slice_norm = slice_weights.sum(2) # Sum over tokens: (B, H, S) - normed_weights = slice_weights / (slice_norm[:, :, None, :] + 1e-2) - - # Weighted aggregation: (B, H, S, N) @ (B, H, N, D) -> (B, H, S, D) - slice_token = torch.matmul(normed_weights.transpose(2, 3), fx) - - return slice_weights, slice_token - def forward( self, x: Float[torch.Tensor, "batch tokens channels"] ) -> Float[torch.Tensor, "batch heads slices dim"]: @@ -313,6 +429,101 @@ def forward( return slice_tokens +class StructuredContextProjector(_SliceToContextMixin, nn.Module): + r"""Context projector with Conv2d/Conv3d geometry encoding on structured grids. + + Same output interface as :class:`ContextProjector`—slice tokens + :math:`(B, H, S, D)`—but projects per-cell geometry via spatial convolutions + aligned with structured GALE attention. 
+ """ + + def __init__( + self, + dim: int, + spatial_shape: tuple[int, ...], + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + slice_num: int = 64, + kernel: int = 3, + use_te: bool = True, + plus: bool = False, + ) -> None: + super().__init__() + if len(spatial_shape) not in (2, 3): + raise ValueError( + f"StructuredContextProjector expects spatial_shape of length 2 or 3, got {spatial_shape!r}" + ) + inner_dim = dim_head * heads + self.dim_head = dim_head + self.heads = heads + self.plus = plus + self.use_te = use_te + self.spatial_shape = tuple(int(s) for s in spatial_shape) + self._nd = len(self.spatial_shape) + pad = kernel // 2 + if self._nd == 2: + H, W = self.spatial_shape + self.H, self.W = H, W + self.in_project_x = nn.Conv2d(dim, inner_dim, kernel, 1, pad) + if not plus: + self.in_project_fx = nn.Conv2d(dim, inner_dim, kernel, 1, pad) + else: + H, W, D_ = self.spatial_shape + self.H, self.W, self.D = H, W, D_ + self.in_project_x = nn.Conv3d(dim, inner_dim, kernel, 1, pad) + if not plus: + self.in_project_fx = nn.Conv3d(dim, inner_dim, kernel, 1, pad) + + self.softmax = nn.Softmax(dim=-1) + self.dropout = nn.Dropout(dropout) + self._init_slice_components(dim_head, slice_num, heads, use_te, plus) + + def _grid_project( + self, x: Float[torch.Tensor, "batch tokens channels"] + ) -> tuple[ + Float[torch.Tensor, "batch heads tokens dim"], + Float[torch.Tensor, "batch heads tokens dim"], + ]: + B, N, C = x.shape + grid = _structured_grid_to_conv_input( + x, B, N, C, self._nd, self.spatial_shape + ) + pattern = ( + "B (H D) h w -> B H (h w) D" + if self._nd == 2 + else "B (H D) h w d -> B H (h w d) D" + ) + px = rearrange( + self.in_project_x(grid), pattern, H=self.heads, D=self.dim_head + ) + if self.plus: + return px, px + pfx = rearrange( + self.in_project_fx(grid), pattern, H=self.heads, D=self.dim_head + ) + return px, pfx + + def forward( + self, x: Float[torch.Tensor, "batch tokens channels"] + ) -> Float[torch.Tensor, "batch heads slices 
dim"]: + if not torch.compiler.is_compiling(): + if x.ndim != 3: + raise ValueError( + f"Expected 3D input (B, N, C), got {x.ndim}D shape {tuple(x.shape)}" + ) + if self.plus: + projected_x = self._grid_project(x)[0] + feature_projection = projected_x + else: + projected_x, feature_projection = self._grid_project(x) + slice_projections = self.in_project_slice(projected_x) + _, slice_tokens = self.compute_slices_from_projections( + slice_projections, feature_projection + ) + return slice_tokens + + class GeometricFeatureProcessor(nn.Module): r"""Processes geometric features at a single spatial scale using BQWarp. @@ -616,6 +827,10 @@ class GlobalContextBuilder(nn.Module): Whether to use Transolver++ features. Default is ``False``. include_local_features : bool, optional Enable local feature extraction. Default is ``False``. + structured_shape : tuple[int, ...] | None, optional + If set, disables ball-query extractors and uses + :class:`StructuredContextProjector` for geometry when ``geometry_dim`` + is set. Default is ``None``. Forward ------- @@ -663,6 +878,7 @@ def __init__( use_te: bool = True, plus: bool = False, include_local_features: bool = False, + structured_shape: tuple[int, ...] 
| None = None, ) -> None: super().__init__() @@ -674,9 +890,17 @@ def __init__( dim_head = n_hidden // n_head context_dim = 0 + self.structured_shape = structured_shape + + # Ball-query local features are not used on structured grids + use_local_bq = ( + geometry_dim is not None + and include_local_features + and structured_shape is None + ) # Multi-scale extractors for local features (one per functional dim) - if geometry_dim is not None and include_local_features: + if use_local_bq: self.local_extractors = nn.ModuleList( [ MultiScaleFeatureExtractor( @@ -700,9 +924,21 @@ def __init__( # Geometry tokenizer for global geometry context if geometry_dim is not None: - self.geometry_tokenizer = ContextProjector( - geometry_dim, n_head, dim_head, dropout, slice_num, use_te, plus - ) + if structured_shape is not None: + self.geometry_tokenizer = StructuredContextProjector( + geometry_dim, + structured_shape, + n_head, + dim_head, + dropout, + slice_num, + use_te=use_te, + plus=plus, + ) + else: + self.geometry_tokenizer = ContextProjector( + geometry_dim, n_head, dim_head, dropout, slice_num, use_te, plus + ) context_dim += dim_head else: self.geometry_tokenizer = None diff --git a/physicsnemo/experimental/models/geotransolver/gale.py b/physicsnemo/experimental/models/geotransolver/gale.py index b64f8e9df5..d8b4870be5 100644 --- a/physicsnemo/experimental/models/geotransolver/gale.py +++ b/physicsnemo/experimental/models/geotransolver/gale.py @@ -33,6 +33,8 @@ from physicsnemo.nn import Mlp from physicsnemo.nn.module.physics_attention import ( PhysicsAttentionIrregularMesh, + PhysicsAttentionStructuredMesh2D, + PhysicsAttentionStructuredMesh3D, ) # Check optional dependency availability @@ -41,6 +43,148 @@ import transformer_engine.pytorch as te +def _gale_compute_slice_attention_cross( + module: nn.Module, + slice_tokens: list[Float[torch.Tensor, "batch heads slices dim"]], + context: Float[torch.Tensor, "batch heads context_slices context_dim"], +) -> 
list[Float[torch.Tensor, "batch heads slices dim"]]: + r"""Shared cross-attention between slice tokens and context. + + Used by :class:`GALE` and :class:`_GALEStructuredForwardMixin` so the + cross-attention implementation lives in one place. Projects queries from + concatenated slice tokens, keys and values from context; runs Transformer + Engine or SDPA attention; splits the result back to one tensor per input. + + Parameters + ---------- + module : nn.Module + Module with ``cross_q``, ``cross_k``, ``cross_v``, ``use_te``, + ``heads``, ``dim_head``, and (if ``use_te``) ``attn_fn``. + slice_tokens : list[torch.Tensor] + One tensor per input, each of shape :math:`(B, H, S, D)`. + context : torch.Tensor + Context tensor of shape :math:`(B, H, S_c, D_c)`. + + Returns + ------- + list[torch.Tensor] + One cross-attention output per element of ``slice_tokens``, each + of shape :math:`(B, H, S, D)`. + """ + q_input = torch.cat(slice_tokens, dim=-2) + q = module.cross_q(q_input) + k = module.cross_k(context) + v = module.cross_v(context) + if module.use_te: + q = rearrange(q, "b h s d -> b s h d") + k = rearrange(k, "b h s d -> b s h d") + v = rearrange(v, "b h s d -> b s h d") + cross_attention = module.attn_fn(q, k, v) + cross_attention = rearrange( + cross_attention, + "b s (h d) -> b h s d", + h=module.heads, + d=module.dim_head, + ) + else: + cross_attention = torch.nn.functional.scaled_dot_product_attention( + q, k, v, is_causal=False + ) + cross_attention = torch.split( + cross_attention, slice_tokens[0].shape[-2], dim=-2 + ) + return list(cross_attention) + + +def _gale_forward_impl( + module: nn.Module, + x: tuple[Float[torch.Tensor, "batch tokens channels"], ...], + context: Float[torch.Tensor, "batch heads context_slices context_dim"] + | None, +) -> list[Float[torch.Tensor, "batch tokens channels"]]: + r"""Single implementation of the GALE forward pipeline. + + Shared by :class:`GALE` and :class:`_GALEStructuredForwardMixin`. 
Steps: + validate inputs; project onto slices; compute slice weights and tokens; + apply self-attention on slices; optionally cross-attend to context and + mix with ``state_mixing``; project attention outputs back to token space. + + Parameters + ---------- + module : nn.Module + GALE-like module with ``project_input_onto_slices``, + ``in_project_slice``, ``_compute_slices_from_projections``, + ``_compute_slice_attention_te``, ``_compute_slice_attention_sdpa``, + ``compute_slice_attention_cross``, ``_project_attention_outputs``, + plus attributes ``use_te``, ``plus``, ``state_mixing``. + x : tuple[torch.Tensor, ...] + Input tensors, each of shape :math:`(B, N, C)`; must be non-empty. + context : torch.Tensor or None + Optional context of shape :math:`(B, H, S_c, D_c)` for cross-attention. + If ``None``, only self-attention is applied. + + Returns + ------- + list[torch.Tensor] + One output tensor per input, each of shape :math:`(B, N, C)`. + + Raises + ------ + ValueError + If ``x`` is empty or any element is not 3D. 
+ """ + if not torch.compiler.is_compiling(): + if len(x) == 0: + raise ValueError("Expected non-empty tuple of input tensors") + for i, tensor in enumerate(x): + if tensor.ndim != 3: + raise ValueError( + f"Expected 3D input tensor (B, N, C) at index {i}, " + f"got {tensor.ndim}D tensor with shape {tuple(tensor.shape)}" + ) + if module.plus: + x_mid = [module.project_input_onto_slices(_x) for _x in x] + fx_mid = [_x_mid for _x_mid in x_mid] + else: + x_mid, fx_mid = zip( + *[module.project_input_onto_slices(_x) for _x in x] + ) + slice_projections = [module.in_project_slice(_x_mid) for _x_mid in x_mid] + slice_weights, slice_tokens = zip( + *[ + module._compute_slices_from_projections(proj, _fx_mid) + for proj, _fx_mid in zip(slice_projections, fx_mid) + ] + ) + if module.use_te: + self_slice_token = [ + module._compute_slice_attention_te(_slice_token) + for _slice_token in slice_tokens + ] + else: + self_slice_token = [ + module._compute_slice_attention_sdpa(_slice_token) + for _slice_token in slice_tokens + ] + if context is not None: + cross_slice_token = [ + module.compute_slice_attention_cross([_slice_token], context)[0] + for _slice_token in slice_tokens + ] + mixing_weight = torch.sigmoid(module.state_mixing) + out_slice_token = [ + mixing_weight * sst + (1 - mixing_weight) * cst + for sst, cst in zip(self_slice_token, cross_slice_token) + ] + else: + out_slice_token = self_slice_token + outputs = [ + module._project_attention_outputs(ost, sw) + for ost, sw in zip(out_slice_token, slice_weights) + ] + return outputs + + class GALE(PhysicsAttentionIrregularMesh): r"""Geometry-Aware Latent Embeddings (GALE) attention layer. 
@@ -121,7 +265,7 @@ def __init__( ) -> None: super().__init__(dim, heads, dim_head, dropout, slice_num, use_te, plus) - linear_layer = te.Linear if self.use_te else nn.Linear + linear_layer = te.Linear if (self.use_te and TE_AVAILABLE) else nn.Linear # Cross-attention projection layers for context integration self.cross_q = linear_layer(dim_head, dim_head) @@ -154,39 +298,10 @@ def compute_slice_attention_cross( list[torch.Tensor] List of cross-attention outputs, each of shape :math:`(B, H, S, D)`. """ - # Concatenate all slice tokens for batched projection - q_input = torch.cat(slice_tokens, dim=-2) # (B, H, total_slices, D) - - # Project queries from slice tokens - q = self.cross_q(q_input) # (B, H, total_slices, D) - - # Project keys and values from context - k = self.cross_k(context) # (B, H, S_c, D) - v = self.cross_v(context) # (B, H, S_c, D) - - # Compute cross-attention using appropriate backend - if self.use_te: - # Transformer Engine expects (B, S, H, D) format - q = rearrange(q, "b h s d -> b s h d") - k = rearrange(k, "b h s d -> b s h d") - v = rearrange(v, "b h s d -> b s h d") - cross_attention = self.attn_fn(q, k, v) - cross_attention = rearrange( - cross_attention, "b s (h d) -> b h s d", h=self.heads, d=self.dim_head - ) - else: - # Use PyTorch's scaled dot-product attention - cross_attention = torch.nn.functional.scaled_dot_product_attention( - q, k, v, is_causal=False - ) - - # Split back into individual slice token outputs - cross_attention = torch.split( - cross_attention, slice_tokens[0].shape[-2], dim=-2 + return _gale_compute_slice_attention_cross( + self, slice_tokens, context ) - return list(cross_attention) - def forward( self, x: tuple[Float[torch.Tensor, "batch tokens channels"], ...], @@ -216,74 +331,102 @@ def forward( List of output tensors, each of shape :math:`(B, N, C)``, same shape as inputs. 
""" - ### Input validation - if not torch.compiler.is_compiling(): - if len(x) == 0: - raise ValueError("Expected non-empty tuple of input tensors") - for i, tensor in enumerate(x): - if tensor.ndim != 3: - raise ValueError( - f"Expected 3D input tensor (B, N, C) at index {i}, " - f"got {tensor.ndim}D tensor with shape {tuple(tensor.shape)}" - ) + return _gale_forward_impl(self, x, context) - # Project inputs onto learned latent spaces - if self.plus: - x_mid = [self.project_input_onto_slices(_x) for _x in x] - # In Transolver++, x_mid is reused for both projections - fx_mid = [_x_mid for _x_mid in x_mid] - else: - x_mid, fx_mid = zip( - *[self.project_input_onto_slices(_x) for _x in x] - ) - # Project latent representations onto physical state slices - slice_projections = [self.in_project_slice(_x_mid) for _x_mid in x_mid] +def _gale_cross_init( + self: nn.Module, + dim_head: int, + context_dim: int, + use_te: bool, +) -> None: + # Match GALE: TE linear only when TE is installed (GALE_block already errors if use_te without TE) + linear_layer = te.Linear if (use_te and TE_AVAILABLE) else nn.Linear + self.cross_q = linear_layer(dim_head, dim_head) + self.cross_k = linear_layer(context_dim, dim_head) + self.cross_v = linear_layer(context_dim, dim_head) + self.state_mixing = nn.Parameter(torch.tensor(0.0)) + - # Compute slice weights and aggregated slice tokens - slice_weights, slice_tokens = zip( - *[ - self._compute_slices_from_projections(proj, _fx_mid) - for proj, _fx_mid in zip(slice_projections, fx_mid) - ] +class _GALEStructuredForwardMixin: + """Shared cross-attention and forward for structured GALE (2D/3D conv projection).""" + + def compute_slice_attention_cross( + self, + slice_tokens: list[Float[torch.Tensor, "batch heads slices dim"]], + context: Float[torch.Tensor, "batch heads context_slices context_dim"], + ) -> list[Float[torch.Tensor, "batch heads slices dim"]]: + return _gale_compute_slice_attention_cross( + self, slice_tokens, context ) - # Apply 
self-attention to slice tokens - if self.use_te: - self_slice_token = [ - self._compute_slice_attention_te(_slice_token) - for _slice_token in slice_tokens - ] - else: - self_slice_token = [ - self._compute_slice_attention_sdpa(_slice_token) - for _slice_token in slice_tokens - ] - - # Apply cross-attention with context if provided - if context is not None: - cross_slice_token = [ - self.compute_slice_attention_cross([_slice_token], context)[0] - for _slice_token in slice_tokens - ] - - # Blend self-attention and cross-attention with learnable mixing weight - mixing_weight = torch.sigmoid(self.state_mixing) - out_slice_token = [ - mixing_weight * sst + (1 - mixing_weight) * cst - for sst, cst in zip(self_slice_token, cross_slice_token) - ] - else: - # Use only self-attention when no context is provided - out_slice_token = self_slice_token + def forward( + self, + x: tuple[Float[torch.Tensor, "batch tokens channels"], ...], + context: Float[torch.Tensor, "batch heads context_slices context_dim"] + | None = None, + ) -> list[Float[torch.Tensor, "batch tokens channels"]]: + return _gale_forward_impl(self, x, context) - # Project attention outputs back to original space using slice weights - outputs = [ - self._project_attention_outputs(ost, sw) - for ost, sw in zip(out_slice_token, slice_weights) - ] - return outputs +class GALEStructuredMesh2D(_GALEStructuredForwardMixin, PhysicsAttentionStructuredMesh2D): + r"""GALE with Conv2d slice projection for 2D structured grids (see :class:`GALE`).""" + + def __init__( + self, + dim: int, + spatial_shape: tuple[int, int], + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + slice_num: int = 64, + kernel: int = 3, + use_te: bool = True, + plus: bool = False, + context_dim: int = 0, + ) -> None: + super().__init__( + dim, + spatial_shape, + heads, + dim_head, + dropout, + slice_num, + kernel, + use_te, + plus, + ) + _gale_cross_init(self, dim_head, context_dim, use_te) + + +class 
GALEStructuredMesh3D(_GALEStructuredForwardMixin, PhysicsAttentionStructuredMesh3D): + r"""GALE with Conv3d slice projection for 3D structured grids (see :class:`GALE`).""" + + def __init__( + self, + dim: int, + spatial_shape: tuple[int, int, int], + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + slice_num: int = 64, + kernel: int = 3, + use_te: bool = True, + plus: bool = False, + context_dim: int = 0, + ) -> None: + super().__init__( + dim, + spatial_shape, + heads, + dim_head, + dropout, + slice_num, + kernel, + use_te, + plus, + ) + _gale_cross_init(self, dim_head, context_dim, use_te) class GALE_block(nn.Module): @@ -317,6 +460,10 @@ class GALE_block(nn.Module): Whether to use Transolver++ features. Default is ``False``. context_dim : int, optional Dimension of the context vector for cross-attention. Default is 0. + spatial_shape : tuple[int, ...] | None, optional + If ``None``, uses irregular-mesh GALE. Length-2 tuple enables 2D Conv2d + projection; length-3 tuple enables 3D Conv3d projection (flattened + :math:`N = H \times W` or :math:`H \times W \times D`). Default is ``None``. Forward ------- @@ -369,6 +516,7 @@ def __init__( use_te: bool = True, plus: bool = False, context_dim: int = 0, + spatial_shape: tuple[int, ...] 
| None = None, ) -> None: super().__init__() @@ -386,17 +534,50 @@ def __init__( else: self.ln_1 = nn.LayerNorm(hidden_dim) - # GALE attention layer - self.Attn = GALE( - hidden_dim, - heads=num_heads, - dim_head=hidden_dim // num_heads, - dropout=dropout, - slice_num=slice_num, - use_te=use_te, - plus=plus, - context_dim=context_dim, - ) + dim_head = hidden_dim // num_heads + if spatial_shape is None: + self.Attn = GALE( + hidden_dim, + heads=num_heads, + dim_head=dim_head, + dropout=dropout, + slice_num=slice_num, + use_te=use_te, + plus=plus, + context_dim=context_dim, + ) + elif len(spatial_shape) == 2: + self.Attn = GALEStructuredMesh2D( + hidden_dim, + spatial_shape=(int(spatial_shape[0]), int(spatial_shape[1])), + heads=num_heads, + dim_head=dim_head, + dropout=dropout, + slice_num=slice_num, + use_te=use_te, + plus=plus, + context_dim=context_dim, + ) + elif len(spatial_shape) == 3: + self.Attn = GALEStructuredMesh3D( + hidden_dim, + spatial_shape=( + int(spatial_shape[0]), + int(spatial_shape[1]), + int(spatial_shape[2]), + ), + heads=num_heads, + dim_head=dim_head, + dropout=dropout, + slice_num=slice_num, + use_te=use_te, + plus=plus, + context_dim=context_dim, + ) + else: + raise ValueError( + f"spatial_shape must be None, length-2, or length-3; got {spatial_shape!r}" + ) # Feed-forward network with layer normalization if use_te: diff --git a/physicsnemo/experimental/models/geotransolver/geotransolver.py b/physicsnemo/experimental/models/geotransolver/geotransolver.py index 5b106810df..b743b6c966 100644 --- a/physicsnemo/experimental/models/geotransolver/geotransolver.py +++ b/physicsnemo/experimental/models/geotransolver/geotransolver.py @@ -23,6 +23,7 @@ from __future__ import annotations +import math from collections.abc import Sequence from dataclasses import dataclass @@ -143,6 +144,47 @@ def _normalize_tensor( raise TypeError(f"Invalid tensor structure") +def _structured_num_tokens(spatial_shape: tuple[int, ...]) -> int: + return 
int(math.prod(spatial_shape)) + + +def _flatten_for_structured( + t: torch.Tensor, + spatial_shape: tuple[int, ...], + name: str, +) -> torch.Tensor: + """Flatten (B,H,W,C) or (B,H,W,D,C) to (B,N,C); pass through (B,N,C) if N matches. + + Mirrors Transolver's structured flatten/unflatten behavior so the rest of + GeoTransolver can assume a single token layout (B, N, C). + """ + n = _structured_num_tokens(spatial_shape) + if t.ndim == 3: + if not torch.compiler.is_compiling() and t.shape[1] != n: + raise ValueError( + f"{name} token count {t.shape[1]} != structured grid size {n}" + ) + return t + if len(spatial_shape) == 2 and t.ndim == 4: + B, H, W, C = t.shape + if (H, W) != spatial_shape: + raise ValueError( + f"{name} spatial dims {(H, W)} != structured_shape {spatial_shape}" + ) + return t.reshape(B, n, C) + if len(spatial_shape) == 3 and t.ndim == 5: + B, H, W, D, C = t.shape + if (H, W, D) != spatial_shape: + raise ValueError( + f"{name} spatial dims {(H, W, D)} != structured_shape {spatial_shape}" + ) + return t.reshape(B, n, C) + raise ValueError( + f"{name}: expected (B,N,C) with N={n}, or spatial layout matching " + f"structured_shape {spatial_shape}; got shape {tuple(t.shape)}" + ) + + class GeoTransolver(Module): r"""GeoTransolver: Geometry-Aware Physics Attention Transformer. @@ -204,13 +246,18 @@ class GeoTransolver(Module): Neighbors in radius for the local features. Default is ``[8, 32]``. n_hidden_local : int, optional Hidden dimension for the local features. Default is 32. + structured_shape : tuple[int, ...] | None, optional + If set to ``(H, W)`` or ``(H, W, D)``, enables structured 2D/3D paths + (Conv2d/Conv3d GALE; no ball-query local features). Inputs may be + flattened :math:`(B, N, C)` with :math:`N = H W` or :math:`H W D`, or + spatial :math:`(B, H, W, C)` / :math:`(B, H, W, D, C)`. Default is ``None``. Forward ------- local_embedding : torch.Tensor | tuple[torch.Tensor, ...] 
- Local embedding of the input data of shape :math:`(B, N, C)` where :math:`B` - is batch size, :math:`N` is number of nodes/tokens, and :math:`C` is - ``functional_dim``. Can be a single tensor or tuple for multiple input types. + Local embedding: unstructured :math:`(B, N, C)`; structured 2D + :math:`(B, H, W, C)` or flattened :math:`(B, H W, C)`; structured 3D + :math:`(B, H, W, D, C)` or flattened. Can be a tuple for multiple input types. local_positions : torch.Tensor | tuple[torch.Tensor, ...] | None, optional Local positions for each input, each of shape :math:`(B, N, 3)`. Required if ``include_local_features=True``. Default is ``None``. @@ -228,9 +275,9 @@ class GeoTransolver(Module): Outputs ------- torch.Tensor | tuple[torch.Tensor, ...] - Output tensor of shape :math:`(B, N, C_{out})` where :math:`C_{out}` is - ``out_dim``. Returns a single tensor if input was a single tensor, or a - tuple if input was a tuple. + Unstructured: :math:`(B, N, C_{out})`. Structured: same as input layout— + flattened :math:`(B, N, C_{out})` or spatial :math:`(B, H, W, C_{out})` / + :math:`(B, H, W, D, C_{out})` when inputs were 4D/5D. Tuple if tuple in. Raises ------ @@ -244,8 +291,9 @@ class GeoTransolver(Module): Notes ----- - GeoTransolver currently supports unstructured mesh input only. Enhancements for - image-based and voxel-based inputs may be available in the future. + Unstructured mesh uses linear GALE projection; structured ``structured_shape`` + uses the same Conv2d/Conv3d slice projection as :class:`~physicsnemo.models.transolver.Transolver`. + Ball-query local features are disabled when ``structured_shape`` is set. For more details on Transolver, see: @@ -293,6 +341,21 @@ class GeoTransolver(Module): >>> output = model(local_emb, global_embedding=global_emb, geometry=geometry) >>> output.shape torch.Size([2, 1000, 3]) + + Structured 2D grid: + + >>> model = GeoTransolver( + ... functional_dim=3, + ... out_dim=1, + ... structured_shape=(8, 8), + ... 
n_hidden=64, + ... n_head=4, + ... n_layers=2, + ... use_te=False, + ... ) + >>> y = model(torch.randn(2, 8, 8, 3)) + >>> y.shape + torch.Size([2, 8, 8, 1]) """ def __init__( @@ -315,6 +378,7 @@ def __init__( radii: list[float] | None = None, neighbors_in_radius: list[int] | None = None, n_hidden_local: int = 32, + structured_shape: tuple[int, ...] | None = None, ) -> None: super().__init__(meta=GeoTransolverMetaData()) self.__name__ = "GeoTransolver" @@ -325,8 +389,22 @@ def __init__( if neighbors_in_radius is None: neighbors_in_radius = [8, 32] + if structured_shape is not None: + if include_local_features: + raise ValueError( + "include_local_features=True is not supported with structured_shape " + "(ball-query path is mesh-only)." + ) + if len(structured_shape) not in (2, 3): + raise ValueError( + f"structured_shape must have length 2 or 3, got {structured_shape!r}" + ) + if not all(int(s) > 0 for s in structured_shape): + raise ValueError(f"structured_shape must be positive ints, got {structured_shape!r}") + self.include_local_features = include_local_features self.use_te = use_te + self.structured_shape = structured_shape # Validate head dimension compatibility if not n_hidden % n_head == 0: @@ -357,6 +435,7 @@ def __init__( use_te=use_te, plus=plus, include_local_features=self.include_local_features, + structured_shape=structured_shape, ) context_dim = self.context_builder.get_context_dim() @@ -404,6 +483,7 @@ def __init__( use_te=use_te, plus=plus, context_dim=context_dim, + spatial_shape=structured_shape, ) for layer_idx in range(n_layers) ] @@ -507,6 +587,27 @@ def forward( if local_positions is not None: local_positions = _normalize_tensor(local_positions) + unflatten_output = False + if self.structured_shape is not None: + unflatten_output = any(le.ndim in (4, 5) for le in local_embedding) + local_embedding = tuple( + _flatten_for_structured( + le, self.structured_shape, f"local_embedding[{i}]" + ) + for i, le in enumerate(local_embedding) + ) + if 
def test_geotransolver_structured_rejects_local_features():
    """Combining ``structured_shape`` with ball-query local features must raise."""
    bad_kwargs = dict(
        functional_dim=8,
        out_dim=1,
        structured_shape=(4, 4),
        include_local_features=True,
        geometry_dim=2,
        use_te=False,
    )
    with pytest.raises(ValueError, match="include_local_features=True"):
        GeoTransolver(**bad_kwargs)


def test_geotransolver_structured_2d_forward(device):
    """Structured 2D forward: spatial and flattened inputs, with and without geometry."""
    torch.manual_seed(0)
    batch, height, width = 2, 4, 4
    n_tokens = height * width

    net = GeoTransolver(
        functional_dim=3,
        out_dim=2,
        structured_shape=(height, width),
        geometry_dim=2,
        global_dim=None,
        n_layers=2,
        n_hidden=32,
        n_head=4,
        slice_num=8,
        mlp_ratio=2,
        use_te=False,
    )
    net = net.to(device)

    fields = torch.randn(batch, height, width, 3, device=device)
    geom = torch.randn(batch, height, width, 2, device=device)

    # Spatial layout in -> spatial layout out.
    out_spatial = net(fields, geometry=geom)
    assert out_spatial.shape == (batch, height, width, 2)
    assert not out_spatial.isnan().any()

    # Flattened layout in -> flattened layout out.
    out_flat = net(
        fields.reshape(batch, n_tokens, 3),
        geometry=geom.reshape(batch, n_tokens, 2),
    )
    assert out_flat.shape == (batch, n_tokens, 2)

    # Geometry is optional.
    out_no_geom = net(fields)
    assert out_no_geom.shape == (batch, height, width, 2)


def test_geotransolver_structured_3d_forward(device):
    """Structured 3D forward on voxel input of shape (B, H, W, D, C)."""
    torch.manual_seed(1)
    grid = (2, 2, 2)
    batch = 1

    net = GeoTransolver(
        functional_dim=4,
        out_dim=1,
        structured_shape=grid,
        n_layers=1,
        n_hidden=32,
        n_head=4,
        slice_num=4,
        mlp_ratio=2,
        use_te=False,
    )
    net = net.to(device)

    voxels = torch.randn(batch, *grid, 4, device=device)
    out = net(voxels)
    assert out.shape == (batch, *grid, 1)


def test_geotransolver_structured_global_context(device):
    """Structured 2D forward with geometry and a global embedding as context."""
    torch.manual_seed(2)
    batch, height, width = 2, 4, 4

    net = GeoTransolver(
        functional_dim=2,
        out_dim=1,
        structured_shape=(height, width),
        geometry_dim=2,
        global_dim=8,
        n_layers=2,
        n_hidden=32,
        n_head=4,
        slice_num=8,
        mlp_ratio=2,
        use_te=False,
    )
    net = net.to(device)

    fields = torch.randn(batch, height, width, 2, device=device)
    geom = torch.randn(batch, height, width, 2, device=device)
    global_emb = torch.randn(batch, 3, 8, device=device)

    out = net(fields, geometry=geom, global_embedding=global_emb)
    assert out.shape == (batch, height, width, 1)