Fix QR/SVD NaNs on zero/orthogonal inputs (#13943)

This commit is contained in:
haofei
2025-12-31 20:40:09 -08:00
committed by GitHub
parent 0ed58c1fcd
commit 20777f30b9
2 changed files with 29 additions and 6 deletions

View File

@@ -65,6 +65,24 @@ class TestLinAlg(unittest.TestCase):
orthogonality_helper(Q)
reconstruction_helper([Q,R],a)
def test_qr_zero_column(self):
a = Tensor([[0.0, 1.0], [0.0, 2.0]]).realize()
Q,R = a.qr()
assert not np.isnan(Q.numpy()).any()
assert not np.isnan(R.numpy()).any()
orthogonality_helper(Q)
reconstruction_helper([Q,R], a)
def test_svd_identity(self):
for a in (Tensor.eye(2), Tensor.zeros(2, 2)):
a = a.realize()
U,S,V = a.svd()
assert not np.isnan(U.numpy()).any()
assert not np.isnan(S.numpy()).any()
assert not np.isnan(V.numpy()).any()
s_diag = (S.unsqueeze(-2) * Tensor.eye(2))
reconstruction_helper([U, s_diag, V], a)
def test_newton_schulz(self):
coefficients = [(2, -1.5, 0.5), (2.0, -1.4, 0.2, 0.2)]#these params map to the sign function
sizes = [(2,2), (3,2), (2,3), (2,2,2)]