[TESTS] better matmul unit testing (#2098)

Philippe Tillet
2023-08-13 17:54:32 -07:00
committed by GitHub
parent fc667d1f8f
commit facc1dcbac
2 changed files with 27 additions and 29 deletions


@@ -99,7 +99,7 @@ jobs:
if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '0' && env.ENABLE_MMA_V3 == '0'}}
run: |
cd python/test/unit
python3 -m pytest -n 8 --ignore=runtime
python3 -m pytest -n 8 --ignore=runtime --ignore=hopper
# run runtime tests serially to avoid race condition with cache handling.
python3 -m pytest runtime/
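
For reference, a rough Python equivalent of the test-selection logic above (a sketch only; it assumes pytest-xdist supplies the -n flag, as the workflow implies):

    import pytest

    # parallel run of the unit tests, now also excluding the hopper/ directory
    pytest.main(["-n", "8", "--ignore=runtime", "--ignore=hopper"])
    # runtime/ is run serially afterwards to avoid races in cache handling
    pytest.main(["runtime/"])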


@@ -58,14 +58,10 @@ def f8_to_f16(x, dtype):
(128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE),
(256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE),
(256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE),
# split-k
(64, 64, 16, 2, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
(64, 64, 16, 4, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
(64, 64, 16, 8, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
# variable input
(128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE),
(128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE, DTYPE),
(128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE, DTYPE),
(128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE),
(128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE),
(128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE),
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]
],
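
The tuple list above is expanded by a comprehension over DTYPE, AT and BT into one pytest case per combination. A minimal sketch of that expansion (the itertools.chain wrapper and the two sample shapes are assumptions for illustration, not the commit's exact list):

    import itertools

    shapes = [
        # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K
        (128, 256, 16, 1, 8, 2, None, None, None),
        (128, 128, 32, 1, 4, 2, 256, 384, 160),
    ]
    configs = list(itertools.chain(*[
        [shape + (AT, BT, DTYPE, DTYPE) for shape in shapes]
        for DTYPE in ["float16", "bfloat16", "float32"]
        for AT in [False, True]
        for BT in [False, True]
    ]))
    # every shape is tested for every dtype and every transpose combination
    assert len(configs) == len(shapes) * 3 * 2 * 2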
@@ -77,9 +73,6 @@ def f8_to_f16(x, dtype):
(128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE),
(256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE),
(128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE),
# split-k
(64, 64, 16, 8, 4, STAGES, 128, 128, 768, AT, BT, DTYPE, DTYPE),
(64, 64, 16, 8, 4, STAGES, 128, 128, 32, AT, BT, DTYPE, DTYPE),
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [4]
],
# mixed-precision
@@ -88,7 +81,6 @@ def f8_to_f16(x, dtype):
(32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE),
(128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE),
(32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE),
(128, 128, 32, 8, 4, 2, 256, 256, 128, AT, BT, ADTYPE, BDTYPE),
] for ADTYPE, BDTYPE in [("float8e4", "float8e5"),
("float8e4", "float16"),
("float16", "float8e5"),
@@ -131,35 +123,44 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K * SPLIT_K if K is None else K
a_fp8 = "float8" in ADTYPE
b_fp8 = "float8" in BDTYPE
def maybe_upcast(x, dtype, is_float8):
if is_float8:
return f8_to_f16(x, dtype)
return x
def init_input(n, m, t, dtype, is_float8):
if t:
return init_input(m, n, False, dtype, is_float8).t()
if is_float8:
return torch.randint(20, 50, (n, m), device="cuda", dtype=torch.int8)
def init_input(m, n, dtype):
if 'float8' in dtype:
ewidth = {'float8e4b15': 4, 'float8e4': 4, 'float8e5': 5}[dtype]
sign = torch.randint(2, size=(m, n), device="cuda", dtype=torch.int8) * 128
val = torch.randint(2**3 - 1, size=(m, n), device="cuda", dtype=torch.int8) << 7 - ewidth
return sign | val
if dtype == "int8":
return torch.randint(-128, 127, (n, m), device="cuda", dtype=torch.int8)
return torch.randint(-128, 127, (m, n), device="cuda", dtype=torch.int8)
dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[dtype]
return .1 * torch.randn((n, m), device="cuda", dtype=dtype)
exponents = torch.randint(-10, 0, size=(m, n))
ret = (2. ** exponents).to(dtype).to("cuda")
return ret
# allocate/transpose inputs
a = init_input(M, K, AT, ADTYPE, a_fp8)
b = init_input(K, N, BT, BDTYPE, b_fp8)
a = init_input(M, K, ADTYPE)
b = init_input(K, N, BDTYPE)
a = a if not AT else a.T.contiguous().T
b = b if not BT else b.T.contiguous().T
# run test
th_a = maybe_upcast(a, ADTYPE, a_fp8).to(torch.float32)
a_fp8 = "float8" in ADTYPE
b_fp8 = "float8" in BDTYPE
th_a = maybe_upcast(a, ADTYPE, a_fp8)
if AT and a_fp8:
th_a = th_a.view(th_a.shape[::-1]).T
th_b = maybe_upcast(b, BDTYPE, b_fp8).to(torch.float32)
th_b = maybe_upcast(b, BDTYPE, b_fp8)
if BT and b_fp8:
th_b = th_b.view(th_b.shape[::-1]).T
th_c = torch.matmul(th_a, th_b)
if th_a.is_floating_point():
ab_dtype = th_a.dtype if th_a.element_size() > th_b.element_size() else th_b.dtype
else:
ab_dtype = torch.float32
th_c = torch.matmul(th_a.to(ab_dtype), th_b.to(ab_dtype))
if ADTYPE == "int8" or BDTYPE == "int8":
th_c = th_c.to(torch.int8)
try:
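
A note on the float8 branch of the new init_input (my reading of the bit arithmetic, not stated in the commit): with ewidth exponent bits, the assumed layout is 1 sign bit, ewidth exponent bits and 7 - ewidth mantissa bits, so shifting the random value left by 7 - ewidth places it entirely inside the exponent field and leaves the mantissa zero. The generated float8 inputs are therefore zeros or exact powers of two; the raw int8 bit patterns are passed to the kernel via triton.reinterpret in the hunk below, while f8_to_f16 upcasts them for the torch reference. A small sketch of the mantissa-is-zero property:

    import torch

    ewidth = 4                     # float8e4 (e4m3): 1 sign, 4 exponent, 3 mantissa bits
    mantissa_bits = 7 - ewidth
    val = torch.randint(2**3 - 1, size=(16,), dtype=torch.int8) << mantissa_bits
    # the random bits land entirely in the exponent field, never in the mantissa
    assert not bool((val & ((1 << mantissa_bits) - 1)).any())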
@@ -168,9 +169,6 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
if b_fp8:
b = triton.reinterpret(b, getattr(tl, BDTYPE))
tt_c = triton.ops.matmul(a, b)
atol, rtol = 1e-2, 0
if ADTYPE == torch.bfloat16 or BDTYPE == torch.bfloat16:
atol, rtol = 3.5e-2, 0
torch.testing.assert_allclose(th_c, tt_c, atol=atol, rtol=rtol)
torch.testing.assert_allclose(th_c, tt_c, atol=0, rtol=0)
except triton.OutOfResources as e:
pytest.skip(str(e))
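
The switch from a tolerance-based comparison to assert_allclose(th_c, tt_c, atol=0, rtol=0) is presumably enabled by the new init_input: every floating-point input is drawn as a power of two in [2**-10, 2**-1], so values survive the cast to float16 or bfloat16 losslessly and the reference and Triton results can be expected to match exactly. A quick illustration of that lossless-cast property (not part of the commit):

    import torch

    exponents = torch.randint(-10, 0, size=(4, 4))
    x = 2. ** exponents            # every entry is an exact power of two in [2**-10, 2**-1]
    # unlike generic randn data, these values round-trip through half precision exactly
    assert torch.equal(x.to(torch.float16).double(), x.double())
    assert torch.equal(x.to(torch.bfloat16).double(), x.double())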