From 260df1a17f974ef7cfaf91e574eecd903c72a012 Mon Sep 17 00:00:00 2001 From: Ignacio Sica Date: Wed, 29 Jan 2025 15:53:23 -0300 Subject: [PATCH] `tc_select` noop (#8801) * tc_select noop * revert changes in test --- test/external/external_test_nv.py | 2 +- test/test_linearizer.py | 11 ++++++----- test/test_linearizer_dumb.py | 2 +- test/test_linearizer_failures.py | 24 ++++++++++++------------ tinygrad/codegen/kernel.py | 25 +++++++++++++++++-------- tinygrad/engine/search.py | 4 ++-- tinygrad/helpers.py | 3 ++- 7 files changed, 41 insertions(+), 30 deletions(-) diff --git a/test/external/external_test_nv.py b/test/external/external_test_nv.py index 702f9c805c..8be629055d 100644 --- a/test/external/external_test_nv.py +++ b/test/external/external_test_nv.py @@ -26,7 +26,7 @@ class TestNV(unittest.TestCase): def test_oor_kernels(self): ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=Ops.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 512, 4, 16, 4, 16), strides=(0, 100352, 0, 196, 0, 14, 0, 1), offset=-15, mask=((0, 1), (0, 256), (0, 1), (0, 512), (0, 4), (1, 15), (0, 4), (1, 15)), contiguous=False), View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(2097152, 0, 0, 128, 2, 4096, 1088, 17), offset=0, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(dtypes.float, False)),), arg=((0, 3, 4), dtypes.float)),), arg=(dtypes.half, False)),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 512, 1, 1, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501 - opts = [Opt(op=OptOps.TC, axis=6, arg=2), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=2, arg=3), Opt(op=OptOps.UPCAST, axis=1, arg=2)] # noqa: E501 + opts = [Opt(op=OptOps.TC, axis=6, arg=(-1, 2)), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=2, arg=3), Opt(op=OptOps.UPCAST, axis=1, arg=2)] # noqa: E501 helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["NV"]) def test_error_on_huge_dims(self): diff --git a/test/test_linearizer.py b/test/test_linearizer.py index a65559b2e0..3bb3e77ac2 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -25,7 +25,7 @@ def helper_realized_ast(r:Union[Tensor, List[Tensor]]) -> Tuple[UOp, List[Buffer bufs = [Buffer((x).device, x.size, x.dtype).allocate() if x in s[-1].outputs else x for x in s[-1].bufs] return s[-1].ast, bufs -def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_opt:int=0): +def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0): a, b = Tensor.rand(m, k, dtype=dtype_in), Tensor.rand(k, n, dtype=dtype_in) np_a, np_b = a.numpy(), b.numpy() r = a.matmul(b, acc_dtype=dtype_out) @@ -34,7 +34,7 @@ def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axi run_schedule(sched) out = r.numpy() k = Kernel(realized_ast) - k.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt) + k.apply_tensor_cores(1, axis=axis, tc_select=tc_select, tc_opt=tc_opt) k.linearize() assert len([uop for uop in k.uops if uop.op is Ops.WMMA]) > 0, "tensor core not triggered" assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included" @@ -44,13 +44,14 @@ def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axi else: tc_atol, tc_rtol = 5e-3, 1e-4 np.testing.assert_allclose(np_c, out, atol=tc_atol, rtol=tc_rtol) -def helper_tc_ensure_uops_and_opts_count(n: int, m:int, k:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_opt:int=0, ensure_triggered:bool=True): +def helper_tc_ensure_uops_and_opts_count(n: int, m:int, k:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0, + ensure_triggered:bool=True): a, b = Tensor.rand(m, k, dtype=dtype_in), Tensor.rand(k, n, dtype=dtype_in) r = a.matmul(b, acc_dtype=dtype_out) sched = r.schedule() realized_ast = sched[-1].ast k = Kernel(realized_ast) - k.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt) + k.apply_tensor_cores(1, axis=axis, tc_select=tc_select, tc_opt=tc_opt) k.linearize() wmmas = len([uop for uop in k.uops if uop.op is Ops.WMMA]) tcs = len([x for x in k.applied_opts if x.op is OptOps.TC]) @@ -1959,7 +1960,7 @@ class TestKernelOpts(unittest.TestCase): UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)) # noqa: E501 k = Kernel(ast, opts=Device[Device.DEFAULT].renderer) with self.assertRaises(KernelOptError): - k.apply_opt(Opt(OptOps.TC, 0, 1)) + k.apply_opt(Opt(OptOps.TC, 0, (-1, 1))) @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") def test_tensor_core_opts(self): diff --git a/test/test_linearizer_dumb.py b/test/test_linearizer_dumb.py index 2408bc2626..85e88bb6d9 100644 --- a/test/test_linearizer_dumb.py +++ b/test/test_linearizer_dumb.py @@ -35,7 +35,7 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), ast_const(dtypes.half, 0.0, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) - opts = [Opt(op=OptOps.TC, axis=2, arg=2), Opt(op=OptOps.UPCAST, axis=2, arg=0), Opt(op=OptOps.UNROLL, axis=1, arg=0)] + opts = [Opt(op=OptOps.TC, axis=2, arg=(-1, 2)), Opt(op=OptOps.UPCAST, axis=2, arg=0), Opt(op=OptOps.UNROLL, axis=1, arg=0)] k = Kernel(ast, opts=Device["METAL"].renderer) k.required_optimizations() for opt in opts: k.apply_opt(opt) diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index 3dffc05c63..b1b385029b 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -740,7 +740,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1, 64, 56, 56, 64, 3, 3), strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) - opts = [Opt(op=OptOps.TC, axis=0, arg=1), Opt(op=OptOps.PADTO, axis=2, arg=32)] + opts = [Opt(op=OptOps.TC, axis=0, arg=(-1, 1)), Opt(op=OptOps.PADTO, axis=2, arg=32)] helper_test_lin(Kernel(ast), opts, failed_platforms=[], atol=1.0) def test_failure_30(self): @@ -800,7 +800,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 256, 14, 14, 256, 3, 3), strides=(0, 0, 2304, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) - opts = [Opt(op=OptOps.TC, axis=2, arg=2), Opt(op=OptOps.UPCAST, axis=2, arg=7), Opt(op=OptOps.UNROLL, axis=1, arg=0), Opt(op=OptOps.LOCAL, axis=1, arg=16)] + opts = [Opt(op=OptOps.TC, axis=2, arg=(-1, 2)), Opt(op=OptOps.UPCAST, axis=2, arg=7), Opt(op=OptOps.UNROLL, axis=1, arg=0), Opt(op=OptOps.LOCAL, axis=1, arg=16)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[], atol=0.1, rtol=0.05) def test_failure_33(self): @@ -861,7 +861,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 2, 5), strides=(0, 0, 10, 0, 0, 0, 5, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), ast_const(dtypes.float, 0.0, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) - opts = [Opt(op=OptOps.TC, axis=0, arg=2), Opt(op=OptOps.UNROLL, axis=0, arg=0)] if unroll else [Opt(op=OptOps.TC, axis=0, arg=2)] + opts = [Opt(op=OptOps.TC, axis=0, arg=(-1, 2)), Opt(op=OptOps.UNROLL, axis=0, arg=0)] if unroll else [Opt(op=OptOps.TC, axis=0, arg=(-1, 2))] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) def test_failure_35(self): self.test_failure_34(True) @@ -910,7 +910,7 @@ class TestLinearizerFailures(unittest.TestCase): ast_const(dtypes.float, 0.0, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) for axis in [0,1,2,3,4,5]: - opts = [Opt(op=OptOps.TC, axis=axis, arg=2)] + opts = [Opt(op=OptOps.TC, axis=axis, arg=(-1, 2))] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) def test_failure_38(self): @@ -930,7 +930,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 32, 24, 24, 1, 5, 5, 256), strides=(18432, 0, 576, 24, 1, 0, 0, 0, 36864), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) for axis in [0,1,3,4]: - opts = [Opt(op=OptOps.TC, axis=axis, arg=2)] + opts = [Opt(op=OptOps.TC, axis=axis, arg=(-1, 2))] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) @unittest.skip("very slow, similar to test_failure_37") @@ -958,7 +958,7 @@ class TestLinearizerFailures(unittest.TestCase): ast_const(dtypes.float, 0.0, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) for axis in [0,1,2,3,4,5]: - opts = [Opt(op=OptOps.TC, axis=axis, arg=2)] + opts = [Opt(op=OptOps.TC, axis=axis, arg=(-1, 2))] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) def test_failure_40(self): @@ -996,7 +996,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 128, 28, 28, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) - opts=[Opt(op=OptOps.TC, axis=5, arg=2), Opt(op=OptOps.UNROLL, axis=0, arg=0)] + opts=[Opt(op=OptOps.TC, axis=5, arg=(-1, 2)), Opt(op=OptOps.UNROLL, axis=0, arg=0)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["AMD", "HIP"], atol=0.02) # llama3 8B failure with BEAM=2 https://github.com/tinygrad/tinygrad/actions/runs/10150118124/job/28066519425#step:14:1, these don't compile @@ -1151,7 +1151,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 64, 56, 56, 256, 1, 1, 256), strides=(0, 0, 3136, 56, 1, 0, 0, 0, 200704), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)) - opts = [Opt(op=OptOps.TC, axis=0, arg=0), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=2)] + opts = [Opt(op=OptOps.TC, axis=0, arg=(-1, 0)), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=2)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) def test_failure_49(self): @@ -1168,7 +1168,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 6, 10), strides=(0, 1, 6), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.TC, axis=0, arg=2), Opt(op=OptOps.UPCAST, axis=0, arg=2)] + opts = [Opt(op=OptOps.TC, axis=0, arg=(-1, 2)), Opt(op=OptOps.UPCAST, axis=0, arg=2)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) def test_failure_50(self): @@ -1229,7 +1229,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), UOp(Ops.CONST, dtypes.half, arg=-1.4426950408889634, src=( x6,)),)),)),)),)),)),)) - opts = [Opt(op=OptOps.TC, axis=0, arg=2)] + opts = [Opt(op=OptOps.TC, axis=0, arg=(-1, 2))] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) @unittest.skipIf(CI and Device.DEFAULT in {"METAL"}, "hangs metal gpu CI") @@ -1251,7 +1251,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) - opts = [Opt(op=OptOps.TC, axis=0, arg=2), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=16)] + opts = [Opt(op=OptOps.TC, axis=0, arg=(-1, 2)), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=16)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) def test_failure_53(self): @@ -1305,7 +1305,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 56, 56, 64, 3, 3), strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) - opts = [Opt(op=OptOps.TC, axis=2, arg=2), Opt(op=OptOps.UPCAST, axis=2, arg=7), Opt(op=OptOps.UPCAST, axis=1, arg=2)] + opts = [Opt(op=OptOps.TC, axis=2, arg=(-1, 2)), Opt(op=OptOps.UPCAST, axis=2, arg=7), Opt(op=OptOps.UPCAST, axis=1, arg=2)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["HIP", "AMD"]) @unittest.skipIf(CI and Device.DEFAULT in {"METAL"}, "hangs metal gpu CI") diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index a20455492e..6f40d0b4b2 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -11,7 +11,7 @@ from tinygrad.device import Device from tinygrad.renderer import Renderer, TensorCore, ProgramSpec from tinygrad.dtype import ImageDType from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put, unwrap, ContextVar -from tinygrad.helpers import DEBUG, TC_OPT, USE_TC, AMX, CAPTURE_PROCESS_REPLAY +from tinygrad.helpers import DEBUG, TC_SELECT, TC_OPT, USE_TC, AMX, CAPTURE_PROCESS_REPLAY from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import strides_for_shape from tinygrad.codegen.linearize import linearize_uop @@ -291,9 +291,10 @@ class Kernel: if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc) return TensorCoreOptions(axes=(s0, s1, s2), axes_exist=(True, True), axis_pads=axis_pads) - def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool: + def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int) -> bool: if use_tensor_cores and self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD: - for tc in self.opts.tensor_cores: + tensor_cores = self.opts.tensor_cores if tc_select == -1 else [self.opts.tensor_cores[tc_select]] + for tc in tensor_cores: tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops] # can only fuse reduces with the same tc options assert all_same(tensor_core_opts) @@ -312,8 +313,9 @@ class Kernel: return True return False - def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[list[Opt]]=None, axis:int=0, tc_opt:Optional[int]=None) -> bool: - """ Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false. + def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[list[Opt]]=None, axis:int=0, tc_select:Optional[int]=None, + tc_opt:Optional[int]=None) -> bool: + """ Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false. Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N). Keyword arguments: @@ -322,15 +324,19 @@ class Kernel: 1: enable tensor cores 2: apply tensor core shape but don't use UOp.WMMA extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None) + tc_select -- specifies which tensor core(s) to use for optimization (default -1) + -1: iterates through all available tensor cores in order and uses the first one that matches the requirements (dims and dtypes) + [0-N]: uses only the n'th tensor core available; useful for search tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise) 0: applies to only kernels with a single reduce axis and direct UOps.LOAD into Ops.MUL 1: allows kernels with multiple reduce axes and also multiplication of UOps.CAST'd buffers 2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed """ + if tc_select is None: tc_select = TC_SELECT.value if tc_opt is None: tc_opt = TC_OPT.value if not self.opts.tensor_cores and use_tensor_cores != 2: return False try: # check TC first and apply hand-coded opts if successful - self.apply_opt(Opt(OptOps.TC, axis, tc_opt)) + self.apply_opt(Opt(OptOps.TC, axis, (tc_select, tc_opt))) if (tc_opts:=self.tensor_core_opts) is not None: if extra_opts is not None: @@ -353,9 +359,12 @@ class Kernel: if opt.op is OptOps.TC: check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine - check(opt.axis is not None and opt.arg is not None, "tensor core opts must have an axis and arg") check((use_tensor_cores:=USE_TC.value) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2") - check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), cast(int, opt.arg)), "no tensor core available") + check(opt.axis is not None, "tensor core opts must have an axis") + check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 2, "tensor core opts must have tc_select and tc_opt") + check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.opts.tensor_cores), "tensor core opts must have valid tc_select") + check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt") + check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt), "no tensor core available") self.applied_opts.append(opt) return diff --git a/tinygrad/engine/search.py b/tinygrad/engine/search.py index ebc8bb778a..87ff15aea8 100644 --- a/tinygrad/engine/search.py +++ b/tinygrad/engine/search.py @@ -19,8 +19,8 @@ actions += [Opt(op=OptOps.GROUPTOP, axis=axis, arg=amt) for amt in [13,16,28,29, actions += [Opt(op=OptOps.GROUP, axis=axis, arg=amt) for amt in [0,4,8,16] for axis in range(3)] if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, arg=amt) for amt in [32] for axis in range(7)] actions += [Opt(op=OptOps.LOCAL, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=6, arg=2)] -actions += [Opt(op=OptOps.TC, axis=0, arg=0)] -actions += [Opt(op=OptOps.TC, axis=axis, arg=getenv("TC_OPT", 2)) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce) +actions += [Opt(op=OptOps.TC, axis=0, arg=(-1, 0))] +actions += [Opt(op=OptOps.TC, axis=axis, arg=(-1, getenv("TC_OPT", 2))) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce) actions += [Opt(op=OptOps.SWAP, axis=axis_0, arg=axis_1) for axis_0 in range(5) for axis_1 in range(axis_0+1, 5)] if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)] diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 9c055acd87..d0503e484a 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -105,7 +105,8 @@ class ContextVar: DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0) JIT = ContextVar("JIT", 2 if platform.system() == 'Darwin' and ('Intel' in platform.processor() or 'i386' in platform.processor()) else 1) WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1) -USE_TC, TC_OPT, AMX, TRANSCENDENTAL = ContextVar("TC", 1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0), ContextVar("TRANSCENDENTAL", 1) +USE_TC, TC_SELECT, TC_OPT, AMX = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0) +TRANSCENDENTAL = ContextVar("TRANSCENDENTAL", 1) FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0) SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1) PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)