Refine test_correctness (#463)

- Check correctness of what is benchmarked
- Add capability to check col_a and col_b
  - But only check col_a=False, col_b=True for now
- Only benchmark col_a=False, col_b=True for now
- Remove the in='int8', out='int8' combination because its error is too large
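
For context on the new parameters: `col_a` and `col_b` are boolean pytest parameters that select whether the A and B operands are generated column-major. Below is a minimal, self-contained sketch of the parametrization pattern the diff adopts (the shapes and dtypes are placeholders, not the real `get_x_vals()` output):

```python
import pytest


@pytest.mark.parametrize("M, N, K, in_dtype, out_dtype, col_a, col_b",
                         [(*shape, in_dtype, out_dtype, col_a, col_b)
                          for shape in [(64, 64, 64), (128, 128, 128)]   # placeholder shapes
                          for in_dtype, out_dtype in [('fp16', 'fp16')]  # placeholder dtypes
                          for col_a in [False]   # A stays row-major (K is contiguous)
                          for col_b in [True]])  # B is generated column-major (K is contiguous)
def test_parametrization_expands(M, N, K, in_dtype, out_dtype, col_a, col_b):
    # Each tuple produced by the comprehension becomes one test case.
    assert col_a is False and col_b is True
```

Pinning `col_a` to `[False]` and `col_b` to `[True]` keeps both operands K-major for now while leaving room to widen the lists later.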
commit a819e48435 (parent 1223f6077a)
Author: Lixun Zhang (committed by GitHub)
Date:   2024-01-16 11:15:54 -06:00


@@ -170,7 +170,7 @@ name_to_tl_types = {
     'fp8e5': tl.float8e5b16,
 }

-def gen_input(M, N, ty_name, seed, device='cuda'):
+def gen_input(M, N, ty_name, needTrans, seed, device='cuda'):
     d_type = name_to_tl_types[ty_name]
     torch.manual_seed(seed)
     torch.cuda.manual_seed(seed)
@@ -183,7 +183,10 @@ def gen_input(M, N, ty_name, seed, device='cuda'):
         output = input
         tl.store(output_ptr + offsets, output, mask=mask)

-    raw_data = torch.randn((M, N), dtype=torch.float32, device='cuda')
+    if needTrans:
+        raw_data = torch.randn((N, M), dtype=torch.float32, device='cuda').T
+    else:
+        raw_data = torch.randn((M, N), dtype=torch.float32, device='cuda')
     if (d_type == tl.float8e4b8 and TORCH_HAS_FP8E4B8) or \
        (d_type == tl.float8e5b16 and TORCH_HAS_FP8E5B16) or not d_type.is_fp8():
         input = raw_data.to(tl_to_torch_types[d_type])
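
The `needTrans` branch is the layout trick behind `col_a`/`col_b`: generating an (N, M) buffer and viewing it through `.T` yields an (M, N) tensor whose first dimension is contiguous, i.e. a column-major operand, without copying any data. A standalone illustration of the resulting strides (plain PyTorch on CPU; the same holds for the CUDA tensors `gen_input` actually builds):

```python
import torch

M, N, K = 128, 32, 64

# col_a=False: A is an (M, K) row-major tensor, so K is the contiguous dimension.
a = torch.randn((M, K), dtype=torch.float32)
assert a.stride() == (K, 1)

# col_b=True: B is built as an (N, K) buffer and transposed to (K, N);
# K is still the contiguous dimension, so B is column-major (K-major), with no copy.
b = torch.randn((N, K), dtype=torch.float32).T
assert b.shape == (K, N)
assert b.stride() == (1, K)
```

Both operands end up K-major, which is the only layout combination the refined test and benchmark exercise for now.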
@@ -205,37 +208,50 @@ def gen_input(M, N, ty_name, seed, device='cuda'):
 # ---------
 #
 # We can test our custom matrix multiplication operation against a native torch implementation (i.e., rocBLAS).
-@pytest.mark.parametrize("M, N, K, in_dtype, out_dtype",
-                         [ (*shape, in_dtype, out_dtype)
-                           for shape in [(128, 256, 32), (128, 16, 32), (32, 128, 64),
-                                         (128, 128, 64), (64, 128, 128), (32, 128, 64),
-                                         (64, 64, 32), (32, 32, 128), (128, 128, 64),
-                                         (64, 128, 128), (512, 512, 512), (1024, 1024, 1024)]
+
+def get_x_vals():
+    x_vals = [(1024 * v, 1024 * v, 1024 * v) for v in range (1, 9)]
+    x_vals += [(4864, 4096, 8192), (9728, 8192, 65536)]
+    return x_vals
+
+
+@pytest.mark.parametrize("M, N, K, in_dtype, out_dtype, col_a, col_b",
+                         [ (*shape, in_dtype, out_dtype, col_a, col_b)
+                           for shape in get_x_vals()
                            for in_dtype, out_dtype in [('fp16', 'fp16'),
                                                        ('bf16', 'bf16'),
                                                        ('fp16', 'fp32'),
                                                        ('fp32', 'fp32'),
                                                        ('fp8e4', 'fp16'),
                                                        ('fp8e5', 'fp16'),
-                                                       ('int8', 'int8'),
-                                                       ('int8', 'int32')]]
+                                                       #('int8', 'int8'),
+                                                       ('int8', 'int32')]
+                           # Only test k-major tensors because
+                           # 1. This is the most performant config and the current focus
+                           # 2. Other case does not work with num_stages=0 (TODO (zhanglx))
+                           for col_a in [False]
+                           for col_b in [True]]
                          )
-def test_correctness(M, N, K, in_dtype, out_dtype):
-    a, a_fp16 = gen_input(M, K, in_dtype, 1, device='cuda')
-    b, b_fp16 = gen_input(K, N, in_dtype, 2, device='cuda')
+def test_correctness(M, N, K, col_a, col_b, in_dtype, out_dtype):
+    a, a_fp16 = gen_input(M, K, in_dtype, col_a, 1, device='cuda')
+    b, b_fp16 = gen_input(K, N, in_dtype, col_b, 2, device='cuda')
     # Allocates output.
     tl_out_dtype = name_to_tl_types[out_dtype]
-    c = torch.empty((M, N), device=a.device, dtype=tl_to_torch_types[tl_out_dtype])
+    torch_out_dtype = tl_to_torch_types[tl_out_dtype]
+    c = torch.empty((M, N), device=a.device, dtype=torch_out_dtype)
     matmul(a, b, c, activation="")
-    torch_output = torch.matmul(a_fp16, b_fp16)
-    print(f"triton_output={c}")
-    print(f"torch_output={torch_output}")
-    rtol = 0 if torch.version.hip is None else 1e-2
-    if torch.allclose(c.to(torch.float16), torch_output, atol=5e-2, rtol=rtol):
-        print("✅ Triton and Torch match")
+    if in_dtype == 'fp8e4' or in_dtype == 'fp8e5' or in_dtype == 'int8':
+        # For f8 and int8 inputs, use fp16 for torch.matmul
+        torch_output = torch.matmul(a_fp16, b_fp16)
     else:
-        print("❌ Triton and Torch differ")
-    assert torch.allclose(c, torch_output, atol=1e-2, rtol=rtol)
+        torch_output = torch.matmul(a, b)
+    #print(f"triton_output={c}")
+    #print(f"torch_output={torch_output}")
+    rtol = 0 if torch.version.hip is None else 1e-2
+    if in_dtype == 'int8':
+        torch.testing.assert_close(c.to(torch.float16), torch_output, atol=5e-2, rtol=rtol)
+    else:
+        torch.testing.assert_close(c, torch_output.to(torch_out_dtype), atol=5e-2, rtol=rtol)


 # %%
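
The refined check compares the Triton result against a `torch.matmul` reference, switching to the fp16 copies for fp8 and int8 inputs (which `torch.matmul` generally cannot consume directly) and using `torch.testing.assert_close` with a loose tolerance. A condensed sketch of that comparison logic as a helper; the function name and argument order are hypothetical and not part of the file:

```python
import torch


def assert_matmul_close(c, a, b, a_fp16, b_fp16, in_dtype, atol=5e-2, rtol=1e-2):
    """Illustrative helper mirroring the test body above: compare a Triton
    GEMM result `c` against a torch.matmul reference."""
    if in_dtype in ('fp8e4', 'fp8e5', 'int8'):
        # torch.matmul generally cannot take fp8/int8 operands directly,
        # so the fp16 copies returned by gen_input serve as the reference.
        ref = torch.matmul(a_fp16, b_fp16)
    else:
        ref = torch.matmul(a, b)

    if in_dtype == 'int8':
        # int8 results are compared in fp16 space.
        torch.testing.assert_close(c.to(torch.float16), ref, atol=atol, rtol=rtol)
    else:
        torch.testing.assert_close(c, ref.to(c.dtype), atol=atol, rtol=rtol)
```

On non-HIP (CUDA) builds the test tightens `rtol` to 0; the `1e-2` default here mirrors the HIP path.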
@@ -252,13 +268,6 @@ def get_type(provider):
     res = re.findall(r'\(.*?\)', provider)
     return res[0][1:-1]

-
-def get_x_vals():
-    x_vals = [(1024 * v, 1024 * v, 1024 * v) for v in range (1, 9)]
-    x_vals += [(4864, 4096, 8192), (9728, 8192, 65536)]
-    return x_vals
-
-
 inout_dtype = {
     'int8': torch.int8,
     'fp16': torch.float16,
@@ -293,8 +302,8 @@ def benchmark(M, N, K, provider):
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
     else: # triton, different data types
         assert "triton" in provider
-        a, _ = gen_input(M, K, in_dtype, 1, device='cuda')
-        b, _ = gen_input(K, N, in_dtype, 2, device='cuda')
+        a, _ = gen_input(M, K, in_dtype, False, 1, device='cuda')
+        b, _ = gen_input(K, N, in_dtype, True, 2, device='cuda')
         # Allocates output.
         c = torch.empty((M, N), device=a.device, dtype=out_dtype)
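
With the `False`/`True` arguments above, the benchmark now generates its inputs with exactly the layouts that `test_correctness` validates, which is the point of "check correctness of what is benchmarked". A small usage sketch, assuming the patched script is importable as a module (the module name `matmul_bench` is hypothetical; `gen_input`, `matmul`, and their signatures are the ones shown in this diff):

```python
import torch

# Hypothetical module name for the file being patched in this diff.
from matmul_bench import gen_input, matmul

M, N, K = 4864, 4096, 8192
in_dtype = 'fp16'

# Same layouts as both the test and the benchmark: A row-major (col_a=False),
# B transposed/column-major (col_b=True); both are K-major.
a, _ = gen_input(M, K, in_dtype, False, 1, device='cuda')
b, _ = gen_input(K, N, in_dtype, True, 2, device='cuda')

c = torch.empty((M, N), device=a.device, dtype=torch.float16)
matmul(a, b, c, activation="")
```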