mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Disable pipelining (#276)
This PR sets the default value of pipeline stages to 1 for AMD. It also adds an explicit num_stages argument in the test_gemm test.
This commit is contained in:
@@ -1511,7 +1511,8 @@ def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLO
|
||||
stride_cm=c.stride(0), stride_cn=c.stride(1),
|
||||
M=a.shape[0], N=b.shape[1], K=a.shape[1],
|
||||
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||
num_warps=NUM_WARPS)
|
||||
num_warps=NUM_WARPS,
|
||||
num_stages=2)
|
||||
golden = torch.matmul(a, b)
|
||||
|
||||
# It's not easy to get a proper error threshold in different size
|
||||
|
||||
@@ -399,7 +399,7 @@ def compile(fn, **kwargs):
|
||||
context = ir.context()
|
||||
constants = kwargs.get("constants", dict())
|
||||
num_warps = kwargs.get("num_warps", 4)
|
||||
num_stages = kwargs.get("num_stages", 3 if is_cuda and arch >= 75 else 2)
|
||||
num_stages = kwargs.get("num_stages", 3 if is_cuda and arch >= 75 else (1 if is_hip else 2))
|
||||
extern_libs = kwargs.get("extern_libs", dict())
|
||||
if extern_libs is None:
|
||||
extern_libs = dict()
|
||||
|
||||
Reference in New Issue
Block a user