mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
wmma: widen TC usage in search by using PADTO on TC axes when possible (#4216)
* wmma: widen TC usage in search by using PADTO on TC axes when possible * test: start tests for the new padding TC behavior * search: upgrade padded TC search to TC_OPT >= 2 * test: add behavior and correctness test for padded TC added optional argument to apply_tensor_core to set TC_OPT level * linearizer: add tests for the PADTO behvaior and docs
This commit is contained in:
@@ -196,6 +196,49 @@ class TestLinearizer(unittest.TestCase):
|
||||
else: tc_atol, tc_rtol = 5e-3, 1e-4
|
||||
np.testing.assert_allclose(np_c, out, atol=tc_atol, rtol=tc_rtol)
|
||||
|
||||
def test_tensor_cores_padded(self):
|
||||
if not Device[Device.DEFAULT].compiler.compiler_opts.has_tensor_cores:
|
||||
self.skipTest("device doesn't have tensor cores")
|
||||
for tc in tensor_cores[Device[Device.DEFAULT].compiler.compiler_opts.device]:
|
||||
if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
|
||||
pad = 1
|
||||
|
||||
def ensure_uops_and_opts_count(m:int, k:int, n:int, tc_opt:int, ensure_triggered:bool=True):
|
||||
a, b = Tensor.rand(m, k, dtype=tc.dtype_in), Tensor.rand(k, n, dtype=tc.dtype_in)
|
||||
r = a.matmul(b, acc_dtype=tc.dtype_out)
|
||||
sched = create_schedule([r.lazydata])
|
||||
realized_ast = sched[-1].ast[0]
|
||||
k = Linearizer(realized_ast)
|
||||
k.apply_tensor_cores(1, tc_opt=tc_opt)
|
||||
k.linearize()
|
||||
wmmas = len([uop for uop in k.uops if uop.uop is UOps.WMMA])
|
||||
tcs = len([x for x in k.applied_opts if x.op is OptOps.TC])
|
||||
if ensure_triggered:
|
||||
assert wmmas > 0, "tensor core not triggered"
|
||||
assert tcs == 1, "tensor core opt not included"
|
||||
else:
|
||||
assert wmmas == 0, "tensor core is incorrectly triggered"
|
||||
assert tcs == 0, "tensor core opt is incorrectly included"
|
||||
|
||||
# check that TC is triggered for TC_OPT=2
|
||||
ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[2]+pad, tc.dims[1]+pad, tc_opt=2, ensure_triggered=True)
|
||||
|
||||
# check that TC is not triggered for TC_OPT<2
|
||||
ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[2]+pad, tc.dims[1]+pad, tc_opt=1, ensure_triggered=False)
|
||||
|
||||
# check excessive padding doesn't trigger padded TC in TC_OPT=2
|
||||
ensure_uops_and_opts_count(tc.dims[0]//2, tc.dims[2], tc.dims[1], tc_opt=2, ensure_triggered=False)
|
||||
ensure_uops_and_opts_count(tc.dims[0], tc.dims[2]//2, tc.dims[1], tc_opt=2, ensure_triggered=False)
|
||||
ensure_uops_and_opts_count(tc.dims[0], tc.dims[2], tc.dims[1]//2, tc_opt=2, ensure_triggered=False)
|
||||
|
||||
# check correctness
|
||||
a, b = Tensor.rand(tc.dims[1]+pad, tc.dims[2]+pad, dtype=tc.dtype_in), Tensor.rand(tc.dims[2]+pad, tc.dims[0]+pad, dtype=tc.dtype_in)
|
||||
r = a.matmul(b, acc_dtype=tc.dtype_out)
|
||||
(atol, rtol) = ((0.25, 0.01) if tc.dtype_out == dtypes.half else (3e-2, 1e-3)) if tc.dtype_in == dtypes.half else (1e-4, 1e-4)
|
||||
helper_linearizer_opt(r, [
|
||||
[Opt(OptOps.TC, axis=0, amt=2)],
|
||||
], atol=atol, rtol=rtol)
|
||||
|
||||
def test_limit_dims_to_max_5d_global(self):
|
||||
t = Tensor.empty(3, 4, 5, 6, 7).pad(((1, 1), (1, 1), (1, 1), (1, 1), (1, 1))) + 1
|
||||
sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps]
|
||||
|
||||
Reference in New Issue
Block a user