From 176a934ddd006e1fe8da4a4cbc0a53d46b6c162b Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 13 Jan 2026 14:49:06 -0500 Subject: [PATCH] Tensor.diagonal support offset and dims (#14130) --- extra/torch_backend/backend.py | 14 +------------- test/test_ops.py | 9 +++++++++ tinygrad/tensor.py | 15 +++++++++++---- 3 files changed, 21 insertions(+), 17 deletions(-) diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py index aba875c648..0498ec5874 100644 --- a/extra/torch_backend/backend.py +++ b/extra/torch_backend/backend.py @@ -74,6 +74,7 @@ view_ops = { "aten.select.int": lambda self, dim, idx: self[(slice(None),) * (dim%self.ndim) + (idx,)], "aten.permute": Tensor.permute, "aten.alias": lambda self: self, + "aten.diagonal": Tensor.diagonal, } # torch 2.10 handles this natively @@ -751,16 +752,3 @@ def _pad_circular(self, padding): return _PadCircular.apply(self, padding) @torch.library.impl("aten::_pad_circular", "AutogradPrivateUse1") def _pad_circular_autograd(self, padding): return _PadCircular.apply(self, padding) - -# only needed for test_diag_backward_gradient_values -# was going through torch before, but now we are using tinygrad directly and tracking views -# Tensor.diagonal does not support all cases tests in the tests -@torch.library.impl("aten::diagonal", "privateuseone") -@wrap_view_op -def diagonal(self, offset=0, dim1=0, dim2=1): - if offset != 0: raise NotImplementedError(f"diagonal with {offset=} not implemented") - dim1, dim2 = dim1 % self.ndim, dim2 % self.ndim - if dim1 != self.ndim - 2 or dim2 != self.ndim - 1: raise NotImplementedError(f"diagonal with {dim1=}, {dim2=} not implemented, only last two dims supported") - batch_shape, m, n = self.shape[:-2], self.shape[-2], self.shape[-1] - diag_len = min(m, n) - return self.reshape(*batch_shape, m*n).pad(tuple((0,0) for _ in batch_shape) + ((0, diag_len),)).reshape(*batch_shape, diag_len, n+1)[..., :, 0] diff --git a/test/test_ops.py b/test/test_ops.py index 961f36fe5b..53f9cdf494 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2024,6 +2024,15 @@ class TestOps(unittest.TestCase): def test_diagonal(self): helper_test_op([(5,5)], lambda x: x.diagonal()) + helper_test_op([(3,4)], lambda x: x.diagonal()) # rectangular + helper_test_op([(4,3)], lambda x: x.diagonal()) # rectangular (other way) + helper_test_op([(3,3,3)], lambda x: x.diagonal(dim1=-2, dim2=-1)) # batched + helper_test_op([(4,5,6)], lambda x: x.diagonal(dim1=-2, dim2=-1)) # batched rectangular + helper_test_op([(2,3,4,5)], lambda x: x.diagonal(dim1=-2, dim2=-1)) # 4D batched + helper_test_op([(5,5)], lambda x: x.diagonal(offset=1)) # positive offset + helper_test_op([(5,5)], lambda x: x.diagonal(offset=-1)) # negative offset + helper_test_op([(3,5)], lambda x: x.diagonal(offset=2)) # offset on rectangular + self.helper_test_exception([(3,3)], lambda x: x.diagonal(dim1=0, dim2=0), expected=RuntimeError) def test_roll(self): helper_test_op([(2, 4)], lambda x: x.roll(1)) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index deca3f3507..6ec5651bf4 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1440,9 +1440,10 @@ class Tensor(OpMixin): if self.ndim != 1: raise ValueError(f"expect input to be 1-D, getting {self.ndim}-D") return self.unsqueeze(-1).pad((None,(0,n:=self.shape[0]))).flatten().shrink(((0,n*n),)).reshape(n,n) - def diagonal(self) -> Tensor: + def diagonal(self, offset:int=0, dim1:int=0, dim2:int=1) -> Tensor: """ - Returns a view of input tensor with its main diagonal elements. + Returns a view of the diagonal elements with respect to `dim1` and `dim2`. + `offset` controls which diagonal: 0 is main, positive is above, negative is below. ```python exec="true" source="above" session="tensor" result="python" t = Tensor.arange(9).reshape(3, 3) @@ -1451,9 +1452,15 @@ class Tensor(OpMixin): ```python exec="true" source="above" session="tensor" result="python" print(t.diagonal().numpy()) ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.diagonal(offset=1).numpy()) + ``` """ - if self.ndim != 2 or (n:=self.shape[0]) != self.shape[1]: raise ValueError(f"only 2-D square tensor is supported, getting {self.shape=}") - return self.flatten().pad(((0, n))).reshape(n, n+1)[:, 0] + if (dim1:=self._resolve_dim(dim1)) == (dim2:=self._resolve_dim(dim2)): raise RuntimeError("dim1 and dim2 cannot be the same dimension") + x = self.permute(*[i for i in range(self.ndim) if i != dim1 and i != dim2], dim1, dim2) + x = x[..., :, offset:] if offset >= 0 else x[..., -offset:, :] + if (d := min(int(x.shape[-2]), int(x.shape[-1]))) <= 0: return x.reshape(*x.shape[:-2], 0) + return x[..., :d, :d].flatten(-2).pad(tuple((0,0) for _ in x.shape[:-2])+((0,d),)).reshape(*x.shape[:-2], d, d+1)[..., 0] def roll(self, shifts:int|tuple[int, ...], dims:int|tuple[int, ...]|None=None) -> Tensor: """