wmma: widen TC usage in search by using PADTO on TC axes when possible (#4216)

* wmma: widen TC usage in search by using PADTO on TC axes when possible

* test: start tests for the new padding TC behavior

* search: upgrade padded TC search to TC_OPT >= 2

* test: add behavior and correctness test for padded TC

added optional argument to apply_tensor_core to set TC_OPT level

* linearizer: add tests for the PADTO behvaior and docs
This commit is contained in:
Francis Lam
2024-04-22 13:50:31 -07:00
committed by GitHub
parent 9e53d6cffa
commit bbb0ad4800
5 changed files with 82 additions and 10 deletions

View File

@@ -196,6 +196,49 @@ class TestLinearizer(unittest.TestCase):
else: tc_atol, tc_rtol = 5e-3, 1e-4
np.testing.assert_allclose(np_c, out, atol=tc_atol, rtol=tc_rtol)
def test_tensor_cores_padded(self):
if not Device[Device.DEFAULT].compiler.compiler_opts.has_tensor_cores:
self.skipTest("device doesn't have tensor cores")
for tc in tensor_cores[Device[Device.DEFAULT].compiler.compiler_opts.device]:
if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
pad = 1
def ensure_uops_and_opts_count(m:int, k:int, n:int, tc_opt:int, ensure_triggered:bool=True):
a, b = Tensor.rand(m, k, dtype=tc.dtype_in), Tensor.rand(k, n, dtype=tc.dtype_in)
r = a.matmul(b, acc_dtype=tc.dtype_out)
sched = create_schedule([r.lazydata])
realized_ast = sched[-1].ast[0]
k = Linearizer(realized_ast)
k.apply_tensor_cores(1, tc_opt=tc_opt)
k.linearize()
wmmas = len([uop for uop in k.uops if uop.uop is UOps.WMMA])
tcs = len([x for x in k.applied_opts if x.op is OptOps.TC])
if ensure_triggered:
assert wmmas > 0, "tensor core not triggered"
assert tcs == 1, "tensor core opt not included"
else:
assert wmmas == 0, "tensor core is incorrectly triggered"
assert tcs == 0, "tensor core opt is incorrectly included"
# check that TC is triggered for TC_OPT=2
ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[2]+pad, tc.dims[1]+pad, tc_opt=2, ensure_triggered=True)
# check that TC is not triggered for TC_OPT<2
ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[2]+pad, tc.dims[1]+pad, tc_opt=1, ensure_triggered=False)
# check excessive padding doesn't trigger padded TC in TC_OPT=2
ensure_uops_and_opts_count(tc.dims[0]//2, tc.dims[2], tc.dims[1], tc_opt=2, ensure_triggered=False)
ensure_uops_and_opts_count(tc.dims[0], tc.dims[2]//2, tc.dims[1], tc_opt=2, ensure_triggered=False)
ensure_uops_and_opts_count(tc.dims[0], tc.dims[2], tc.dims[1]//2, tc_opt=2, ensure_triggered=False)
# check correctness
a, b = Tensor.rand(tc.dims[1]+pad, tc.dims[2]+pad, dtype=tc.dtype_in), Tensor.rand(tc.dims[2]+pad, tc.dims[0]+pad, dtype=tc.dtype_in)
r = a.matmul(b, acc_dtype=tc.dtype_out)
(atol, rtol) = ((0.25, 0.01) if tc.dtype_out == dtypes.half else (3e-2, 1e-3)) if tc.dtype_in == dtypes.half else (1e-4, 1e-4)
helper_linearizer_opt(r, [
[Opt(OptOps.TC, axis=0, amt=2)],
], atol=atol, rtol=rtol)
def test_limit_dims_to_max_5d_global(self):
t = Tensor.empty(3, 4, 5, 6, 7).pad(((1, 1), (1, 1), (1, 1), (1, 1), (1, 1))) + 1
sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps]