Tensor.diagonal: support offset and dims (#14130)

This commit is contained in:
chenyu
2026-01-13 14:49:06 -05:00
committed by GitHub
parent 2a217ba206
commit 176a934ddd
3 changed files with 21 additions and 17 deletions

View File

@@ -74,6 +74,7 @@ view_ops = {
"aten.select.int": lambda self, dim, idx: self[(slice(None),) * (dim%self.ndim) + (idx,)],
"aten.permute": Tensor.permute,
"aten.alias": lambda self: self,
"aten.diagonal": Tensor.diagonal,
}
# torch 2.10 handles this natively
@@ -751,16 +752,3 @@ def _pad_circular(self, padding): return _PadCircular.apply(self, padding)
@torch.library.impl("aten::_pad_circular", "AutogradPrivateUse1")
def _pad_circular_autograd(self, padding): return _PadCircular.apply(self, padding)
# only needed for test_diag_backward_gradient_values
# was going through torch before, but now we are using tinygrad directly and tracking views
# NOTE: Tensor.diagonal does not support all cases tested in the tests, hence this fallback impl
@torch.library.impl("aten::diagonal", "privateuseone")
@wrap_view_op
def diagonal(self, offset=0, dim1=0, dim2=1):
  """Main diagonal over the trailing two dims via a flatten/pad/reshape trick.

  Only offset=0 with (dim1, dim2) resolving to the last two dimensions is
  implemented; everything else raises NotImplementedError.
  """
  if offset != 0: raise NotImplementedError(f"diagonal with {offset=} not implemented")
  # normalize negative dim indices
  dim1, dim2 = dim1 % self.ndim, dim2 % self.ndim
  if dim1 != self.ndim - 2 or dim2 != self.ndim - 1: raise NotImplementedError(f"diagonal with {dim1=}, {dim2=} not implemented, only last two dims supported")
  batch_shape, m, n = self.shape[:-2], self.shape[-2], self.shape[-1]
  diag_len = min(m, n)
  # Flatten the (m, n) tail to m*n, append diag_len padding elements, then view
  # the result as (diag_len, n+1): element [i, 0] of that view is original [i, i].
  # NOTE(review): the element count only works out when m <= n (diag_len == m,
  # total m*(n+1)); for m > n the padded length is n*(m+1) != n*(n+1), so that
  # case appears unsupported here -- presumably covered by the caveat above.
  return self.reshape(*batch_shape, m*n).pad(tuple((0,0) for _ in batch_shape) + ((0, diag_len),)).reshape(*batch_shape, diag_len, n+1)[..., :, 0]

View File

@@ -2024,6 +2024,15 @@ class TestOps(unittest.TestCase):
def test_diagonal(self):
  """Tensor.diagonal vs torch across square/rectangular, batched, and offset cases."""
  cases = [
    ((5,5), {}),                          # square, main diagonal
    ((3,4), {}),                          # rectangular, wide
    ((4,3), {}),                          # rectangular, tall
    ((3,3,3), {"dim1":-2, "dim2":-1}),    # batched
    ((4,5,6), {"dim1":-2, "dim2":-1}),    # batched rectangular
    ((2,3,4,5), {"dim1":-2, "dim2":-1}),  # 4D batched
    ((5,5), {"offset":1}),                # above the main diagonal
    ((5,5), {"offset":-1}),               # below the main diagonal
    ((3,5), {"offset":2}),                # offset on rectangular
  ]
  # bind kwargs as a default to avoid the late-binding closure pitfall
  for shape, kwargs in cases:
    helper_test_op([shape], lambda x, kw=kwargs: x.diagonal(**kw))
  # identical dims must be rejected
  self.helper_test_exception([(3,3)], lambda x: x.diagonal(dim1=0, dim2=0), expected=RuntimeError)
def test_roll(self):
helper_test_op([(2, 4)], lambda x: x.roll(1))

View File

@@ -1440,9 +1440,10 @@ class Tensor(OpMixin):
if self.ndim != 1: raise ValueError(f"expect input to be 1-D, getting {self.ndim}-D")
return self.unsqueeze(-1).pad((None,(0,n:=self.shape[0]))).flatten().shrink(((0,n*n),)).reshape(n,n)
def diagonal(self) -> Tensor:
def diagonal(self, offset:int=0, dim1:int=0, dim2:int=1) -> Tensor:
"""
Returns a view of input tensor with its main diagonal elements.
Returns a view of the diagonal elements with respect to `dim1` and `dim2`.
`offset` controls which diagonal: 0 is main, positive is above, negative is below.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(9).reshape(3, 3)
@@ -1451,9 +1452,15 @@ class Tensor(OpMixin):
```python exec="true" source="above" session="tensor" result="python"
print(t.diagonal().numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.diagonal(offset=1).numpy())
```
"""
if self.ndim != 2 or (n:=self.shape[0]) != self.shape[1]: raise ValueError(f"only 2-D square tensor is supported, getting {self.shape=}")
return self.flatten().pad(((0, n))).reshape(n, n+1)[:, 0]
if (dim1:=self._resolve_dim(dim1)) == (dim2:=self._resolve_dim(dim2)): raise RuntimeError("dim1 and dim2 cannot be the same dimension")
x = self.permute(*[i for i in range(self.ndim) if i != dim1 and i != dim2], dim1, dim2)
x = x[..., :, offset:] if offset >= 0 else x[..., -offset:, :]
if (d := min(int(x.shape[-2]), int(x.shape[-1]))) <= 0: return x.reshape(*x.shape[:-2], 0)
return x[..., :d, :d].flatten(-2).pad(tuple((0,0) for _ in x.shape[:-2])+((0,d),)).reshape(*x.shape[:-2], d, d+1)[..., 0]
def roll(self, shifts:int|tuple[int, ...], dims:int|tuple[int, ...]|None=None) -> Tensor:
"""