diff --git a/test/test_ops.py b/test/test_ops.py index 92b9f1b7fc..0e2f18ae8c 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -301,6 +301,12 @@ class TestOps(unittest.TestCase): helper_test_op([(3,3)], lambda x: x.tril(2)) helper_test_op([(3,3)], lambda x: x.tril(-1)) helper_test_op([(3,3)], lambda x: x.tril(-2)) + helper_test_op([(4,5)], lambda x: x.tril(4)) + helper_test_op([(4,5)], lambda x: x.tril(5)) + helper_test_op([(4,5)], lambda x: x.tril(6)) + helper_test_op([(4,5)], lambda x: x.tril(-4)) + helper_test_op([(4,5)], lambda x: x.tril(-5)) + helper_test_op([(4,5)], lambda x: x.tril(-6)) helper_test_op([(5,3,3)], lambda x: x.tril()) helper_test_op([(5,0,3)], lambda x: x.tril()) helper_test_op([(5,3,3)], lambda x: x.tril(1)) @@ -311,6 +317,12 @@ class TestOps(unittest.TestCase): helper_test_op([(3,3)], lambda x: x.triu(2)) helper_test_op([(3,3)], lambda x: x.triu(-1)) helper_test_op([(3,3)], lambda x: x.triu(-2)) + helper_test_op([(4,5)], lambda x: x.triu(4)) + helper_test_op([(4,5)], lambda x: x.triu(5)) + helper_test_op([(4,5)], lambda x: x.triu(6)) + helper_test_op([(4,5)], lambda x: x.triu(-4)) + helper_test_op([(4,5)], lambda x: x.triu(-5)) + helper_test_op([(4,5)], lambda x: x.triu(-6)) helper_test_op([(5,3,3)], lambda x: x.triu()) helper_test_op([(5,0,3)], lambda x: x.triu()) helper_test_op([(5,3,3)], lambda x: x.triu(1)) diff --git a/test/test_schedule.py b/test/test_schedule.py index f333e89a7c..2023612cb3 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -31,7 +31,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt for i, s in enumerate(sched): print("kernel", i+1) for op in s.ast: print_tree(op) - if len(sched) != allowed: raise KernelCountException() + if len(sched) != allowed: raise KernelCountException(f"{len(sched)=} != {allowed}") # test the (non loadops) ops linearize for s in sched: if s.ast[0].op in LoadOps: continue @@ -797,7 +797,7 @@ class TestSchedule(unittest.TestCase): def test_scaled_dot_product_attention_causal_fusion(self): x, y, z, m = (Tensor.empty(32, 8, 16, 16) for _ in range(4)) out = Tensor.scaled_dot_product_attention(x, y, z, attn_mask=m, is_causal=True) - check_schedule(out, 7) + check_schedule(out, 6) def test_adam_step_fusion(self): with Tensor.train(): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 713445fd6d..992c31388c 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1796,9 +1796,13 @@ class Tensor: @staticmethod def _tri(r:sint, c:sint, diagonal:int=0, **kwargs) -> Tensor: - assert all_int((r,c)), "does not support symbolic" - if r == 0: return Tensor.zeros((r, c), **kwargs) - return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-diagonal, c-diagonal, **kwargs).unsqueeze(0).expand(r,c) + assert isinstance(r, int) and isinstance(c, int), f"does not support symbolic, getting {r=}, {c=}" + if r == 0 or c == 0 or diagonal >= c: return Tensor.zeros(r,c,**kwargs) + if r+diagonal <= 0: return Tensor.ones(r,c,**kwargs) + s = r+c-1 + # build a (s, s) upper triangle + t = Tensor.ones(s,s,**kwargs).pad((None,(0,s))).flatten().shrink(((0,s*(2*s-1)),)).reshape(s,-1).shrink((None,(0,s))) + return t[:r,-diagonal:c-diagonal] if diagonal <= 0 else t[diagonal:r+diagonal,:c] def triu(self, diagonal:int=0) -> Tensor: """ @@ -1821,7 +1825,7 @@ class Tensor: print(t.triu(diagonal=-1).numpy()) ``` """ - return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal, device=self.device).where(self, 0).cast(self.dtype) + return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal, device=self.device, dtype=dtypes.bool).where(self, 0).cast(self.dtype) def tril(self, diagonal:int=0) -> Tensor: """ @@ -1844,7 +1848,7 @@ class Tensor: print(t.tril(diagonal=-1).numpy()) ``` """ - return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device).where(0, self).cast(self.dtype) + return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device, dtype=dtypes.bool).where(0, self).cast(self.dtype) # ***** unary ops *****