mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[TESTS] better matmul unit testing (#2098)
This commit is contained in:
@@ -58,14 +58,10 @@ def f8_to_f16(x, dtype):
|
||||
(128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
# split-k
|
||||
(64, 64, 16, 2, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(64, 64, 16, 4, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(64, 64, 16, 8, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
# variable input
|
||||
(128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE),
|
||||
(128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE, DTYPE),
|
||||
(128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE, DTYPE),
|
||||
(128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE),
|
||||
(128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE),
|
||||
(128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE),
|
||||
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]
|
||||
],
|
||||
@@ -77,9 +73,6 @@ def f8_to_f16(x, dtype):
|
||||
(128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE),
|
||||
(256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE),
|
||||
(128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE),
|
||||
# split-k
|
||||
(64, 64, 16, 8, 4, STAGES, 128, 128, 768, AT, BT, DTYPE, DTYPE),
|
||||
(64, 64, 16, 8, 4, STAGES, 128, 128, 32, AT, BT, DTYPE, DTYPE),
|
||||
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [4]
|
||||
],
|
||||
# mixed-precision
|
||||
@@ -88,7 +81,6 @@ def f8_to_f16(x, dtype):
|
||||
(32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE),
|
||||
(128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE),
|
||||
(32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE),
|
||||
(128, 128, 32, 8, 4, 2, 256, 256, 128, AT, BT, ADTYPE, BDTYPE),
|
||||
] for ADTYPE, BDTYPE in [("float8e4", "float8e5"),
|
||||
("float8e4", "float16"),
|
||||
("float16", "float8e5"),
|
||||
@@ -131,35 +123,44 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
|
||||
M = BLOCK_M if M is None else M
|
||||
N = BLOCK_N if N is None else N
|
||||
K = BLOCK_K * SPLIT_K if K is None else K
|
||||
a_fp8 = "float8" in ADTYPE
|
||||
b_fp8 = "float8" in BDTYPE
|
||||
|
||||
def maybe_upcast(x, dtype, is_float8):
|
||||
if is_float8:
|
||||
return f8_to_f16(x, dtype)
|
||||
return x
|
||||
|
||||
def init_input(n, m, t, dtype, is_float8):
|
||||
if t:
|
||||
return init_input(m, n, False, dtype, is_float8).t()
|
||||
if is_float8:
|
||||
return torch.randint(20, 50, (n, m), device="cuda", dtype=torch.int8)
|
||||
def init_input(m, n, dtype):
|
||||
if 'float8' in dtype:
|
||||
ewidth = {'float8e4b15': 4, 'float8e4': 4, 'float8e5': 5}[dtype]
|
||||
sign = torch.randint(2, size=(m, n), device="cuda", dtype=torch.int8) * 128
|
||||
val = torch.randint(2**3 - 1, size=(m, n), device="cuda", dtype=torch.int8) << 7 - ewidth
|
||||
return sign | val
|
||||
if dtype == "int8":
|
||||
return torch.randint(-128, 127, (n, m), device="cuda", dtype=torch.int8)
|
||||
return torch.randint(-128, 127, (m, n), device="cuda", dtype=torch.int8)
|
||||
dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[dtype]
|
||||
return .1 * torch.randn((n, m), device="cuda", dtype=dtype)
|
||||
exponents = torch.randint(-10, 0, size=(m, n))
|
||||
ret = (2. ** exponents).to(dtype).to("cuda")
|
||||
return ret
|
||||
|
||||
# allocate/transpose inputs
|
||||
a = init_input(M, K, AT, ADTYPE, a_fp8)
|
||||
b = init_input(K, N, BT, BDTYPE, b_fp8)
|
||||
a = init_input(M, K, ADTYPE)
|
||||
b = init_input(K, N, BDTYPE)
|
||||
a = a if not AT else a.T.contiguous().T
|
||||
b = b if not BT else b.T.contiguous().T
|
||||
# run test
|
||||
th_a = maybe_upcast(a, ADTYPE, a_fp8).to(torch.float32)
|
||||
a_fp8 = "float8" in ADTYPE
|
||||
b_fp8 = "float8" in BDTYPE
|
||||
th_a = maybe_upcast(a, ADTYPE, a_fp8)
|
||||
if AT and a_fp8:
|
||||
th_a = th_a.view(th_a.shape[::-1]).T
|
||||
th_b = maybe_upcast(b, BDTYPE, b_fp8).to(torch.float32)
|
||||
th_b = maybe_upcast(b, BDTYPE, b_fp8)
|
||||
if BT and b_fp8:
|
||||
th_b = th_b.view(th_b.shape[::-1]).T
|
||||
th_c = torch.matmul(th_a, th_b)
|
||||
if th_a.is_floating_point():
|
||||
ab_dtype = th_a.dtype if th_a.element_size() > th_b.element_size() else th_b.dtype
|
||||
else:
|
||||
ab_dtype = torch.float32
|
||||
th_c = torch.matmul(th_a.to(ab_dtype), th_b.to(ab_dtype))
|
||||
if ADTYPE == "int8" or BDTYPE == "int8":
|
||||
th_c = th_c.to(torch.int8)
|
||||
try:
|
||||
@@ -168,9 +169,6 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
|
||||
if b_fp8:
|
||||
b = triton.reinterpret(b, getattr(tl, BDTYPE))
|
||||
tt_c = triton.ops.matmul(a, b)
|
||||
atol, rtol = 1e-2, 0
|
||||
if ADTYPE == torch.bfloat16 or BDTYPE == torch.bfloat16:
|
||||
atol, rtol = 3.5e-2, 0
|
||||
torch.testing.assert_allclose(th_c, tt_c, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_allclose(th_c, tt_c, atol=0, rtol=0)
|
||||
except triton.OutOfResources as e:
|
||||
pytest.skip(str(e))
|
||||
|
||||
Reference in New Issue
Block a user