[OPTIMIZER] Fix the load and store fallback issue of test_persisten… (#2057)

Co-authored-by: Biao Wang <biaow@nvidia.com>
This commit is contained in:
Beal Wang
2023-08-09 16:42:01 +08:00
committed by GitHub
parent 6d98a0899f
commit de47bba07d
3 changed files with 148 additions and 9 deletions

View File

@@ -125,9 +125,6 @@ def static_persistent_tma_matmul_kernel(
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, TRANS_A, TRANS_B):
-    # TODO: fix RewriteTensorPtrPass
-    pytest.skip('RewriteTensorPtrPass issue')
if (TRANS_A):
a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T
else:
@@ -458,8 +455,6 @@ def static_persistent_tma_warp_specialized_matmul_kernel(
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B):
-    # TODO: fix RewriteTensorPtrPass
-    pytest.skip('RewriteTensorPtrPass issue')
if (TRANS_A):
a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T
else: