TC_SEARCH_OVER_SHAPE to search multiple TC shapes (#8793)

* squash search over search

* refactor assert

* init benchmark

* cleaner get_kernel_actions

* cleaner get_kernel_actions

* add comment
This commit is contained in:
Ignacio Sica
2025-02-05 13:03:46 -03:00
committed by GitHub
parent e7edadda54
commit 15f94ac964
4 changed files with 32 additions and 5 deletions

View File

@@ -1112,7 +1112,9 @@ class TestLinearizer(unittest.TestCase):
# check that get_kernel_actions produces all 9 options
from tinygrad.engine.search import get_kernel_actions
tc_actions = [k for i, k in get_kernel_actions(Kernel(realized_ast), False).items() if k.applied_opts[0].op == OptOps.TC]
assert len(tc_actions) == 9, f"get_kernel_actions should contain 9 possible TC actions, only got {len(tc_actions)}"
available_tc = len([x for x in Device[Device.DEFAULT].renderer.tensor_cores if x.dtype_in == tc.dtype_in and x.dtype_out == tc.dtype_out])
assert len(tc_actions) == 9 * available_tc, f"should contain 9 possible TC actions for every available TC, got {len(tc_actions)}"
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores_unroll_phi(self):

View File

@@ -102,6 +102,24 @@ class TestBEAM(unittest.TestCase):
if Opt(OptOps.GROUPTOP, 0, 0) in actions:
assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.GROUPTOP, axis=0, arg=3)]) == 0, "did not de-dup GROUPTOP"
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_search_over_shape(self):
from test.test_linearizer import helper_realized_ast
from tinygrad.engine.search import get_kernel_actions
dtype_pairs = [(tc.dtype_in, tc.dtype_out) for tc in Device[Device.DEFAULT].renderer.tensor_cores]
multi_shape_dtype_pairs = [dts for dts in dtype_pairs if dtype_pairs.count(dts) > 1]
if len(multi_shape_dtype_pairs) == 0: raise unittest.SkipTest("only one tc available per dtype pair to search over")
for (dtype_in, dtype_out) in multi_shape_dtype_pairs:
a = Tensor.rand(16, 16, dtype=dtype_in)
b = Tensor.rand(16, 16, dtype=dtype_in)
realized_ast, _ = helper_realized_ast(a.matmul(b, acc_dtype=dtype_out))
lins = get_kernel_actions(Kernel(realized_ast)).values()
assert len(set(lin.tensor_core.dims for lin in lins if lin.tensor_core is not None)) > 1
def test_filter_global_buffer(self):
# taken from https://github.com/tinygrad/tinygrad/issues/4612
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(