mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix opt in process replay [pr] (#14599)
This commit is contained in:
@@ -53,7 +53,9 @@ def replay_get_rangeify_map(ret:dict[UOp, UOp], big_sink:UOp) -> tuple[str, str,
|
||||
|
||||
def replay_get_program(p:ProgramSpec, ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> tuple[str, str, tuple[Any, ...]]:
|
||||
# the ast.arg is non None if we are inside of search.py
|
||||
sink_arg = ast.arg or KernelInfo(opts_to_apply=tuple(opts) if opts is not None else p.applied_opts if BEAM>=1 else None)
|
||||
sink_arg = ast.arg or KernelInfo()
|
||||
if opts is not None: sink_arg = replace(sink_arg, opts_to_apply=tuple(opts))
|
||||
elif BEAM >= 1 and sink_arg.opts_to_apply is None: sink_arg = replace(sink_arg, opts_to_apply=p.applied_opts)
|
||||
input_ast = ast if ast.op is Ops.PROGRAM else ast.replace(arg=replace(sink_arg, name=p.name))
|
||||
p2 = get_program(input_ast, renderer=renderer)
|
||||
def to_str(ret:ProgramSpec) -> str:
|
||||
|
||||
34
test/null/test_process_replay.py
Normal file
34
test/null/test_process_replay.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor, Device
|
||||
from tinygrad.engine.realize import get_program
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
from test.external.process_replay.process_replay import replay_get_program
|
||||
|
||||
N = 16
|
||||
class TestProcessReplay(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.ast = (Tensor.empty(N, N) @ Tensor.empty(N, N)).schedule()[-1].ast
|
||||
cls.renderer = Device[Device.DEFAULT].renderer
|
||||
|
||||
def test_replay_no_opts(self):
|
||||
# opts=None means use default heuristic path
|
||||
p = get_program(self.ast, self.renderer)
|
||||
good, compare, _ = replay_get_program(p, self.ast, self.renderer)
|
||||
self.assertEqual(good, compare)
|
||||
|
||||
def test_replay_empty_opts(self):
|
||||
# opts=[] means explicitly apply zero opts (unoptimized)
|
||||
p = get_program(self.ast, self.renderer, opts=[])
|
||||
good, compare, _ = replay_get_program(p, self.ast, self.renderer, opts=[])
|
||||
self.assertEqual(good, compare)
|
||||
|
||||
def test_replay_with_opt(self):
|
||||
# opts=[Opt(...)] means apply a specific opt
|
||||
opts = [Opt(OptOps.UPCAST, 0, 4)]
|
||||
p = get_program(self.ast, self.renderer, opts=opts)
|
||||
good, compare, _ = replay_get_program(p, self.ast, self.renderer, opts=opts)
|
||||
self.assertEqual(good, compare)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
Reference in New Issue
Block a user