Mirror of https://github.com/tinygrad/tinygrad.git
wmma: add reduce axis choice to TC action space (#4328)
* wmma: add reduce axis choice to TC action space
* add test for TC multi-reduce axis choice
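The TC Opt's axis field previously selected only a pair of global axes (one N candidate from buffer 0, one M candidate from buffer 1); this change adds the reduce (K) axis as a third component of that choice, threads an explicit axis argument through apply_tensor_cores, and widens the search action space from range(4) to range(9). The sketch below is illustrative only: it mirrors the itertools.product change inside class Kernel, with made-up axis indices standing in for the real axis_buf0/axis_buf1/shape_len data.

import itertools

# Hypothetical axis candidates for a conv2d-like kernel (indices are invented):
# one N candidate from buf0, three M candidates from buf1, and three reduce (K)
# candidates -- the "3 global * 3 reduce" = 9 choices that range(9) covers.
axis_buf0 = [0]          # global axes where buf0 has stride 0 (N side)
axis_buf1 = [1, 2, 3]    # global axes where buf1 has stride 0 (M side)
reduce_axes = [4, 5, 6]  # axes from first_reduce to shape_len (K side)

# Same enumeration shape as the kernel change, simplified to plain ints.
axis_choices = list(itertools.product(axis_buf0, axis_buf1, reduce_axes))
assert len(axis_choices) == 9

# The TC opt indexes the list from the end, as in axis_choices[-(axis+1)].
for axis in range(len(axis_choices)):
  s0, s1, s2 = axis_choices[-(axis+1)]  # s0 is n, s1 is m, s2 is k
  print(f"axis={axis}: n-axis {s0}, m-axis {s1}, k-axis {s2}")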
@@ -15,7 +15,15 @@ from tinygrad.helpers import prod, Context, getenv
 from tinygrad.dtype import DType, dtypes
 from tinygrad.codegen.uops import UOpGraph

-def helper_tc_allclose(m:int, k:int, n:int, dtype_in:DType, dtype_out:DType, tc_opt:int):
+def helper_realized_ast(r:Tensor):
+  s = create_schedule([r.lazydata])
+  run_schedule(s[:-1]) # run all kernels except the last one
+  # now all input LazyBuffers buffers in s[-1] should be realized
+  # allocate an output buffer
+  output_buffer = Buffer((out:=s[-1].outputs[0]).device, out.size, out.dtype).allocate()
+  return s[-1].ast[0], [output_buffer] + list(s[-1].inputs)
+
+def helper_tc_allclose(m:int, k:int, n:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_opt:int=0):
   a, b = Tensor.rand(m, k, dtype=dtype_in), Tensor.rand(k, n, dtype=dtype_in)
   np_a, np_b = a.numpy(), b.numpy()
   r = a.matmul(b, acc_dtype=dtype_out)
@@ -24,7 +32,7 @@ def helper_tc_allclose(m:int, k:int, n:int, dtype_in:DType, dtype_out:DType, tc_
   run_schedule(sched)
   out = r.numpy()
   k = Linearizer(realized_ast)
-  k.apply_tensor_cores(1, tc_opt=tc_opt)
+  k.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt)
   k.linearize()
   assert len([uop for uop in k.uops if uop.uop is UOps.WMMA]) > 0, "tensor core not triggered"
   assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
@@ -34,13 +42,13 @@ def helper_tc_allclose(m:int, k:int, n:int, dtype_in:DType, dtype_out:DType, tc_
   else: tc_atol, tc_rtol = 5e-3, 1e-4
   np.testing.assert_allclose(np_c, out, atol=tc_atol, rtol=tc_rtol)

-def helper_tc_ensure_uops_and_opts_count(m:int, k:int, n:int, dtype_in:DType, dtype_out:DType, tc_opt:int, ensure_triggered:bool=True):
+def helper_tc_ensure_uops_and_opts_count(m:int, k:int, n:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_opt:int=0, ensure_triggered:bool=True):
   a, b = Tensor.rand(m, k, dtype=dtype_in), Tensor.rand(k, n, dtype=dtype_in)
   r = a.matmul(b, acc_dtype=dtype_out)
   sched = create_schedule([r.lazydata])
   realized_ast = sched[-1].ast[0]
   k = Linearizer(realized_ast)
-  k.apply_tensor_cores(1, tc_opt=tc_opt)
+  k.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt)
   k.linearize()
   wmmas = len([uop for uop in k.uops if uop.uop is UOps.WMMA])
   tcs = len([x for x in k.applied_opts if x.op is OptOps.TC])
@@ -214,7 +222,7 @@ class TestLinearizer(unittest.TestCase):
       self.skipTest("device doesn't have tensor cores")
     for tc in tensor_cores[Device[Device.DEFAULT].compiler.compiler_opts.device]:
       if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
-      helper_tc_allclose(tc.dims[1], tc.dims[2], tc.dims[0], tc.dtype_in, tc.dtype_out, tc_opt=0)
+      helper_tc_allclose(tc.dims[1], tc.dims[2], tc.dims[0], tc.dtype_in, tc.dtype_out, axis=0, tc_opt=0)

   def test_tensor_cores_padded(self):
     if not Device[Device.DEFAULT].compiler.compiler_opts.has_tensor_cores:
@@ -224,7 +232,8 @@ class TestLinearizer(unittest.TestCase):
       pad = 1

       # check that TC is triggered for TC_OPT=2
-      helper_tc_ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[2]+pad, tc.dims[1]+pad,tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=True)
+      helper_tc_ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[2]+pad, tc.dims[1]+pad,
+                                           tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=True)

       # check that TC is not triggered for TC_OPT<2
       helper_tc_ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[2]+pad, tc.dims[1]+pad,
@@ -240,6 +249,33 @@ class TestLinearizer(unittest.TestCase):
       # check correctness
       helper_tc_allclose(tc.dims[1]+pad, tc.dims[2]+pad, tc.dims[0]+pad, tc.dtype_in, tc.dtype_out, tc_opt=2)

+  def test_tensor_cores_multi_reduce(self):
+    if not Device[Device.DEFAULT].compiler.compiler_opts.has_tensor_cores:
+      self.skipTest("device doesn't have tensor cores")
+    for tc in tensor_cores[Device[Device.DEFAULT].compiler.compiler_opts.device]:
+      if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
+      # this will be a M=G16, N=G32, M=G16, M=G16, K=R16, K=R16, K=R16 with 9 choices of TC MNK axes
+      golden_result = None
+      for axis in range(9):
+        a = Tensor.rand(16, 16, 29, 29, dtype=tc.dtype_in).realize()
+        b = Tensor.rand(32, 16, 16, 16, dtype=tc.dtype_in).realize()
+        c = a.conv2d(b, padding=1, acc_dtype=tc.dtype_out)
+        realized_ast, real_bufs = helper_realized_ast(c)
+
+        k = Linearizer(realized_ast)
+        k.apply_tensor_cores(1, axis=axis, tc_opt=2)
+        k.linearize()
+        assert len([uop for uop in k.uops if uop.uop is UOps.WMMA]) > 0, "tensor core not triggered"
+        assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
+
+        prg = Device[Device.DEFAULT].to_program(k)
+        real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np).data) # Zero to check that all values are filled
+        prg.exec(real_bufs)
+        result = np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np)
+
+        if golden_result is None: golden_result = np.frombuffer(real_bufs[0].as_buffer(), real_bufs[0].dtype.np)
+        np.testing.assert_allclose(result, golden_result, atol=0.1, rtol=0.15)
+
   def test_limit_dims_to_max_5d_global(self):
     t = Tensor.empty(3, 4, 5, 6, 7).pad(((1, 1), (1, 1), (1, 1), (1, 1), (1, 1))) + 1
     sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps]
@@ -315,14 +351,6 @@ class TestLinearizer(unittest.TestCase):
     helper(Tensor.arange(256), max_ops=2)
     helper(Tensor.arange(255), max_ops=0)

-def helper_realized_ast(r:Tensor):
-  s = create_schedule([r.lazydata])
-  run_schedule(s[:-1]) # run all kernels except the last one
-  # now all input LazyBuffers buffers in s[-1] should be realized
-  # allocate an output buffer
-  output_buffer = Buffer((out:=s[-1].outputs[0]).device, out.size, out.dtype).allocate()
-  return s[-1].ast[0], [output_buffer] + list(s[-1].inputs)
-
 @unittest.skipUnless(Device[Device.DEFAULT].compiler.compiler_opts.supports_float4, "need backends that support float4")
 class TestFloat4(unittest.TestCase):
   @staticmethod
@@ -554,7 +582,7 @@ def helper_linearizer_opt(r:Tensor, opts=[], apply_tc=False, atol=1e-4, rtol=1e-
   def check_opt(opts, create_k, to_prg, expected_color_size):
     k = create_k()
     if apply_tc:
-      assert k.apply_tensor_cores(1, opts), "no tensor core triggered"
+      assert k.apply_tensor_cores(1, extra_opts=opts), "no tensor core triggered"
     else:
       for opt in opts:
         k.apply_opt(opt)

@@ -352,11 +352,11 @@ class Kernel:
         axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
         if not(axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): continue

-        axis_choices = list(itertools.product(axis_buf0, axis_buf1))
+        axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
         if not(axis < len(axis_choices)): continue

-        s0, s1 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0] # s0 is n, s1 is m
-        axis_pads = [(x, tc.dims[i]) for i, x in enumerate([s0, s1, self.first_reduce]) if self.full_shape[x]%tc.dims[i] != 0]
+        s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
+        axis_pads = [(x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if self.full_shape[x]%tc.dims[i] != 0]
         if axis_pads and (opt_level < 2): continue

         # tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern
@@ -366,7 +366,7 @@ class Kernel:
         try:
           for axis, dim in axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
         except KernelOptError: continue
-        self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2]), append_opt=False)
+        self.apply_opt(Opt(OptOps.UNROLL, s2-self.first_reduce, tc.dims[2]), append_opt=False)
         for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M
           if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False)
         for (tc_dim, tc_amt) in tc.threads:
@@ -379,7 +379,7 @@ class Kernel:
     return False


-  def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None, tc_opt:Optional[int]=getenv("TC_OPT")) -> bool:
+  def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None, axis:int=0, tc_opt:int=getenv("TC_OPT")) -> bool:
     """ Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
     Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).

@@ -396,7 +396,7 @@ class Kernel:
     """
     if not self.opts.has_tensor_cores and use_tensor_cores != 2: return False
     try: # check TC first and apply hand-coded opts if successful
-      self.apply_opt(Opt(OptOps.TC, 0, tc_opt))
+      self.apply_opt(Opt(OptOps.TC, axis, tc_opt))

       if (tc_opts:=self.tensor_core_opts) is not None:
         if extra_opts is not None:

@@ -17,7 +17,7 @@ actions += [Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,28,29,
 actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for axis in range(3)]
 actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)]
 actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.UPCASTMID, axis=1, amt=4), Opt(op=OptOps.TC, axis=0, amt=0)]
-actions += [Opt(op=OptOps.TC, axis=axis, amt=getenv("TC_OPT", 2)) for axis in range(4)]
+actions += [Opt(op=OptOps.TC, axis=axis, amt=getenv("TC_OPT", 2)) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce)
 if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]

 def _get_test_global_size(global_size, max_global_size, var_vals):