[TESTS] Fix tl.dot test on sm75 (#2140)

Disable tf32 when running on sm75 and below.
Fix the pattern match used to check the generated PTX when running
on sm75.
This commit is contained in:
Alexander Zinoviev
2023-08-19 22:21:18 -07:00
committed by GitHub
parent e072da5b57
commit a7b40a10f9

View File

@@ -2151,7 +2151,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
if capability[0] < 8:
if in_dtype == 'int8':
pytest.skip("Only test int8 on devices with sm >= 80")
elif in_dtype == 'float32' and allow_tf32:
elif allow_tf32:
pytest.skip("Only test tf32 on devices with sm >= 80")
if capability[0] == 7:
if (M, N, K, num_warps) == (128, 256, 32, 8):
@@ -2205,7 +2205,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
z = num / den[:, None]
if CHAIN_DOT:
w = tl.load(Ws)
z = tl.dot(z.to(w.dtype), w, out_dtype=out_dtype)
z = tl.dot(z.to(w.dtype), w, allow_tf32=ALLOW_TF32, out_dtype=out_dtype)
tl.store(Zs, z)
# input
rs = RandomState(17)
@@ -2312,9 +2312,15 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
if in_dtype == 'float32' and allow_tf32:
assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.tf32.tf32', ptx)
elif in_dtype == 'float16' and out_dtype == tl.float32:
assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.f16.f16', ptx)
if capability[0] == 7 and capability[1] == 5: # Turing
assert re.search(r'mma.sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.f16.f16', ptx)
else:
assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.f16.f16', ptx)
elif in_dtype == 'float16' and out_dtype == tl.float16:
assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f16.f16.f16', ptx)
if capability[0] == 7 and capability[1] == 5: # Turing
assert re.search(r'mma.sync.aligned.m\d+n\d+k8(?:.row.col)?.f16.f16.f16', ptx)
else:
assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f16.f16.f16', ptx)
elif in_dtype == 'int8':
assert 'wgmma.mma_async.sync.aligned' in ptx or\
'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
@@ -2322,6 +2328,9 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
@pytest.mark.parametrize('in_dtype', ['float32'])
def test_dot_mulbroadcastred(in_dtype, device):
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
pytest.skip("Requires sm >= 80 to run")
@triton.jit
def kernel(Z, X, Y,
M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
@@ -2419,11 +2428,13 @@ def test_constexpr(literal, dtype_str, device):
@pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
def test_dot_without_load(dtype_str, device):
capability = torch.cuda.get_device_capability()
allow_tf32 = capability[0] > 7
@triton.jit
def _kernel(out):
def _kernel(out, ALLOW_TF32: tl.constexpr):
a = GENERATE_TEST_HERE
b = GENERATE_TEST_HERE
c = tl.dot(a, b)
c = tl.dot(a, b, allow_tf32=ALLOW_TF32)
out_ptr = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
tl.store(out_ptr, c)
kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.full((32, 32), 1.0, tl.{dtype_str})"})
@@ -2431,7 +2442,7 @@ def test_dot_without_load(dtype_str, device):
b = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device=device)
out_ref = torch.matmul(a, b)
out = torch.zeros((32, 32), dtype=getattr(torch, dtype_str), device=device)
kernel[(1,)](out)
kernel[(1,)](out, ALLOW_TF32=allow_tf32)
assert torch.all(out == out_ref)
# ---------------