diff --git a/python/perf-kernels/03-matrix-multiplication-all-types.py b/python/perf-kernels/03-matrix-multiplication-all-types.py
index 87ab3ccbf..6a8928a1f 100644
--- a/python/perf-kernels/03-matrix-multiplication-all-types.py
+++ b/python/perf-kernels/03-matrix-multiplication-all-types.py
@@ -170,7 +170,7 @@ name_to_tl_types = {
     'fp8e5': tl.float8e5b16,
 }
 
-def gen_input(M, N, ty_name, seed, device='cuda'):
+def gen_input(M, N, ty_name, needTrans, seed, device='cuda'):
     d_type = name_to_tl_types[ty_name]
     torch.manual_seed(seed)
     torch.cuda.manual_seed(seed)
@@ -183,7 +183,10 @@ def gen_input(M, N, ty_name, seed, device='cuda'):
         output = input
         tl.store(output_ptr + offsets, output, mask=mask)
 
-    raw_data = torch.randn((M, N), dtype=torch.float32, device='cuda')
+    if needTrans:
+        raw_data = torch.randn((N, M), dtype=torch.float32, device='cuda').T
+    else:
+        raw_data = torch.randn((M, N), dtype=torch.float32, device='cuda')
     if (d_type == tl.float8e4b8 and TORCH_HAS_FP8E4B8) or \
         (d_type == tl.float8e5b16 and TORCH_HAS_FP8E5B16) or not d_type.is_fp8():
         input = raw_data.to(tl_to_torch_types[d_type])
@@ -205,37 +208,50 @@ def gen_input(M, N, ty_name, seed, device='cuda'):
 # ---------
 #
 # We can test our custom matrix multiplication operation against a native torch implementation (i.e., rocBLAS).
-@pytest.mark.parametrize("M, N, K, in_dtype, out_dtype",
-[ (*shape, in_dtype, out_dtype)
-    for shape in [(128, 256, 32), (128, 16, 32), (32, 128, 64),
-                  (128, 128, 64), (64, 128, 128), (32, 128, 64),
-                  (64, 64, 32), (32, 32, 128), (128, 128, 64),
-                  (64, 128, 128), (512, 512, 512), (1024, 1024, 1024)]
+def get_x_vals():
+    x_vals = [(1024 * v, 1024 * v, 1024 * v) for v in range (1, 9)]
+
+    x_vals += [(4864, 4096, 8192), (9728, 8192, 65536)]
+
+    return x_vals
+
+@pytest.mark.parametrize("M, N, K, in_dtype, out_dtype, col_a, col_b",
+[ (*shape, in_dtype, out_dtype, col_a, col_b)
+    for shape in get_x_vals()
     for in_dtype, out_dtype in [('fp16', 'fp16'),
                                 ('bf16', 'bf16'),
                                 ('fp16', 'fp32'),
                                 ('fp32', 'fp32'),
                                 ('fp8e4', 'fp16'),
                                 ('fp8e5', 'fp16'),
-                                ('int8', 'int8'),
-                                ('int8', 'int32')]]
+                                #('int8', 'int8'),
+                                ('int8', 'int32')]
+    # Only test k-major tensors because:
+    #  1. This is the most performant config and the current focus
+    #  2. The other cases do not work with num_stages=0 (TODO (zhanglx))
+    for col_a in [False]
+    for col_b in [True]]
 )
-def test_correctness(M, N, K, in_dtype, out_dtype):
-    a, a_fp16 = gen_input(M, K, in_dtype, 1, device='cuda')
-    b, b_fp16 = gen_input(K, N, in_dtype, 2, device='cuda')
+def test_correctness(M, N, K, col_a, col_b, in_dtype, out_dtype):
+    a, a_fp16 = gen_input(M, K, in_dtype, col_a, 1, device='cuda')
+    b, b_fp16 = gen_input(K, N, in_dtype, col_b, 2, device='cuda')
     # Allocates output.
     tl_out_dtype = name_to_tl_types[out_dtype]
-    c = torch.empty((M, N), device=a.device, dtype=tl_to_torch_types[tl_out_dtype])
+    torch_out_dtype = tl_to_torch_types[tl_out_dtype]
+    c = torch.empty((M, N), device=a.device, dtype=torch_out_dtype)
     matmul(a, b, c, activation="")
-    torch_output = torch.matmul(a_fp16, b_fp16)
-    print(f"triton_output={c}")
-    print(f"torch_output={torch_output}")
-    rtol = 0 if torch.version.hip is None else 1e-2
-    if torch.allclose(c.to(torch.float16), torch_output, atol=5e-2, rtol=rtol):
-        print("✅ Triton and Torch match")
+    if in_dtype == 'fp8e4' or in_dtype == 'fp8e5' or in_dtype == 'int8':
+        # For fp8 and int8 inputs, use the fp16 copies for torch.matmul
+        torch_output = torch.matmul(a_fp16, b_fp16)
     else:
-        print("❌ Triton and Torch differ")
-    assert torch.allclose(c, torch_output, atol=1e-2, rtol=rtol)
+        torch_output = torch.matmul(a, b)
+    #print(f"triton_output={c}")
+    #print(f"torch_output={torch_output}")
+    rtol = 0 if torch.version.hip is None else 1e-2
+    if in_dtype == 'int8':
+        torch.testing.assert_close(c.to(torch.float16), torch_output, atol=5e-2, rtol=rtol)
+    else:
+        torch.testing.assert_close(c, torch_output.to(torch_out_dtype), atol=5e-2, rtol=rtol)
 
 
 # %%
@@ -252,13 +268,6 @@ def get_type(provider):
     res = re.findall(r'\(.*?\)', provider)
     return res[0][1:-1]
 
-def get_x_vals():
-    x_vals = [(1024 * v, 1024 * v, 1024 * v) for v in range (1, 9)]
-
-    x_vals += [(4864, 4096, 8192), (9728, 8192, 65536)]
-
-    return x_vals
-
 inout_dtype = {
     'int8': torch.int8,
     'fp16': torch.float16,
@@ -293,8 +302,8 @@ def benchmark(M, N, K, provider):
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
     else:  # triton, different data types
         assert "triton" in provider
-        a, _ = gen_input(M, K, in_dtype, 1, device='cuda')
-        b, _ = gen_input(K, N, in_dtype, 2, device='cuda')
+        a, _ = gen_input(M, K, in_dtype, False, 1, device='cuda')
+        b, _ = gen_input(K, N, in_dtype, True, 2, device='cuda')
         # Allocates output.
         c = torch.empty((M, N), device=a.device, dtype=out_dtype)
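
Reviewer note (not part of the patch): the new needTrans flag only changes the memory layout of the generated operand, not its logical shape. A minimal standalone sketch of the same trick, assuming only PyTorch and using illustrative variable names (row_major, col_major):

import torch

M, N = 4, 8

# needTrans=False path: a plain (M, N) allocation is row-major (unit stride on N).
row_major = torch.randn((M, N), dtype=torch.float32)
print(row_major.shape, row_major.stride())   # torch.Size([4, 8]) (8, 1)

# needTrans=True path: allocate (N, M) and take the transposed view.
# Same logical (M, N) shape, but column-major strides (1, M). When gen_input
# is called as gen_input(K, N, ..., True, ...) for the b operand, K becomes
# the unit-stride dimension, i.e. the k-major layout the tests restrict to.
col_major = torch.randn((N, M), dtype=torch.float32).T
print(col_major.shape, col_major.stride())   # torch.Size([4, 8]) (1, 4)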