diff --git a/test/unit/test_linalg.py b/test/unit/test_linalg.py index 9bdff0b5cf..2e97ad6c9b 100644 --- a/test/unit/test_linalg.py +++ b/test/unit/test_linalg.py @@ -83,6 +83,15 @@ class TestLinAlg(unittest.TestCase): s_diag = (S.unsqueeze(-2) * Tensor.eye(2)) reconstruction_helper([U, s_diag, V], a) + def test_svd_identity_4x4(self): + a = Tensor.eye(4) + U,S,V = a.svd() + assert not np.isnan(U.numpy()).any() + assert not np.isnan(S.numpy()).any() + assert not np.isnan(V.numpy()).any() + s_diag = (S.unsqueeze(-2) * Tensor.eye(4)) + reconstruction_helper([U, s_diag, V], a) + def test_svd_rank1(self): a = Tensor([[1.0, 1.0], [2.0, 2.0]]).realize() U, S, V = a.svd() diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 605142ed9d..21c6fbb6fc 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3462,8 +3462,9 @@ class Tensor(OpMixin): #preprocess the matrix Q, R = (self.qr() if m >= n else self.transpose(-2, -1).qr()) num, q_num = min(m, n), max(m, n) - U = R.shrink(tuple([None] * len(b_shape) + [(0, num), (0, num)])) - V = Tensor.eye(num, dtype=self.dtype).reshape((1,) * len(b_shape) + (num, num)).expand(b_shape + (num, num)) + # TODO: codegen infinite loop without contiguous + U = R.shrink(tuple([None] * len(b_shape) + [(0, num), (0, num)])).contiguous() + V = Tensor.eye(num, dtype=self.dtype).reshape((1,) * len(b_shape) + (num, num)).expand(b_shape + (num, num)).contiguous() #prepare round robin pairing permute, inverse_permute = Tensor.arange(0, num, dtype=dtypes.int), Tensor.zeros(num, dtype=dtypes.int) permute[num//2:num] = permute[num//2:num].flip(0)