mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
Fix QR/SVD NaNs on zero/orthogonal inputs (#13943)
This commit is contained in:
@@ -65,6 +65,24 @@ class TestLinAlg(unittest.TestCase):
|
||||
orthogonality_helper(Q)
|
||||
reconstruction_helper([Q,R],a)
|
||||
|
||||
def test_qr_zero_column(self):
|
||||
a = Tensor([[0.0, 1.0], [0.0, 2.0]]).realize()
|
||||
Q,R = a.qr()
|
||||
assert not np.isnan(Q.numpy()).any()
|
||||
assert not np.isnan(R.numpy()).any()
|
||||
orthogonality_helper(Q)
|
||||
reconstruction_helper([Q,R], a)
|
||||
|
||||
def test_svd_identity(self):
|
||||
for a in (Tensor.eye(2), Tensor.zeros(2, 2)):
|
||||
a = a.realize()
|
||||
U,S,V = a.svd()
|
||||
assert not np.isnan(U.numpy()).any()
|
||||
assert not np.isnan(S.numpy()).any()
|
||||
assert not np.isnan(V.numpy()).any()
|
||||
s_diag = (S.unsqueeze(-2) * Tensor.eye(2))
|
||||
reconstruction_helper([U, s_diag, V], a)
|
||||
|
||||
def test_newton_schulz(self):
|
||||
coefficients = [(2, -1.5, 0.5), (2.0, -1.4, 0.2, 0.2)]#these params map to the sign function
|
||||
sizes = [(2,2), (3,2), (2,3), (2,2,2)]
|
||||
|
||||
Reference in New Issue
Block a user