Mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Commit: Tensor.diagonal: support offset and dims (#14130)
This commit is contained in:
@@ -74,6 +74,7 @@ view_ops = {
|
||||
"aten.select.int": lambda self, dim, idx: self[(slice(None),) * (dim%self.ndim) + (idx,)],
|
||||
"aten.permute": Tensor.permute,
|
||||
"aten.alias": lambda self: self,
|
||||
"aten.diagonal": Tensor.diagonal,
|
||||
}
|
||||
|
||||
# torch 2.10 handles this natively
|
||||
@@ -751,16 +752,3 @@ def _pad_circular(self, padding): return _PadCircular.apply(self, padding)
|
||||
|
||||
@torch.library.impl("aten::_pad_circular", "AutogradPrivateUse1")
def _pad_circular_autograd(self, padding):
  # autograd-key registration: forward circular padding through the custom autograd Function
  return _PadCircular.apply(self, padding)
|
||||
|
||||
# only needed for test_diag_backward_gradient_values
# was going through torch before, but now we are using tinygrad directly and tracking views
# Tensor.diagonal does not support all the cases exercised in the tests
@torch.library.impl("aten::diagonal", "privateuseone")
@wrap_view_op
def diagonal(self, offset=0, dim1=0, dim2=1):
  # Main-diagonal view over the last two dims. Raises NotImplementedError for
  # nonzero offset or for dim1/dim2 other than the trailing two dims.
  if offset != 0: raise NotImplementedError(f"diagonal with {offset=} not implemented")
  dim1, dim2 = dim1 % self.ndim, dim2 % self.ndim
  if dim1 != self.ndim - 2 or dim2 != self.ndim - 1: raise NotImplementedError(f"diagonal with {dim1=}, {dim2=} not implemented, only last two dims supported")
  batch_shape, m, n = self.shape[:-2], self.shape[-2], self.shape[-1]
  diag_len = min(m, n)
  # diagonal elements sit at stride n+1 in the flattened last two dims, so resize the
  # flat axis to exactly diag_len*(n+1) and read column 0 of a (diag_len, n+1) view.
  # BUGFIX: the previous version always padded diag_len elements, which only lines up
  # when m <= n (wide/square); tall matrices (m > n) need the tail dropped instead.
  flat = self.reshape(*batch_shape, m*n)
  target = diag_len * (n + 1)
  if target > m*n:  # wide or square (m <= n): pad the tail by diag_len elements
    flat = flat.pad(tuple((0,0) for _ in batch_shape) + ((0, target - m*n),))
  elif target < m*n:  # tall (m > n+1): drop the trailing elements past the last diagonal entry
    flat = flat.shrink(tuple((0,s) for s in batch_shape) + ((0, target),))
  return flat.reshape(*batch_shape, diag_len, n+1)[..., :, 0]
|
||||
|
||||
@@ -2024,6 +2024,15 @@ class TestOps(unittest.TestCase):
|
||||
|
||||
def test_diagonal(self):
  # main diagonal: square and rectangular (both orientations)
  for shape in [(5,5), (3,4), (4,3)]:
    helper_test_op([shape], lambda x: x.diagonal())
  # batched: diagonal over the trailing two dims, including 4D inputs
  for shape in [(3,3,3), (4,5,6), (2,3,4,5)]:
    helper_test_op([shape], lambda x: x.diagonal(dim1=-2, dim2=-1))
  # off-main diagonals selected via offset (positive, negative, rectangular)
  helper_test_op([(5,5)], lambda x: x.diagonal(offset=1))
  helper_test_op([(5,5)], lambda x: x.diagonal(offset=-1))
  helper_test_op([(3,5)], lambda x: x.diagonal(offset=2))
  # identical dim1/dim2 must be rejected
  self.helper_test_exception([(3,3)], lambda x: x.diagonal(dim1=0, dim2=0), expected=RuntimeError)
|
||||
|
||||
def test_roll(self):
|
||||
helper_test_op([(2, 4)], lambda x: x.roll(1))
|
||||
|
||||
@@ -1440,9 +1440,10 @@ class Tensor(OpMixin):
|
||||
if self.ndim != 1: raise ValueError(f"expect input to be 1-D, getting {self.ndim}-D")
|
||||
return self.unsqueeze(-1).pad((None,(0,n:=self.shape[0]))).flatten().shrink(((0,n*n),)).reshape(n,n)
|
||||
|
||||
def diagonal(self) -> Tensor:
|
||||
def diagonal(self, offset:int=0, dim1:int=0, dim2:int=1) -> Tensor:
|
||||
"""
|
||||
Returns a view of input tensor with its main diagonal elements.
|
||||
Returns a view of the diagonal elements with respect to `dim1` and `dim2`.
|
||||
`offset` controls which diagonal: 0 is main, positive is above, negative is below.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.arange(9).reshape(3, 3)
|
||||
@@ -1451,9 +1452,15 @@ class Tensor(OpMixin):
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.diagonal().numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.diagonal(offset=1).numpy())
|
||||
```
|
||||
"""
|
||||
if self.ndim != 2 or (n:=self.shape[0]) != self.shape[1]: raise ValueError(f"only 2-D square tensor is supported, getting {self.shape=}")
|
||||
return self.flatten().pad(((0, n))).reshape(n, n+1)[:, 0]
|
||||
if (dim1:=self._resolve_dim(dim1)) == (dim2:=self._resolve_dim(dim2)): raise RuntimeError("dim1 and dim2 cannot be the same dimension")
|
||||
x = self.permute(*[i for i in range(self.ndim) if i != dim1 and i != dim2], dim1, dim2)
|
||||
x = x[..., :, offset:] if offset >= 0 else x[..., -offset:, :]
|
||||
if (d := min(int(x.shape[-2]), int(x.shape[-1]))) <= 0: return x.reshape(*x.shape[:-2], 0)
|
||||
return x[..., :d, :d].flatten(-2).pad(tuple((0,0) for _ in x.shape[:-2])+((0,d),)).reshape(*x.shape[:-2], d, d+1)[..., 0]
|
||||
|
||||
def roll(self, shifts:int|tuple[int, ...], dims:int|tuple[int, ...]|None=None) -> Tensor:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user