Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def _aggregate_patch_parameters(
if isinstance(param, torch.nn.Parameter) and type(param.data) is torch.Tensor:
pass
elif type(param) is torch.Tensor:
# Plain tensor (e.g. after cast_to_device moved a Parameter to another device).
pass
elif type(param) is GGMLTensor:
# Move to device and dequantize here. Doing it in the patch layer can result in redundant casts /
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.dora_layer import DoRALayer
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.layers.lokr_layer import LoKRLayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
Expand Down Expand Up @@ -346,6 +347,7 @@ def test_inference_autocast_from_cpu_to_device(device: str, layer_under_test: La
"concatenated_lora",
"flux_control_lora",
"single_lokr",
"single_dora",
]
)
def patch_under_test(request: pytest.FixtureRequest) -> PatchUnderTest:
Expand Down Expand Up @@ -432,6 +434,20 @@ def patch_under_test(request: pytest.FixtureRequest) -> PatchUnderTest:
)
input = torch.randn(1, in_features)
return ([(lokr_layer, 0.7)], input)
elif layer_type == "single_dora":
# Regression coverage for #8624: DoRA + partial-loading + CPU->device autocast.
# Scaled down so the patched weight stays well-conditioned for allclose comparisons.
# dora_scale has shape (1, in_features) to broadcast against direction_norm in
# DoRALayer.get_weight — see dora_layer.py:74-82.
dora_layer = DoRALayer(
up=torch.randn(out_features, rank) * 0.01,
down=torch.randn(rank, in_features) * 0.01,
dora_scale=torch.ones(1, in_features),
alpha=1.0,
bias=torch.randn(out_features) * 0.01,
)
input = torch.randn(1, in_features)
return ([(dora_layer, 0.7)], input)
else:
raise ValueError(f"Unsupported layer_type: {layer_type}")

Expand Down Expand Up @@ -676,3 +692,45 @@ def test_conv2d_mixed_dtype_sidecar_parameter_patch(dtype: torch.dtype):

assert output.dtype == input.dtype
assert output.shape == (2, 16, 3, 3)


@torch.no_grad()
def test_aggregate_patch_parameters_preserves_plain_tensor_with_dora():
    """Regression test for #8624 (DoRA + partial loading).

    When partial-loading autocasts a CPU Parameter onto the compute device,
    cast_to_device hands back a plain torch.Tensor rather than a Parameter.
    _aggregate_patch_parameters must accept that tensor as the real base
    weight instead of substituting a meta-device dummy — otherwise DoRA's
    quantization guard fires on non-quantized base models.

    CPU-only: the cross-device hand-off is simulated by constructing a plain
    torch.Tensor directly. The full CUDA/MPS flow is covered by the
    "single_dora" variant of
    test_linear_sidecar_patches_with_autocast_from_cpu_to_device.
    """
    in_features, out_features, rank = 32, 64, 4

    wrapped = wrap_single_custom_layer(torch.nn.Linear(in_features, out_features))

    patch = DoRALayer(
        up=torch.randn(out_features, rank) * 0.01,
        down=torch.randn(rank, in_features) * 0.01,
        dora_scale=torch.ones(1, in_features),
        alpha=1.0,
        bias=None,
    )

    # A plain (non-Parameter) tensor — the same shape of input that
    # _cast_weight_bias_for_input passes into _aggregate_patch_parameters
    # after autocasting a Parameter across devices.
    base_weight = torch.randn(out_features, in_features)
    assert type(base_weight) is torch.Tensor

    orig_params = {"weight": base_weight}
    aggregated = wrapped._aggregate_patch_parameters(
        patches_and_weights=[(patch, 1.0)],
        orig_params=orig_params,
        device=torch.device("cpu"),
    )

    # Before the fix, the aggregator replaced orig_params["weight"] with a
    # meta-device dummy, so DoRALayer.get_parameters raised
    # "not compatible with DoRA patches".
    assert orig_params["weight"].device.type == "cpu"
    assert aggregated["weight"].shape == (out_features, in_features)
    assert aggregated["weight"].device.type == "cpu"
    assert not torch.isnan(aggregated["weight"]).any()
Loading