mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
@@ -83,6 +83,15 @@ class TestLinAlg(unittest.TestCase):
|
||||
s_diag = (S.unsqueeze(-2) * Tensor.eye(2))
|
||||
reconstruction_helper([U, s_diag, V], a)
|
||||
|
||||
def test_svd_identity_4x4(self):
|
||||
a = Tensor.eye(4)
|
||||
U,S,V = a.svd()
|
||||
assert not np.isnan(U.numpy()).any()
|
||||
assert not np.isnan(S.numpy()).any()
|
||||
assert not np.isnan(V.numpy()).any()
|
||||
s_diag = (S.unsqueeze(-2) * Tensor.eye(4))
|
||||
reconstruction_helper([U, s_diag, V], a)
|
||||
|
||||
def test_svd_rank1(self):
|
||||
a = Tensor([[1.0, 1.0], [2.0, 2.0]]).realize()
|
||||
U, S, V = a.svd()
|
||||
|
||||
@@ -3462,8 +3462,9 @@ class Tensor(OpMixin):
|
||||
#preprocess the matrix
|
||||
Q, R = (self.qr() if m >= n else self.transpose(-2, -1).qr())
|
||||
num, q_num = min(m, n), max(m, n)
|
||||
U = R.shrink(tuple([None] * len(b_shape) + [(0, num), (0, num)]))
|
||||
V = Tensor.eye(num, dtype=self.dtype).reshape((1,) * len(b_shape) + (num, num)).expand(b_shape + (num, num))
|
||||
# TODO: codegen infinite loop without contiguous
|
||||
U = R.shrink(tuple([None] * len(b_shape) + [(0, num), (0, num)])).contiguous()
|
||||
V = Tensor.eye(num, dtype=self.dtype).reshape((1,) * len(b_shape) + (num, num)).expand(b_shape + (num, num)).contiguous()
|
||||
#prepare round robin pairing
|
||||
permute, inverse_permute = Tensor.arange(0, num, dtype=dtypes.int), Tensor.zeros(num, dtype=dtypes.int)
|
||||
permute[num//2:num] = permute[num//2:num].flip(0)
|
||||
|
||||
Reference in New Issue
Block a user