Mirror of https://github.com/ROCm/ROCm.git (synced 2026-04-05 03:01:17 -04:00)

Fix dangling gpu_has_mfma use (#325)

* Fix dangling gpu_has_mfma use: this PR replaces the remaining use of gpu_has_mfma with gpu_matrix_core_version.
* Add a basic test.

This commit is contained in:
@@ -1588,10 +1588,11 @@ def get_variant_golden(a, b):
     return c_padded[:SIZE_M, :SIZE_N]


-@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K', [
-    [64, 32, 128, 4, 64, 32, 64],
+@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,NUM_STAGES', [
+    [64, 32, 128, 4, 64, 32, 64, 0],
+    [64, 32, 128, 4, 64, 32, 64, 2]
 ])
-def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K):
+def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, NUM_STAGES):
     a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
     b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
     c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)
@@ -1603,7 +1604,7 @@ def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLO
               M=a.shape[0], N=b.shape[1], K=a.shape[1],
               BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
               num_warps=NUM_WARPS,
-              num_stages=2)
+              num_stages=NUM_STAGES)
     golden = torch.matmul(a, b)

     # It's not easy to get a proper error threshold in different size
@@ -91,7 +91,7 @@ def optimize_ttgir(mod, num_stages, arch):
     pm.add_tritongpu_accelerate_matmul_pass(matrix_core_version)
     pm.add_tritongpu_remove_layout_conversions_pass()
     pm.add_tritongpu_optimize_dot_operands_pass()
-    if num_stages == 0 and is_hip() and gpu_has_mfma():
+    if num_stages == 0 and is_hip() and gpu_matrix_core_version() != 0:
         pm.add_tritongpu_stream_pipeline_pass()
         pm.add_canonicalizer_pass()
     else:
Reference in New Issue
Block a user