mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix process_replay Ops.BEAM [pr] (#15752)
This commit is contained in:
@@ -43,6 +43,7 @@ class ProcessReplayWarning(Warning): pass
|
||||
# *** replay the function and convert return values to string
|
||||
|
||||
def replay_get_program(p:ProgramSpec, ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> tuple[str, str, tuple[Any, ...]]:
|
||||
if ast.op is Ops.BEAM: ast = ast.src[0]
|
||||
# the ast.arg is non None if we are inside of search.py
|
||||
sink_arg = ast.arg or KernelInfo()
|
||||
if opts is not None: sink_arg = replace(sink_arg, opts_to_apply=tuple(opts))
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor, Device
|
||||
from tinygrad import Tensor, Device, Context
|
||||
from tinygrad.engine.realize import get_program
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
from test.external.process_replay.process_replay import replay_get_program
|
||||
@@ -30,5 +30,12 @@ class TestProcessReplay(unittest.TestCase):
|
||||
good, compare, _ = replay_get_program(p, self.ast, self.renderer, opts=opts)
|
||||
self.assertEqual(good, compare)
|
||||
|
||||
@Context(BEAM=1)
|
||||
def test_beam(self):
|
||||
si = (Tensor.empty(N, N) @ Tensor.empty(N, N)).schedule()[-1]
|
||||
p = get_program(si.ast, self.renderer)
|
||||
good, compare, _ = replay_get_program(p, self.ast, self.renderer)
|
||||
self.assertEqual(good, compare)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user