[BACKEND] turn on MMA V3 by default on Hopper (#2414)

This commit is contained in:
Thomas Raoux
2023-09-28 22:45:28 -07:00
committed by GitHub
parent d4fae90169
commit 90bef57acf
11 changed files with 23 additions and 33 deletions

View File

@@ -331,7 +331,6 @@ def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A,
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
# with ENABLE_TMA=0 and ENABLE_MMA_V3=0
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
@@ -444,7 +443,7 @@ def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A,
atol=1e-3,
check_dtype=False)
enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower()
if enable_mmav3 in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
ptx = pgm.asm['ptx']
assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(BLOCK_N), ptx)

View File

@@ -818,7 +818,6 @@ def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WAR
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
# with ENABLE_TMA=0 and ENABLE_MMA_V3=0
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',

View File

@@ -2433,11 +2433,10 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
red_code = ptx[start:end]
assert len(red_code) > 0
import os
enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower()
enable_tma = os.environ.get('ENABLE_TMA', 'not found').lower()
# skip this check on hopper because there are some functions whose name contain "shared" in ptx.
# TODO: we should eliminate these unused functions in ptx code.
if not (enable_mmav3 in ["on", "true", "1"] and enable_tma in ["on", "true", "1"]):
if not (capability[0] >= 9):
assert "shared" not in red_code
assert "bar.sync" not in red_code
# torch result
@@ -2540,13 +2539,12 @@ def test_dot_mulbroadcastred(in_dtype, device):
if is_hip():
return
assert "tt.dot" in h.asm['ttir']
# with option ENABLE_MMA_V3 on, we will not pipeline the load op for Y
# when using MMAv3, we will not pipeline the load op for Y
# as the loaded value is in rowmajor. But MMAv3 requires its second
# operand to be in colmajor because transpose is not supported for MMAv3
# with float32 input.
import os
enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower()
if enable_mmav3 in ["on", "true", "1"]:
if capability[0] >= 9:
assert "triton_gpu.async_wait {num = 1 : i32}" in h.asm['ttgir']
else:
assert "triton_gpu.async_wait {num = 2 : i32}" in h.asm['ttgir']