Tensor.diagonal (#11122)

only implemented main diagonal for 2-D tensors. with diagonal and qr, we can get determinant
This commit is contained in:
chenyu
2025-07-07 16:21:26 -04:00
committed by GitHub
parent 584fd6af5a
commit 341a686799
2 changed files with 18 additions and 0 deletions

View File

@@ -1933,6 +1933,9 @@ class TestOps(unittest.TestCase):
def test_diag(self):
helper_test_op([(5,)], lambda x: x.diag())
def test_diagonal(self):
helper_test_op([(5,5)], lambda x: x.diagonal())
def test_roll(self):
helper_test_op([(2, 4)], lambda x: torch.roll(x, 1, 0), lambda x: x.roll(1, 0))
helper_test_op([(2, 4)], lambda x: torch.roll(x, -1, 0), lambda x: x.roll(-1, 0))

View File

@@ -1567,6 +1567,21 @@ class Tensor(MathTrait):
if self.ndim != 1: raise ValueError(f"expect input to be 1-D, getting {self.ndim}-D")
return self.unsqueeze(-1).pad((None,(0,n:=self.shape[0]))).flatten().shrink(((0,n*n),)).reshape(n,n)
def diagonal(self) -> Tensor:
"""
Returns a view of input tensor with its main diagonal elements.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(9).reshape(3, 3)
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.diagonal().numpy())
```
"""
if self.ndim != 2 or (n:=self.shape[0]) != self.shape[1]: raise ValueError(f"only 2-D square tensor is supported, getting {self.shape=}")
return self.flatten().pad(((0, n))).reshape(n, n+1)[:, 0]
def roll(self, shifts:int|tuple[int, ...], dims:int|tuple[int, ...]) -> Tensor:
"""
Rolls the tensor along specified dimension(s).