mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
@@ -26,5 +26,6 @@
|
||||
::: tinygrad.Tensor.transpose
|
||||
::: tinygrad.Tensor.flatten
|
||||
::: tinygrad.Tensor.unflatten
|
||||
::: tinygrad.Tensor.diag
|
||||
::: tinygrad.Tensor.roll
|
||||
::: tinygrad.Tensor.rearrange
|
||||
@@ -1924,6 +1924,9 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.unflatten(3, (3, 2)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.unflatten(-1, (3, 2, 1)))
|
||||
|
||||
def test_diag(self):
|
||||
helper_test_op([(5,)], lambda x: x.diag())
|
||||
|
||||
def test_roll(self):
|
||||
helper_test_op([(2, 4)], lambda x: torch.roll(x, 1, 0), lambda x: x.roll(1, 0))
|
||||
helper_test_op([(2, 4)], lambda x: torch.roll(x, -1, 0), lambda x: x.roll(-1, 0))
|
||||
|
||||
@@ -661,7 +661,7 @@ class Tensor(MathTrait):
|
||||
```
|
||||
"""
|
||||
if n < 0 or (m is not None and m < 0): raise ValueError(f"cannot have negative {n=}, {m=}")
|
||||
x = Tensor.ones((n,1),**kwargs).pad((None,(0,n))).flatten().shrink(((0,n*n),)).reshape(n,n)
|
||||
x = Tensor.ones(n, **kwargs).diag()
|
||||
return x if m is None else x.pad((None, (0, m-n))) if m > n else x.shrink((None, (0, m)))
|
||||
|
||||
def full_like(self, fill_value:ConstType, **kwargs) -> Tensor:
|
||||
@@ -1558,6 +1558,17 @@ class Tensor(MathTrait):
|
||||
dim = self._resolve_dim(dim)
|
||||
return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
|
||||
|
||||
def diag(self) -> Tensor:
|
||||
"""
|
||||
Returns a 2-D square tensor with the elements of input as the main diagonal.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(Tensor([1, 2, 3]).diag().numpy())
|
||||
```
|
||||
"""
|
||||
if self.ndim != 1: raise ValueError(f"expect input to be 1-D, getting {self.ndim}-D")
|
||||
return self.unsqueeze(-1).pad((None,(0,n:=self.shape[0]))).flatten().shrink(((0,n*n),)).reshape(n,n)
|
||||
|
||||
def roll(self, shifts:int|tuple[int, ...], dims:int|tuple[int, ...]) -> Tensor:
|
||||
"""
|
||||
Rolls the tensor along specified dimension(s).
|
||||
|
||||
Reference in New Issue
Block a user