pipe linalg svd to torch (#11109)

and found a bug in svd
This commit is contained in:
chenyu
2025-07-06 08:37:25 -04:00
committed by GitHub
parent 845a4d32bc
commit ba88ec3ad0
2 changed files with 19 additions and 0 deletions

View File

@@ -3058,6 +3058,19 @@ class TestOps(unittest.TestCase):
def test_bitcast(self):
helper_test_op([(3, 3)], lambda x: x.view(torch.int32), lambda x: x.bitcast(dtypes.int32), forward_only=True)
def test_svd(self):
# test for tiny backend. real svd tests are in test_linalg
A = torch.randn(5, 5)
U, S, Vh = torch.linalg.svd(A, full_matrices=True)
np.testing.assert_allclose(torch.dist(A, U @ torch.diag(S) @ Vh).cpu().numpy(), 0, atol=1e-5)
U, S, Vh = torch.linalg.svd(A, full_matrices=False)
np.testing.assert_allclose(torch.dist(A, U @ torch.diag(S) @ Vh).cpu().numpy(), 0, atol=1e-5)
# # TODO: this works with torch, but not TINY_BACKEND. U has a wrong shape
# A = torch.randn(5, 3)
# U, S, Vh = torch.linalg.svd(A, full_matrices=False)
# np.testing.assert_allclose(torch.dist(A, U @ torch.diag(S) @ Vh).cpu().numpy(), 0, atol=1e-5)
@unittest.skipUnless(is_dtype_supported(dtypes.uchar), f"no uint8 on {Device.DEFAULT}")
class TestOpsUint8(unittest.TestCase):
def test_cast(self):