Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
ae8102e
Implements Predictor specialization for multi-diffusion
CharlelieLrt Apr 17, 2026
9bd0474
Compile denoiser in multi-diffusion sampling compile tests
CharlelieLrt Apr 17, 2026
9c84b00
Avoid fullgraph compile in multi-diffusion sampling test
CharlelieLrt Apr 17, 2026
f8b906a
Flatten MultiDiffusionPredictor hot path for torch.compile
CharlelieLrt Apr 17, 2026
1259e43
Loosen torch.compile euler check in multi-diffusion sampling tests
CharlelieLrt Apr 17, 2026
cefe1db
Force contiguous t_cur/t_next in Euler solvers
CharlelieLrt Apr 17, 2026
a000e1c
Drop dead is_compiling guard and inherit from Predictor in MultiDiffu…
CharlelieLrt Apr 18, 2026
8f945a5
Narrow _patching type and tighten multi-diffusion tests
CharlelieLrt Apr 18, 2026
7e1db11
Force contiguous pos_embd before patching
CharlelieLrt Apr 18, 2026
feb0d9e
Use functional F.pad in image_batching
CharlelieLrt Apr 20, 2026
3dfcdb5
Replace einops.rearrange with native torch reshape+permute
CharlelieLrt Apr 20, 2026
746518f
Materialise returned tensors in multi-diffusion fuse path
CharlelieLrt Apr 20, 2026
a007c46
Use clone instead of contiguous at fuse boundary
CharlelieLrt Apr 20, 2026
869a7c1
Revert speculative fuse-boundary copies and xfail full-sampler compil…
CharlelieLrt Apr 20, 2026
78c10d8
Merge remote-tracking branch 'upstream/main' into multi_diffusion_sam…
CharlelieLrt Apr 20, 2026
3207587
Minor updates to predictor.py
CharlelieLrt Apr 22, 2026
2759f57
Merge remote-tracking branch 'upstream/main' into multi_diffusion_sam…
CharlelieLrt Apr 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion physicsnemo/diffusion/guidance/dps_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def __call__(
...


class DPSScorePredictor:
class DPSScorePredictor(Predictor):
r"""
Score predictor that combines an x0-predictor with DPS-style guidance.

Expand Down
1 change: 1 addition & 0 deletions physicsnemo/diffusion/multi_diffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@
image_batching,
image_fuse,
)
from .predictor import MultiDiffusionPredictor
29 changes: 3 additions & 26 deletions physicsnemo/diffusion/multi_diffusion/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,30 +26,7 @@

from physicsnemo.diffusion.multi_diffusion.models import MultiDiffusionModel2D
from physicsnemo.diffusion.noise_schedulers import NoiseScheduler
from physicsnemo.diffusion.utils.utils import apply_loss_weight


def _unwrap_multi_diffusion(model: torch.nn.Module) -> MultiDiffusionModel2D:
    """Strip DDP / ``torch.compile`` wrappers until the wrapped
    :class:`MultiDiffusionModel2D` is exposed.

    Wrappers may be nested in any order: ``OptimizedModule`` (from
    ``torch.compile``) is unwrapped via ``_orig_mod`` and
    ``DistributedDataParallel``-style wrappers via ``.module``.

    Raises
    ------
    TypeError
        If unwrapping reaches an object that is neither a known wrapper
        nor a :class:`MultiDiffusionModel2D`.
    """
    current = model
    while True:
        if isinstance(current, MultiDiffusionModel2D):
            return current
        # torch.compile wrapper: the original module lives on _orig_mod.
        if isinstance(current, torch._dynamo.eval_frame.OptimizedModule):
            current = current._orig_mod
            continue
        # DDP-style wrapper: the original module lives on .module.
        if hasattr(current, "module"):
            current = current.module
            continue
        raise TypeError(
            f"Could not unwrap a MultiDiffusionModel2D from "
            f"{type(model).__name__}. Found leaf type "
            f"{type(current).__name__}."
        )
from physicsnemo.diffusion.utils.utils import _unwrap_module, apply_loss_weight


class _CompiledPatchX:
Expand Down Expand Up @@ -256,7 +233,7 @@ def __init__(
reduction: Literal["none", "mean", "sum"] = "mean",
) -> None:
self.model = model
self._md_model = _unwrap_multi_diffusion(model)
self._md_model = _unwrap_module(model, MultiDiffusionModel2D)
self.noise_scheduler = noise_scheduler
self._compiled_patch_x = _CompiledPatchX(self._md_model)

Expand Down Expand Up @@ -504,7 +481,7 @@ def __init__(
reduction: Literal["none", "mean", "sum"] = "mean",
) -> None:
self.model = model
self._md_model = _unwrap_multi_diffusion(model)
self._md_model = _unwrap_module(model, MultiDiffusionModel2D)
self.noise_scheduler = noise_scheduler
self._compiled_patch_x = _CompiledPatchX(self._md_model)

Expand Down
83 changes: 44 additions & 39 deletions physicsnemo/diffusion/multi_diffusion/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ def __init__(
self._patching: RandomPatching2D | GridPatching2D | None = None
self._patching_type: Literal["random", "grid"] | None = None
self._fuse: bool = False
self._skip_positional_embedding_injection: bool = False
# Normalise condition flags to defaultdict for uniform access
if not isinstance(condition_patch, (bool, dict)):
raise TypeError(
Expand Down Expand Up @@ -890,26 +891,13 @@ def forward(
"The model will run without patching.",
stacklevel=2,
)
if self.pos_embd is not None:
if (
self.pos_embd is not None
and not self._skip_positional_embedding_injection
):
B = x.shape[0]
pos_embd = self.pos_embd.unsqueeze(0).expand(B, -1, -1, -1)
if condition is None:
condition = TensorDict(
{"positional_embedding": pos_embd}, batch_size=[B]
)
elif isinstance(condition, TensorDict):
condition["positional_embedding"] = pos_embd
elif isinstance(condition, Tensor):
condition = TensorDict(
{"condition": condition, "positional_embedding": pos_embd},
batch_size=[B],
)
else:
raise ValueError(
"When positional embeddings are configured, condition "
"must be a Tensor, TensorDict, or None, got "
f"{type(condition).__name__}."
)
condition = self._inject_patched_pos_embd(condition, pos_embd, B)
return self.model(x, t, condition=condition, **model_kwargs)

P = self._patching.patch_num
Expand All @@ -935,30 +923,15 @@ def forward(
if not condition_is_patched:
condition = self.patch_condition(condition)

# Positional embeddings are always injected (internal to the wrapper)
if self.pos_embd is not None:
# Positional embeddings injected here unless _skip_positional_embedding_injection
# is set (e.g. by MultiDiffusionPredictor which pre-patches PE at construction time)
if self.pos_embd is not None and not self._skip_positional_embedding_injection:
pos_embd_patched = self._patching.apply(
self.pos_embd.unsqueeze(0).expand(B, -1, -1, -1)
) # (P*B, C_PE, Hp, Wp)
PB = P * B
if condition is None:
condition = TensorDict(
{"positional_embedding": pos_embd_patched},
batch_size=[PB],
)
elif isinstance(condition, TensorDict):
condition["positional_embedding"] = pos_embd_patched
elif isinstance(condition, Tensor):
condition = TensorDict(
{"condition": condition, "positional_embedding": pos_embd_patched},
batch_size=[PB],
)
else:
raise ValueError(
"When positional embeddings are configured, condition "
"must be a Tensor, TensorDict, or None, got "
f"{type(condition).__name__}."
)
condition = self._inject_patched_pos_embd(
condition, pos_embd_patched, P * B
)

output = self.model(x, t, condition=condition, **model_kwargs)

Expand Down Expand Up @@ -1006,3 +979,35 @@ def _process_condition_tensor(

# Default: repeat along batch dimension
return tensor.repeat(P, *([1] * (tensor.ndim - 1)))

def _inject_patched_pos_embd(
    self,
    condition: Tensor | TensorDict | None,
    pos_embd_patched: Float[Tensor, "P_times_B C_PE Hp Wp"],
    PB: int,
) -> TensorDict:
    """Attach an already-patched positional embedding to ``condition``
    under the ``"positional_embedding"`` key.

    Shared helper used by :meth:`forward` and by
    :class:`~physicsnemo.diffusion.multi_diffusion.MultiDiffusionPredictor`.
    A ``TensorDict`` condition is updated in place (no rebuild); any other
    accepted input is wrapped into a fresh ``TensorDict`` of batch size
    ``PB``.
    """
    # Fast path: mutate an existing TensorDict rather than rebuilding it.
    if isinstance(condition, TensorDict):
        condition["positional_embedding"] = pos_embd_patched
        return condition
    # Anything other than None / Tensor / TensorDict is unsupported.
    if condition is not None and not isinstance(condition, Tensor):
        raise ValueError(
            "When positional embeddings are configured, condition must be a "
            f"Tensor, TensorDict, or None, got {type(condition).__name__}."
        )
    entries: dict = {"positional_embedding": pos_embd_patched}
    if condition is not None:
        entries["condition"] = condition
    return TensorDict(entries, batch_size=[PB])
Loading
Loading