clean up svd tests (#14133)

removed from test_ops and added to TestTorchBackend
This commit is contained in:
chenyu
2026-01-13 16:32:21 -05:00
committed by GitHub
parent 84b88a0a31
commit fe00682502
2 changed files with 13 additions and 15 deletions

View File

@@ -191,6 +191,19 @@ class TestTorchBackend(unittest.TestCase):
assert torch.equal(tensor_a, tensor_b) assert torch.equal(tensor_a, tensor_b)
assert not torch.equal(tensor_a, tensor_c) assert not torch.equal(tensor_a, tensor_c)
def test_linalg_svd(self):
A = torch.randn(5, 5, device=device)
U, S, Vh = torch.linalg.svd(A)
np.testing.assert_equal(U.shape, (5,5))
np.testing.assert_equal(Vh.shape, (5,5))
np.testing.assert_allclose(torch.dist(A, U @ torch.diag(S) @ Vh).cpu().numpy(), 0, atol=1e-5)
A = torch.randn(5, 3, device=device)
U, S, Vh = torch.linalg.svd(A, full_matrices=False)
np.testing.assert_equal(U.shape, (5,3))
np.testing.assert_equal(Vh.shape, (3,3))
np.testing.assert_allclose(torch.dist(A, U @ torch.diag(S) @ Vh).cpu().numpy(), 0, atol=1e-5)
def test_linalg_eigh(self): def test_linalg_eigh(self):
a = torch.tensor([[1, 2], [2, 1]], dtype=torch.float32, device=device) a = torch.tensor([[1, 2], [2, 1]], dtype=torch.float32, device=device)
w, v = torch.linalg.eigh(a) w, v = torch.linalg.eigh(a)

View File

@@ -3278,21 +3278,6 @@ class TestOps(unittest.TestCase):
def test_bitcast(self): def test_bitcast(self):
helper_test_op([(3, 3)], lambda x: x.view(torch.int32), lambda x: x.bitcast(dtypes.int32), forward_only=True) helper_test_op([(3, 3)], lambda x: x.view(torch.int32), lambda x: x.bitcast(dtypes.int32), forward_only=True)
@unittest.skip("we have test_linalg, no need to test here. TODO: should be in torch backend tests")
def test_svd(self):
# test for tiny backend. real svd tests are in test_linalg
A = torch.randn(5, 5)
U, S, Vh = torch.linalg.svd(A)
np.testing.assert_equal(U.shape, (5,5))
np.testing.assert_equal(Vh.shape, (5,5))
np.testing.assert_allclose(torch.dist(A, U @ torch.diag(S) @ Vh).cpu().numpy(), 0, atol=1e-5)
A = torch.randn(5, 3)
U, S, Vh = torch.linalg.svd(A, full_matrices=False)
np.testing.assert_equal(U.shape, (5,3))
np.testing.assert_equal(Vh.shape, (3,3))
np.testing.assert_allclose(torch.dist(A, U @ torch.diag(S) @ Vh).cpu().numpy(), 0, atol=1e-5)
@unittest.skipUnless(is_dtype_supported(dtypes.uchar), f"no uint8 on {Device.DEFAULT}") @unittest.skipUnless(is_dtype_supported(dtypes.uchar), f"no uint8 on {Device.DEFAULT}")
class TestOpsUint8(unittest.TestCase): class TestOpsUint8(unittest.TestCase):
def test_cast(self): def test_cast(self):