Files
tinygrad/test/test_search.py
chenyu c86adabe15 time with real global buffers in search (#4621)
* filter fake buffers in search

* test that

* update test
2024-05-17 12:36:23 -04:00

87 lines
6.5 KiB
Python

import unittest
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions, beam_search
from tinygrad.device import Device, Buffer
from tinygrad.ops import LazyOp, LoadOps, BufferOps, ReduceOps, BinaryOps, MemBuffer, ConstBuffer
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from tinygrad.helpers import Context
from tinygrad.engine.realize import capturing
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
class TestTimeLinearizer(unittest.TestCase):
  """Checks for `time_linearizer` and `bufs_from_lin` on a tiny elementwise kernel."""

  @staticmethod
  def _first_compute_item():
    """Schedule `Tensor([1,2,3,4]).add(1)` and return the first item whose AST root op is not in LoadOps."""
    schedule = create_schedule([Tensor([1,2,3,4]).add(1).lazydata])
    compute_items = [item for item in schedule if item.ast[0].op not in LoadOps]
    return compute_items[0]

  @unittest.skipIf(Device.DEFAULT in {"AMD", "NV"}, "Tries to open HSA/CUDA. #4607")
  def test_reasonable_time(self):
    """Timing the kernel on freshly allocated buffers gives a positive time that is not infinity."""
    si = self._first_compute_item()
    first_out = si.outputs[0]
    out = Buffer(Device.DEFAULT, first_out.size, first_out.dtype).allocate()

    # buffer index -> real size of the view used by each LOAD in the AST
    memops = {}
    for node in si.ast[0].lazyops:
      if node.op is BufferOps.LOAD:
        memops[node.arg.idx] = node.arg.st.real_size()

    # input buffer indices are numbered after the outputs
    rawbufs = [out]
    for buf_idx, inp in enumerate(si.inputs, start=len(si.outputs)):
      rawbufs.append(Buffer(Device.DEFAULT, memops[buf_idx], inp.dtype).allocate())

    tm = time_linearizer(Linearizer(*si.ast), rawbufs, allow_test_size=False, cnt=10)
    assert tm > 0
    assert tm != float('inf')

  @unittest.skipIf(Device.DEFAULT in {"AMD", "NV"}, "Tries to open HSA/CUDA. #4607")
  def test_bufs_from_lin(self):
    """`bufs_from_lin` returns one non-empty Buffer per entry in the linearizer's `membufs`."""
    si = self._first_compute_item()
    lin = Linearizer(*si.ast)
    rawbufs = bufs_from_lin(lin)
    assert len(rawbufs) == len(lin.membufs)
    assert all(r is not None for r in rawbufs)
    assert all(isinstance(r, Buffer) for r in rawbufs)
    assert all(r.size > 0 for r in rawbufs)
class TestBEAM(unittest.TestCase):
  """Tests for BEAM-driven kernel search: dynamic BEAM, action de-duplication and buffer handling."""

  @unittest.skipIf(Device.DEFAULT in {"AMD", "NV"}, "Tries to open HSA/CUDA. #4607")
  def test_dynamic_beam(self):
    """Realizing the same tensor under BEAM=1 and BEAM=0 must yield different program source."""
    # TODO: make this infra globally usable
    class Capture:
      def __init__(self): self.captured = []
      def add(self, x): self.captured.append(x)

    def realize_captured(beam):
      """Realize a small tensor under Context(BEAM=beam) and return what the installed Capture recorded."""
      capturing.append(Capture())
      try:
        with Context(BEAM=beam): Tensor.zeros(16).contiguous().realize()
        return capturing[0].captured
      finally:
        # `capturing` is a shared list imported from tinygrad.engine.realize; clear it even when
        # realize() raises so a failure here cannot leave the hook installed for later tests.
        capturing.clear()

    k_beam_1 = realize_captured(1)
    k_beam_0 = realize_captured(0)
    assert k_beam_0[-1].prg.p.src != k_beam_1[-1].prg.p.src

  def test_get_linearizer_actions(self):
    """An explicit full-size `amt` must not appear alongside the equivalent `amt=0` action."""
    from test.test_linearizer import helper_realized_ast
    a = Tensor.rand(4, 3)
    b = Tensor.rand(3)
    realized_ast, _ = helper_realized_ast(a @ b)
    from tinygrad.engine.search import get_linearizer_actions
    lins = get_linearizer_actions(Linearizer(realized_ast), False).values()

    # ensure amt=0 are not duplicated
    # (op, explicit amt that should have been de-duplicated, failure message)
    # NOTE(review): the amts presumably match the axis sizes of the (4,3) @ (3,) kernel -- confirm.
    dedup_cases = (
      (OptOps.UPCAST, 4, "did not de-dup UPCAST"),
      (OptOps.LOCAL, 4, "did not de-dup LOCAL"),
      (OptOps.UNROLL, 3, "did not de-dup UNROLL"),
      (OptOps.GROUP, 3, "did not de-dup GROUP"),
      (OptOps.GROUPTOP, 3, "did not de-dup GROUPTOP"),
    )
    for op, full_amt, msg in dedup_cases:
      # only check ops whose amt=0 variant is part of the search action space
      if Opt(op, 0, 0) in actions:
        assert not any(x.applied_opts[0] == Opt(op, axis=0, amt=full_amt) for x in lins), msg

  @unittest.skipIf(Device.DEFAULT in {"NV"}, "Tries to open CUDA. #4607")
  def test_filter_global_buffer(self):
    """beam_search followed by an uncached time_linearizer must succeed on the AST from issue 4612."""
    # taken from https://github.com/tinygrad/tinygrad/issues/4612
    ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.MAX, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=0, mask=((0, 64128),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-64128, mask=((64128, 128256),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-128256, mask=((128256, 192384),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=4, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-192384, mask=((192384, 256512),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=5, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-256512, mask=((256512, 320640),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=6, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-320640, mask=((320640, 384768),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.4285714285714286, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 501, 256), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(1,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
    lin = Linearizer(ast)
    bufs = bufs_from_lin(lin)
    best_lin = beam_search(lin, bufs, 3)
    assert best_lin
    # need disable_cache to trigger.
    tm = time_linearizer(best_lin, bufs, allow_test_size=False, cnt=2, disable_cache=True)
    assert tm
# Allow running this test module directly as a script.
if __name__ == '__main__':
  unittest.main()