From a9cc9305e5a9193b184e80ad99649d3fdff247af Mon Sep 17 00:00:00 2001 From: lantudou <78593615@qq.com> Date: Wed, 24 Dec 2025 16:23:22 +0800 Subject: [PATCH] Perf: Replace full SVD with torch.svd_lowrank for acceleration --- deepcompressor/nn/patch/lowrank.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/deepcompressor/nn/patch/lowrank.py b/deepcompressor/nn/patch/lowrank.py index 6247741..d7fb8b9 100644 --- a/deepcompressor/nn/patch/lowrank.py +++ b/deepcompressor/nn/patch/lowrank.py @@ -46,11 +46,12 @@ def reset_parameters(self, weight: torch.Tensor | None = None) -> None: if self.rank < 0: self.a.weight.data.copy_(weight) elif self.rank > 0: - u, s, vh = torch.linalg.svd(weight.double()) - # tensor: [oc, ic], u: [oc, oc], s: [oc], vh: [ic, ic] + buffer = 10 + u, s, v = torch.svd_lowrank(weight.double(), q=self.rank + buffer, niter=4) + # tensor: [oc, ic], u: [oc, q], s: [q], v: [ic, q] # us: [oc, rank], vh: [rank, ic] us = u[:, : self.rank] * s[: self.rank] - vh = vh[: self.rank] + vh = v[:, : self.rank].t() assert not us.isnan().any(), "NaN in U * S" assert not vh.isnan().any(), "NaN in V^T" assert not us.isinf().any(), "Inf in U * S"