add contiguous back to svd (#15074)

can cause infinite loop
This commit is contained in:
chenyu
2026-02-28 16:49:26 -05:00
committed by GitHub
parent fe0fa8333b
commit 103ea16ec0
2 changed files with 12 additions and 2 deletions

View File

@@ -83,6 +83,15 @@ class TestLinAlg(unittest.TestCase):
s_diag = (S.unsqueeze(-2) * Tensor.eye(2))
reconstruction_helper([U, s_diag, V], a)
def test_svd_identity_4x4(self):
a = Tensor.eye(4)
U,S,V = a.svd()
assert not np.isnan(U.numpy()).any()
assert not np.isnan(S.numpy()).any()
assert not np.isnan(V.numpy()).any()
s_diag = (S.unsqueeze(-2) * Tensor.eye(4))
reconstruction_helper([U, s_diag, V], a)
def test_svd_rank1(self):
a = Tensor([[1.0, 1.0], [2.0, 2.0]]).realize()
U, S, V = a.svd()

View File

@@ -3462,8 +3462,9 @@ class Tensor(OpMixin):
#preprocess the matrix
Q, R = (self.qr() if m >= n else self.transpose(-2, -1).qr())
num, q_num = min(m, n), max(m, n)
U = R.shrink(tuple([None] * len(b_shape) + [(0, num), (0, num)]))
V = Tensor.eye(num, dtype=self.dtype).reshape((1,) * len(b_shape) + (num, num)).expand(b_shape + (num, num))
# TODO: codegen infinite loop without contiguous
U = R.shrink(tuple([None] * len(b_shape) + [(0, num), (0, num)])).contiguous()
V = Tensor.eye(num, dtype=self.dtype).reshape((1,) * len(b_shape) + (num, num)).expand(b_shape + (num, num)).contiguous()
#prepare round robin pairing
permute, inverse_permute = Tensor.arange(0, num, dtype=dtypes.int), Tensor.zeros(num, dtype=dtypes.int)
permute[num//2:num] = permute[num//2:num].flip(0)