mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Tensor._tri with arange (#13297)
This commit is contained in:
@@ -5,6 +5,8 @@ from tinygrad.helpers import CI, Context, getenv
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
|
||||
from tinygrad.uop.ops import Ops
|
||||
from tinygrad.renderer import Estimates
|
||||
from tinygrad.renderer.ptx import PTXRenderer
|
||||
|
||||
class TestArange(unittest.TestCase):
|
||||
def _get_flops(self, tensor, desired):
|
||||
@@ -29,6 +31,14 @@ class TestArange(unittest.TestCase):
|
||||
# NOTE: not every backend supports CMPEQ
|
||||
self.assertLessEqual(self._get_flops(Tensor.eye(2560).contiguous(), np.eye(2560)), 2*2560*2560)
|
||||
|
||||
# Complexity test for triu(): bounds the estimated op count of the program built
# from the last schedule item of `triu()` on a 256x256 tensor.
# Skipped when the default device's renderer is a PTXRenderer (stated reason:
# "PTX indexing is weird").
# NOTE(review): indentation was lost in this view, so how many of the statements
# below sit inside the `with Context(NOOPT=1):` body must be confirmed against
# the real file.
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX indexing is weird")
|
||||
def test_tri_complexity(self):
|
||||
# Run with the NOOPT context variable set to 1.
with Context(NOOPT=1):
|
||||
# Build and realize the 256x256 ones input up front
# (presumably so its creation is not part of the triu schedule -- confirm).
t = Tensor.ones(256, 256).contiguous().realize()
|
||||
# Schedule the upper-triangular op on the realized input.
sched = t.triu().schedule()
|
||||
# Turn the AST of the last schedule item into a program.
p = get_program(sched[-1].ast)
|
||||
# Budget: estimated ops must not exceed 4 * 256 * 256, i.e. 4 per element of the 256x256 input.
self.assertLessEqual(Estimates.from_uops(p.uops).ops, 4 * 256 * 256)
|
||||
|
||||
# Module-level constants: DSET = 2048, DDIM = 32.
# NOTE(review): presumably dataset size and per-row dimension for the indexing tests
# that follow -- their use is not visible in this view; confirm.
DSET, DDIM = 2048, 32
|
||||
|
||||
class TestIndexing(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user