From 20fabd8a5b028f3adda5b505cd577735844062cb Mon Sep 17 00:00:00 2001 From: chenyu Date: Sat, 22 Jun 2024 21:59:50 -0400 Subject: [PATCH] update Tensor.triu and Tensor.tril (#5109) renamed arg to `diagonal` that matches torch api, and added document and examples --- test/test_ops.py | 4 ++++ tinygrad/tensor.py | 40 ++++++++++++++++++++++++++++++---------- 2 files changed, 34 insertions(+), 10 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 243332a23e..92b9f1b7fc 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -298,7 +298,9 @@ class TestOps(unittest.TestCase): def test_tril(self): helper_test_op([(3,3)], lambda x: x.tril()) helper_test_op([(3,3)], lambda x: x.tril(1)) + helper_test_op([(3,3)], lambda x: x.tril(2)) helper_test_op([(3,3)], lambda x: x.tril(-1)) + helper_test_op([(3,3)], lambda x: x.tril(-2)) helper_test_op([(5,3,3)], lambda x: x.tril()) helper_test_op([(5,0,3)], lambda x: x.tril()) helper_test_op([(5,3,3)], lambda x: x.tril(1)) @@ -306,7 +308,9 @@ class TestOps(unittest.TestCase): def test_triu(self): helper_test_op([(3,3)], lambda x: x.triu()) helper_test_op([(3,3)], lambda x: x.triu(1)) + helper_test_op([(3,3)], lambda x: x.triu(2)) helper_test_op([(3,3)], lambda x: x.triu(-1)) + helper_test_op([(3,3)], lambda x: x.triu(-2)) helper_test_op([(5,3,3)], lambda x: x.triu()) helper_test_op([(5,0,3)], lambda x: x.triu()) helper_test_op([(5,3,3)], lambda x: x.triu(1)) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 28a13c927c..713445fd6d 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1795,36 +1795,56 @@ class Tensor: return fix(ret) + fix(base_add) @staticmethod - def _tri(r:sint, c:sint, k:int=0, **kwargs) -> Tensor: + def _tri(r:sint, c:sint, diagonal:int=0, **kwargs) -> Tensor: assert all_int((r,c)), "does not support symbolic" if r == 0: return Tensor.zeros((r, c), **kwargs) - return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c) - def triu(self, k:int=0) -> Tensor: + return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-diagonal, c-diagonal, **kwargs).unsqueeze(0).expand(r,c) + + def triu(self, diagonal:int=0) -> Tensor: """ Returns the upper triangular part of the tensor, the other elements are set to 0. + The argument `diagonal` determines which diagonal is on the boundary. `diagonal = 0` means the main diagonal. + Positive `diagonal` means above the main diagonal, and negative `diagonal` means below the main diagonal. + ```python exec="true" source="above" session="tensor" result="python" - t = Tensor([[1, 2, 3], [4, 5, 6]]) + t = Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" - print(t.triu(k=1).numpy()) + print(t.triu(diagonal=0).numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.triu(diagonal=1).numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.triu(diagonal=-1).numpy()) ``` """ - return Tensor._tri(self.shape[-2], self.shape[-1], k=k, device=self.device).where(self, 0).cast(self.dtype) - def tril(self, k:int=0) -> Tensor: + return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal, device=self.device).where(self, 0).cast(self.dtype) + + def tril(self, diagonal:int=0) -> Tensor: """ Returns the lower triangular part of the tensor, the other elements are set to 0. + The argument `diagonal` determines which diagonal is on the boundary. `diagonal = 0` means the main diagonal. + Positive `diagonal` means above the main diagonal, and negative `diagonal` means below the main diagonal. + ```python exec="true" source="above" session="tensor" result="python" - t = Tensor([[1, 2, 3], [4, 5, 6]]) + t = Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" - print(t.tril().numpy()) + print(t.tril(diagonal=0).numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.tril(diagonal=1).numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.tril(diagonal=-1).numpy()) ``` """ - return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, device=self.device).where(0, self).cast(self.dtype) + return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device).where(0, self).cast(self.dtype) # ***** unary ops *****