Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
ae8102e
Implements Predictor specialization for multi-diffusion
CharlelieLrt Apr 17, 2026
9bd0474
Compile denoiser in multi-diffusion sampling compile tests
CharlelieLrt Apr 17, 2026
9c84b00
Avoid fullgraph compile in multi-diffusion sampling test
CharlelieLrt Apr 17, 2026
f8b906a
Flatten MultiDiffusionPredictor hot path for torch.compile
CharlelieLrt Apr 17, 2026
1259e43
Loosen torch.compile euler check in multi-diffusion sampling tests
CharlelieLrt Apr 17, 2026
cefe1db
Force contiguous t_cur/t_next in Euler solvers
CharlelieLrt Apr 17, 2026
a000e1c
Drop dead is_compiling guard and inherit from Predictor in MultiDiffu…
CharlelieLrt Apr 18, 2026
8f945a5
Narrow _patching type and tighten multi-diffusion tests
CharlelieLrt Apr 18, 2026
7e1db11
Force contiguous pos_embd before patching
CharlelieLrt Apr 18, 2026
feb0d9e
Use functional F.pad in image_batching
CharlelieLrt Apr 20, 2026
3dfcdb5
Replace einops.rearrange with native torch reshape+permute
CharlelieLrt Apr 20, 2026
746518f
Materialise returned tensors in multi-diffusion fuse path
CharlelieLrt Apr 20, 2026
a007c46
Use clone instead of contiguous at fuse boundary
CharlelieLrt Apr 20, 2026
869a7c1
Revert speculative fuse-boundary copies and xfail full-sampler compil…
CharlelieLrt Apr 20, 2026
78c10d8
Merge remote-tracking branch 'upstream/main' into multi_diffusion_sam…
CharlelieLrt Apr 20, 2026
3207587
Minor updates to predictor.py
CharlelieLrt Apr 22, 2026
2759f57
Merge remote-tracking branch 'upstream/main' into multi_diffusion_sam…
CharlelieLrt Apr 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion physicsnemo/diffusion/guidance/dps_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def __call__(
...


class DPSScorePredictor:
class DPSScorePredictor(Predictor):
r"""
Score predictor that combines an x0-predictor with DPS-style guidance.
Expand Down
1 change: 1 addition & 0 deletions physicsnemo/diffusion/multi_diffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@
image_batching,
image_fuse,
)
from .predictor import MultiDiffusionPredictor
29 changes: 3 additions & 26 deletions physicsnemo/diffusion/multi_diffusion/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,30 +27,7 @@
from physicsnemo.diffusion.base import PredictorType
from physicsnemo.diffusion.multi_diffusion.models import MultiDiffusionModel2D
from physicsnemo.diffusion.noise_schedulers import NoiseScheduler
from physicsnemo.diffusion.utils.utils import apply_loss_weight


def _unwrap_multi_diffusion(model: torch.nn.Module) -> MultiDiffusionModel2D:
"""Peel off DDP / torch.compile wrappers to reach the underlying
:class:`MultiDiffusionModel2D`.

The unwrapping order handles arbitrary nesting of
``DistributedDataParallel`` (``model.module``) and ``torch.compile``
(``OptimizedModule._orig_mod``).
"""
m = model
while not isinstance(m, MultiDiffusionModel2D):
if isinstance(m, torch._dynamo.eval_frame.OptimizedModule):
m = m._orig_mod
elif hasattr(m, "module"):
m = m.module
else:
raise TypeError(
f"Could not unwrap a MultiDiffusionModel2D from "
f"{type(model).__name__}. Found leaf type "
f"{type(m).__name__}."
)
return m
from physicsnemo.diffusion.utils.utils import _unwrap_module, apply_loss_weight


class _CompiledPatchX:
Expand Down Expand Up @@ -265,7 +242,7 @@ def __init__(
reduction: Literal["none", "mean", "sum"] = "mean",
) -> None:
self.model = model
self._md_model = _unwrap_multi_diffusion(model)
self._md_model = _unwrap_module(model, MultiDiffusionModel2D)
self.noise_scheduler = noise_scheduler
self._compiled_patch_x = _CompiledPatchX(self._md_model)

Expand Down Expand Up @@ -528,7 +505,7 @@ def __init__(
reduction: Literal["none", "mean", "sum"] = "mean",
) -> None:
self.model = model
self._md_model = _unwrap_multi_diffusion(model)
self._md_model = _unwrap_module(model, MultiDiffusionModel2D)
self.noise_scheduler = noise_scheduler
self._compiled_patch_x = _CompiledPatchX(self._md_model)

Expand Down
170 changes: 96 additions & 74 deletions physicsnemo/diffusion/multi_diffusion/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ def __init__(
self._patching: RandomPatching2D | GridPatching2D | None = None
self._patching_type: Literal["random", "grid"] | None = None
self._fuse: bool = False
self._skip_positional_embedding_injection: bool = False
# Normalise condition flags to defaultdict for uniform access
if not isinstance(condition_patch, (bool, dict)):
raise TypeError(
Expand Down Expand Up @@ -545,7 +546,7 @@ def reset_patch_indices(self) -> None:
RuntimeError
If the current patching strategy is not random patching.
"""
if self._patching_type != "random":
if not isinstance(self._patching, RandomPatching2D):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like the internal `_patching_type` attribute is not really used anymore — can it be dropped?

raise RuntimeError(
"reset_patch_indices() is only available when random "
"patching is active. Call set_random_patching() first."
Expand Down Expand Up @@ -633,12 +634,13 @@ def patch_x(
RuntimeError
If no patching strategy has been configured.
"""
patching = self._patching
if patching is None:
raise RuntimeError(
"No patching strategy set. Call set_random_patching() "
"or set_grid_patching() first."
)
if not torch.compiler.is_compiling():
if self._patching is None:
raise RuntimeError(
"No patching strategy set. Call set_random_patching() "
"or set_grid_patching() first."
)
if x.ndim != 4:
raise ValueError(
f"patch_x expects a 4D tensor (B, C, H, W), got {x.ndim}D."
Expand All @@ -648,7 +650,7 @@ def patch_x(
f"Spatial dimensions {tuple(x.shape[2:])} do not match "
f"global_spatial_shape {self.global_spatial_shape}."
)
return self._patching.apply(x)
return patching.apply(x)

def patch_t(self, t: Float[Tensor, " B"]) -> Float[Tensor, " P_times_B"]:
r"""Convert a diffusion-time tensor to patch-compatible format.
Expand All @@ -671,13 +673,13 @@ def patch_t(self, t: Float[Tensor, " B"]) -> Float[Tensor, " P_times_B"]:
RuntimeError
If no patching strategy has been configured.
"""
if not torch.compiler.is_compiling():
if self._patching is None:
raise RuntimeError(
"No patching strategy set. Call set_random_patching() "
"or set_grid_patching() first."
)
return t.repeat(self._patching.patch_num)
patching = self._patching
if patching is None:
raise RuntimeError(
"No patching strategy set. Call set_random_patching() "
"or set_grid_patching() first."
)
return t.repeat(patching.patch_num)

def patch_condition(
self,
Expand Down Expand Up @@ -737,16 +739,16 @@ def patch_condition(
>>> cp["vec"].shape # default: repeated P times
torch.Size([8, 5])
"""
if not torch.compiler.is_compiling():
if self._patching is None:
raise RuntimeError(
"No patching strategy set. Call set_random_patching() "
"or set_grid_patching() first."
)
patching = self._patching
if patching is None:
raise RuntimeError(
"No patching strategy set. Call set_random_patching() "
"or set_grid_patching() first."
)
if condition is None:
return None

P = self._patching.patch_num
P = patching.patch_num

if isinstance(condition, Tensor):
if self._condition_has_per_key_flags:
Expand Down Expand Up @@ -858,7 +860,7 @@ def fuse(
>>> torch.allclose(md.fuse(x_patched, batch_size=2), x)
True
"""
if self._patching_type != "grid":
if not isinstance(self._patching, GridPatching2D):
raise RuntimeError(
"Fusing is only supported with grid patching. "
"Call set_grid_patching() first."
Expand All @@ -883,36 +885,27 @@ def forward(
**model_kwargs: Any,
) -> Float[Tensor, "P_times_B C Hp Wp"] | Float[Tensor, "B C H W"]:
# No patching strategy: warn and pass through
if self._patching_type is None:
patching = self._patching
if patching is None:
if not torch.compiler.is_compiling():
warnings.warn(
"No patching strategy set on MultiDiffusionModel2D. "
"The model will run without patching.",
stacklevel=2,
)
if self.pos_embd is not None:
if (
self.pos_embd is not None
and not self._skip_positional_embedding_injection
):
B = x.shape[0]
pos_embd = self.pos_embd.unsqueeze(0).expand(B, -1, -1, -1)
if condition is None:
condition = TensorDict(
{"positional_embedding": pos_embd}, batch_size=[B]
)
elif isinstance(condition, TensorDict):
condition["positional_embedding"] = pos_embd
elif isinstance(condition, Tensor):
condition = TensorDict(
{"condition": condition, "positional_embedding": pos_embd},
batch_size=[B],
)
else:
raise ValueError(
"When positional embeddings are configured, condition "
"must be a Tensor, TensorDict, or None, got "
f"{type(condition).__name__}."
)
# .expand creates a stride-0 view that can trip up downstream
# torch ops (e.g. nn.ReflectionPad2d / F.unfold on torch 2.10).
# Materialise a contiguous copy before handing it off.
pos_embd = self.pos_embd.unsqueeze(0).expand(B, -1, -1, -1).contiguous()
condition = self._inject_patched_pos_embd(condition, pos_embd, B)
return self.model(x, t, condition=condition, **model_kwargs)

P = self._patching.patch_num
P = patching.patch_num

# Determine original batch size B
if x_is_patched:
Expand All @@ -935,35 +928,23 @@ def forward(
if not condition_is_patched:
condition = self.patch_condition(condition)

# Positional embeddings are always injected (internal to the wrapper)
if self.pos_embd is not None:
pos_embd_patched = self._patching.apply(
self.pos_embd.unsqueeze(0).expand(B, -1, -1, -1)
# Positional embeddings injected here unless _skip_positional_embedding_injection
# is set (e.g. by MultiDiffusionPredictor which pre-patches PE at construction time)
if self.pos_embd is not None and not self._skip_positional_embedding_injection:
# .expand creates a stride-0 view that can trip up downstream
# torch ops (e.g. nn.ReflectionPad2d / F.unfold on torch 2.10).
# Materialise a contiguous copy before passing to patching.
pos_embd_patched = patching.apply(
self.pos_embd.unsqueeze(0).expand(B, -1, -1, -1).contiguous()
) # (P*B, C_PE, Hp, Wp)
PB = P * B
if condition is None:
condition = TensorDict(
{"positional_embedding": pos_embd_patched},
batch_size=[PB],
)
elif isinstance(condition, TensorDict):
condition["positional_embedding"] = pos_embd_patched
elif isinstance(condition, Tensor):
condition = TensorDict(
{"condition": condition, "positional_embedding": pos_embd_patched},
batch_size=[PB],
)
else:
raise ValueError(
"When positional embeddings are configured, condition "
"must be a Tensor, TensorDict, or None, got "
f"{type(condition).__name__}."
)
condition = self._inject_patched_pos_embd(
condition, pos_embd_patched, P * B
)

output = self.model(x, t, condition=condition, **model_kwargs)

if self._fuse:
output = self._patching.fuse(output, batch_size=B)
output = patching.fuse(output, batch_size=B)

return output

Expand Down Expand Up @@ -996,13 +977,54 @@ def _process_condition_tensor(
f"(B, C, H, W), got {tensor.ndim}D."
)

# Default case: no patching needed, just repeat along the batch dim.
if not do_patch and not do_interp:
return tensor.repeat(P, *([1] * (tensor.ndim - 1)))

# Both patch and interp need an active patching strategy.
patching = self._patching
if patching is None:
raise RuntimeError(
"No patching strategy set. Call set_random_patching() "
"or set_grid_patching() first."
)

if do_patch:
return self._patching.apply(tensor)
return patching.apply(tensor)

if do_interp:
Hp, Wp = self._patching.patch_shape
tensor = F.interpolate(tensor, size=(Hp, Wp), mode="bilinear")
return tensor.repeat(P, 1, 1, 1)
# do_interp case
Hp, Wp = patching.patch_shape
tensor = F.interpolate(tensor, size=(Hp, Wp), mode="bilinear")
return tensor.repeat(P, 1, 1, 1)

# Default: repeat along batch dimension
return tensor.repeat(P, *([1] * (tensor.ndim - 1)))
def _inject_patched_pos_embd(
    self,
    condition: Tensor | TensorDict | None,
    pos_embd_patched: Float[Tensor, "P_times_B C_PE Hp Wp"],
    PB: int,
) -> TensorDict:
    """Attach an already-patched positional embedding to ``condition``.

    The embedding is stored under the ``"positional_embedding"`` key of the
    returned ``TensorDict``. This is shared logic factored out of
    :meth:`forward` so it can be reused by
    :class:`~physicsnemo.diffusion.multi_diffusion.MultiDiffusionPredictor`.
    An existing ``TensorDict`` condition is mutated in place for efficiency;
    a bare ``Tensor`` or ``None`` condition yields a freshly built
    ``TensorDict`` with batch size ``PB``.
    """
    if isinstance(condition, TensorDict):
        # In-place insertion avoids re-allocating the TensorDict.
        condition["positional_embedding"] = pos_embd_patched
        return condition
    if isinstance(condition, Tensor):
        return TensorDict(
            {"condition": condition, "positional_embedding": pos_embd_patched},
            batch_size=[PB],
        )
    if condition is None:
        return TensorDict(
            {"positional_embedding": pos_embd_patched},
            batch_size=[PB],
        )
    raise ValueError(
        "When positional embeddings are configured, condition must be a "
        f"Tensor, TensorDict, or None, got {type(condition).__name__}."
    )
13 changes: 9 additions & 4 deletions physicsnemo/diffusion/multi_diffusion/patching.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,12 +725,17 @@ def image_batching(
padded_shape_y = stride_y * (patch_num_y - 1) + patch_shape_y + boundary_pix
patch_num = patch_num_x * patch_num_y

# Reflection-pad to fit the grid
# Reflection-pad to fit the grid. Use the functional form (not
# ``torch.nn.ReflectionPad2d(...)(input)``) to avoid instantiating a fresh
# nn.Module on every call, which is much less friendly to ``torch.compile``
# / AOT autograd tracing.
pad_x_right = padded_shape_x - img_shape_x - boundary_pix
pad_y_right = padded_shape_y - img_shape_y - boundary_pix
input_padded = torch.nn.ReflectionPad2d(
(boundary_pix, pad_x_right, boundary_pix, pad_y_right)
)(input)
input_padded = torch.nn.functional.pad(
input,
(boundary_pix, pad_x_right, boundary_pix, pad_y_right),
mode="reflect",
)

# Integer dtypes are not supported by unfold — cast temporarily
if input.dtype == torch.int32:
Expand Down
Loading
Loading