diff --git a/extra/torch_backend/test.py b/extra/torch_backend/test.py index c9389e0a82..236da279b9 100644 --- a/extra/torch_backend/test.py +++ b/extra/torch_backend/test.py @@ -206,15 +206,16 @@ class TestTorchBackend(unittest.TestCase): X.cpu(), Y.cpu() self.assertLessEqual(GlobalCounters.global_ops, 10_000_000) - def _check_diag(self, *shape, dtype=torch.float32): - a = torch.randn(*shape, dtype=dtype) - ref = np.diagonal(a.numpy(), axis1=-2, axis2=-1) - linalg_tiny = torch.linalg.diagonal(a) - np.testing.assert_equal(linalg_tiny.cpu().numpy(), ref) + def _test_diagonal(self, *shape): + a = torch.randn(*shape, dtype=torch.float32, device=device) + ref = np.diagonal(a.cpu().numpy(), axis1=-2, axis2=-1) + diag = torch.linalg.diagonal(a) + np.testing.assert_equal(diag.cpu().numpy(), ref) + np.testing.assert_equal(diag[-1].cpu().numpy(), ref[-1]) - def test_cube(self): self._check_diag(3, 3, 3) - def test_rectangular_last_dims(self): self._check_diag(4, 5, 6) - def test_high_dimensional(self): self._check_diag(2, 3, 4, 5) + def test_diagonal_cube(self): self._test_diagonal(3, 3, 3) + def test_diagonal_rectangular(self): self._test_diagonal(4, 5, 6) + def test_diagonal_4d(self): self._test_diagonal(2, 3, 4, 5) if __name__ == "__main__": unittest.main()