mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-14 09:28:04 -05:00
80 lines · 3.4 KiB · Python
import unittest
|
|
|
|
from tinygrad.codegen.opt.kernel import Opt, OptOps, Kernel
|
|
from tinygrad.codegen.opt.search import bufs_from_lin, actions, beam_search
|
|
from tinygrad.tensor import Tensor
|
|
from tinygrad.helpers import Context, GlobalCounters
|
|
from tinygrad.engine.realize import capturing
|
|
|
|
class TestBEAM(unittest.TestCase):
  """Tests for BEAM search behavior: kernel capture under BEAM, action dedup,
  actions-state preservation, and kernel naming during beam search."""

  def _capture_realize(self, beam:int):
    """Realize Tensor.zeros(16).contiguous() under BEAM=beam and return the captured kernels.

    Asserts exactly one new kernel was counted. Always clears the global
    `capturing` list, even on failure, so later tests aren't polluted.
    """
    # TODO: make this infra globally usable
    class Capture:
      def __init__(self): self.captured = []
      def add(self, x): self.captured.append(x)

    capturing.append(Capture())
    try:
      kernel_count = GlobalCounters.kernel_count
      with Context(BEAM=beam): Tensor.zeros(16).contiguous().realize()
      # exactly one kernel should have been launched by the realize
      self.assertEqual(GlobalCounters.kernel_count, kernel_count + 1)
      return capturing[0].captured
    finally:
      capturing.clear()

  def test_dynamic_beam(self):
    """BEAM=1 and BEAM=0 should produce different source for the same op."""
    k_beam_1 = self._capture_realize(1)
    k_beam_0 = self._capture_realize(0)
    self.assertNotEqual(k_beam_0[-1].prg.p.src, k_beam_1[-1].prg.p.src)

  def test_get_kernel_actions_dedup(self):
    """Opts with arg=0 must not duplicate the equivalent full-size opt."""
    from test.test_linearizer import helper_realized_ast
    from tinygrad.codegen.opt.search import get_kernel_actions
    a = Tensor.empty(4, 3)
    b = Tensor.empty(3)
    realized_ast, _ = helper_realized_ast(a @ b)
    # each pair is (arg=0 shorthand, explicit full-size arg) for the same axis
    candidates = [
      Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UPCAST, axis=0, arg=4),
      Opt(op=OptOps.LOCAL, axis=0, arg=0), Opt(op=OptOps.LOCAL, axis=0, arg=4),
      Opt(op=OptOps.UNROLL, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=3),
      Opt(op=OptOps.GROUP, axis=0, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=3),
      Opt(op=OptOps.GROUPTOP, axis=0, arg=0), Opt(op=OptOps.GROUPTOP, axis=0, arg=3),
    ]
    lins = get_kernel_actions(Kernel(realized_ast), include_0=False, candidates=candidates).values()

    # ensure amt=0 are not duplicated
    self.assertTrue(all(len(x.applied_opts) == 1 for x in lins))
    kernel_actions = [x.applied_opts[0] for x in lins]
    self.assertNotIn(Opt(OptOps.UPCAST, axis=0, arg=4), kernel_actions, "did not de-dup UPCAST")
    self.assertNotIn(Opt(OptOps.LOCAL, axis=0, arg=4), kernel_actions, "did not de-dup LOCAL")
    self.assertNotIn(Opt(OptOps.UNROLL, axis=0, arg=3), kernel_actions, "did not de-dup UNROLL")
    self.assertNotIn(Opt(OptOps.GROUP, axis=0, arg=3), kernel_actions, "did not de-dup GROUP")
    self.assertNotIn(Opt(OptOps.GROUPTOP, axis=0, arg=3), kernel_actions, "did not de-dup GROUPTOP")

  def test_get_kernel_actions_preserves_actions_state(self):
    """get_kernel_actions must not mutate the module-level `actions` list."""
    from test.test_linearizer import helper_realized_ast
    from tinygrad.codegen.opt.search import get_kernel_actions
    a = Tensor.rand(16, 16)
    b = Tensor.rand(16, 16)
    realized_ast, _ = helper_realized_ast(a @ b)
    actions_before = actions.copy()
    get_kernel_actions(Kernel(realized_ast))
    actions_after = actions.copy()
    self.assertEqual(actions_after, actions_before, "actions state was not preserved")

  def test_beam_unnamed_kernels(self):
    """beam_search should not register new kernel names for unnamed kernels."""
    from test.test_linearizer import push_views
    a = Tensor.rand(100)
    b = Tensor.rand(100)
    si = (a+b).schedule()[-1]
    lin = Kernel(push_views(si.ast))
    bufs = bufs_from_lin(lin)
    # TODO: beam should have better instrumentation so we don't have to check this indirect thing
    kcount = len(Kernel.kernel_cnt)
    beam_search(lin, bufs, 3, disable_cache=True)
    self.assertEqual(kcount, len(Kernel.kernel_cnt))
|
|
|
|
# Allow running this test module directly.
if __name__ == '__main__':
  unittest.main()