mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
Tensor.diagonal (#11122)
only implemented main diagonal for 2-D tensors. with diagonal and qr, we can get determinant
This commit is contained in:
@@ -1933,6 +1933,9 @@ class TestOps(unittest.TestCase):
|
||||
def test_diag(self):
|
||||
helper_test_op([(5,)], lambda x: x.diag())
|
||||
|
||||
def test_diagonal(self):
|
||||
helper_test_op([(5,5)], lambda x: x.diagonal())
|
||||
|
||||
def test_roll(self):
|
||||
helper_test_op([(2, 4)], lambda x: torch.roll(x, 1, 0), lambda x: x.roll(1, 0))
|
||||
helper_test_op([(2, 4)], lambda x: torch.roll(x, -1, 0), lambda x: x.roll(-1, 0))
|
||||
|
||||
@@ -1567,6 +1567,21 @@ class Tensor(MathTrait):
|
||||
if self.ndim != 1: raise ValueError(f"expect input to be 1-D, getting {self.ndim}-D")
|
||||
return self.unsqueeze(-1).pad((None,(0,n:=self.shape[0]))).flatten().shrink(((0,n*n),)).reshape(n,n)
|
||||
|
||||
def diagonal(self) -> Tensor:
|
||||
"""
|
||||
Returns a view of input tensor with its main diagonal elements.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.arange(9).reshape(3, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.diagonal().numpy())
|
||||
```
|
||||
"""
|
||||
if self.ndim != 2 or (n:=self.shape[0]) != self.shape[1]: raise ValueError(f"only 2-D square tensor is supported, getting {self.shape=}")
|
||||
return self.flatten().pad(((0, n))).reshape(n, n+1)[:, 0]
|
||||
|
||||
def roll(self, shifts:int|tuple[int, ...], dims:int|tuple[int, ...]) -> Tensor:
|
||||
"""
|
||||
Rolls the tensor along specified dimension(s).
|
||||
|
||||
Reference in New Issue
Block a user