mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 06:48:22 -05:00
update Tensor.triu and Tensor.tril (#5109)
renamed arg to `diagonal` that matches torch api, and added document and examples
This commit is contained in:
@@ -298,7 +298,9 @@ class TestOps(unittest.TestCase):
|
||||
def test_tril(self):
|
||||
helper_test_op([(3,3)], lambda x: x.tril())
|
||||
helper_test_op([(3,3)], lambda x: x.tril(1))
|
||||
helper_test_op([(3,3)], lambda x: x.tril(2))
|
||||
helper_test_op([(3,3)], lambda x: x.tril(-1))
|
||||
helper_test_op([(3,3)], lambda x: x.tril(-2))
|
||||
helper_test_op([(5,3,3)], lambda x: x.tril())
|
||||
helper_test_op([(5,0,3)], lambda x: x.tril())
|
||||
helper_test_op([(5,3,3)], lambda x: x.tril(1))
|
||||
@@ -306,7 +308,9 @@ class TestOps(unittest.TestCase):
|
||||
def test_triu(self):
|
||||
helper_test_op([(3,3)], lambda x: x.triu())
|
||||
helper_test_op([(3,3)], lambda x: x.triu(1))
|
||||
helper_test_op([(3,3)], lambda x: x.triu(2))
|
||||
helper_test_op([(3,3)], lambda x: x.triu(-1))
|
||||
helper_test_op([(3,3)], lambda x: x.triu(-2))
|
||||
helper_test_op([(5,3,3)], lambda x: x.triu())
|
||||
helper_test_op([(5,0,3)], lambda x: x.triu())
|
||||
helper_test_op([(5,3,3)], lambda x: x.triu(1))
|
||||
|
||||
@@ -1795,36 +1795,56 @@ class Tensor:
|
||||
return fix(ret) + fix(base_add)
|
||||
|
||||
@staticmethod
|
||||
def _tri(r:sint, c:sint, k:int=0, **kwargs) -> Tensor:
|
||||
def _tri(r:sint, c:sint, diagonal:int=0, **kwargs) -> Tensor:
|
||||
assert all_int((r,c)), "does not support symbolic"
|
||||
if r == 0: return Tensor.zeros((r, c), **kwargs)
|
||||
return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
|
||||
def triu(self, k:int=0) -> Tensor:
|
||||
return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-diagonal, c-diagonal, **kwargs).unsqueeze(0).expand(r,c)
|
||||
|
||||
def triu(self, diagonal:int=0) -> Tensor:
|
||||
"""
|
||||
Returns the upper triangular part of the tensor, the other elements are set to 0.
|
||||
|
||||
The argument `diagonal` determines which diagonal is on the boundary. `diagonal = 0` means the main diagonal.
|
||||
Positive `diagonal` means above the main diagonal, and negative `diagonal` means below the main diagonal.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([[1, 2, 3], [4, 5, 6]])
|
||||
t = Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.triu(k=1).numpy())
|
||||
print(t.triu(diagonal=0).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.triu(diagonal=1).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.triu(diagonal=-1).numpy())
|
||||
```
|
||||
"""
|
||||
return Tensor._tri(self.shape[-2], self.shape[-1], k=k, device=self.device).where(self, 0).cast(self.dtype)
|
||||
def tril(self, k:int=0) -> Tensor:
|
||||
return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal, device=self.device).where(self, 0).cast(self.dtype)
|
||||
|
||||
def tril(self, diagonal:int=0) -> Tensor:
|
||||
"""
|
||||
Returns the lower triangular part of the tensor, the other elements are set to 0.
|
||||
|
||||
The argument `diagonal` determines which diagonal is on the boundary. `diagonal = 0` means the main diagonal.
|
||||
Positive `diagonal` means above the main diagonal, and negative `diagonal` means below the main diagonal.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([[1, 2, 3], [4, 5, 6]])
|
||||
t = Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.tril().numpy())
|
||||
print(t.tril(diagonal=0).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.tril(diagonal=1).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.tril(diagonal=-1).numpy())
|
||||
```
|
||||
"""
|
||||
return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, device=self.device).where(0, self).cast(self.dtype)
|
||||
return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device).where(0, self).cast(self.dtype)
|
||||
|
||||
# ***** unary ops *****
|
||||
|
||||
|
||||
Reference in New Issue
Block a user