From a9a1fa6bbfcb7bb66549678fa4e48540f0cbb0f0 Mon Sep 17 00:00:00 2001
From: Francis Lam
Date: Mon, 29 Apr 2024 16:15:39 -0700
Subject: [PATCH] wmma: add reduce axis choice to TC action space (#4328)

* wmma: add reduce axis choice to TC action space

* add test for TC multi-reduce axis choice
---
 test/test_linearizer.py     | 58 +++++++++++++++++++++++++++----------
 tinygrad/codegen/kernel.py  | 12 ++++----
 tinygrad/features/search.py |  2 +-
 3 files changed, 50 insertions(+), 22 deletions(-)

diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index f9561f8520..59ce46b829 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -15,7 +15,15 @@ from tinygrad.helpers import prod, Context, getenv
 from tinygrad.dtype import DType, dtypes
 from tinygrad.codegen.uops import UOpGraph
 
-def helper_tc_allclose(m:int, k:int, n:int, dtype_in:DType, dtype_out:DType, tc_opt:int):
+def helper_realized_ast(r:Tensor):
+  s = create_schedule([r.lazydata])
+  run_schedule(s[:-1]) # run all kernels except the last one
+  # now all input LazyBuffers buffers in s[-1] should be realized
+  # allocate an output buffer
+  output_buffer = Buffer((out:=s[-1].outputs[0]).device, out.size, out.dtype).allocate()
+  return s[-1].ast[0], [output_buffer] + list(s[-1].inputs)
+
+def helper_tc_allclose(m:int, k:int, n:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_opt:int=0):
   a, b = Tensor.rand(m, k, dtype=dtype_in), Tensor.rand(k, n, dtype=dtype_in)
   np_a, np_b = a.numpy(), b.numpy()
   r = a.matmul(b, acc_dtype=dtype_out)
@@ -24,7 +32,7 @@ def helper_tc_allclose(m:int, k:int, n:int, dtype_in:DType, dtype_out:DType, tc_
   run_schedule(sched)
   out = r.numpy()
   k = Linearizer(realized_ast)
-  k.apply_tensor_cores(1, tc_opt=tc_opt)
+  k.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt)
   k.linearize()
   assert len([uop for uop in k.uops if uop.uop is UOps.WMMA]) > 0, "tensor core not triggered"
   assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
@@ -34,13 +42,13 @@ def helper_tc_allclose(m:int, k:int, n:int, dtype_in:DType, dtype_out:DType, tc_
   else: tc_atol, tc_rtol = 5e-3, 1e-4
   np.testing.assert_allclose(np_c, out, atol=tc_atol, rtol=tc_rtol)
 
-def helper_tc_ensure_uops_and_opts_count(m:int, k:int, n:int, dtype_in:DType, dtype_out:DType, tc_opt:int, ensure_triggered:bool=True):
+def helper_tc_ensure_uops_and_opts_count(m:int, k:int, n:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_opt:int=0, ensure_triggered:bool=True):
   a, b = Tensor.rand(m, k, dtype=dtype_in), Tensor.rand(k, n, dtype=dtype_in)
   r = a.matmul(b, acc_dtype=dtype_out)
   sched = create_schedule([r.lazydata])
   realized_ast = sched[-1].ast[0]
   k = Linearizer(realized_ast)
-  k.apply_tensor_cores(1, tc_opt=tc_opt)
+  k.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt)
   k.linearize()
   wmmas = len([uop for uop in k.uops if uop.uop is UOps.WMMA])
   tcs = len([x for x in k.applied_opts if x.op is OptOps.TC])
@@ -214,7 +222,7 @@ class TestLinearizer(unittest.TestCase):
       self.skipTest("device doesn't have tensor cores")
     for tc in tensor_cores[Device[Device.DEFAULT].compiler.compiler_opts.device]:
       if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
-      helper_tc_allclose(tc.dims[1], tc.dims[2], tc.dims[0], tc.dtype_in, tc.dtype_out, tc_opt=0)
+      helper_tc_allclose(tc.dims[1], tc.dims[2], tc.dims[0], tc.dtype_in, tc.dtype_out, axis=0, tc_opt=0)
 
   def test_tensor_cores_padded(self):
     if not Device[Device.DEFAULT].compiler.compiler_opts.has_tensor_cores:
@@ -224,7 +232,8 @@ class TestLinearizer(unittest.TestCase):
       pad = 1
 
       # check that TC is triggered for TC_OPT=2
-      helper_tc_ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[2]+pad, tc.dims[1]+pad,tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=True)
+      helper_tc_ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[2]+pad, tc.dims[1]+pad,
+                                           tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=True)
 
       # check that TC is not triggered for TC_OPT<2
       helper_tc_ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[2]+pad, tc.dims[1]+pad,
@@ -240,6 +249,33 @@ class TestLinearizer(unittest.TestCase):
       # check correctness
       helper_tc_allclose(tc.dims[1]+pad, tc.dims[2]+pad, tc.dims[0]+pad, tc.dtype_in, tc.dtype_out, tc_opt=2)
 
+  def test_tensor_cores_multi_reduce(self):
+    if not Device[Device.DEFAULT].compiler.compiler_opts.has_tensor_cores:
+      self.skipTest("device doesn't have tensor cores")
+    for tc in tensor_cores[Device[Device.DEFAULT].compiler.compiler_opts.device]:
+      if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
+      # this will be a M=G16, N=G32, M=G16, M=G16, K=R16, K=R16, K=R16 with 9 choices of TC MNK axes
+      golden_result = None
+      for axis in range(9):
+        a = Tensor.rand(16, 16, 29, 29, dtype=tc.dtype_in).realize()
+        b = Tensor.rand(32, 16, 16, 16, dtype=tc.dtype_in).realize()
+        c = a.conv2d(b, padding=1, acc_dtype=tc.dtype_out)
+        realized_ast, real_bufs = helper_realized_ast(c)
+
+        k = Linearizer(realized_ast)
+        k.apply_tensor_cores(1, axis=axis, tc_opt=2)
+        k.linearize()
+        assert len([uop for uop in k.uops if uop.uop is UOps.WMMA]) > 0, "tensor core not triggered"
+        assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
+
+        prg = Device[Device.DEFAULT].to_program(k)
+        real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np).data) # Zero to check that all values are filled
+        prg.exec(real_bufs)
+        result = np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np)
+
+        if golden_result is None: golden_result = np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np)
+        np.testing.assert_allclose(result, golden_result, atol=0.1, rtol=0.15)
+
   def test_limit_dims_to_max_5d_global(self):
     t = Tensor.empty(3, 4, 5, 6, 7).pad(((1, 1), (1, 1), (1, 1), (1, 1), (1, 1))) + 1
     sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps]
@@ -315,14 +351,6 @@ class TestLinearizer(unittest.TestCase):
     helper(Tensor.arange(256), max_ops=2)
     helper(Tensor.arange(255), max_ops=0)
 
-def helper_realized_ast(r:Tensor):
-  s = create_schedule([r.lazydata])
-  run_schedule(s[:-1]) # run all kernels except the last one
-  # now all input LazyBuffers buffers in s[-1] should be realized
-  # allocate an output buffer
-  output_buffer = Buffer((out:=s[-1].outputs[0]).device, out.size, out.dtype).allocate()
-  return s[-1].ast[0], [output_buffer] + list(s[-1].inputs)
-
 @unittest.skipUnless(Device[Device.DEFAULT].compiler.compiler_opts.supports_float4, "need backends that support float4")
 class TestFloat4(unittest.TestCase):
   @staticmethod
@@ -554,7 +582,7 @@ def helper_linearizer_opt(r:Tensor, opts=[], apply_tc=False, atol=1e-4, rtol=1e-
   def check_opt(opts, create_k, to_prg, expected_color_size):
     k = create_k()
     if apply_tc:
-      assert k.apply_tensor_cores(1, opts), "no tensor core triggered"
+      assert k.apply_tensor_cores(1, extra_opts=opts), "no tensor core triggered"
     else:
       for opt in opts:
         k.apply_opt(opt)
diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 8a9f3685e9..994ab30569 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -352,11 +352,11 @@ class Kernel:
         axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
         if not(axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): continue
 
-        axis_choices = list(itertools.product(axis_buf0, axis_buf1))
+        axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
         if not(axis < len(axis_choices)): continue
 
-        s0, s1 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0] # s0 is n, s1 is m
-        axis_pads = [(x, tc.dims[i]) for i, x in enumerate([s0, s1, self.first_reduce]) if self.full_shape[x]%tc.dims[i] != 0]
+        s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
+        axis_pads = [(x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if self.full_shape[x]%tc.dims[i] != 0]
         if axis_pads and (opt_level < 2): continue
 
         # tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern
@@ -366,7 +366,7 @@ class Kernel:
         try:
           for axis, dim in axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
         except KernelOptError: continue
-        self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2]), append_opt=False)
+        self.apply_opt(Opt(OptOps.UNROLL, s2-self.first_reduce, tc.dims[2]), append_opt=False)
         for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M
           if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False)
         for (tc_dim, tc_amt) in tc.threads:
@@ -379,7 +379,7 @@ class Kernel:
 
     return False
 
-  def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None, tc_opt:Optional[int]=getenv("TC_OPT")) -> bool:
+  def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None, axis:int=0, tc_opt:int=getenv("TC_OPT")) -> bool:
     """ Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
     Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).
 
@@ -396,7 +396,7 @@ class Kernel:
     """
     if not self.opts.has_tensor_cores and use_tensor_cores != 2: return False
     try: # check TC first and apply hand-coded opts if successful
-      self.apply_opt(Opt(OptOps.TC, 0, tc_opt))
+      self.apply_opt(Opt(OptOps.TC, axis, tc_opt))
       if (tc_opts:=self.tensor_core_opts) is not None:
         if extra_opts is not None:
diff --git a/tinygrad/features/search.py b/tinygrad/features/search.py
index 9892b0b524..daf75b5f59 100644
--- a/tinygrad/features/search.py
+++ b/tinygrad/features/search.py
@@ -17,7 +17,7 @@ actions += [Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,28,29,
 actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for axis in range(3)]
 actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)]
 actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.UPCASTMID, axis=1, amt=4), Opt(op=OptOps.TC, axis=0, amt=0)]
-actions += [Opt(op=OptOps.TC, axis=axis, amt=getenv("TC_OPT", 2)) for axis in range(4)]
+actions += [Opt(op=OptOps.TC, axis=axis, amt=getenv("TC_OPT", 2)) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce)
 if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
 
 def _get_test_global_size(global_size, max_global_size, var_vals):
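
Illustrative note, not part of the patch: the new `axis` argument indexes (from the end) into the flattened product of candidate global axes and reduce axes built in kernel.py. Below is a minimal, standalone Python sketch of that indexing. The `axis_buf0`/`axis_buf1` tuples, the shape positions, and the printed labels are made-up assumptions loosely shaped like the conv kernel in test_tensor_cores_multi_reduce (one N-candidate of size 32, three M-candidates of size 16, three reduce axes of size 16), not values from a real kernel.

import itertools

# assumed (shape_index, size, stride) tuples, mirroring what kernel.py collects
axis_buf0 = [(1, 32, 0)]                          # candidate n axes (assumed values)
axis_buf1 = [(0, 16, 0), (2, 16, 0), (3, 16, 0)]  # candidate m axes (assumed values)
first_reduce, shape_len = 4, 7                    # reduce axes at positions 4, 5, 6 (assumed)

# after this patch the reduce axis is part of the choice: 1 * 3 * 3 = 9 options,
# which is why search.py now emits Opt(OptOps.TC, axis=axis, ...) for axis in range(9)
axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(first_reduce, shape_len)))

for axis in range(len(axis_choices)):
  # kernel.py indexes from the end: axis=0 selects the last combination
  choice = axis_choices[-(axis+1)]
  s0, s1, s2 = choice[0][0], choice[1][0], choice[2]  # s0 is n, s1 is m, s2 is k
  print(f"axis={axis}: n-axis={s0}, m-axis={s1}, k-axis={s2}")

Each axis value therefore selects a different assignment of kernel axes to the tensor core's M, N, and K dimensions, so BEAM search can now also choose which reduce axis feeds the WMMA K dimension.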