# Mirror of https://github.com/ROCm/ROCm.git (synced 2026-04-05 03:01:17 -04:00).
"""
|
|
Matrix Multiplication with TMA Store (Experimental)
|
|
================================================
|
|
In this tutorial, you will write a very short high-performance multiplication kernel that achieves
|
|
performance on parallel with cuBLAS.
|
|
"""
|
|
|
|
# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import torch
from torch.testing import assert_close

import triton
import triton.language as tl

# The TMA store path exercised below needs Hopper-class hardware
# (compute capability 9.x or newer); exit quietly on older GPUs.
major, _minor = torch.cuda.get_device_capability()
if major < 9:
    import sys

    print("Skipping TMA benchmark for GPU with compute capability < 9")
    sys.exit(0)


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4),
        # triton.Config({'BLOCK_SIZE_M': 512, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4, num_ctas=4),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C = A @ B.

    A is (M, K) and B is (K, N); all strides are in elements. The
    accumulator is float32 and is stored unconverted, so C is expected to
    be a float32 buffer (the host wrapper allocates it that way).
    NOTE(review): the K loop applies no masking, so K must be a multiple
    of BLOCK_SIZE_K — the host-side `matmul` asserts a related constraint.
    """
    pid = tl.program_id(axis=0)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    # Grouped program ordering: GROUP_SIZE_M rows of tiles are scheduled
    # before moving to the next column of tiles (the classic Triton
    # L2-cache-friendly swizzle).
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group can be narrower when num_pid_m is not a multiple of
    # GROUP_SIZE_M.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    # Element offsets of this program's output tile.
    block_offset_m = pid_m * BLOCK_SIZE_M
    block_offset_n = pid_n * BLOCK_SIZE_N

    # Block pointers carry the full tensor shape, strides and tile offset;
    # order marks which dimension is innermost/contiguous in memory.
    a_tile_ptr = tl.make_block_ptr(
        base=a_ptr, shape=(
            M, K), strides=(
            stride_am, stride_ak), offsets=(
            block_offset_m, 0), block_shape=(
            BLOCK_SIZE_M, BLOCK_SIZE_K), order=(
            1, 0))
    b_tile_ptr = tl.make_block_ptr(
        base=b_ptr, shape=(
            K, N), strides=(
            stride_bk, stride_bn), offsets=(
            0, block_offset_n), block_shape=(
            BLOCK_SIZE_K, BLOCK_SIZE_N), order=(
            0, 1))
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # March A right and B down along K, accumulating partial products.
    for k in range(0, K, BLOCK_SIZE_K):
        a = tl.load(a_tile_ptr)
        b = tl.load(b_tile_ptr)
        accumulator += tl.dot(a, b)
        a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_SIZE_K])
        b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_SIZE_K, 0])

    # Store through a block pointer — per this tutorial's title, this is
    # the form the compiler can lower to a TMA store on Hopper.
    c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
                                    offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))

    tl.store(c_block_ptr, accumulator)


def matmul(a, b):
    """Compute a @ b with matmul_kernel and return a float32 (M, N) tensor.

    Shapes must be compatible, and K must be a multiple of 32 because the
    kernel performs no bounds masking along K.
    """
    M, K = a.shape
    K, N = b.shape
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    assert (
        K % 32 == 0
    ), "We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K"

    out = torch.empty((M, N), device=a.device, dtype=torch.float32)

    # One program per output tile; tile sizes are chosen by the autotuner.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)

    matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=out,
                        M=M, N=N, K=K,
                        stride_am=a.stride(0), stride_ak=a.stride(1),
                        stride_bk=b.stride(0), stride_bn=b.stride(1),
                        stride_cm=out.stride(0), stride_cn=out.stride(1))
    return out


# Smoke test: compare the Triton kernel against torch.matmul on a 512x512
# problem. b is built as a transpose, so it is column-major in memory.
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16).T
c = matmul(a, b)
# Normalize both results so the tolerance check is scale-independent.
c = torch.nn.functional.normalize(c)

golden = torch.nn.functional.normalize(torch.matmul(a, b))

# Print full tensors (no truncation) if the comparison below fails.
torch.set_printoptions(profile="full")
# check_dtype=False: the kernel returns float32 while torch.matmul on
# float16 inputs returns float16.
assert_close(
    c,
    golden,
    rtol=1e-2,
    atol=1e-3,
    check_dtype=False)


@triton.testing.perf_report(
    triton.testing.Benchmark(
        # x-axis: N=K grows while M stays fixed at 2048
        x_names=['M', 'N', 'K'],
        x_vals=[
            [2048, 512, 512],
            [2048, 1024, 1024],
            [2048, 2048, 2048],
            [2048, 4096, 4096],
            [2048, 8192, 8192]
        ],
        # one plotted line per provider value below
        line_arg='provider',
        line_vals=['cublas', 'triton'],
        line_names=["cuBLAS", "Triton"],
        styles=[('green', '-'), ('green', '--'),
                ('blue', '-'), ('blue', '--')],
        ylabel="TFLOPS",
        # also used as the file name if the plot is saved
        plot_name="matmul-performance",
        args={},
    )
)
def benchmark(M, N, K, provider):
    """Time one provider on an (M, K) x (K, N) fp16 matmul.

    Returns throughput in TFLOPS for the 0.5, 0.8 and 0.2 timing
    quantiles reported by do_bench (median, slowest, fastest).
    """
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
    quantiles = [0.5, 0.2, 0.8]

    # Select the workload for this provider, then time it once.
    if provider == 'cublas':
        run = lambda: torch.matmul(a, b)
    elif provider == 'triton':
        run = lambda: matmul(a, b)
    ms, min_ms, max_ms = triton.testing.do_bench(
        run, rep=100, quantiles=quantiles, fast_flush=False)

    def perf(t_ms):
        # 2*M*N*K flops, scaled to tera-flops per second (ms -> s).
        return 2 * M * N * K * 1e-12 / (t_ms * 1e-3)

    return perf(ms), perf(max_ms), perf(min_ms)


# Run the benchmark sweep; print the timing table rather than showing plots.
benchmark.run(show_plots=False, print_data=True)