Fix torch.linalg.diagonal crash due to invalid shrink in to_movement_ops (#10945)

* fix as_strided shrink bug breaking torch.linalg.diagonal on tinygrad backend

* cleanup

* generic fix

* tests

* cmp with diagonal too

* oops

* move tests

* fix test

* remove unnecessary import

* fix assert

* compare against numpy

---------

Co-authored-by: Utkarsh Gill <engelbart@Utkarshs-MacBook-Pro.local>
This commit is contained in:
Utkarsh Gill
2025-06-25 01:06:06 +05:30
committed by GitHub
parent 26ddf8d714
commit 7f9958b632
2 changed files with 11 additions and 0 deletions

View File

@@ -59,6 +59,7 @@ view_ops = {
"aten.squeeze.dim": Tensor.squeeze,
"aten.unsqueeze": Tensor.unsqueeze,
"aten.detach": Tensor.detach,
"aten.select.int": lambda self, dim, idx: self[(slice(None),) * (dim%self.ndim) + (idx,)],
}
for k,v in view_ops.items(): torch.library.impl(k.replace("aten.", "aten::"), "privateuseone")(wrap_view_op(v))

View File

@@ -206,5 +206,15 @@ class TestTorchBackend(unittest.TestCase):
X.cpu(), Y.cpu()
self.assertLessEqual(GlobalCounters.global_ops, 10_000_000)
def _check_diag(self, *shape, dtype=torch.float32):
a = torch.randn(*shape, dtype=dtype)
ref = np.diagonal(a.numpy(), axis1=-2, axis2=-1)
linalg_tiny = torch.linalg.diagonal(a)
np.testing.assert_equal(linalg_tiny.cpu().numpy(), ref)
def test_cube(self): self._check_diag(3, 3, 3)
def test_rectangular_last_dims(self): self._check_diag(4, 5, 6)
def test_high_dimensional(self): self._check_diag(2, 3, 4, 5)
if __name__ == "__main__":
unittest.main()