mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Fix SVD rank‑1 Jacobi rotation when tau == 0 (#13945)
This commit is contained in:
@@ -83,6 +83,12 @@ class TestLinAlg(unittest.TestCase):
|
||||
s_diag = (S.unsqueeze(-2) * Tensor.eye(2))
|
||||
reconstruction_helper([U, s_diag, V], a)
|
||||
|
||||
def test_svd_rank1(self):
|
||||
a = Tensor([[1.0, 1.0], [2.0, 2.0]]).realize()
|
||||
U, S, V = a.svd()
|
||||
np.testing.assert_allclose(S.numpy(), [np.sqrt(10), 0.0], atol=1e-4, rtol=1e-4)
|
||||
reconstruction_helper([U, S.unsqueeze(-2) * Tensor.eye(2), V], a)
|
||||
|
||||
def test_newton_schulz(self):
|
||||
coefficients = [(2, -1.5, 0.5), (2.0, -1.4, 0.2, 0.2)]#these params map to the sign function
|
||||
sizes = [(2,2), (3,2), (2,3), (2,2,2)]
|
||||
|
||||
@@ -3672,7 +3672,7 @@ class Tensor(OpMixin):
|
||||
alpha, beta = U_permuted.square().sum(-2).unsqueeze(-2).split(num//2, -1)
|
||||
rot = gamma != 0
|
||||
tau = (beta - alpha) / (2 * rot.where(gamma, 1))
|
||||
t = tau.sign() / (tau.abs() + (1 + tau.square()).sqrt())
|
||||
t = (tau != 0).where(tau.sign(), 1) / (tau.abs() + (1 + tau.square()).sqrt())
|
||||
t = rot.where(t, 0)
|
||||
c = 1 / (1 + t.square()).sqrt()
|
||||
s = c * t
|
||||
|
||||
Reference in New Issue
Block a user