diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py
index 8262abd4bb..2ebb8c17c2 100755
--- a/test/external/process_replay/process_replay.py
+++ b/test/external/process_replay/process_replay.py
@@ -14,7 +14,6 @@ try:
   from tinygrad.uop.ops import UOp, Ops, KernelInfo
   from tinygrad.codegen.opt import Opt
   from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm, BEAM
-  from tinygrad.device import Device
 except ImportError as e:
   print(repr(e))
   exit(int(ASSERT_DIFF))
@@ -52,12 +51,10 @@ def replay_get_rangeify_map(ret:dict[UOp, UOp], big_sink:UOp) -> tuple[str, str,
     return "\n".join([f"{len(asts)} kernels", *asts])
   return to_str(new_sink), to_str(big_sink.substitute(ret)), (big_sink,)

-def replay_get_program(p:ProgramSpec, ast:UOp, renderer:Renderer|None=None, opts:list[Opt]|None=None) -> tuple[str, str, tuple[Any, ...]]:
+def replay_get_program(p:ProgramSpec, ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> tuple[str, str, tuple[Any, ...]]:
   # the ast.arg is non None if we are inside of search.py
   sink_arg = ast.arg or KernelInfo(opts_to_apply=tuple(opts) if opts is not None else p.applied_opts if BEAM>=1 else None)
   input_ast = ast.replace(arg=replace(sink_arg, name=p.name))
-  # if no renderer was provided, open the device to get it
-  if renderer is None: renderer = Device[p.device].renderer
   p2 = get_program(input_ast, renderer=renderer)
   def to_str(ret:ProgramSpec) -> str:
     # PYTHON renderer pickles UOps, first unpickle and decode here
diff --git a/test/opt/test_gen_float4.py b/test/opt/test_gen_float4.py
index 357dccae6d..03780d5a0b 100644
--- a/test/opt/test_gen_float4.py
+++ b/test/opt/test_gen_float4.py
@@ -24,7 +24,7 @@ class TestFloat4(unittest.TestCase):
     s = c.schedule()[0]
     realized_ast = s.ast
     opts_to_apply = [Opt(op=OptOps.UPCAST, axis=0, arg=4)]
-    program = get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts_to_apply)
+    program = get_program(realized_ast, renderer=Device[Device.DEFAULT].renderer, opts=opts_to_apply)

     assert TestFloat4.count_float4(program.uops) == (2, 1)

@@ -35,7 +35,8 @@ class TestFloat4(unittest.TestCase):
     c = a + b
     s = c.schedule()[0]

-    uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=2)]).uops
+    uops = get_program(s.ast, renderer=Device[Device.DEFAULT].renderer,
+                       opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=2)]).uops

     assert TestFloat4.count_float4(uops) == (4, 2)

   @unittest.skipUnless(Device.DEFAULT in {"CPU"} and AMX, "Only CPU with AMX upcasts float up to size 16")
@@ -46,7 +47,8 @@ class TestFloat4(unittest.TestCase):
       c = a + b
       s = c.schedule()[0]

-      return get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=shift)]).uops
+      return get_program(s.ast, renderer=Device[Device.DEFAULT].renderer,
+                         opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=shift)]).uops

     sizes = [12, 8, 16]
     shifts = [3, 2, 4]
@@ -64,7 +66,7 @@ class TestFloat4(unittest.TestCase):
     s = c.schedule()[0]
     realized_ast = s.ast
     opts_to_apply = [Opt(op=OptOps.UPCAST, axis=0, arg=4)]
-    program = get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts_to_apply)
+    program = get_program(realized_ast, renderer=Device[Device.DEFAULT].renderer, opts=opts_to_apply)

     assert TestFloat4.count_float4(program.uops) == (0, 1)

@@ -75,7 +77,8 @@ class TestFloat4(unittest.TestCase):
     c = a + b
     s = c.schedule()[0]

-    uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=2)]).uops
+    uops = get_program(s.ast, renderer=Device[Device.DEFAULT].renderer,
+                       opts=[Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=2)]).uops

     assert TestFloat4.count_float4(uops) == (0, 2)

@@ -87,7 +90,8 @@ class TestFloat4(unittest.TestCase):
       c = a + b
       s = c.schedule()[0]

-      return get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=shift)]).uops
+      return get_program(s.ast, renderer=Device[Device.DEFAULT].renderer,
+                         opts=[Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=shift)]).uops

     sizes = [13, 9, 17]
     shifts = [3, 2, 4]
@@ -105,7 +109,7 @@ class TestFloat4(unittest.TestCase):
     # float4 should be emitted (the reduce axis of size 4 is the float4 axis here)
     s = c.schedule()[0]

-    uops = get_program(s.ast, opts=[Opt(op=OptOps.UNROLL, axis=0, arg=4)]).uops
+    uops = get_program(s.ast, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(op=OptOps.UNROLL, axis=0, arg=4)]).uops

     assert TestFloat4.count_float4(uops) == (0, 0)

@@ -119,7 +123,8 @@ class TestFloat4(unittest.TestCase):
     # UPDATE: now we do this fusion
     s = c.schedule()[0]

-    uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]).uops
+    uops = get_program(s.ast, renderer=Device[Device.DEFAULT].renderer,
+                       opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]).uops

     assert TestFloat4.count_float4(uops) in {(0,1), (1,1)}

@@ -132,7 +137,7 @@ class TestFloat4(unittest.TestCase):
     # since the top axis is not contiguous.
     s = c.schedule()[0]

-    uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4)]).uops
+    uops = get_program(s.ast, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4)]).uops

     assert TestFloat4.count_float4(uops) == (0, 1)

@@ -144,7 +149,7 @@ class TestFloat4(unittest.TestCase):
     # should float4 b but not a
     s = c.schedule()[0]

-    uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4)]).uops
+    uops = get_program(s.ast, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4)]).uops

     assert TestFloat4.count_float4(uops) == (1, 1)
diff --git a/test/test_arange.py b/test/test_arange.py
index 8cb8f8972d..962db05efa 100644
--- a/test/test_arange.py
+++ b/test/test_arange.py
@@ -13,7 +13,7 @@ class TestArange(unittest.TestCase):
     GlobalCounters.reset()
     sched = tensor.schedule()
     self.assertEqual(len(sched), 1)
-    p = get_program(sched[-1].ast)
+    p = get_program(sched[-1].ast, renderer=Device[Device.DEFAULT].renderer)
     ExecItem(CompiledRunner(p), [tensor.uop.buffer]).run()
     np.testing.assert_equal(tensor.numpy(), desired)
     return p.estimates.ops
@@ -36,7 +36,7 @@ class TestArange(unittest.TestCase):
     with Context(NOOPT=1):
       t = Tensor.ones(256, 256).contiguous().realize()
       sched = t.triu().schedule()
-      p = get_program(sched[-1].ast)
+      p = get_program(sched[-1].ast, renderer=Device[Device.DEFAULT].renderer)
       self.assertLessEqual(Estimates.from_uops(p.uops).ops, 4 * 256 * 256)

 DSET, DDIM = 2048, 32
diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index 183974ee1b..db11c260b9 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -45,7 +45,7 @@ class TestLinearizer(unittest.TestCase):
     tst = Tensor.ones(16, dtype=dtypes.int).contiguous().realize()
     out = tst.neg().cast(dtypes.char).cast(dtypes.int).cast(dtypes.char) * 2
     ast = helper_linearizer_opt(out)
-    uops = get_program(ast, opts=[]).uops
+    uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=[]).uops
     self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 1)

   @unittest.expectedFailure
@@ -53,7 +53,7 @@ class TestLinearizer(unittest.TestCase):
     tst = Tensor.ones(16, dtype=dtypes.int).contiguous().realize()
     out = tst.neg().cast(dtypes.char).cast(dtypes.int) * 2
     ast = helper_linearizer_opt(out)
-    uops = get_program(ast, opts=[]).uops
+    uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=[]).uops
     self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 0)

   @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "broken on ptx")
@@ -63,7 +63,7 @@ class TestLinearizer(unittest.TestCase):
     b = Tensor.empty(16)
     out = img.conv2d(w, b)
     ast = helper_linearizer_opt(out)
-    uops = get_program(ast, opts=[]).uops
+    uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=[]).uops
     # slice at the last loop end
     uslice = [i for i,u in enumerate(uops) if u.op == Ops.END][-1]
     # only valid test if outermost range is the reduce
@@ -84,7 +84,7 @@ class TestLinearizer(unittest.TestCase):
     a = Tensor.randn(2, ).realize()
     out = a.reshape(2, 1).expand(2, 3).sum()
     ast = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)).sum()])
-    uops = get_program(ast, opts=[]).uops
+    uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=[]).uops
     ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE]
     assert len(ranges) == 1 # NOTE: it collapses now

@@ -92,7 +92,7 @@ class TestLinearizer(unittest.TestCase):
     a = Tensor.randn(2, ).realize()
     out = a.reshape(2, 1).expand(2, 3).expand(2, 2, 3).sum()
     ast = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)), (2, 2, 3)).sum()])
-    uops = get_program(ast, opts=[]).uops
+    uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=[]).uops
     ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE]
     assert len(ranges) == 1 # NOTE: it collapses now

@@ -101,7 +101,7 @@ class TestLinearizer(unittest.TestCase):
     a = Tensor([2, 2]).realize()
     out = a.reshape(2, 1).pad(((1, 1), (1, 1)), value=2).sum()
     ast = helper_linearizer_opt(out, wanna_output=[24])
-    uops = get_program(ast, opts=[]).uops
+    uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=[]).uops
     ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE]
     # RANGE -> ALU -> RANGE -> ALU + LOAD -> STORE
     assert any(x.op in GroupOp.ALU for x in uops[ranges[0]:ranges[1]])
@@ -114,7 +114,7 @@ class TestLinearizer(unittest.TestCase):
     b = Tensor.randn(1, 1).realize()
     out = (a + b[0]).sum() + b[0]
     ast = helper_linearizer_opt(out, wanna_output=[(a.numpy()+b.numpy()[0]).sum()+b.numpy()])
-    uops = get_program(ast, opts=[]).uops
+    uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=[]).uops
     ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE]
     # LOAD -> RANGE -> LOAD -> STORE
     assert len([x for x in uops[:ranges[0]] if x.op is Ops.LOAD]) == 1
@@ -124,7 +124,7 @@ class TestLinearizer(unittest.TestCase):
     b = Tensor.randn(1, 1).realize()
     out = (a.reshape(2, 1).expand(2, 3) + b[0]).sum() + b[0]
     ast = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)) + b.numpy()[0]).sum() + b.numpy()])
-    uops = get_program(ast, opts=[]).uops
+    uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=[]).uops
     ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE]
     assert len(ranges) == 1 # NOTE: it collapses now

@@ -135,7 +135,7 @@ class TestLinearizer(unittest.TestCase):
     # these are of size 3 to avoid float4 coalesce
     r = a[:-1] + a[1:]

-    uops = get_program(r.schedule()[-1].ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0)]).uops
+    uops = get_program(r.schedule()[-1].ast, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0)]).uops
     num_loads = len([uop for uop in uops if uop.op is Ops.LOAD])
     assert num_loads <= 4, "more load uops than needed"
     assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?"
@@ -147,7 +147,7 @@ class TestLinearizer(unittest.TestCase):
     a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
     r = a.expand([2]) + b.expand([2])

-    uops = get_program(r.schedule()[-1].ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0)]).uops
+    uops = get_program(r.schedule()[-1].ast, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0)]).uops
     num_ops = len([uop for uop in uops if uop.op in GroupOp.ALU])
     assert num_ops <= 1, "more alu uops than needed"

@@ -156,7 +156,8 @@ class TestLinearizer(unittest.TestCase):
     x, w = Tensor.randn((1,1,3)).realize(), Tensor.randn((1,1,2)).realize()
     r = Tensor.conv2d(x,w,padding=1).relu()

-    uops = get_program(r.schedule()[-1].ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]).uops
+    uops = get_program(r.schedule()[-1].ast, renderer=Device[Device.DEFAULT].renderer,
+                       opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]).uops
     accs = [u for u in uops if u.op is Ops.DEFINE_REG]
     stores = [u for u in uops if u.op is Ops.STORE]
     assert len(accs) == 0 # it's removed now
@@ -179,7 +180,7 @@ class TestLinearizer(unittest.TestCase):
     x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
     r = (x@y).relu()
     opts_to_apply = [Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4)]
-    program = get_program(r.schedule()[-1].ast, opts=opts_to_apply)
+    program = get_program(r.schedule()[-1].ast, renderer=Device[Device.DEFAULT].renderer, opts=opts_to_apply)

     stores = [u for u in program.uops if u.op is Ops.STORE and u.src[0].dtype.addrspace != AddrSpace.REG]
@@ -194,7 +195,7 @@ class TestLinearizer(unittest.TestCase):
   def test_zero_fold(self):
     a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
     r = Tensor.stack(a, b)
-    uops = get_program(r.schedule()[-1].ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0)]).uops
+    uops = get_program(r.schedule()[-1].ast, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0)]).uops
     num_ops = len([uop for uop in uops if uop.op in GroupOp.ALU])
     assert num_ops == 0, "more alu uops than needed"

@@ -204,14 +205,14 @@ class TestLinearizer(unittest.TestCase):
       if is_dtype_supported(tensor_dtype) and is_dtype_supported(acc_dtype):
         a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
         realized_ast = a.schedule()[-1].ast
-        program = get_program(realized_ast, opts=[])
+        program = get_program(realized_ast, renderer=Device[Device.DEFAULT].renderer, opts=[])
         local = [uop for uop in program.uops if uop.op is Ops.DEFINE_REG]
         assert local[0].dtype.base == acc_dtype

   def test_arg_acc_dtype(self):
     def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
       realized_ast = c.schedule()[-1].ast
-      program = get_program(realized_ast, opts=[])
+      program = get_program(realized_ast, renderer=Device[Device.DEFAULT].renderer, opts=[])
       local = [uop for uop in program.uops if uop.op is Ops.DEFINE_REG]
       self.assertEqual(local[0].dtype.base, expected_dtype)
@@ -239,7 +240,7 @@ class TestLinearizer(unittest.TestCase):
     opt = [Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4)]
     ast = helper_linearizer_opt(r, [opt])
     # the uops graph is DEFINE_REG -> 4x STORE 0.0 -> RANGE -> 4x ALU -> 4x STORE -> ENDRANGE
-    uops = get_program(ast, opts=opt).uops
+    uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=opt).uops
     begin_range = [i for i, x in enumerate(uops) if x.op is Ops.RANGE][-1]
     end_range = [i for i, x in enumerate(uops) if x.op is Ops.END][0]
     for i,u in enumerate(uops): print(i, u.op, [uops.index(s) for s in u.src], u.arg, u.dtype)
@@ -353,7 +354,7 @@ class TestLinearizer(unittest.TestCase):
     # shrink so that the dims do not collapse
     t = Tensor.ones(5, 6, 7).contiguous().realize().shrink(((0, 4), (0, 5), (0, 6)))
     ast = helper_linearizer_opt(t+1)
-    uops = get_program(ast, opts=[]).uops
+    uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=[]).uops
     idxs = dedup([uop for uop in uops if uop.op is Ops.SPECIAL])
     idxs = sorted(idxs, key=lambda uop: uop.arg)
     assert (idxs[0].arg, idxs[0].src[0].arg) == ('gidx0', 6), idxs[0]
@@ -386,13 +387,13 @@ class TestLinearizer(unittest.TestCase):
     sched_copy = sched[:]
     run_schedule(sched)
     np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
-    program = get_program(sched_copy[-1].ast, opts=())
+    program = get_program(sched_copy[-1].ast, renderer=Device[Device.DEFAULT].renderer, opts=())
     assert not any(u.op == Ops.WHERE for u in program.uops), "found where where where should be folded"

   def test_phi_simplification(self):
     def helper(t, max_ops=0):
       ast = helper_linearizer_opt(t)
-      uops = get_program(ast).uops
+      uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer).uops
       # ignore kernel optimized IF statements for now
       if if_op:=next((u for u in uops if u.op is Ops.IF), None): uops = uops[:uops.index(if_op)]
@@ -425,7 +426,7 @@ class TestLinearizer(unittest.TestCase):
     out = x.matmul(y)
     with Context(TC=0):
       ast = helper_linearizer_opt(out)
-      uops = get_program(ast).uops
+      uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer).uops
     # check that the float4 cast collapses
     store_vals = [u.src[1] for u in uops if u.op is Ops.STORE and u.src[0].dtype.addrspace != AddrSpace.REG]
     for val in store_vals:
@@ -436,7 +437,7 @@ class TestLinearizer(unittest.TestCase):
     x = Tensor.randn((4,3,6,6)).realize()
     out = x.flip((0,1)).contiguous()
     ast = helper_linearizer_opt(out)
-    store_val = [u.src[1] for u in get_program(ast).uops if u.op is Ops.STORE][0]
+    store_val = [u.src[1] for u in get_program(ast, renderer=Device[Device.DEFAULT].renderer).uops if u.op is Ops.STORE][0]
     assert store_val.dtype == dtypes.float.vec(4) and store_val.op is not Ops.VECTORIZE

   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@@ -449,7 +450,7 @@ class TestLinearizer(unittest.TestCase):
            Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 2)] # upcast accs in both reduces
     ast = helper_linearizer_opt(out, opts=[opt])
     def get_recursive(uop): return set.union(set(uop.src), [uop], *[get_recursive(v) for v in uop.src])
-    uops = get_program(ast, opts=opt).uops
+    uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=opt).uops
     local_stores = [u for u in uops if u.op is Ops.STORE and any(x.op is Ops.DEFINE_LOCAL for x in get_recursive(u.src[0]))]
     global_stores = [u for u in uops if u.op is Ops.STORE and any(x.op is Ops.DEFINE_GLOBAL for x in get_recursive(u.src[0]))]
     barrier = [u for u in uops if u.op is Ops.BARRIER]
@@ -470,7 +471,7 @@ class TestLinearizer(unittest.TestCase):
     x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
     r = (x@y).relu()
     ast = helper_linearizer_opt(r)
-    uops = get_program(ast).uops
+    uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer).uops
     stores = [u for u in uops if u.op is Ops.STORE and u.src[0].dtype.addrspace != AddrSpace.REG]

     # the float4 value stores directly in lds and we skip upcast
@@ -517,7 +518,7 @@ def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=[]
   device = real_bufs[0].device
   wanna_output = [np.array(x).flatten() for x in wanna_output]

-  def get_prg(opts): return CompiledRunner(replace(get_program(realized_ast, opts=opts), device=device))
+  def get_prg(opts): return CompiledRunner(replace(get_program(realized_ast, renderer=Device[Device.DEFAULT].renderer, opts=opts), device=device))

   def check_opt(opts):
     prg = get_prg(opts=opts)
diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py
index e5c1521b91..8d4a451423 100644
--- a/test/test_linearizer_failures.py
+++ b/test/test_linearizer_failures.py
@@ -3,6 +3,7 @@ import unittest
 from tinygrad.uop.ops import UOp, Ops, AxisType
 from tinygrad.dtype import dtypes
 from tinygrad.engine.realize import get_program
+from tinygrad.device import Device

 class TestLinearizerFailures(unittest.TestCase):
   def test_fail_1(self):
@@ -18,7 +19,7 @@ class TestLinearizerFailures(unittest.TestCase):
     c9 = ((((c6+(c8*UOp.const(dtypes.float, -1.0)))*(c6+(c8*UOp.const(dtypes.float, -1.0)))).reduce(c5, arg=Ops.ADD)*UOp.const(dtypes.float, 0.000390625))+UOp.const(dtypes.float, 1e-05)).sqrt().reciprocal()
     c10 = c0.index(c3).store(c9).end(c1, c2)
     ast = c10.sink()
-    get_program(ast)
+    get_program(ast, renderer=Device[Device.DEFAULT].renderer)

 if __name__ == '__main__':
-  unittest.main()
\ No newline at end of file
+  unittest.main()
diff --git a/test/test_opt_gemm.py b/test/test_opt_gemm.py
index 27aa767eaa..de7d0b438c 100644
--- a/test/test_opt_gemm.py
+++ b/test/test_opt_gemm.py
@@ -1,6 +1,6 @@
 import numpy as np
 import unittest
-from tinygrad import Tensor
+from tinygrad import Tensor, Device
 from tinygrad.helpers import get_single_element
 from tinygrad.codegen.opt import Opt, OptOps
 from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
@@ -17,7 +17,7 @@ class TestOptGemm(unittest.TestCase):
     t = self.a.T @ self.b.T
     # TODO: this should be a generic test helper
     si = get_single_element(t.schedule())
-    run = CompiledRunner(get_program(si.ast, opts=opts))
+    run = CompiledRunner(get_program(si.ast, renderer=Device[Device.DEFAULT].renderer, opts=opts))
     ExecItem(run, si.bufs).run()
     test = si.bufs[0].numpy().reshape(self.res.shape)
     np.testing.assert_allclose(self.res, test, atol=1e-4)
diff --git a/test/test_opts.py b/test/test_opts.py
index 67f90ef1f1..359441cbf1 100644
--- a/test/test_opts.py
+++ b/test/test_opts.py
@@ -13,7 +13,7 @@ class TestOpts(unittest.TestCase):
     s = out.schedule()
     self.assertEqual(s[-1].ast.arg.opts_to_apply, opts)
     if Device.DEFAULT in {"CPU", "CL", "METAL"} and not CPU_LLVM and not CPU_LVP:
-      prg = get_program(s[-1].ast)
+      prg = get_program(s[-1].ast, renderer=Device[Device.DEFAULT].renderer)
       self.assertIn('float4', prg.src)

 if __name__ == '__main__':
diff --git a/test/test_quantize_onnx.py b/test/test_quantize_onnx.py
index dba794cbc8..4a30b9ad2c 100644
--- a/test/test_quantize_onnx.py
+++ b/test/test_quantize_onnx.py
@@ -38,7 +38,7 @@ def create_gemm_model(model_path:str, batch_size=N, in_size=N, out_size=N, bias=

 def sexec(out:Tensor, opts:list[Opt], replace_src=None, run_count=3):
   si = out.schedule()[-1]
-  prg = get_program(si.ast, opts=opts)
+  prg = get_program(si.ast, renderer=Device[Device.DEFAULT].renderer, opts=opts)
   if replace_src is not None:
     old_name = prg.src.split("__attribute__((noinline)) void ")[1].split("(")[0]
     prg = replace(prg, src=replace_src + "/* DSP boilerplate */" + prg.src.split("/* DSP boilerplate */")[1].replace(old_name, "fxn"))
diff --git a/test/test_uops_stats.py b/test/test_uops_stats.py
index 1aff484ab6..78c35e08ca 100644
--- a/test/test_uops_stats.py
+++ b/test/test_uops_stats.py
@@ -175,13 +175,13 @@ class TestStatsOptimized(unittest.TestCase):
     self.assertEqual(p.estimates.mem, 3*N*N*4) # 3 NxN mats with floats

   def test_gemm(self):
-    p = get_program(self.ast_gemm, opts=[])
+    p = get_program(self.ast_gemm, renderer=Device[Device.DEFAULT].renderer, opts=[])
     self.check_gemm(p)
     self.assertEqual(p.estimates.lds, 2*N*N*N*4 + 4*N*N)

   def test_gemm_tc_unroll(self):
     try:
-      p = get_program(self.ast_gemm, opts=[Opt(OptOps.TC, 0, (-1, 0, 1)), Opt(OptOps.UNROLL, 0, 2)])
+      p = get_program(self.ast_gemm, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(OptOps.TC, 0, (-1, 0, 1)), Opt(OptOps.UNROLL, 0, 2)])
     except KernelOptError: raise unittest.SkipTest("no tensor cores")
     print(p.src)
@@ -190,18 +190,19 @@ class TestStatsOptimized(unittest.TestCase):

   # this is a good lesson about why UPCASTing is a good idea
   def test_gemm_one_upcasted(self):
-    p = get_program(self.ast_gemm, opts=[Opt(OptOps.UPCAST, 0, 4)])
+    p = get_program(self.ast_gemm, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(OptOps.UPCAST, 0, 4)])
     self.check_gemm(p)
     self.assertEqual(p.estimates.lds, N*N*N*4 + N*N*N*4//4 + 4*N*N)

   def test_gemm_upcasted(self):
-    p = get_program(self.ast_gemm, opts=[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)])
+    p = get_program(self.ast_gemm, renderer=Device[Device.DEFAULT].renderer,
+                    opts=[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)])
     self.check_gemm(p)
     self.assertEqual(p.estimates.lds, 2*N*N*N*4//4 + 4*N*N)

   def test_gemm_upcasted_locals(self):
     try:
-      p = get_program(self.ast_gemm, opts=[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4),
+      p = get_program(self.ast_gemm, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4),
                                            Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4)])
     except KernelOptError: raise unittest.SkipTest("no locals")
@@ -210,7 +211,7 @@ class TestStatsOptimized(unittest.TestCase):

   def test_gemm_group(self):
     try:
-      p = get_program(self.ast_gemm, opts=[Opt(OptOps.GROUP, 0, 4)])
+      p = get_program(self.ast_gemm, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(OptOps.GROUP, 0, 4)])
     except KernelOptError: raise unittest.SkipTest("no locals")
     SZ = N*N*4
@@ -219,14 +220,14 @@ class TestStatsOptimized(unittest.TestCase):
     self.assertEqual(p.estimates.lds, 2*N*N*N*4 + SZ*4 + (SZ*4 + 4*N*N)*4)

   def test_reduce(self):
-    p = get_program(self.ast_reduce, opts=[])
+    p = get_program(self.ast_reduce, renderer=Device[Device.DEFAULT].renderer, opts=[])
     print(p.name, p.estimates.ops, p.estimates.mem, p.estimates.lds)
     self.assertEqual(p.estimates.ops, N*N)
     self.assertEqual(p.estimates.mem, N*N*4 + 4)

   def test_reduce_group(self):
     try:
-      p = get_program(self.ast_reduce, opts=[Opt(OptOps.GROUP, 0, 50)])
+      p = get_program(self.ast_reduce, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(OptOps.GROUP, 0, 50)])
     except KernelOptError: raise unittest.SkipTest("no locals")
     # NOTE: these are wrong, they don't respect the if statement
diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py
index 3529dbefee..a2edd2ad62 100644
--- a/tinygrad/engine/realize.py
+++ b/tinygrad/engine/realize.py
@@ -14,7 +14,7 @@ from tinygrad.codegen.opt import Opt
 # **************** Program Creation ****************

 @track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret), replay=True)
-def get_program(ast:UOp, renderer:Renderer|None=None, opts:list[Opt]|None=None) -> ProgramSpec:
+def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> ProgramSpec:
   """
   Transform an AST into a ProgramSpec. May trigger BEAM search.

@@ -30,7 +30,6 @@ def get_program(ast:UOp, renderer:Renderer|None=None, opts:list[Opt]|None=None)
   if DEBUG >= 5: print(pyrender(ast))

   # linearize
-  if renderer is None: renderer = Device.default.renderer
   if opts is not None:
     assert ast.arg is None, "can't apply opts if sink has an arg"
     ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts)))
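
Note: with the renderer=None fallback removed from get_program (and from replay_get_program), every call site now has to open the device and pass its renderer explicitly, as the test updates above do. A minimal sketch of the updated calling convention, assuming the default device; the tensor shape and the single UPCAST opt below are illustrative, not taken from the diff:

    from tinygrad import Tensor, Device
    from tinygrad.codegen.opt import Opt, OptOps
    from tinygrad.engine.realize import get_program

    # schedule a tiny elementwise kernel and grab its AST (illustrative shape)
    si = (Tensor.ones(16).contiguous() + 1).schedule()[-1]

    # get_program no longer falls back to a default renderer; the caller
    # supplies the renderer for the device it wants to compile for
    renderer = Device[Device.DEFAULT].renderer
    prg = get_program(si.ast, renderer=renderer, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4)])
    print(prg.name)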