diff --git a/test/test_arange.py b/test/test_arange.py
index 7c8411a891..8cb8f8972d 100644
--- a/test/test_arange.py
+++ b/test/test_arange.py
@@ -5,6 +5,8 @@ from tinygrad.helpers import CI, Context, getenv
 from tinygrad.engine.realize import run_schedule
 from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
 from tinygrad.uop.ops import Ops
+from tinygrad.renderer import Estimates
+from tinygrad.renderer.ptx import PTXRenderer
 
 class TestArange(unittest.TestCase):
   def _get_flops(self, tensor, desired):
@@ -29,6 +31,14 @@ class TestArange(unittest.TestCase):
     # NOTE: not every backend supports CMPEQ
     self.assertLessEqual(self._get_flops(Tensor.eye(2560).contiguous(), np.eye(2560)), 2*2560*2560)
 
+  @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX indexing is weird")
+  def test_tri_complexity(self):
+    with Context(NOOPT=1):
+      t = Tensor.ones(256, 256).contiguous().realize()
+      sched = t.triu().schedule()
+      p = get_program(sched[-1].ast)
+      self.assertLessEqual(Estimates.from_uops(p.uops).ops, 4 * 256 * 256)
+
 DSET, DDIM = 2048, 32
 
 class TestIndexing(unittest.TestCase):
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index d405280eea..339ff82913 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -2456,14 +2456,9 @@ class Tensor(OpMixin):
     return self._split_cumalu(axis, Ops.MAX)
 
   @staticmethod
-  def _tri(r:sint, c:sint, diagonal:int=0, **kwargs) -> Tensor:
+  def _tri(r:sint, c:sint, diagonal:int=0, device=None, requires_grad:bool|None=None) -> Tensor:
     assert isinstance(r, int) and isinstance(c, int), f"does not support symbolic, getting {r=}, {c=}"
-    if r == 0 or c == 0 or diagonal >= c: return Tensor.zeros(r,c,**kwargs)
-    if r+diagonal <= 0: return Tensor.ones(r,c,**kwargs)
-    s = r+c-1
-    # build a (s, s) upper triangle
-    t = Tensor.ones(s,s,**kwargs).pad((None,(0,s))).flatten().shrink(((0,s*(2*s-1)),)).reshape(s,-1).shrink((None,(0,s)))
-    return t[:r,-diagonal:c-diagonal] if diagonal <= 0 else t[diagonal:r+diagonal,:c]
+    return (Tensor.arange(r, device=device).unsqueeze(-1) + diagonal <= Tensor.arange(c, device=device)).requires_grad_(requires_grad)
 
   def triu(self, diagonal:int=0) -> Tensor:
     """
     Returns the upper triangular part of the tensor, the other elements are set to 0.
@@ -2486,7 +2481,7 @@ class Tensor(OpMixin):
     print(t.triu(diagonal=-1).numpy())
     ```
     """
-    return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal, device=self.device, dtype=dtypes.bool).where(self, self.zeros_like())
+    return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal, device=self.device).where(self, self.zeros_like())
 
   def tril(self, diagonal:int=0) -> Tensor:
     """
@@ -2509,7 +2504,7 @@ class Tensor(OpMixin):
     print(t.tril(diagonal=-1).numpy())
     ```
     """
-    return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device, dtype=dtypes.bool).where(self.zeros_like(), self)
+    return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device).where(self.zeros_like(), self)
 
   def interpolate(self, size:tuple[int, ...], mode:str="linear", align_corners:bool=False) -> Tensor:
     """