tinygrad/test/test_search.py
chenyu fb71d1e5fd delete some test_search tests (#11998)
TC_SEARCH_OVER_SHAPE was removed, so its tests should be removed too
2025-09-04 11:19:49 -04:00

import unittest

from tinygrad.codegen.opt.kernel import Opt, OptOps, Kernel
from tinygrad.codegen.opt.search import bufs_from_lin, actions, beam_search
from tinygrad.tensor import Tensor
from tinygrad.helpers import Context, GlobalCounters
from tinygrad.engine.realize import capturing

class TestBEAM(unittest.TestCase):
  def test_dynamic_beam(self):
    # TODO: make this infra globally usable
    class Capture:
      def __init__(self): self.captured = []
      def add(self, x): self.captured.append(x)
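
    # realize with BEAM=1 (beam search enabled) and record what gets run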
    capturing.append(Capture())
    kernel_count = GlobalCounters.kernel_count
    with Context(BEAM=1): Tensor.zeros(16).contiguous().realize()
    assert GlobalCounters.kernel_count == kernel_count + 1
    k_beam_1 = capturing[0].captured
    capturing.clear()
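
    # repeat the same realize with beam search disabled (BEAM=0)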
    capturing.append(Capture())
    kernel_count = GlobalCounters.kernel_count
    with Context(BEAM=0): Tensor.zeros(16).contiguous().realize()
    assert GlobalCounters.kernel_count == kernel_count + 1
    k_beam_0 = capturing[0].captured
    capturing.clear()
    # the beam-searched program source should differ from the default-heuristics one
    self.assertNotEqual(k_beam_0[-1].prg.p.src, k_beam_1[-1].prg.p.src)

  def test_get_kernel_actions_dedup(self):
    from test.test_linearizer import helper_realized_ast
    from tinygrad.codegen.opt.search import get_kernel_actions
    a = Tensor.empty(4, 3)
    b = Tensor.empty(3)
    realized_ast, _ = helper_realized_ast(a @ b)
    candidates = [
      Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UPCAST, axis=0, arg=4),
      Opt(op=OptOps.LOCAL, axis=0, arg=0), Opt(op=OptOps.LOCAL, axis=0, arg=4),
      Opt(op=OptOps.UNROLL, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=3),
      Opt(op=OptOps.GROUP, axis=0, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=3),
      Opt(op=OptOps.GROUPTOP, axis=0, arg=0), Opt(op=OptOps.GROUPTOP, axis=0, arg=3),
    ]
    lins = get_kernel_actions(Kernel(realized_ast), include_0=False, candidates=candidates).values()
    # ensure arg=0 (full-size) opts are not duplicated with their explicit-size equivalents
    assert all(len(x.applied_opts) == 1 for x in lins)
    kernel_actions = [x.applied_opts[0] for x in lins]
    assert Opt(OptOps.UPCAST, axis=0, arg=4) not in kernel_actions, "did not de-dup UPCAST"
    assert Opt(OptOps.LOCAL, axis=0, arg=4) not in kernel_actions, "did not de-dup LOCAL"
    assert Opt(OptOps.UNROLL, axis=0, arg=3) not in kernel_actions, "did not de-dup UNROLL"
    assert Opt(OptOps.GROUP, axis=0, arg=3) not in kernel_actions, "did not de-dup GROUP"
    assert Opt(OptOps.GROUPTOP, axis=0, arg=3) not in kernel_actions, "did not de-dup GROUPTOP"

  def test_get_kernel_actions_preserves_actions_state(self):
    from test.test_linearizer import helper_realized_ast
    from tinygrad.codegen.opt.search import get_kernel_actions
    a = Tensor.rand(16, 16)
    b = Tensor.rand(16, 16)
    realized_ast, _ = helper_realized_ast(a @ b)
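
    # enumerating actions must not mutate the shared module-level actions list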
    actions_before = actions.copy()
    get_kernel_actions(Kernel(realized_ast))
    actions_after = actions.copy()
    assert actions_after == actions_before, "actions state was not preserved"

  def test_beam_unnamed_kernels(self):
    from test.test_linearizer import push_views
    a = Tensor.rand(100)
    b = Tensor.rand(100)
    si = (a+b).schedule()[-1]
    lin = Kernel(push_views(si.ast))
    bufs = bufs_from_lin(lin)
    # TODO: beam should have better instrumentation so we don't have to check this indirect thing
    kcount = len(Kernel.kernel_cnt)
    beam_search(lin, bufs, 3, disable_cache=True)
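    # beam search over an unnamed kernel should not register new entries in Kernel.kernel_cnt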
    self.assertEqual(kcount, len(Kernel.kernel_cnt))

if __name__ == '__main__':
  unittest.main()