mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
test_diagonal touchup (#10962)
This commit is contained in:
@@ -206,15 +206,16 @@ class TestTorchBackend(unittest.TestCase):
|
||||
X.cpu(), Y.cpu()
|
||||
self.assertLessEqual(GlobalCounters.global_ops, 10_000_000)
|
||||
|
||||
def _check_diag(self, *shape, dtype=torch.float32):
|
||||
a = torch.randn(*shape, dtype=dtype)
|
||||
ref = np.diagonal(a.numpy(), axis1=-2, axis2=-1)
|
||||
linalg_tiny = torch.linalg.diagonal(a)
|
||||
np.testing.assert_equal(linalg_tiny.cpu().numpy(), ref)
|
||||
def _test_diagonal(self, *shape):
|
||||
a = torch.randn(*shape, dtype=torch.float32, device=device)
|
||||
ref = np.diagonal(a.cpu().numpy(), axis1=-2, axis2=-1)
|
||||
diag = torch.linalg.diagonal(a)
|
||||
np.testing.assert_equal(diag.cpu().numpy(), ref)
|
||||
np.testing.assert_equal(diag[-1].cpu().numpy(), ref[-1])
|
||||
|
||||
def test_cube(self): self._check_diag(3, 3, 3)
|
||||
def test_rectangular_last_dims(self): self._check_diag(4, 5, 6)
|
||||
def test_high_dimensional(self): self._check_diag(2, 3, 4, 5)
|
||||
def test_diagonal_cube(self): self._test_diagonal(3, 3, 3)
|
||||
def test_diagonal_rectangular(self): self._test_diagonal(4, 5, 6)
|
||||
def test_diagonal_4d(self): self._test_diagonal(2, 3, 4, 5)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user