mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-14 09:28:04 -05:00
clean up svd tests (#14133)
removed from test_ops and added to TestTorchBackend
This commit is contained in:
@@ -191,6 +191,19 @@ class TestTorchBackend(unittest.TestCase):
|
|||||||
assert torch.equal(tensor_a, tensor_b)
|
assert torch.equal(tensor_a, tensor_b)
|
||||||
assert not torch.equal(tensor_a, tensor_c)
|
assert not torch.equal(tensor_a, tensor_c)
|
||||||
|
|
||||||
|
def test_linalg_svd(self):
|
||||||
|
A = torch.randn(5, 5, device=device)
|
||||||
|
U, S, Vh = torch.linalg.svd(A)
|
||||||
|
np.testing.assert_equal(U.shape, (5,5))
|
||||||
|
np.testing.assert_equal(Vh.shape, (5,5))
|
||||||
|
np.testing.assert_allclose(torch.dist(A, U @ torch.diag(S) @ Vh).cpu().numpy(), 0, atol=1e-5)
|
||||||
|
|
||||||
|
A = torch.randn(5, 3, device=device)
|
||||||
|
U, S, Vh = torch.linalg.svd(A, full_matrices=False)
|
||||||
|
np.testing.assert_equal(U.shape, (5,3))
|
||||||
|
np.testing.assert_equal(Vh.shape, (3,3))
|
||||||
|
np.testing.assert_allclose(torch.dist(A, U @ torch.diag(S) @ Vh).cpu().numpy(), 0, atol=1e-5)
|
||||||
|
|
||||||
def test_linalg_eigh(self):
|
def test_linalg_eigh(self):
|
||||||
a = torch.tensor([[1, 2], [2, 1]], dtype=torch.float32, device=device)
|
a = torch.tensor([[1, 2], [2, 1]], dtype=torch.float32, device=device)
|
||||||
w, v = torch.linalg.eigh(a)
|
w, v = torch.linalg.eigh(a)
|
||||||
|
|||||||
@@ -3278,21 +3278,6 @@ class TestOps(unittest.TestCase):
|
|||||||
def test_bitcast(self):
|
def test_bitcast(self):
|
||||||
helper_test_op([(3, 3)], lambda x: x.view(torch.int32), lambda x: x.bitcast(dtypes.int32), forward_only=True)
|
helper_test_op([(3, 3)], lambda x: x.view(torch.int32), lambda x: x.bitcast(dtypes.int32), forward_only=True)
|
||||||
|
|
||||||
@unittest.skip("we have test_linalg, no need to test here. TODO: should be in torch backend tests")
|
|
||||||
def test_svd(self):
|
|
||||||
# test for tiny backend. real svd tests are in test_linalg
|
|
||||||
A = torch.randn(5, 5)
|
|
||||||
U, S, Vh = torch.linalg.svd(A)
|
|
||||||
np.testing.assert_equal(U.shape, (5,5))
|
|
||||||
np.testing.assert_equal(Vh.shape, (5,5))
|
|
||||||
np.testing.assert_allclose(torch.dist(A, U @ torch.diag(S) @ Vh).cpu().numpy(), 0, atol=1e-5)
|
|
||||||
|
|
||||||
A = torch.randn(5, 3)
|
|
||||||
U, S, Vh = torch.linalg.svd(A, full_matrices=False)
|
|
||||||
np.testing.assert_equal(U.shape, (5,3))
|
|
||||||
np.testing.assert_equal(Vh.shape, (3,3))
|
|
||||||
np.testing.assert_allclose(torch.dist(A, U @ torch.diag(S) @ Vh).cpu().numpy(), 0, atol=1e-5)
|
|
||||||
|
|
||||||
@unittest.skipUnless(is_dtype_supported(dtypes.uchar), f"no uint8 on {Device.DEFAULT}")
|
@unittest.skipUnless(is_dtype_supported(dtypes.uchar), f"no uint8 on {Device.DEFAULT}")
|
||||||
class TestOpsUint8(unittest.TestCase):
|
class TestOpsUint8(unittest.TestCase):
|
||||||
def test_cast(self):
|
def test_cast(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user