diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py index 15fa660ced..351945a29c 100755 --- a/test/external/process_replay/process_replay.py +++ b/test/external/process_replay/process_replay.py @@ -53,7 +53,9 @@ def replay_get_rangeify_map(ret:dict[UOp, UOp], big_sink:UOp) -> tuple[str, str, def replay_get_program(p:ProgramSpec, ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> tuple[str, str, tuple[Any, ...]]: # the ast.arg is non None if we are inside of search.py - sink_arg = ast.arg or KernelInfo(opts_to_apply=tuple(opts) if opts is not None else p.applied_opts if BEAM>=1 else None) + sink_arg = ast.arg or KernelInfo() + if opts is not None: sink_arg = replace(sink_arg, opts_to_apply=tuple(opts)) + elif BEAM >= 1 and sink_arg.opts_to_apply is None: sink_arg = replace(sink_arg, opts_to_apply=p.applied_opts) input_ast = ast if ast.op is Ops.PROGRAM else ast.replace(arg=replace(sink_arg, name=p.name)) p2 = get_program(input_ast, renderer=renderer) def to_str(ret:ProgramSpec) -> str: diff --git a/test/null/test_process_replay.py b/test/null/test_process_replay.py new file mode 100644 index 0000000000..e2fc5db108 --- /dev/null +++ b/test/null/test_process_replay.py @@ -0,0 +1,34 @@ +import unittest +from tinygrad import Tensor, Device +from tinygrad.engine.realize import get_program +from tinygrad.codegen.opt import Opt, OptOps +from test.external.process_replay.process_replay import replay_get_program + +N = 16 +class TestProcessReplay(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.ast = (Tensor.empty(N, N) @ Tensor.empty(N, N)).schedule()[-1].ast + cls.renderer = Device[Device.DEFAULT].renderer + + def test_replay_no_opts(self): + # opts=None means use default heuristic path + p = get_program(self.ast, self.renderer) + good, compare, _ = replay_get_program(p, self.ast, self.renderer) + self.assertEqual(good, compare) + + def test_replay_empty_opts(self): + # opts=[] means explicitly apply zero opts (unoptimized) + p = get_program(self.ast, self.renderer, opts=[]) + good, compare, _ = replay_get_program(p, self.ast, self.renderer, opts=[]) + self.assertEqual(good, compare) + + def test_replay_with_opt(self): + # opts=[Opt(...)] means apply a specific opt + opts = [Opt(OptOps.UPCAST, 0, 4)] + p = get_program(self.ast, self.renderer, opts=opts) + good, compare, _ = replay_get_program(p, self.ast, self.renderer, opts=opts) + self.assertEqual(good, compare) + +if __name__ == '__main__': + unittest.main(verbosity=2)