Tensor._tri with arange (#13297)

This commit is contained in:
chenyu
2025-11-16 07:21:16 -08:00
committed by GitHub
parent 6372c95094
commit 8f0e747b3a
2 changed files with 14 additions and 9 deletions

View File

@@ -5,6 +5,8 @@ from tinygrad.helpers import CI, Context, getenv
from tinygrad.engine.realize import run_schedule
from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
from tinygrad.uop.ops import Ops
from tinygrad.renderer import Estimates
from tinygrad.renderer.ptx import PTXRenderer
class TestArange(unittest.TestCase):
def _get_flops(self, tensor, desired):
@@ -29,6 +31,14 @@ class TestArange(unittest.TestCase):
# NOTE: not every backend supports CMPEQ
self.assertLessEqual(self._get_flops(Tensor.eye(2560).contiguous(), np.eye(2560)), 2*2560*2560)
# Checks that the program generated for triu() on a 256x256 tensor stays within an estimated op budget.
# Skipped when the default device's renderer is a PTXRenderer.
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX indexing is weird")
def test_tri_complexity(self):
# NOTE(review): indentation was lost in this view; presumably every statement below sits inside this
# Context(NOOPT=1) block — confirm against the real file.
with Context(NOOPT=1):
# Build and realize the 256x256 input of ones up front.
# NOTE(review): presumably so the fill is not part of the schedule inspected below — confirm.
t = Tensor.ones(256, 256).contiguous().realize()
sched = t.triu().schedule()
# Only the last scheduled item's AST is turned into a program and measured.
p = get_program(sched[-1].ast)
# Budget is 4 * 256 * 256, i.e. 4 estimated ops per element of the 256x256 input.
# NOTE(review): presumably guards against triu lowering to a per-element reduction — confirm.
self.assertLessEqual(Estimates.from_uops(p.uops).ops, 4 * 256 * 256)
# NOTE(review): presumably size constants shared by the TestIndexing tests below; usage is not visible here — confirm.
DSET, DDIM = 2048, 32
class TestIndexing(unittest.TestCase):