From 526fd4ec7104eda1ef8114e64d99b2788910a8fd Mon Sep 17 00:00:00 2001 From: haofei Date: Wed, 31 Dec 2025 21:30:18 -0800 Subject: [PATCH] =?UTF-8?q?Fix=20SVD=20rank=E2=80=911=20Jacobi=20rotation?= =?UTF-8?q?=20when=20tau=20=3D=3D=200=20(#13945)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/unit/test_linalg.py | 6 ++++++ tinygrad/tensor.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/test/unit/test_linalg.py b/test/unit/test_linalg.py index d9c5f52507..5bc33590b1 100644 --- a/test/unit/test_linalg.py +++ b/test/unit/test_linalg.py @@ -83,6 +83,12 @@ class TestLinAlg(unittest.TestCase): s_diag = (S.unsqueeze(-2) * Tensor.eye(2)) reconstruction_helper([U, s_diag, V], a) + def test_svd_rank1(self): + a = Tensor([[1.0, 1.0], [2.0, 2.0]]).realize() + U, S, V = a.svd() + np.testing.assert_allclose(S.numpy(), [np.sqrt(10), 0.0], atol=1e-4, rtol=1e-4) + reconstruction_helper([U, S.unsqueeze(-2) * Tensor.eye(2), V], a) + def test_newton_schulz(self): coefficients = [(2, -1.5, 0.5), (2.0, -1.4, 0.2, 0.2)]#these params map to the sign function sizes = [(2,2), (3,2), (2,3), (2,2,2)] diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 736e6868de..4e650acb49 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3672,7 +3672,7 @@ class Tensor(OpMixin): alpha, beta = U_permuted.square().sum(-2).unsqueeze(-2).split(num//2, -1) rot = gamma != 0 tau = (beta - alpha) / (2 * rot.where(gamma, 1)) - t = tau.sign() / (tau.abs() + (1 + tau.square()).sqrt()) + t = (tau != 0).where(tau.sign(), 1) / (tau.abs() + (1 + tau.square()).sqrt()) t = rot.where(t, 0) c = 1 / (1 + t.square()).sqrt() s = c * t