diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py index 80fc038c76..eba99816ef 100644 --- a/extra/torch_backend/backend.py +++ b/extra/torch_backend/backend.py @@ -128,6 +128,12 @@ def _linalg_eigh(self, UPLO: str = 'U'): w, v = torch.linalg.eigh(self.cpu(), UPLO=UPLO) return w.tiny(), v.tiny() +@torch.library.impl("aten::_linalg_det", "privateuseone") +# TODO: move to tinygrad +def _linalg_det(self: torch.Tensor): + result = aten._linalg_det(self.cpu()) + return result[0].tiny(), result[1].tiny(), result[2].tiny() + def upsample_backward(grad_out, output_size, input_size, *args, f=None): return f(grad_out.cpu(), output_size, input_size, *args).tiny() for i in [ diff --git a/extra/torch_backend/test.py b/extra/torch_backend/test.py index 10c077001d..087f944204 100644 --- a/extra/torch_backend/test.py +++ b/extra/torch_backend/test.py @@ -198,6 +198,11 @@ class TestTorchBackend(unittest.TestCase): recon = (v @ torch.diag(w) @ v.T).cpu().numpy() np.testing.assert_allclose(recon, a.cpu().numpy(), atol=1e-6) + def test_linalg_det(self): + a = torch.diag(torch.tensor([1,2,3,4,5], dtype = torch.float32, device=device)) + b = torch.linalg.det(a) + np.testing.assert_equal(b.cpu().numpy(), 120.0) + def test_scalar_assign(self): a = torch.tensor([1, 2, 3], device=device) a[1] = 4