Fix dangling gpu_has_mfma use (#325)

* Fix dangling gpu_has_mfma use

This PR replaces the use of gpu_has_mfma() with gpu_matrix_core_version() != 0.

* add basic test
This commit is contained in:
Alexander Efimov
2023-09-11 19:31:48 +02:00
committed by GitHub
parent 6691de65db
commit a06072f8ff
2 changed files with 6 additions and 5 deletions

View File

@@ -1588,10 +1588,11 @@ def get_variant_golden(a, b):
return c_padded[:SIZE_M, :SIZE_N]
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K', [
[64, 32, 128, 4, 64, 32, 64],
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,NUM_STAGES', [
[64, 32, 128, 4, 64, 32, 64, 0],
[64, 32, 128, 4, 64, 32, 64, 2]
])
def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K):
def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, NUM_STAGES):
a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)
@@ -1603,7 +1604,7 @@ def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLO
M=a.shape[0], N=b.shape[1], K=a.shape[1],
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
num_warps=NUM_WARPS,
num_stages=2)
num_stages=NUM_STAGES)
golden = torch.matmul(a, b)
# It's not easy to get a proper error threshold in different size

View File

@@ -91,7 +91,7 @@ def optimize_ttgir(mod, num_stages, arch):
pm.add_tritongpu_accelerate_matmul_pass(matrix_core_version)
pm.add_tritongpu_remove_layout_conversions_pass()
pm.add_tritongpu_optimize_dot_operands_pass()
if num_stages == 0 and is_hip() and gpu_has_mfma():
if num_stages == 0 and is_hip() and gpu_matrix_core_version() != 0:
pm.add_tritongpu_stream_pipeline_pass()
pm.add_canonicalizer_pass()
else: