mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[TESTS] Fix tl.dot test on sm75 (#2140)
Disable tf32 when running on sm75 and below. Fix the pattern match that the generated PTX is compared against when running on sm75.
This commit is contained in:
committed by
GitHub
parent
e072da5b57
commit
a7b40a10f9
@@ -2151,7 +2151,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
if capability[0] < 8:
|
||||
if in_dtype == 'int8':
|
||||
pytest.skip("Only test int8 on devices with sm >= 80")
|
||||
elif in_dtype == 'float32' and allow_tf32:
|
||||
elif allow_tf32:
|
||||
pytest.skip("Only test tf32 on devices with sm >= 80")
|
||||
if capability[0] == 7:
|
||||
if (M, N, K, num_warps) == (128, 256, 32, 8):
|
||||
@@ -2205,7 +2205,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
z = num / den[:, None]
|
||||
if CHAIN_DOT:
|
||||
w = tl.load(Ws)
|
||||
z = tl.dot(z.to(w.dtype), w, out_dtype=out_dtype)
|
||||
z = tl.dot(z.to(w.dtype), w, allow_tf32=ALLOW_TF32, out_dtype=out_dtype)
|
||||
tl.store(Zs, z)
|
||||
# input
|
||||
rs = RandomState(17)
|
||||
@@ -2312,9 +2312,15 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
if in_dtype == 'float32' and allow_tf32:
|
||||
assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.tf32.tf32', ptx)
|
||||
elif in_dtype == 'float16' and out_dtype == tl.float32:
|
||||
assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.f16.f16', ptx)
|
||||
if capability[0] == 7 and capability[1] == 5: # Turing
|
||||
assert re.search(r'mma.sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.f16.f16', ptx)
|
||||
else:
|
||||
assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.f16.f16', ptx)
|
||||
elif in_dtype == 'float16' and out_dtype == tl.float16:
|
||||
assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f16.f16.f16', ptx)
|
||||
if capability[0] == 7 and capability[1] == 5: # Turing
|
||||
assert re.search(r'mma.sync.aligned.m\d+n\d+k8(?:.row.col)?.f16.f16.f16', ptx)
|
||||
else:
|
||||
assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f16.f16.f16', ptx)
|
||||
elif in_dtype == 'int8':
|
||||
assert 'wgmma.mma_async.sync.aligned' in ptx or\
|
||||
'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
|
||||
@@ -2322,6 +2328,9 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
|
||||
@pytest.mark.parametrize('in_dtype', ['float32'])
|
||||
def test_dot_mulbroadcastred(in_dtype, device):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 8:
|
||||
pytest.skip("Requires sm >= 80 to run")
|
||||
@triton.jit
|
||||
def kernel(Z, X, Y,
|
||||
M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
|
||||
@@ -2419,11 +2428,13 @@ def test_constexpr(literal, dtype_str, device):
|
||||
|
||||
@pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
|
||||
def test_dot_without_load(dtype_str, device):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
allow_tf32 = capability[0] > 7
|
||||
@triton.jit
|
||||
def _kernel(out):
|
||||
def _kernel(out, ALLOW_TF32: tl.constexpr):
|
||||
a = GENERATE_TEST_HERE
|
||||
b = GENERATE_TEST_HERE
|
||||
c = tl.dot(a, b)
|
||||
c = tl.dot(a, b, allow_tf32=ALLOW_TF32)
|
||||
out_ptr = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
|
||||
tl.store(out_ptr, c)
|
||||
kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.full((32, 32), 1.0, tl.{dtype_str})"})
|
||||
@@ -2431,7 +2442,7 @@ def test_dot_without_load(dtype_str, device):
|
||||
b = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device=device)
|
||||
out_ref = torch.matmul(a, b)
|
||||
out = torch.zeros((32, 32), dtype=getattr(torch, dtype_str), device=device)
|
||||
kernel[(1,)](out)
|
||||
kernel[(1,)](out, ALLOW_TF32=allow_tf32)
|
||||
assert torch.all(out == out_ref)
|
||||
|
||||
# ---------------
|
||||
|
||||
Reference in New Issue
Block a user