update tests get_runner (#4522)

This commit is contained in:
George Hotz
2024-05-10 20:09:22 -07:00
committed by GitHub
parent a0448ff595
commit 827058f030
5 changed files with 50 additions and 49 deletions

View File

@@ -25,12 +25,12 @@ if __name__ == "__main__":
# cuda compile
culin = ast_str_to_lin(ast, opts=cudev.compiler.compiler_opts)
culin.hand_coded_optimizations()
cuda_prg = cudev.to_program(culin)
cuda_prg = cudev.to_runner(culin)
cubufs = bufs_from_lin(culin)
nvlin = ast_str_to_lin(ast, opts=nvdev.compiler.compiler_opts)
nvlin.hand_coded_optimizations()
nv_prg = nvdev.to_program(nvlin)
nv_prg = nvdev.to_runner(nvlin)
nvbufs = bufs_from_lin(nvlin)
# warmup

View File

@@ -1,9 +1,9 @@
import unittest
import time
import numpy as np
from tinygrad import Tensor, dtypes, Device
from tinygrad import Tensor, dtypes
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule
from tinygrad.engine.realize import lower_schedule_item, run_schedule
class TestFusionOp(unittest.TestCase):
def test_contiguous_add(self):
@@ -27,9 +27,9 @@ class TestFusionOp(unittest.TestCase):
a = Tensor([1,2,3,4])
for _ in range(24): a = a + a
sched = create_schedule([a.lazydata], None)
ji = Device[Device.DEFAULT].get_runner(*sched[-1].ast)
ei = lower_schedule_item(sched[-1])
self.assertLess(time.perf_counter()-st, 1.0)
assert len(ji.p.src.splitlines()) < 250
assert len(ei.prg.p.src.splitlines()) < 250
def test_recursive_add_cmp(self):
st = time.perf_counter()

View File

@@ -1,6 +1,7 @@
import unittest
from tinygrad import Tensor, Device
from tinygrad import Tensor
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import lower_schedule_item
# TODO: can copy this in here when we remove it
#from tinygrad.ops import get_lazyop_info
@@ -12,8 +13,8 @@ from tinygrad.engine.schedule import create_schedule
def get_stats(x:Tensor):
si = create_schedule([x.lazydata])[-1]
runner = Device[Device.DEFAULT].get_runner(*si.ast)
return runner.op_estimate, runner.mem_estimate
ei = lower_schedule_item(si)
return ei.prg.p.op_estimate, ei.prg.p.mem_estimate
class TestUOpsStats(unittest.TestCase):
def test_simple_add(self):