diff --git a/deepcompressor/nn/patch/lowrank.py b/deepcompressor/nn/patch/lowrank.py index 6247741..d7fb8b9 100644 --- a/deepcompressor/nn/patch/lowrank.py +++ b/deepcompressor/nn/patch/lowrank.py @@ -46,11 +46,12 @@ def reset_parameters(self, weight: torch.Tensor | None = None) -> None: if self.rank < 0: self.a.weight.data.copy_(weight) elif self.rank > 0: - u, s, vh = torch.linalg.svd(weight.double()) - # tensor: [oc, ic], u: [oc, oc], s: [oc], vh: [ic, ic] + buffer = 10 + u, s, v = torch.svd_lowrank(weight.double(), q=self.rank + buffer, niter=4) + # tensor: [oc, ic], u: [oc, q], s: [q], v: [ic, q] # us: [oc, rank], vh: [rank, ic] us = u[:, : self.rank] * s[: self.rank] - vh = vh[: self.rank] + vh = v[:, : self.rank].t() assert not us.isnan().any(), "NaN in U * S" assert not vh.isnan().any(), "NaN in V^T" assert not us.isinf().any(), "Inf in U * S"