Mirror of https://github.com/ROCm/ROCm.git (synced 2026-04-05 03:01:17 -04:00)
[BACKEND] turn on MMA V3 by default on Hopper (#2414)
This commit is contained in:
@@ -331,7 +331,6 @@ def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A,
]:
pytest.skip('shapePerCTA[1] < 16 not supported')

# with ENABLE_TMA=0 and ENABLE_MMA_V3=0
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
@@ -444,7 +443,7 @@ def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A,
atol=1e-3,
check_dtype=False)

enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower()
if enable_mmav3 in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
ptx = pgm.asm['ptx']
assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(BLOCK_N), ptx)

@@ -818,7 +818,6 @@ def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WAR
]:
pytest.skip('shapePerCTA[1] < 16 not supported')

# with ENABLE_TMA=0 and ENABLE_MMA_V3=0
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',

@@ -2433,11 +2433,10 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
red_code = ptx[start:end]
assert len(red_code) > 0
import os
enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower()
enable_tma = os.environ.get('ENABLE_TMA', 'not found').lower()

# skip this check on hopper because there are some functions whose name contain "shared" in ptx.
# TODO: we should eliminate these unused functions in ptx code.
if not (enable_mmav3 in ["on", "true", "1"] and enable_tma in ["on", "true", "1"]):
if not (capability[0] >= 9):
assert "shared" not in red_code
assert "bar.sync" not in red_code
# torch result
@@ -2540,13 +2539,12 @@ def test_dot_mulbroadcastred(in_dtype, device):
if is_hip():
return
assert "tt.dot" in h.asm['ttir']
# with option ENABLE_MMA_V3 on, we will not pipeline the load op for Y
# when using MMAv3, we will not pipeline the load op for Y
# as the loaded value is in rowmajor. But MMAv3 requires it's second
# operand is in colmajor because transpose is not supported for MMAv3
# with float32 input.
import os
enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower()
if enable_mmav3 in ["on", "true", "1"]:
if capability[0] >= 9:
assert "triton_gpu.async_wait {num = 1 : i32}" in h.asm['ttgir']
else:
assert "triton_gpu.async_wait {num = 2 : i32}" in h.asm['ttgir']

Reference in New Issue
Block a user