Tensor.diag (#11108)

also updated Tensor.eye to use it
This commit is contained in:
chenyu
2025-07-05 23:03:02 -04:00
committed by GitHub
parent 4905af4ae0
commit 845a4d32bc
3 changed files with 16 additions and 1 deletions

View File

@@ -26,5 +26,6 @@
::: tinygrad.Tensor.transpose
::: tinygrad.Tensor.flatten
::: tinygrad.Tensor.unflatten
::: tinygrad.Tensor.diag
::: tinygrad.Tensor.roll
::: tinygrad.Tensor.rearrange

View File

@@ -1924,6 +1924,9 @@ class TestOps(unittest.TestCase):
helper_test_op([(4,3,6,6)], lambda x: x.unflatten(3, (3, 2)))
helper_test_op([(4,3,6,6)], lambda x: x.unflatten(-1, (3, 2, 1)))
def test_diag(self):
helper_test_op([(5,)], lambda x: x.diag())
def test_roll(self):
helper_test_op([(2, 4)], lambda x: torch.roll(x, 1, 0), lambda x: x.roll(1, 0))
helper_test_op([(2, 4)], lambda x: torch.roll(x, -1, 0), lambda x: x.roll(-1, 0))

View File

@@ -661,7 +661,7 @@ class Tensor(MathTrait):
```
"""
if n < 0 or (m is not None and m < 0): raise ValueError(f"cannot have negative {n=}, {m=}")
x = Tensor.ones((n,1),**kwargs).pad((None,(0,n))).flatten().shrink(((0,n*n),)).reshape(n,n)
x = Tensor.ones(n, **kwargs).diag()
return x if m is None else x.pad((None, (0, m-n))) if m > n else x.shrink((None, (0, m)))
def full_like(self, fill_value:ConstType, **kwargs) -> Tensor:
@@ -1558,6 +1558,17 @@ class Tensor(MathTrait):
dim = self._resolve_dim(dim)
return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
def diag(self) -> Tensor:
"""
Returns a 2-D square tensor with the elements of input as the main diagonal.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([1, 2, 3]).diag().numpy())
```
"""
if self.ndim != 1: raise ValueError(f"expect input to be 1-D, getting {self.ndim}-D")
return self.unsqueeze(-1).pad((None,(0,n:=self.shape[0]))).flatten().shrink(((0,n*n),)).reshape(n,n)
def roll(self, shifts:int|tuple[int, ...], dims:int|tuple[int, ...]) -> Tensor:
"""
Rolls the tensor along specified dimension(s).