mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[HOPPER][WS] fix TMA store hang in ws mode (#2056)
This commit is contained in:
@@ -93,11 +93,6 @@ def matmul_no_scf_kernel(
|
||||
]))
|
||||
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
|
||||
def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE, ENABLE_WS):
|
||||
if '-'.join(map(str, [USE_TMA_EPILOGUE, ENABLE_WS])) in [
|
||||
'True-True'
|
||||
]:
|
||||
pytest.skip("error, skip")
|
||||
|
||||
if (TRANS_A):
|
||||
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
|
||||
else:
|
||||
@@ -335,12 +330,6 @@ def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A,
|
||||
]:
|
||||
pytest.skip('Known legacy issue, ldmatrix can only support x4')
|
||||
|
||||
# with ENABLE_TMA=1 and ENABLE_MMA_V3=1
|
||||
if ENABLE_WS:
|
||||
# example:
|
||||
# [128-128-64-4-1-None-None-None-False-False-False-chain-dot-float16-False-3-True]
|
||||
pytest.skip('hang!')
|
||||
|
||||
M = BLOCK_M if M is None else M
|
||||
N = BLOCK_N if N is None else N
|
||||
K = BLOCK_K if K is None else K
|
||||
|
||||
Reference in New Issue
Block a user