test_diagonal touchup (#10962)

This commit is contained in:
chenyu
2025-06-24 15:51:19 -04:00
committed by GitHub
parent 7f9958b632
commit ffb032e31d

View File

@@ -206,15 +206,16 @@ class TestTorchBackend(unittest.TestCase):
X.cpu(), Y.cpu()
self.assertLessEqual(GlobalCounters.global_ops, 10_000_000)
def _check_diag(self, *shape, dtype=torch.float32):
a = torch.randn(*shape, dtype=dtype)
ref = np.diagonal(a.numpy(), axis1=-2, axis2=-1)
linalg_tiny = torch.linalg.diagonal(a)
np.testing.assert_equal(linalg_tiny.cpu().numpy(), ref)
def _test_diagonal(self, *shape):
a = torch.randn(*shape, dtype=torch.float32, device=device)
ref = np.diagonal(a.cpu().numpy(), axis1=-2, axis2=-1)
diag = torch.linalg.diagonal(a)
np.testing.assert_equal(diag.cpu().numpy(), ref)
np.testing.assert_equal(diag[-1].cpu().numpy(), ref[-1])
def test_cube(self): self._check_diag(3, 3, 3)
def test_rectangular_last_dims(self): self._check_diag(4, 5, 6)
def test_high_dimensional(self): self._check_diag(2, 3, 4, 5)
def test_diagonal_cube(self): self._test_diagonal(3, 3, 3)
def test_diagonal_rectangular(self): self._test_diagonal(4, 5, 6)
def test_diagonal_4d(self): self._test_diagonal(2, 3, 4, 5)
if __name__ == "__main__":
unittest.main()