mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
TC_SEARCH_OVER_SHAPE to search multiple TC shapes (#8793)
* squash search over search
* refactor assert
* init benchmark
* cleaner get_kernel_actions
* cleaner get_kernel_actions
* add comment
This commit is contained in:
@@ -1112,7 +1112,9 @@ class TestLinearizer(unittest.TestCase):
|
||||
# check that get_kernel_actions produces all 9 options
|
||||
from tinygrad.engine.search import get_kernel_actions
|
||||
tc_actions = [k for i, k in get_kernel_actions(Kernel(realized_ast), False).items() if k.applied_opts[0].op == OptOps.TC]
|
||||
assert len(tc_actions) == 9, f"get_kernel_actions should contain 9 possible TC actions, only got {len(tc_actions)}"
|
||||
|
||||
available_tc = len([x for x in Device[Device.DEFAULT].renderer.tensor_cores if x.dtype_in == tc.dtype_in and x.dtype_out == tc.dtype_out])
|
||||
assert len(tc_actions) == 9 * available_tc, f"should contain 9 possible TC actions for every available TC, got {len(tc_actions)}"
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
|
||||
def test_tensor_cores_unroll_phi(self):
|
||||
|
||||
@@ -102,6 +102,24 @@ class TestBEAM(unittest.TestCase):
|
||||
if Opt(OptOps.GROUPTOP, 0, 0) in actions:
|
||||
assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.GROUPTOP, axis=0, arg=3)]) == 0, "did not de-dup GROUPTOP"
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_search_over_shape(self):
  """The action search should propose kernels using more than one tensor-core shape
  when the device offers several tensor cores for the same (dtype_in, dtype_out) pair."""
  from test.test_linearizer import helper_realized_ast
  from tinygrad.engine.search import get_kernel_actions

  # collect the (input, output) dtype pairs that are served by multiple tensor cores
  pairs = [(tc.dtype_in, tc.dtype_out) for tc in Device[Device.DEFAULT].renderer.tensor_cores]
  duplicated_pairs = [p for p in pairs if pairs.count(p) > 1]

  if not duplicated_pairs: raise unittest.SkipTest("only one tc available per dtype pair to search over")

  for dtype_in, dtype_out in duplicated_pairs:
    lhs = Tensor.rand(16, 16, dtype=dtype_in)
    rhs = Tensor.rand(16, 16, dtype=dtype_in)
    realized_ast, _ = helper_realized_ast(lhs.matmul(rhs, acc_dtype=dtype_out))

    candidates = get_kernel_actions(Kernel(realized_ast)).values()
    # more than one distinct TC dims value must appear among the proposed kernels
    distinct_dims = {lin.tensor_core.dims for lin in candidates if lin.tensor_core is not None}
    assert len(distinct_dims) > 1
def test_filter_global_buffer(self):
|
||||
# taken from https://github.com/tinygrad/tinygrad/issues/4612
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
|
||||
Reference in New Issue
Block a user