From 845a4d32bc4234713abb6451a8376825eba1b5ee Mon Sep 17 00:00:00 2001 From: chenyu Date: Sat, 5 Jul 2025 23:03:02 -0400 Subject: [PATCH] Tensor.diag (#11108) also updated Tensor.eye to use it --- docs/tensor/movement.md | 1 + test/test_ops.py | 3 +++ tinygrad/tensor.py | 13 ++++++++++++- 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/docs/tensor/movement.md b/docs/tensor/movement.md index 7c40bdc94e..fa0461028d 100644 --- a/docs/tensor/movement.md +++ b/docs/tensor/movement.md @@ -26,5 +26,6 @@ ::: tinygrad.Tensor.transpose ::: tinygrad.Tensor.flatten ::: tinygrad.Tensor.unflatten +::: tinygrad.Tensor.diag ::: tinygrad.Tensor.roll ::: tinygrad.Tensor.rearrange \ No newline at end of file diff --git a/test/test_ops.py b/test/test_ops.py index adc368dee3..be332f0bb1 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1924,6 +1924,9 @@ class TestOps(unittest.TestCase): helper_test_op([(4,3,6,6)], lambda x: x.unflatten(3, (3, 2))) helper_test_op([(4,3,6,6)], lambda x: x.unflatten(-1, (3, 2, 1))) + def test_diag(self): + helper_test_op([(5,)], lambda x: x.diag()) + def test_roll(self): helper_test_op([(2, 4)], lambda x: torch.roll(x, 1, 0), lambda x: x.roll(1, 0)) helper_test_op([(2, 4)], lambda x: torch.roll(x, -1, 0), lambda x: x.roll(-1, 0)) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index b0cbd875a3..fb9a4fb260 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -661,7 +661,7 @@ class Tensor(MathTrait): ``` """ if n < 0 or (m is not None and m < 0): raise ValueError(f"cannot have negative {n=}, {m=}") - x = Tensor.ones((n,1),**kwargs).pad((None,(0,n))).flatten().shrink(((0,n*n),)).reshape(n,n) + x = Tensor.ones(n, **kwargs).diag() return x if m is None else x.pad((None, (0, m-n))) if m > n else x.shrink((None, (0, m))) def full_like(self, fill_value:ConstType, **kwargs) -> Tensor: @@ -1558,6 +1558,17 @@ class Tensor(MathTrait): dim = self._resolve_dim(dim) return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:]) + def diag(self) -> Tensor: + """ + Returns a 2-D square tensor with the elements of input as the main diagonal. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([1, 2, 3]).diag().numpy()) + ``` + """ + if self.ndim != 1: raise ValueError(f"expect input to be 1-D, getting {self.ndim}-D") + return self.unsqueeze(-1).pad((None,(0,n:=self.shape[0]))).flatten().shrink(((0,n*n),)).reshape(n,n) + def roll(self, shifts:int|tuple[int, ...], dims:int|tuple[int, ...]) -> Tensor: """ Rolls the tensor along specified dimension(s).