tinygrad/test/test_search.py

import unittest
from test.helpers import ast_const
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.codegen.kernel import Kernel
from tinygrad.ops import UOp, Ops
from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions, beam_search
from tinygrad.device import Device, Buffer
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from tinygrad.helpers import Context, GlobalCounters
from tinygrad.engine.realize import capturing
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
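
# A rough sketch of how the search helpers exercised below fit together (names as imported above;
# these comments are descriptive notes inferred from the tests, not authoritative API docs):
#   lin  = Kernel(si.ast)              # wrap a scheduled AST in a Kernel
#   bufs = bufs_from_lin(lin)          # allocate one Buffer per membuf of the kernel
#   tm   = time_linearizer(lin, bufs)  # compile and run the kernel, returning its measured runtime
#   best = beam_search(lin, bufs, 2)   # beam-search over the Opt `actions` for a faster Kernel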

class TestTimeLinearizer(unittest.TestCase):
  @unittest.skipIf(Device.DEFAULT == "WEBGPU", "WebGPU timestamps are low precision, tm is 0")
  def test_reasonable_time(self):
    a = Tensor([1,2,3,4]).realize()
    si = (a+1).schedule()[0]
    out = Buffer(Device.DEFAULT, si.outputs[0].size, si.outputs[0].dtype).allocate()
    memops = {x.src[0].arg:x.src[-1].arg.real_size() for x in si.ast.toposort if x.op is Ops.LOAD}
    rawbufs = [out] + [Buffer(Device.DEFAULT, memops[i], x.dtype).allocate() for i,x in enumerate(si.inputs, start=len(si.outputs))]
    tm = time_linearizer(Kernel(si.ast), rawbufs, allow_test_size=False, cnt=10, disable_cache=True)
    assert tm > 0 and tm != float('inf')

  def test_bufs_from_lin(self):
    a = Tensor([1,2,3,4]).realize()
    si = (a+1).schedule()[0]
    rawbufs = bufs_from_lin(lin:=Kernel(si.ast))
    assert len(rawbufs) == len(lin.membufs) == 2
    assert all(r is not None for r in rawbufs)
    assert all(isinstance(r, Buffer) for r in rawbufs)
    assert all(r.size > 0 for r in rawbufs)

  def test_bufs_from_lin_alt(self):
    a = Tensor.randn(4, 4).realize()
    b = a+a[0]
    si = b.schedule()[0]
    rawbufs = bufs_from_lin(k:=Kernel(si.ast))
    assert len(rawbufs) == len(k.membufs) == 2
    assert all(r is not None for r in rawbufs)
    assert all(isinstance(r, Buffer) for r in rawbufs)
    assert all(r.size > 0 for r in rawbufs)

  def test_kernel_count(self):
    """
    Ensure that the kernel count is not incremented by time_linearizer when clearing l2
    """
    # ast of Tensor.zeros(16).contiguous().realize()
    ast = UOp(Ops.SINK, src=(
      UOp(Ops.STORE, src=(
        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
        UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),))),
        ast_const(dtypes.float, 0.0, st_src=(
          UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(16,), strides=(0,), offset=0, mask=None, contiguous=False),))),)),)),))
    lin = Kernel(ast)
    bufs = bufs_from_lin(lin)
    kernel_count = GlobalCounters.kernel_count
    time_linearizer(lin, bufs, allow_test_size=False, cnt=2, disable_cache=True, clear_l2=True)
    assert GlobalCounters.kernel_count == kernel_count, "kernel count was incremented by time_linearizer"
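
# BEAM is a context variable: with Context(BEAM=n) set, kernels are optimized via beam search at
# realization time, while BEAM=0 takes the default path, so the two runs in test_dynamic_beam are
# expected to produce different generated source for the same op (a summary of the test below, not a spec).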

class TestBEAM(unittest.TestCase):
  def test_dynamic_beam(self):
    # TODO: make this infra globally usable
    class Capture:
      def __init__(self): self.captured = []
      def add(self, x): self.captured.append(x)

    capturing.append(Capture())
    kernel_count = GlobalCounters.kernel_count
    with Context(BEAM=1): Tensor.zeros(16).contiguous().realize()
    assert GlobalCounters.kernel_count == kernel_count + 1
    k_beam_1 = capturing[0].captured
    capturing.clear()

    capturing.append(Capture())
    kernel_count = GlobalCounters.kernel_count
    with Context(BEAM=0): Tensor.zeros(16).contiguous().realize()
    assert GlobalCounters.kernel_count == kernel_count + 1
    k_beam_0 = capturing[0].captured
    capturing.clear()
    self.assertNotEqual(k_beam_0[-1].prg.p.src, k_beam_1[-1].prg.p.src)

  def test_get_kernel_actions(self):
    from test.test_linearizer import helper_realized_ast
    a = Tensor.rand(4, 3)
    b = Tensor.rand(3)
    realized_ast, _ = helper_realized_ast(a @ b)
    from tinygrad.engine.search import get_kernel_actions
    lins = get_kernel_actions(Kernel(realized_ast), False).values()

    # ensure amt=0 are not duplicated
    if Opt(OptOps.UPCAST, 0, 0) in actions:
      assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.UPCAST, axis=0, arg=4)]) == 0, "did not de-dup UPCAST"
    if Opt(OptOps.LOCAL, 0, 0) in actions:
      assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.LOCAL, axis=0, arg=4)]) == 0, "did not de-dup LOCAL"
    if Opt(OptOps.UNROLL, 0, 0) in actions:
      assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.UNROLL, axis=0, arg=3)]) == 0, "did not de-dup UNROLL"
    if Opt(OptOps.GROUP, 0, 0) in actions:
      assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.GROUP, axis=0, arg=3)]) == 0, "did not de-dup GROUP"
    if Opt(OptOps.GROUPTOP, 0, 0) in actions:
      assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.GROUPTOP, axis=0, arg=3)]) == 0, "did not de-dup GROUPTOP"

  @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
  def test_search_over_shape(self):
    from test.test_linearizer import helper_realized_ast
    from tinygrad.engine.search import get_kernel_actions

    dtype_pairs = [(tc.dtype_in, tc.dtype_out) for tc in Device[Device.DEFAULT].renderer.tensor_cores]
    multi_shape_dtype_pairs = [dts for dts in dtype_pairs if dtype_pairs.count(dts) > 1]
    if len(multi_shape_dtype_pairs) == 0: raise unittest.SkipTest("only one tc available per dtype pair to search over")

    for (dtype_in, dtype_out) in multi_shape_dtype_pairs:
      a = Tensor.rand(16, 16, dtype=dtype_in)
      b = Tensor.rand(16, 16, dtype=dtype_in)
      realized_ast, _ = helper_realized_ast(a.matmul(b, acc_dtype=dtype_out))
      lins = get_kernel_actions(Kernel(realized_ast)).values()
      assert len(set(lin.tensor_core.dims for lin in lins if lin.tensor_core is not None)) > 1

  def test_get_kernel_actions_preserves_actions_state(self):
    from test.test_linearizer import helper_realized_ast
    from tinygrad.engine.search import get_kernel_actions
    a = Tensor.rand(16, 16)
    b = Tensor.rand(16, 16)
    realized_ast, _ = helper_realized_ast(a @ b)
    actions_before = actions.copy()
    get_kernel_actions(Kernel(realized_ast))
    actions_after = actions.copy()
    assert actions_after == actions_before, "actions state was not preserved"

  def test_filter_global_buffer(self):
    # taken from https://github.com/tinygrad/tinygrad/issues/4612
    ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
      UOp(Ops.STORE, dtypes.void, arg=None, src=(
        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
        UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.MAX, (1,)), src=(
          UOp(Ops.MUL, dtypes.float, arg=None, src=(
            UOp(Ops.ADD, dtypes.float, arg=None, src=(
              UOp(Ops.ADD, dtypes.float, arg=None, src=(
                UOp(Ops.ADD, dtypes.float, arg=None, src=(
                  UOp(Ops.ADD, dtypes.float, arg=None, src=(
                    UOp(Ops.ADD, dtypes.float, arg=None, src=(
                      UOp(Ops.LOAD, dtypes.float, arg=None, src=(
                        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
                        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=0, mask=((0, 64128),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)), # noqa: E501
                      UOp(Ops.LOAD, dtypes.float, arg=None, src=(
                        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
                        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-64128, mask=((64128, 128256),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
                    UOp(Ops.LOAD, dtypes.float, arg=None, src=(
                      UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
                      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-128256, mask=((128256, 192384),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
                  UOp(Ops.LOAD, dtypes.float, arg=None, src=(
                    UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()),
                    UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-192384, mask=((192384, 256512),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
                UOp(Ops.LOAD, dtypes.float, arg=None, src=(
                  UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()),
                  UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-256512, mask=((256512, 320640),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
              UOp(Ops.LOAD, dtypes.float, arg=None, src=(
                UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()),
                UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-320640, mask=((320640, 384768),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
            ast_const(dtypes.float, 1.4285714285714286, st_src=(
              UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 501, 256), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) # noqa: E501
    lin = Kernel(ast)
    bufs = bufs_from_lin(lin)
    best_lin = beam_search(lin, bufs, 2)
    assert best_lin
    # need disable_cache to trigger.
    tm = time_linearizer(best_lin, bufs, allow_test_size=False, cnt=2, disable_cache=True)
    assert tm

  def test_beam_unnamed_kernels(self):
    a = Tensor.rand(100)
    b = Tensor.rand(100)
    si = (a+b).schedule()[-1]
    lin = Kernel(si.ast)
    bufs = bufs_from_lin(lin)
    # TODO: beam should have better instrumentation so we don't have to check this indirect thing
    kcount = len(Kernel.kernel_cnt)
    beam_search(lin, bufs, 3, disable_cache=True)
    self.assertEqual(kcount, len(Kernel.kernel_cnt))

if __name__ == '__main__':
  unittest.main()