update Tensor.triu and Tensor.tril (#5109)

Renamed the argument to `diagonal` to match the torch API, and added documentation and examples.
This commit is contained in:
chenyu
2024-06-22 21:59:50 -04:00
committed by GitHub
parent 8f6ae84e4a
commit 20fabd8a5b
2 changed files with 34 additions and 10 deletions

View File

@@ -298,7 +298,9 @@ class TestOps(unittest.TestCase):
def test_tril(self):
helper_test_op([(3,3)], lambda x: x.tril())
helper_test_op([(3,3)], lambda x: x.tril(1))
helper_test_op([(3,3)], lambda x: x.tril(2))
helper_test_op([(3,3)], lambda x: x.tril(-1))
helper_test_op([(3,3)], lambda x: x.tril(-2))
helper_test_op([(5,3,3)], lambda x: x.tril())
helper_test_op([(5,0,3)], lambda x: x.tril())
helper_test_op([(5,3,3)], lambda x: x.tril(1))
@@ -306,7 +308,9 @@ class TestOps(unittest.TestCase):
def test_triu(self):
helper_test_op([(3,3)], lambda x: x.triu())
helper_test_op([(3,3)], lambda x: x.triu(1))
helper_test_op([(3,3)], lambda x: x.triu(2))
helper_test_op([(3,3)], lambda x: x.triu(-1))
helper_test_op([(3,3)], lambda x: x.triu(-2))
helper_test_op([(5,3,3)], lambda x: x.triu())
helper_test_op([(5,0,3)], lambda x: x.triu())
helper_test_op([(5,3,3)], lambda x: x.triu(1))

View File

@@ -1795,36 +1795,56 @@ class Tensor:
return fix(ret) + fix(base_add)
@staticmethod
def _tri(r:sint, c:sint, k:int=0, **kwargs) -> Tensor:
def _tri(r:sint, c:sint, diagonal:int=0, **kwargs) -> Tensor:
assert all_int((r,c)), "does not support symbolic"
if r == 0: return Tensor.zeros((r, c), **kwargs)
return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
def triu(self, k:int=0) -> Tensor:
return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-diagonal, c-diagonal, **kwargs).unsqueeze(0).expand(r,c)
def triu(self, diagonal:int=0) -> Tensor:
"""
Returns the upper triangular part of the tensor, the other elements are set to 0.
The argument `diagonal` determines which diagonal is on the boundary. `diagonal = 0` means the main diagonal.
Positive `diagonal` means above the main diagonal, and negative `diagonal` means below the main diagonal.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([[1, 2, 3], [4, 5, 6]])
t = Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.triu(k=1).numpy())
print(t.triu(diagonal=0).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.triu(diagonal=1).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.triu(diagonal=-1).numpy())
```
"""
return Tensor._tri(self.shape[-2], self.shape[-1], k=k, device=self.device).where(self, 0).cast(self.dtype)
def tril(self, k:int=0) -> Tensor:
return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal, device=self.device).where(self, 0).cast(self.dtype)
def tril(self, diagonal:int=0) -> Tensor:
"""
Returns the lower triangular part of the tensor, the other elements are set to 0.
The argument `diagonal` determines which diagonal is on the boundary. `diagonal = 0` means the main diagonal.
Positive `diagonal` means above the main diagonal, and negative `diagonal` means below the main diagonal.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([[1, 2, 3], [4, 5, 6]])
t = Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.tril().numpy())
print(t.tril(diagonal=0).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.tril(diagonal=1).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.tril(diagonal=-1).numpy())
```
"""
return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, device=self.device).where(0, self).cast(self.dtype)
return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device).where(0, self).cast(self.dtype)
# ***** unary ops *****