mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Tweak matmul tutorial on MI2xx GPU (#376)
* Tweak matmul tutorial on MI2xx GPU * Add config for 9728 --------- Co-authored-by: Shucai Xiao <shucai.xiao@amd.com>
This commit is contained in:
@@ -174,24 +174,17 @@ import pytest
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
|
||||
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
|
||||
] if torch.version.hip is None else [
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 16}, num_warps=4, num_stages=0),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4}, num_warps=4, num_stages=0),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=0),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 32}, num_warps=4, num_stages=0),
|
||||
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 32}, num_warps=4, num_stages=0),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 16}, num_warps=4, num_stages=0),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 32}, num_warps=4, num_stages=0),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_warps=4, num_stages=0),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=0),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4}, num_warps=4, num_stages=0),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=0),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 16}, num_warps=4, num_stages=0),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 32}, num_warps=4, num_stages=0),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_warps=4, num_stages=0),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=0),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, num_warps=4, num_stages=0),
|
||||
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2}, num_warps=8, num_stages=0),
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, num_warps=8, num_stages=0),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3}, num_warps=4, num_stages=0),
|
||||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8}, num_warps=4, num_stages=0),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
)
|
||||
@triton.heuristics({
|
||||
'EVEN_K': lambda args: args['K'] % args['BLOCK_SIZE_K'] == 0,
|
||||
})
|
||||
@triton.jit
|
||||
def matmul_kernel(
|
||||
# Pointers to matrices
|
||||
@@ -206,6 +199,7 @@ def matmul_kernel(
|
||||
stride_cm, stride_cn,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
EVEN_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
ACTIVATION: tl.constexpr,
|
||||
):
|
||||
@@ -219,12 +213,16 @@ def matmul_kernel(
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + (pid % group_size_m)
|
||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||
if GROUP_SIZE_M == 1:
|
||||
pid_m = pid // num_pid_n
|
||||
pid_n = pid % num_pid_n
|
||||
else:
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + (pid % group_size_m)
|
||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# Create pointers for the first blocks of A and B.
|
||||
@@ -254,8 +252,12 @@ def matmul_kernel(
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
# Load the next block of A and B, generate a mask by checking the K dimension.
|
||||
# If it is out of bounds, set it to 0.
|
||||
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||
if EVEN_K:
|
||||
a = tl.load(a_ptrs)
|
||||
b = tl.load(b_ptrs)
|
||||
else:
|
||||
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||
# We accumulate along the K dimension.
|
||||
accumulator += tl.dot(a, b)
|
||||
# Advance the ptrs to the next K block.
|
||||
@@ -362,13 +364,17 @@ verbose = False
|
||||
triton.testing.Benchmark(
|
||||
x_names=['M', 'N', 'K'], # Argument names to use as an x-axis for the plot
|
||||
x_vals=[
|
||||
128 * i for i in range(2, 33)
|
||||
(1024, 1024, 1024),
|
||||
(2048, 2048, 2048),
|
||||
(4096, 4096, 4096),
|
||||
(8192, 8192, 8192),
|
||||
(9728, 8192, 65536)
|
||||
], # Different possible values for `x_name`
|
||||
line_arg='provider', # Argument name whose value corresponds to a different line in the plot
|
||||
# Possible values for `line_arg`
|
||||
line_vals=['cublas', 'triton'],
|
||||
line_vals=['rocblas', 'triton'],
|
||||
# Label name for the lines
|
||||
line_names=["cuBLAS", "Triton"],
|
||||
line_names=["rocBLAS", "Triton"],
|
||||
# Line styles
|
||||
styles=[('green', '-'), ('blue', '-')],
|
||||
ylabel="TFLOPS", # Label name for the y-axis
|
||||
@@ -380,7 +386,7 @@ def benchmark(M, N, K, provider):
|
||||
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
|
||||
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == 'cublas':
|
||||
if provider == 'rocblas':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
|
||||
if provider == 'triton':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
|
||||
|
||||
Reference in New Issue
Block a user