mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge commit '36fc54b6f28168d3644808bfe299f1ba06a36272' into ifu230908-2
Conflicts: .gitignore bin/triton-translate.cpp include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td lib/Analysis/Utility.cpp lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp lib/Conversion/TritonGPUToLLVM/Utility.h lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp lib/Dialect/TritonGPU/IR/Dialect.cpp lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp lib/Target/LLVMIR/LLVMIRTranslation.cpp python/src/triton.cc python/test/unit/runtime/test_subproc.py python/triton/compiler/compiler.py python/triton/compiler/make_launcher.py python/triton/language/semantic.py python/triton/runtime/jit.py python/tutorials/06-fused-attention.py test/Conversion/triton_to_tritongpu.mlir test/Conversion/tritongpu_to_llvm.mlir test/TritonGPU/coalesce.mlir unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt
This commit is contained in:
@@ -260,7 +260,7 @@ class LayerNorm(torch.autograd.Function):
|
||||
# enqueue kernel
|
||||
_layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,
|
||||
x_arg.stride(0), N, eps,
|
||||
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
|
||||
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1)
|
||||
ctx.save_for_backward(x, weight, bias, mean, rstd)
|
||||
ctx.BLOCK_SIZE = BLOCK_SIZE
|
||||
ctx.num_warps = num_warps
|
||||
@@ -296,7 +296,7 @@ class LayerNorm(torch.autograd.Function):
|
||||
# accumulate partial sums in separate kernel
|
||||
_layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,
|
||||
BLOCK_SIZE_M=32,
|
||||
BLOCK_SIZE_N=128)
|
||||
BLOCK_SIZE_N=128, num_ctas=1)
|
||||
return dx, None, dw, db, None
|
||||
|
||||
|
||||
@@ -356,19 +356,21 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='c
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
# utility functions
|
||||
if provider == 'triton':
|
||||
y_fwd = lambda: layer_norm(x, w_shape, weight, bias, eps)
|
||||
def y_fwd(): return layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704
|
||||
if provider == 'torch':
|
||||
y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps)
|
||||
def y_fwd(): return torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704
|
||||
if provider == 'apex':
|
||||
apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)
|
||||
y_fwd = lambda: apex_layer_norm(x)
|
||||
apex_layer_norm = apex.normalization.FusedLayerNorm(
|
||||
w_shape).to(x.device).to(x.dtype)
|
||||
|
||||
def y_fwd(): return apex_layer_norm(x) # noqa: F811, E704
|
||||
# forward pass
|
||||
if mode == 'forward':
|
||||
gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500)
|
||||
# backward pass
|
||||
if mode == 'backward':
|
||||
gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6
|
||||
def gbps(ms): return 3 * x.numel() * x.element_size() / ms * 1e-6 # noqa: F811, E704
|
||||
y = y_fwd()
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),
|
||||
quantiles=quantiles, grad_to_none=[x], rep=500)
|
||||
|
||||
@@ -87,10 +87,14 @@ def _fwd_kernel(
|
||||
k = tl.load(K_block_ptr)
|
||||
v = tl.load(V_block_ptr)
|
||||
# -- compute qk ---
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16)
|
||||
if IS_CAUSAL:
|
||||
qk = tl.where(P_SEQ + offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
|
||||
<<<<<<< HEAD
|
||||
qk += tl.dot(q, k)
|
||||
=======
|
||||
qk += tl.dot(q, k, out_dtype=tl.float16)
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
# -- compute scaling constant ---
|
||||
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
|
||||
alpha = tl.math.exp2(m_i - m_i_new)
|
||||
@@ -148,8 +152,8 @@ def _bwd_kernel(
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
Z, H, N_CTX,
|
||||
num_block,
|
||||
Z, H, N_CTX, P_SEQ,
|
||||
num_block_q, num_block_kv,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
CAUSAL: tl.constexpr,
|
||||
@@ -160,17 +164,23 @@ def _bwd_kernel(
|
||||
qk_scale = sm_scale * 1.44269504
|
||||
# offset pointers for batch/head
|
||||
Q += off_z * stride_qz + off_h * stride_qh
|
||||
K += off_z * stride_qz + off_h * stride_qh
|
||||
V += off_z * stride_qz + off_h * stride_qh
|
||||
K += off_z * stride_kz + off_h * stride_kh
|
||||
V += off_z * stride_vz + off_h * stride_vh
|
||||
DO += off_z * stride_qz + off_h * stride_qh
|
||||
DQ += off_z * stride_qz + off_h * stride_qh
|
||||
<<<<<<< HEAD
|
||||
DK += off_z * stride_qz + off_h * stride_qh
|
||||
DV += off_z * stride_qz + off_h * stride_qh
|
||||
# See fwd pass above for explanation.
|
||||
qk_scale = sm_scale * 1.44269504
|
||||
for start_n in range(0, num_block):
|
||||
=======
|
||||
DK += off_z * stride_kz + off_h * stride_kh
|
||||
DV += off_z * stride_vz + off_h * stride_vh
|
||||
for start_n in range(0, num_block_kv):
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
if CAUSAL:
|
||||
lo = start_n * BLOCK_M
|
||||
lo = tl.math.max(start_n * BLOCK_M - P_SEQ, 0)
|
||||
else:
|
||||
lo = 0
|
||||
# initialize row/col offsets
|
||||
@@ -187,14 +197,14 @@ def _bwd_kernel(
|
||||
# pointer to row-wise quantities in value-like data
|
||||
D_ptrs = D + off_hz * N_CTX
|
||||
l_ptrs = L + off_hz * N_CTX
|
||||
# initialize dv amd dk
|
||||
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
# initialize dk amd dv
|
||||
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
# k and v stay in SRAM throughout
|
||||
k = tl.load(k_ptrs)
|
||||
v = tl.load(v_ptrs)
|
||||
# loop over rows
|
||||
for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
|
||||
for start_m in range(lo, num_block_q * BLOCK_M, BLOCK_M):
|
||||
offs_m_curr = start_m + offs_m
|
||||
# load q, k, v, do on-chip
|
||||
q = tl.load(q_ptrs)
|
||||
@@ -226,10 +236,10 @@ def _bwd_kernel(
|
||||
q_ptrs += BLOCK_M * stride_qm
|
||||
do_ptrs += BLOCK_M * stride_qm
|
||||
# write-back
|
||||
dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
|
||||
tl.store(dv_ptrs, dv)
|
||||
dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
tl.store(dk_ptrs, dk)
|
||||
tl.store(dv_ptrs, dv)
|
||||
|
||||
@triton.jit
|
||||
def _bwd_kernel_dk_dv(
|
||||
@@ -456,6 +466,7 @@ class _attention(torch.autograd.Function):
|
||||
assert Lk in {16, 32, 64, 128}
|
||||
o = torch.empty_like(q)
|
||||
BLOCK_M = 128
|
||||
<<<<<<< HEAD
|
||||
if torch.version.hip is None:
|
||||
BLOCK_N = 64 if Lk <= 64 else 32
|
||||
num_stages = 4 if Lk <= 64 else 3
|
||||
@@ -469,6 +480,14 @@ class _attention(torch.autograd.Function):
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
P_SEQ = 0 if q.shape[-2] == k.shape[-2] else k.shape[-2] - q.shape[-2]
|
||||
|
||||
=======
|
||||
BLOCK_N = 64 if Lk <= 64 else 32
|
||||
num_stages = 4 if Lk <= 64 else 3
|
||||
num_warps = 4
|
||||
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
P_SEQ = 0 if q.shape[-2] == k.shape[-2] else k.shape[-2] - q.shape[-2]
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
_fwd_kernel[grid](
|
||||
q, k, v, sm_scale,
|
||||
L,
|
||||
@@ -488,7 +507,10 @@ class _attention(torch.autograd.Function):
|
||||
ctx.sm_scale = sm_scale
|
||||
ctx.BLOCK_DMODEL = Lk
|
||||
ctx.causal = causal
|
||||
<<<<<<< HEAD
|
||||
ctx.split_kernel = split_kernel
|
||||
=======
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
ctx.P_SEQ = P_SEQ
|
||||
return o
|
||||
|
||||
@@ -516,8 +538,28 @@ class _attention(torch.autograd.Function):
|
||||
block_scale = (q.shape[2] // ctx.grid[0]) // BLOCK
|
||||
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
|
||||
o, do,
|
||||
<<<<<<< HEAD
|
||||
do_scaled, delta,
|
||||
BLOCK_M=block_scale * BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
|
||||
=======
|
||||
delta,
|
||||
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
|
||||
)
|
||||
_bwd_kernel[(ctx.grid[1],)](
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do,
|
||||
dq, dk, dv,
|
||||
L, delta,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2], ctx.P_SEQ,
|
||||
ctx.grid[0], triton.cdiv(k.shape[2], BLOCK),
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
|
||||
CAUSAL=ctx.causal,
|
||||
num_stages=1,
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
)
|
||||
if not ctx.split_kernel:
|
||||
_bwd_kernel[(ctx.grid[1],)](
|
||||
@@ -569,6 +611,7 @@ class _attention(torch.autograd.Function):
|
||||
attention = _attention.apply
|
||||
|
||||
|
||||
<<<<<<< HEAD
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, P_SEQ',
|
||||
[(4, 48, 1024, 64, 128),
|
||||
(4, 48, 2048, 64, 128),
|
||||
@@ -578,10 +621,16 @@ attention = _attention.apply
|
||||
])
|
||||
@pytest.mark.parametrize('causal', [False, True])
|
||||
def test_op_fwd(Z, H, N_CTX, D_HEAD, P_SEQ, causal, dtype=torch.float16):
|
||||
=======
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, P_SEQ', [(6, 9, 1024, 64, 128)])
|
||||
@pytest.mark.parametrize('causal', [False, True])
|
||||
def test_op(Z, H, N_CTX, D_HEAD, P_SEQ, causal, dtype=torch.float16):
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
torch.manual_seed(20)
|
||||
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
|
||||
k = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
|
||||
v = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
|
||||
<<<<<<< HEAD
|
||||
sm_scale = q.shape[-1] ** (-0.5)
|
||||
dout = torch.randn_like(q)
|
||||
# reference implementation
|
||||
@@ -614,6 +663,12 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, P_SEQ, dtype=torch.float16):
|
||||
dout = torch.randn_like(q)
|
||||
# reference implementation
|
||||
M = torch.tril(torch.ones((N_CTX, N_CTX + P_SEQ), device="cuda"), diagonal=P_SEQ)
|
||||
=======
|
||||
sm_scale = 0.5
|
||||
dout = torch.randn_like(q)
|
||||
# reference implementation
|
||||
M = torch.tril(torch.ones((N_CTX, N_CTX + P_SEQ), device="cuda"), diagonal=P_SEQ)
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
|
||||
if causal:
|
||||
p[:, :, M == 0] = float("-inf")
|
||||
|
||||
200
python/tutorials/09-experimental-tma-matrix-multiplication.py
Normal file
200
python/tutorials/09-experimental-tma-matrix-multiplication.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""
|
||||
Matrix Multiplication with TMA (Experimental)
|
||||
================================================
|
||||
In this tutorial, you will write a very short high-performance multiplication kernel that achieves
|
||||
performance on parallel with cuBLAS.
|
||||
"""
|
||||
|
||||
# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files
|
||||
# (the "Software"), to deal in the Software without restriction,
|
||||
# including without limitation the rights to use, copy, modify, merge,
|
||||
# publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
# and to permit persons to whom the Software is furnished to do so,
|
||||
# subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
if torch.cuda.get_device_capability()[0] < 9:
|
||||
import sys
|
||||
print("Skipping TMA benchmark for GPU with compute capability < 9")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4),
|
||||
# triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4, num_ctas=2),
|
||||
# triton.Config({'BLOCK_SIZE_M': 512, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4, num_ctas=4),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
)
|
||||
@triton.jit
|
||||
def matmul_kernel(
|
||||
a_ptr, b_ptr, z_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_zm, stride_zn,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
|
||||
A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr,
|
||||
B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + (pid % group_size_m)
|
||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||
block_offset_m = pid_m * BLOCK_SIZE_M
|
||||
block_offset_n = pid_n * BLOCK_SIZE_N
|
||||
|
||||
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
|
||||
offsets=(block_offset_m, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K), order=(A_ORDER_0, A_ORDER_1))
|
||||
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
|
||||
offsets=(0, block_offset_n), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), order=(B_ORDER_0, B_ORDER_1))
|
||||
z = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
offs_m = block_offset_m + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = block_offset_n + tl.arange(0, BLOCK_SIZE_N)
|
||||
z_ptrs = z_ptr + offs_m[:, None] * stride_zm + offs_n[None, :] * stride_zn
|
||||
mask = (offs_m < M)[:, None] & (offs_n < N)[None, :]
|
||||
|
||||
for k in range(0, K, BLOCK_SIZE_K):
|
||||
a = tl.load(a_tile_ptr)
|
||||
b = tl.load(b_tile_ptr)
|
||||
z += tl.dot(a, b)
|
||||
a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_SIZE_K])
|
||||
b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_SIZE_K, 0])
|
||||
|
||||
z = z.to(tl.float16)
|
||||
|
||||
tl.store(z_ptrs, z, mask=mask)
|
||||
|
||||
|
||||
def matmul(a, b, a_order, b_order):
|
||||
# checks constraints
|
||||
assert a.shape[1] == b.shape[0], "incompatible dimensions"
|
||||
M, K = a.shape
|
||||
K, N = b.shape
|
||||
|
||||
z = torch.empty((M, N), device=a.device, dtype=torch.float16)
|
||||
|
||||
def grid(META):
|
||||
return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
|
||||
matmul_kernel[grid](a_ptr=a, b_ptr=b, z_ptr=z,
|
||||
M=M, N=N, K=K,
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1),
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1),
|
||||
stride_zm=z.stride(0), stride_zn=z.stride(1),
|
||||
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1],
|
||||
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1]
|
||||
)
|
||||
return z
|
||||
|
||||
|
||||
problem_list = [
|
||||
[2048, 512, 512, False, True],
|
||||
[2048, 1024, 1024, False, False],
|
||||
[2048, 2048, 2048, True, False],
|
||||
[2048, 4096, 4096, True, True],
|
||||
]
|
||||
|
||||
|
||||
def test_matmul():
|
||||
for case in problem_list:
|
||||
M, N, K, TRANS_A, TRANS_B = case
|
||||
print(M, N, K, TRANS_A, TRANS_B)
|
||||
if (TRANS_A):
|
||||
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
|
||||
a_order = [0, 1]
|
||||
else:
|
||||
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
|
||||
a_order = [1, 0]
|
||||
|
||||
if (TRANS_B):
|
||||
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
|
||||
b_order = [0, 1]
|
||||
else:
|
||||
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
|
||||
b_order = [1, 0]
|
||||
|
||||
golden = torch.matmul(a, b)
|
||||
z = matmul(a, b, a_order, b_order)
|
||||
|
||||
golden = torch.nn.functional.normalize(golden)
|
||||
z = torch.nn.functional.normalize(z)
|
||||
torch.set_printoptions(profile="full")
|
||||
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
# argument names to use as an x-axis for the plot
|
||||
x_names=['M', 'N', 'K', 'TRANS_A', 'TRANS_B'],
|
||||
x_vals=problem_list, # different possible values for `x_name`
|
||||
line_arg='provider',
|
||||
# argument name whose value corresponds to a different line in the plot
|
||||
# possible values for `line_arg``
|
||||
line_vals=['cublas', 'triton'],
|
||||
# label name for the lines
|
||||
line_names=["cuBLAS", "Triton"],
|
||||
# line styles
|
||||
styles=[('green', '-'), ('green', '--'),
|
||||
('blue', '-'), ('blue', '--')],
|
||||
ylabel="TFLOPS", # label name for the y-axis
|
||||
plot_name="matmul-performance",
|
||||
# name for the plot. Used also as a file name for saving the plot.
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(M, N, K, TRANS_A, TRANS_B, provider):
|
||||
if (TRANS_A):
|
||||
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
|
||||
a_order = [0, 1]
|
||||
else:
|
||||
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
|
||||
a_order = [1, 0]
|
||||
|
||||
if (TRANS_B):
|
||||
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
|
||||
b_order = [0, 1]
|
||||
else:
|
||||
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
|
||||
b_order = [1, 0]
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == 'cublas':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: torch.matmul(a, b), rep=100, quantiles=quantiles, fast_flush=False)
|
||||
if provider == 'triton':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: matmul(a, b, a_order, b_order), rep=100, quantiles=quantiles, fast_flush=False)
|
||||
|
||||
def perf(ms):
|
||||
return 2 * M * N * K * 1e-12 / (ms * 1e-3)
|
||||
return perf(ms), perf(max_ms), perf(min_ms)
|
||||
|
||||
|
||||
test_matmul()
|
||||
benchmark.run(show_plots=False, print_data=True)
|
||||
@@ -0,0 +1,179 @@
|
||||
"""
|
||||
Matrix Multiplication with TMA Store (Experimental)
|
||||
================================================
|
||||
In this tutorial, you will write a very short high-performance multiplication kernel that achieves
|
||||
performance on parallel with cuBLAS.
|
||||
"""
|
||||
|
||||
# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files
|
||||
# (the "Software"), to deal in the Software without restriction,
|
||||
# including without limitation the rights to use, copy, modify, merge,
|
||||
# publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
# and to permit persons to whom the Software is furnished to do so,
|
||||
# subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
if torch.cuda.get_device_capability()[0] < 9:
|
||||
import sys
|
||||
print("Skipping TMA benchmark for GPU with compute capability < 9")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4),
|
||||
# triton.Config({'BLOCK_SIZE_M': 512, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4, num_ctas=4),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
)
|
||||
@triton.jit
|
||||
def matmul_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + (pid % group_size_m)
|
||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||
block_offset_m = pid_m * BLOCK_SIZE_M
|
||||
block_offset_n = pid_n * BLOCK_SIZE_N
|
||||
|
||||
a_tile_ptr = tl.make_block_ptr(
|
||||
base=a_ptr, shape=(
|
||||
M, K), strides=(
|
||||
stride_am, stride_ak), offsets=(
|
||||
block_offset_m, 0), block_shape=(
|
||||
BLOCK_SIZE_M, BLOCK_SIZE_K), order=(
|
||||
1, 0))
|
||||
b_tile_ptr = tl.make_block_ptr(
|
||||
base=b_ptr, shape=(
|
||||
K, N), strides=(
|
||||
stride_bk, stride_bn), offsets=(
|
||||
0, block_offset_n), block_shape=(
|
||||
BLOCK_SIZE_K, BLOCK_SIZE_N), order=(
|
||||
0, 1))
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
for k in range(0, K, BLOCK_SIZE_K):
|
||||
a = tl.load(a_tile_ptr)
|
||||
b = tl.load(b_tile_ptr)
|
||||
accumulator += tl.dot(a, b)
|
||||
a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_SIZE_K])
|
||||
b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_SIZE_K, 0])
|
||||
|
||||
c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
|
||||
offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))
|
||||
|
||||
tl.store(c_block_ptr, accumulator)
|
||||
|
||||
|
||||
def matmul(a, b):
|
||||
# checks constraints
|
||||
assert a.shape[1] == b.shape[0], "incompatible dimensions"
|
||||
M, K = a.shape
|
||||
K, N = b.shape
|
||||
assert (
|
||||
K % 32 == 0
|
||||
), "We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K"
|
||||
|
||||
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
|
||||
|
||||
def grid(META):
|
||||
return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
|
||||
|
||||
matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
|
||||
M=M, N=N, K=K,
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1),
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1),
|
||||
stride_cm=c.stride(0), stride_cn=c.stride(1))
|
||||
return c
|
||||
|
||||
|
||||
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
|
||||
b = torch.randn((512, 512), device='cuda', dtype=torch.float16).T
|
||||
c = matmul(a, b)
|
||||
c = torch.nn.functional.normalize(c)
|
||||
|
||||
golden = torch.nn.functional.normalize(torch.matmul(a, b))
|
||||
|
||||
torch.set_printoptions(profile="full")
|
||||
assert_close(
|
||||
c,
|
||||
golden,
|
||||
rtol=1e-2,
|
||||
atol=1e-3,
|
||||
check_dtype=False)
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
# argument names to use as an x-axis for the plot
|
||||
x_names=['M', 'N', 'K'],
|
||||
x_vals=[
|
||||
[2048, 512, 512],
|
||||
[2048, 1024, 1024],
|
||||
[2048, 2048, 2048],
|
||||
[2048, 4096, 4096],
|
||||
[2048, 8192, 8192]
|
||||
], # different possible values for `x_name`
|
||||
line_arg='provider',
|
||||
# argument name whose value corresponds to a different line in the plot
|
||||
# possible values for `line_arg``
|
||||
line_vals=['cublas', 'triton'],
|
||||
# label name for the lines
|
||||
line_names=["cuBLAS", "Triton"],
|
||||
# line styles
|
||||
styles=[('green', '-'), ('green', '--'),
|
||||
('blue', '-'), ('blue', '--')],
|
||||
ylabel="TFLOPS", # label name for the y-axis
|
||||
plot_name="matmul-performance",
|
||||
# name for the plot. Used also as a file name for saving the plot.
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(M, N, K, provider):
|
||||
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
|
||||
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == 'cublas':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: torch.matmul(a, b), rep=100, quantiles=quantiles, fast_flush=False)
|
||||
if provider == 'triton':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: matmul(a, b), rep=100, quantiles=quantiles, fast_flush=False)
|
||||
|
||||
def perf(ms):
|
||||
return 2 * M * N * K * 1e-12 / (ms * 1e-3)
|
||||
return perf(ms), perf(max_ms), perf(min_ms)
|
||||
|
||||
|
||||
benchmark.run(show_plots=False, print_data=True)
|
||||
Reference in New Issue
Block a user