"""
|
|
Matrix Multiplication
|
|
=====================
|
|
In this tutorial, you will write a very short high-performance FP16 matrix multiplication kernel that achieves
|
|
performance on parallel with cuBLAS.
|
|
|
|
You will specifically learn about:
|
|
|
|
* Block-level matrix multiplications.
|
|
|
|
* Multi-dimensional pointer arithmetics.
|
|
|
|
* Program re-ordering for improved L2 cache hit rate.
|
|
|
|
* Automatic performance tuning.
|
|
|
|
"""
|
|
|
|
# %%
# Motivations
# -----------
#
# Matrix multiplications are a key building block of most modern high-performance computing systems.
# They are notoriously hard to optimize, hence their implementation is generally done by
# hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
# Unfortunately, these libraries are often proprietary and cannot be easily customized
# to accommodate the needs of modern deep learning workloads (e.g., fused activation functions).
# In this tutorial, you will learn how to implement efficient matrix multiplications by
# yourself with Triton, in a way that is easy to customize and extend.
#
# Roughly speaking, the kernel that we will write will implement the following blocked
# algorithm to multiply a (M, K) by a (K, N) matrix:
#
# .. code-block:: python
#
#     # Do in parallel
#     for m in range(0, M, BLOCK_SIZE_M):
#         # Do in parallel
#         for n in range(0, N, BLOCK_SIZE_N):
#             acc = zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=float32)
#             for k in range(0, K, BLOCK_SIZE_K):
#                 a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K]
#                 b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]
#                 acc += dot(a, b)
#             C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc
#
# where each iteration of the doubly-nested for-loop is performed by a dedicated Triton program instance.

# %%
# Compute Kernel
# --------------
#
# The above algorithm is, actually, fairly straightforward to implement in Triton.
# The main difficulty comes from the computation of the memory locations at which blocks
# of :code:`A` and :code:`B` must be read in the inner loop. For that, we need
# multi-dimensional pointer arithmetic.
#
# Pointer Arithmetic
# ~~~~~~~~~~~~~~~~~~
#
# For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given
# by :code:`&X[i, j] = X + i*stride_xi + j*stride_xj`.
# Therefore, blocks of pointers for :code:`A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K]` and
# :code:`B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]` can be defined in pseudo-code as:
#
# .. code-block:: python
#
#     &A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K] = a_ptr + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1)
#     &B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N] = b_ptr + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1)
#
# This means that pointers for blocks of A and B can be initialized (i.e., :code:`k=0`) in Triton as shown in the
# code below. Also note that we need an extra modulo to handle the case where :code:`M` is not a multiple of
# :code:`BLOCK_SIZE_M` or :code:`N` is not a multiple of :code:`BLOCK_SIZE_N`, in which case we can pad the data with
# some useless values, which will not contribute to the results. For the :code:`K` dimension, we will handle that later
# using masking load semantics.
#
# .. code-block:: python
#
#     offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
#     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
#     offs_k = tl.arange(0, BLOCK_SIZE_K)
#     a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
#     b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
#
# And then updated in the inner loop as follows:
#
# .. code-block:: python
#
#     a_ptrs += BLOCK_SIZE_K * stride_ak
#     b_ptrs += BLOCK_SIZE_K * stride_bk
#
#
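# As a concrete illustration of the stride values involved (an assumption for the common case of a
# contiguous row-major tensor; PyTorch reports strides in elements, not bytes):
#
# .. code-block:: python
#
#     A = torch.empty((64, 32), device='cuda', dtype=torch.float16)
#     A.stride()  # (32, 1): moving one row down skips 32 elements, one column to the right skips 1
#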
# L2 Cache Optimizations
# ~~~~~~~~~~~~~~~~~~~~~~
#
# As mentioned above, each program instance computes a :code:`[BLOCK_SIZE_M, BLOCK_SIZE_N]`
# block of :code:`C`.
# It is important to remember that the order in which these blocks are computed does
# matter, since it affects the L2 cache hit rate of our program, and unfortunately, a
# simple row-major ordering
#
# .. code-block:: python
#
#     pid = tl.program_id(axis=0)
#     grid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
#     grid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
#     pid_m = pid // grid_n
#     pid_n = pid % grid_n
#
# is just not going to cut it.
#
# One possible solution is to launch blocks in an order that promotes data reuse.
# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_SIZE_M` rows before
# switching to the next column:
#
# .. code-block:: python
#
#     # Program ID
#     pid = tl.program_id(axis=0)
#     # Number of program ids along the M axis
#     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
#     # Number of program ids along the N axis
#     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
#     # Number of programs in group
#     num_pid_in_group = GROUP_SIZE_M * num_pid_n
#     # Id of the group this program is in
#     group_id = pid // num_pid_in_group
#     # Row-id of the first program in the group
#     first_pid_m = group_id * GROUP_SIZE_M
#     # If `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller
#     group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
#     # *Within groups*, programs are ordered in a column-major order
#     # Row-id of the program in the *launch grid*
#     pid_m = first_pid_m + (pid % group_size_m)
#     # Col-id of the program in the *launch grid*
#     pid_n = (pid % num_pid_in_group) // group_size_m
#
# For example, in the following matmul where each matrix is 9 blocks by 9 blocks,
# we can see that if we compute the output in row-major ordering, we need to load 90
# blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped
# ordering, we only need to load 54 blocks.
#
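# As a quick sanity check of these numbers (a back-of-the-envelope sketch, assuming the 9-by-9 grid
# of blocks above and a hypothetical :code:`GROUP_SIZE_M` of 3, so that the first 9 programs cover a
# 3x3 tile of :code:`C`):
#
# .. code-block:: python
#
#     grid = 9   # the output is a 9x9 grid of blocks
#     group = 3  # hypothetical GROUP_SIZE_M
#     # Row-major: the first 9 outputs reuse a single block-row of A (9 blocks)
#     # but walk through every block-column of B (9 columns of 9 blocks each).
#     row_major_loads = grid + grid * grid          # 9 + 81 = 90
#     # Grouped: the first 9 outputs form a 3x3 tile of C, touching only
#     # 3 block-rows of A and 3 block-columns of B.
#     grouped_loads = group * grid + group * grid   # 27 + 27 = 54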
#
# .. image:: grouped_vs_row_major_ordering.png
#
# In practice, this can improve the performance of our matrix multiplication kernel by
# more than 10% on some hardware architectures (e.g., 220 to 245 TFLOPS on A100).
#

# %%
# Final Result
# ------------

import argparse
import sys

import pytest
import torch

import triton
import triton.language as tl

# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
#   - A list of `triton.Config` objects that define different configurations of
#     meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
#   - An auto-tuning *key* whose change in values will trigger evaluation of all the
#     provided configs
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    ] if torch.version.hip is None else [
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, num_warps=4, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2}, num_warps=8, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, num_warps=8, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3}, num_warps=4, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8}, num_warps=4, num_stages=0),
    ],
    key=['M', 'N', 'K'],
)
# `EVEN_K` is derived from the runtime arguments: it records whether `K` is a multiple of
# `BLOCK_SIZE_K`, so the kernel can skip the masking on its K-dimension loads when it is not needed.
@triton.heuristics({
    'EVEN_K': lambda args: args['K'] % args['BLOCK_SIZE_K'] == 0,
})
@triton.jit
def matmul_kernel(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        # Matrix dimensions
        M, N, K,
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
        # by to get the element one row down (A has M rows).
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
        EVEN_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
        ACTIVATION: tl.constexpr,
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See above `L2 Cache Optimizations` section for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    if GROUP_SIZE_M == 1:
        pid_m = pid // num_pid_n
        pid_n = pid % num_pid_n
    else:
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate.
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See above `Pointer Arithmetic` section for details.
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    if torch.version.hip is None:
        offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
        offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    else:
        offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))
        offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))
        a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
        b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generating a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        if EVEN_K:
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
        else:
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # You can fuse arbitrary activation functions here
    # while the accumulator is still in FP32!
    if ACTIVATION == "leaky_relu":
        accumulator = leaky_relu(accumulator)
    c = accumulator.to(tl.float16)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`.
@triton.jit
def leaky_relu(x):
    x = x + 1
    return tl.where(x >= 0, x, 0.01 * x)


# %%
# We can now create a convenience wrapper function that only takes two input tensors,
# and (1) checks the shape constraints; (2) allocates the output; (3) launches the above kernel.


def matmul(a, b, activation=""):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    assert b.is_contiguous(), "Matrix B must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel[grid](
        a, b, c,  #
        M, N, K,  #
        a.stride(0), a.stride(1),  #
        b.stride(0), b.stride(1),  #
        c.stride(0), c.stride(1),  #
        ACTIVATION=activation  #
    )
    return c
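

# %%
# For example, the fused leaky-ReLU epilogue is enabled simply by passing the meta-parameter
# (a minimal usage sketch, assuming fp16 inputs that already live on the GPU):
#
# .. code-block:: python
#
#     a = torch.randn(512, 512, device='cuda', dtype=torch.float16)
#     b = torch.randn(512, 512, device='cuda', dtype=torch.float16)
#     c = matmul(a, b, activation="leaky_relu")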


# %%
# Unit Test
# ---------
#
# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS or rocBLAS).
@pytest.mark.parametrize("M, N, K, in_dtype, out_dtype",
                         [(*shape, in_dtype, out_dtype)
                          for shape in [(128, 256, 32), (128, 16, 32), (32, 128, 64),
                                        (128, 128, 64), (64, 128, 128), (32, 128, 64),
                                        (64, 64, 32), (32, 32, 128), (128, 128, 64),
                                        (64, 128, 128), (512, 512, 512), (1024, 1024, 1024)]
                          for in_dtype, out_dtype in [('int8', 'int8'),
                                                      ('float16', 'float16'),
                                                      ('bfloat16', 'bfloat16'),
                                                      ('float16', 'float32'),
                                                      ('float32', 'float32')]])
def test_correctness(M, N, K, in_dtype, out_dtype):
    torch.manual_seed(0)
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    triton_output = matmul(a, b)
    torch_output = torch.matmul(a, b)
    print(f"triton_output={triton_output}")
    print(f"torch_output={torch_output}")
    rtol = 0 if torch.version.hip is None else 1e-2
    if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
        print("✅ Triton and Torch match")
    else:
        print("❌ Triton and Torch differ")
    assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol)


# %%
# Benchmark
# ---------
#
# Square Matrix Performance
# ~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We can now compare the performance of our kernel against that of cuBLAS or rocBLAS. Here we focus on square matrices,
# but feel free to adapt this script to benchmark any other matrix shape.

global verbose
verbose = False

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['M', 'N', 'K'],  # Argument names to use as an x-axis for the plot
        x_vals=[
            (1024, 1024, 1024),
            (2048, 2048, 2048),
            (4096, 4096, 4096),
            (8192, 8192, 8192),
            (9728, 8192, 65536),
        ],  # Different possible values for `x_name`
        line_arg='provider',  # Argument name whose value corresponds to a different line in the plot
        # Possible values for `line_arg`
        line_vals=['rocblas', 'triton'],
        # Label name for the lines
        line_names=["rocBLAS", "Triton"],
        # Line styles
        styles=[('green', '-'), ('blue', '-')],
        ylabel="TFLOPS",  # Label name for the y-axis
        plot_name="matmul-performance",  # Name for the plot, used also as a file name for saving the plot.
        args={},
    ))
def benchmark(M, N, K, provider):
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    quantiles = [0.5, 0.2, 0.8]
    if provider == 'rocblas':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
        global verbose
        if verbose:
            print(f'SIZE: {M},{N},{K} Best tuning config: ({matmul_kernel.get_best_config()})')
    # A matmul performs 2 * M * N * K floating-point operations; `ms` is in milliseconds, so this yields TFLOPS.
    perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)


def parse_args():
    parser = argparse.ArgumentParser(
        prog="GEMM tutorial example",
        allow_abbrev=False,
    )

    parser.add_argument("-v", action='store_true', default=False, help="Print out the best tuning config")
    args = parser.parse_args()

    return args


def main():
    # Assign to the global `verbose` flag to indicate whether to print
    # the best tuning config.
    global verbose
    args = parse_args()
    verbose = args.v
    benchmark.run(show_plots=True, print_data=True)


if __name__ == '__main__':
    sys.exit(main())