diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py
index da49fbcc1b..0cfde467cc 100644
--- a/extra/torch_backend/backend.py
+++ b/extra/torch_backend/backend.py
@@ -352,6 +352,11 @@ def sort_values(input, dim=-1, descending=False, stable=True, values=None, indic
     unwrap(indices).assign(out_indices.cast(dtypes.int64))
   return wrap(out_values), wrap(out_indices)
 
+@torch.library.impl("aten::_linalg_svd", "privateuseone")
+def _linalg_svd(self, full_matrices=True):
+  U, S, Vh = unwrap(self).svd(full_matrices)
+  return wrap(U), wrap(S), wrap(Vh)
+
 # register some decompositions
 from torch._decomp import get_decompositions
 decomps = [
@@ -413,5 +418,6 @@ decomps = [
   #aten.lgamma, # this needs copy_strided
   #aten.lerp,
+  aten.norm,
 ]
 for k,v in get_decompositions(decomps).items():
   key = str(k._schema).split("(")[0]
diff --git a/test/test_ops.py b/test/test_ops.py
index be332f0bb1..8339fbc653 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -3058,6 +3058,19 @@ class TestOps(unittest.TestCase):
   def test_bitcast(self):
     helper_test_op([(3, 3)], lambda x: x.view(torch.int32), lambda x: x.bitcast(dtypes.int32), forward_only=True)
 
+  def test_svd(self):
+    # test for tiny backend. real svd tests are in test_linalg
+    A = torch.randn(5, 5)
+    U, S, Vh = torch.linalg.svd(A, full_matrices=True)
+    np.testing.assert_allclose(torch.dist(A, U @ torch.diag(S) @ Vh).cpu().numpy(), 0, atol=1e-5)
+    U, S, Vh = torch.linalg.svd(A, full_matrices=False)
+    np.testing.assert_allclose(torch.dist(A, U @ torch.diag(S) @ Vh).cpu().numpy(), 0, atol=1e-5)
+
+    # # TODO: this works with torch, but not TINY_BACKEND. U has a wrong shape
+    # A = torch.randn(5, 3)
+    # U, S, Vh = torch.linalg.svd(A, full_matrices=False)
+    # np.testing.assert_allclose(torch.dist(A, U @ torch.diag(S) @ Vh).cpu().numpy(), 0, atol=1e-5)
+
 @unittest.skipUnless(is_dtype_supported(dtypes.uchar), f"no uint8 on {Device.DEFAULT}")
 class TestOpsUint8(unittest.TestCase):
   def test_cast(self):