mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[HOPPER][WS] remove numCTAs = 1 check in guard pass (#2066)
This commit is contained in:
@@ -777,20 +777,17 @@ def full_static_persistent_matmul_kernel(
|
||||
(*shape_w_c, trans_a, trans_b, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [
|
||||
[64, 64, 32, 4, 1, 128, 256, 64],
|
||||
# TODO: enable when num_ctas != 1 is supported.
|
||||
# [128, 128, 16, 4, 4, 512, 256, 64],
|
||||
# [128, 256, 32, 4, 8, 256, 256, 192],
|
||||
# [512, 256, 32, 4, 8, 1024, 256, 192],
|
||||
[128, 128, 16, 4, 4, 512, 256, 64],
|
||||
[128, 256, 32, 4, 8, 256, 256, 192],
|
||||
[512, 256, 32, 4, 8, 1024, 256, 192],
|
||||
# BLOCK_K >= 128
|
||||
[64, 128, 128, 4, 1, 512, 256, 256],
|
||||
[128, 128, 128, 4, 1, 256, 256, 192],
|
||||
# TODO: enable when num_ctas != 1 is supported.
|
||||
# [128, 128, 128, 4, 2, 256, 256, 192],
|
||||
[128, 128, 128, 4, 2, 256, 256, 192],
|
||||
# small BLOCK_M and BLOCK_K
|
||||
[16, 32, 32, 4, 1, 128, 256, 64],
|
||||
[32, 32, 16, 4, 1, 256, 256, 192],
|
||||
# TODO: enable when num_ctas != 1 is supported.
|
||||
# [16, 32, 64, 4, 4, 512, 256, 64],
|
||||
[16, 32, 64, 4, 4, 512, 256, 64],
|
||||
]
|
||||
for out_dtype in ['float16', 'float32']
|
||||
for use_tma_store in [False, True]
|
||||
@@ -801,13 +798,13 @@ def full_static_persistent_matmul_kernel(
|
||||
] + [(*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [
|
||||
[64, 64, 16, 4, 1, 128, 128, 64],
|
||||
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4] for num_ctas in [1]],
|
||||
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4] for num_ctas in [1, 2, 4]],
|
||||
# for chain-dot
|
||||
[128, 128, 64, 4, 1, None, None, None],
|
||||
[64, 64, 16, 4, 1, None, None, None],
|
||||
# small BLOCK_M and BLOCK_K
|
||||
[16, 16, 64, 4, 1, 128, 128, 64],
|
||||
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4] for num_ctas in [1]],
|
||||
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4] for num_ctas in [1, 2]],
|
||||
# # TODO: enable when num_warps != 4 is supported.
|
||||
# # repeat
|
||||
# # [64, 64, 32, 8, 1, 128, 256, 64],
|
||||
@@ -836,9 +833,9 @@ def full_static_persistent_matmul_kernel(
|
||||
(*shape_w_c, *shape, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
# irregular shapes
|
||||
for shape_w_c in [
|
||||
[128, 128, 64, 4, 1]
|
||||
# [256, 128, 64, 4, 2],
|
||||
# [128, 128, 128, 4, 2],
|
||||
[128, 128, 64, 4, 1],
|
||||
[256, 128, 64, 4, 2],
|
||||
[128, 128, 128, 4, 2]
|
||||
]
|
||||
for shape in list(itertools.product([*range(512, 4096, 360)], [*range(512, 4096, 360)], [512, 1024]))
|
||||
for out_dtype in ['float16', 'float32']
|
||||
|
||||
Reference in New Issue
Block a user