mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[OPTIMIZER] Fix the load and store fallback issue of test_persisten… (#2057)
Co-authored-by: Biao Wang <biaow@nvidia.com>
This commit is contained in:
@@ -125,9 +125,6 @@ def static_persistent_tma_matmul_kernel(
|
||||
])
|
||||
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
|
||||
def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, TRANS_A, TRANS_B):
|
||||
# TODO: fix RewriteTensorPtrPass
|
||||
pytest.skip('RewriteTensorPtrPass issue')
|
||||
|
||||
if (TRANS_A):
|
||||
a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T
|
||||
else:
|
||||
@@ -458,8 +455,6 @@ def static_persistent_tma_warp_specialized_matmul_kernel(
|
||||
])
|
||||
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
|
||||
def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B):
|
||||
# TODO: fix RewriteTensorPtrPass
|
||||
pytest.skip('RewriteTensorPtrPass issue')
|
||||
if (TRANS_A):
|
||||
a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user