fix process_replay Ops.BEAM [pr] (#15752)

This commit is contained in:
qazal
2026-04-16 01:35:28 +03:00
committed by GitHub
parent 41421c3b48
commit 96092d110c
2 changed files with 9 additions and 1 deletions

View File

@@ -43,6 +43,7 @@ class ProcessReplayWarning(Warning): pass
# *** replay the function and convert return values to string
def replay_get_program(p:ProgramSpec, ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> tuple[str, str, tuple[Any, ...]]:
if ast.op is Ops.BEAM: ast = ast.src[0]
# the ast.arg is non None if we are inside of search.py
sink_arg = ast.arg or KernelInfo()
if opts is not None: sink_arg = replace(sink_arg, opts_to_apply=tuple(opts))

View File

@@ -1,5 +1,5 @@
import unittest
from tinygrad import Tensor, Device
from tinygrad import Tensor, Device, Context
from tinygrad.engine.realize import get_program
from tinygrad.codegen.opt import Opt, OptOps
from test.external.process_replay.process_replay import replay_get_program
@@ -30,5 +30,12 @@ class TestProcessReplay(unittest.TestCase):
good, compare, _ = replay_get_program(p, self.ast, self.renderer, opts=opts)
self.assertEqual(good, compare)
@Context(BEAM=1)
def test_beam(self):
si = (Tensor.empty(N, N) @ Tensor.empty(N, N)).schedule()[-1]
p = get_program(si.ast, self.renderer)
good, compare, _ = replay_get_program(p, self.ast, self.renderer)
self.assertEqual(good, compare)
if __name__ == '__main__':
unittest.main(verbosity=2)