diff --git a/test/test_ops.py b/test/test_ops.py index 426c621651..4ff437b413 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1933,6 +1933,9 @@ class TestOps(unittest.TestCase): def test_diag(self): helper_test_op([(5,)], lambda x: x.diag()) + def test_diagonal(self): + helper_test_op([(5,5)], lambda x: x.diagonal()) + def test_roll(self): helper_test_op([(2, 4)], lambda x: torch.roll(x, 1, 0), lambda x: x.roll(1, 0)) helper_test_op([(2, 4)], lambda x: torch.roll(x, -1, 0), lambda x: x.roll(-1, 0)) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index d49dd21b45..952459ac40 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1567,6 +1567,21 @@ class Tensor(MathTrait): if self.ndim != 1: raise ValueError(f"expect input to be 1-D, getting {self.ndim}-D") return self.unsqueeze(-1).pad((None,(0,n:=self.shape[0]))).flatten().shrink(((0,n*n),)).reshape(n,n) + def diagonal(self) -> Tensor: + """ + Returns a view of input tensor with its main diagonal elements. + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor.arange(9).reshape(3, 3) + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.diagonal().numpy()) + ``` + """ + if self.ndim != 2 or (n:=self.shape[0]) != self.shape[1]: raise ValueError(f"only 2-D square tensor is supported, getting {self.shape=}") + return self.flatten().pad(((0, n))).reshape(n, n+1)[:, 0] + def roll(self, shifts:int|tuple[int, ...], dims:int|tuple[int, ...]) -> Tensor: """ Rolls the tensor along specified dimension(s).