From 15f94ac964e0760639b94fc157643429210a73cd Mon Sep 17 00:00:00 2001
From: Ignacio Sica
Date: Wed, 5 Feb 2025 13:03:46 -0300
Subject: [PATCH] TC_SEARCH_OVER_SHAPE to search multiple TC shapes (#8793)

* squash search over search

* refactor assert

* init benchmark

* cleaner get_kernel_actions

* cleaner get_kernel_actions

* add comment
---
 test/test_linearizer.py   |  4 +++-
 test/test_search.py       | 18 ++++++++++++++++++
 tinygrad/engine/search.py | 13 ++++++++++---
 tinygrad/helpers.py       |  2 +-
 4 files changed, 32 insertions(+), 5 deletions(-)

diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index 3bb3e77ac2..fa3df34608 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -1112,7 +1112,9 @@ class TestLinearizer(unittest.TestCase):
       # check that get_kernel_actions produces all 9 options
       from tinygrad.engine.search import get_kernel_actions
       tc_actions = [k for i, k in get_kernel_actions(Kernel(realized_ast), False).items() if k.applied_opts[0].op == OptOps.TC]
-      assert len(tc_actions) == 9, f"get_kernel_actions should contain 9 possible TC actions, only got {len(tc_actions)}"
+
+      available_tc = len([x for x in Device[Device.DEFAULT].renderer.tensor_cores if x.dtype_in == tc.dtype_in and x.dtype_out == tc.dtype_out])
+      assert len(tc_actions) == 9 * available_tc, f"should contain 9 possible TC actions for every available TC, got {len(tc_actions)}"
 
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
   def test_tensor_cores_unroll_phi(self):
diff --git a/test/test_search.py b/test/test_search.py
index d0d6cf9114..d6bda9aa5d 100644
--- a/test/test_search.py
+++ b/test/test_search.py
@@ -102,6 +102,24 @@ class TestBEAM(unittest.TestCase):
     if Opt(OptOps.GROUPTOP, 0, 0) in actions:
       assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.GROUPTOP, axis=0, arg=3)]) == 0, "did not de-dup GROUPTOP"
 
+  @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
+  def test_search_over_shape(self):
+    from test.test_linearizer import helper_realized_ast
+    from tinygrad.engine.search import get_kernel_actions
+
+    dtype_pairs = [(tc.dtype_in, tc.dtype_out) for tc in Device[Device.DEFAULT].renderer.tensor_cores]
+    multi_shape_dtype_pairs = [dts for dts in dtype_pairs if dtype_pairs.count(dts) > 1]
+
+    if len(multi_shape_dtype_pairs) == 0: raise unittest.SkipTest("only one tc available per dtype pair to search over")
+
+    for (dtype_in, dtype_out) in multi_shape_dtype_pairs:
+      a = Tensor.rand(16, 16, dtype=dtype_in)
+      b = Tensor.rand(16, 16, dtype=dtype_in)
+      realized_ast, _ = helper_realized_ast(a.matmul(b, acc_dtype=dtype_out))
+
+      lins = get_kernel_actions(Kernel(realized_ast)).values()
+      assert len(set(lin.tensor_core.dims for lin in lins if lin.tensor_core is not None)) > 1
+
   def test_filter_global_buffer(self):
     # taken from https://github.com/tinygrad/tinygrad/issues/4612
     ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
diff --git a/tinygrad/engine/search.py b/tinygrad/engine/search.py
index 87ff15aea8..443d22f0e9 100644
--- a/tinygrad/engine/search.py
+++ b/tinygrad/engine/search.py
@@ -5,7 +5,7 @@ from dataclasses import replace
 from tinygrad.ops import UOp, Ops, Variable, sym_infer
 from tinygrad.device import Device, Buffer, Compiler
 from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
-from tinygrad.helpers import IGNORE_BEAM_CACHE
+from tinygrad.helpers import IGNORE_BEAM_CACHE, TC_SEARCH_OVER_SHAPE
 from tinygrad.dtype import ImageDType, PtrDType
 from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
 from tinygrad.tensor import Tensor
@@ -102,8 +102,15 @@ def bufs_from_lin(lin:Kernel, allocate:bool=True) -> list[Buffer]:
 
 # get dictionary of all possible actions
 def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]:
-  acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
-  for i,a in enumerate(actions):
+  acted_lins, max_up, max_lcl, kernel_actions = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024), actions
+
+  if TC_SEARCH_OVER_SHAPE and len(lin.applied_opts) == 0: # tensor core opts must be first
+    for i, action in enumerate(kernel_actions):
+      if action.op == OptOps.TC and (tc_arg := cast(tuple, action.arg))[0] == -1:
+        # replace every tc_action with default tc with one tc_action for each available tc
+        kernel_actions[i:i+1] = [Opt(op=OptOps.TC, axis=action.axis, arg=(tc_select, tc_arg[1])) for tc_select,_ in enumerate(lin.opts.tensor_cores)]
+
+  for i,a in enumerate(kernel_actions):
     if a.axis is not None and a.op is not OptOps.TC:
       if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in actions): continue
       lin2 = lin.copy()
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 42fb309546..09054b4419 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -106,7 +106,7 @@ DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), Cont
 JIT = ContextVar("JIT", 2 if platform.system() == 'Darwin' and ('Intel' in platform.processor() or 'i386' in platform.processor()) else 1)
 WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
 USE_TC, TC_SELECT, TC_OPT, AMX = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0)
-TRANSCENDENTAL = ContextVar("TRANSCENDENTAL", 1)
+TRANSCENDENTAL, TC_SEARCH_OVER_SHAPE = ContextVar("TRANSCENDENTAL", 1), ContextVar("TC_SEARCH_OVER_SHAPE", 1)
 FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0)
 SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
 PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)