Merge commit '36fc54b6f28168d3644808bfe299f1ba06a36272' into ifu230908-2

Conflicts:
	.gitignore
	bin/triton-translate.cpp
	include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h
	include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
	include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
	lib/Analysis/Utility.cpp
	lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
	lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
	lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
	lib/Conversion/TritonGPUToLLVM/Utility.h
	lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp
	lib/Dialect/TritonGPU/IR/Dialect.cpp
	lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
	lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
	lib/Target/LLVMIR/LLVMIRTranslation.cpp
	python/src/triton.cc
	python/test/unit/runtime/test_subproc.py
	python/triton/compiler/compiler.py
	python/triton/compiler/make_launcher.py
	python/triton/language/semantic.py
	python/triton/runtime/jit.py
	python/tutorials/06-fused-attention.py
	test/Conversion/triton_to_tritongpu.mlir
	test/Conversion/tritongpu_to_llvm.mlir
	test/TritonGPU/coalesce.mlir
	unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt
This commit is contained in:
Jason Furmanek
2023-10-02 18:01:04 +00:00
259 changed files with 32652 additions and 3712 deletions

View File

@@ -260,7 +260,7 @@ class LayerNorm(torch.autograd.Function):
# enqueue kernel
_layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd,
x_arg.stride(0), N, eps,
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1)
ctx.save_for_backward(x, weight, bias, mean, rstd)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
@@ -296,7 +296,7 @@ class LayerNorm(torch.autograd.Function):
# accumulate partial sums in separate kernel
_layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N,
BLOCK_SIZE_M=32,
BLOCK_SIZE_N=128)
BLOCK_SIZE_N=128, num_ctas=1)
return dx, None, dw, db, None
@@ -356,19 +356,21 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='c
quantiles = [0.5, 0.2, 0.8]
# utility functions
if provider == 'triton':
y_fwd = lambda: layer_norm(x, w_shape, weight, bias, eps)
def y_fwd(): return layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704
if provider == 'torch':
y_fwd = lambda: torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps)
def y_fwd(): return torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704
if provider == 'apex':
apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)
y_fwd = lambda: apex_layer_norm(x)
apex_layer_norm = apex.normalization.FusedLayerNorm(
w_shape).to(x.device).to(x.dtype)
def y_fwd(): return apex_layer_norm(x) # noqa: F811, E704
# forward pass
if mode == 'forward':
gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500)
# backward pass
if mode == 'backward':
gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6
def gbps(ms): return 3 * x.numel() * x.element_size() / ms * 1e-6 # noqa: F811, E704
y = y_fwd()
ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True),
quantiles=quantiles, grad_to_none=[x], rep=500)

View File

@@ -87,10 +87,14 @@ def _fwd_kernel(
k = tl.load(K_block_ptr)
v = tl.load(V_block_ptr)
# -- compute qk ---
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16)
if IS_CAUSAL:
qk = tl.where(P_SEQ + offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
<<<<<<< HEAD
qk += tl.dot(q, k)
=======
qk += tl.dot(q, k, out_dtype=tl.float16)
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
# -- compute scaling constant ---
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
alpha = tl.math.exp2(m_i - m_i_new)
@@ -148,8 +152,8 @@ def _bwd_kernel(
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX,
num_block,
Z, H, N_CTX, P_SEQ,
num_block_q, num_block_kv,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
CAUSAL: tl.constexpr,
@@ -160,17 +164,23 @@ def _bwd_kernel(
qk_scale = sm_scale * 1.44269504
# offset pointers for batch/head
Q += off_z * stride_qz + off_h * stride_qh
K += off_z * stride_qz + off_h * stride_qh
V += off_z * stride_qz + off_h * stride_qh
K += off_z * stride_kz + off_h * stride_kh
V += off_z * stride_vz + off_h * stride_vh
DO += off_z * stride_qz + off_h * stride_qh
DQ += off_z * stride_qz + off_h * stride_qh
<<<<<<< HEAD
DK += off_z * stride_qz + off_h * stride_qh
DV += off_z * stride_qz + off_h * stride_qh
# See fwd pass above for explanation.
qk_scale = sm_scale * 1.44269504
for start_n in range(0, num_block):
=======
DK += off_z * stride_kz + off_h * stride_kh
DV += off_z * stride_vz + off_h * stride_vh
for start_n in range(0, num_block_kv):
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
if CAUSAL:
lo = start_n * BLOCK_M
lo = tl.math.max(start_n * BLOCK_M - P_SEQ, 0)
else:
lo = 0
# initialize row/col offsets
@@ -187,14 +197,14 @@ def _bwd_kernel(
# pointer to row-wise quantities in value-like data
D_ptrs = D + off_hz * N_CTX
l_ptrs = L + off_hz * N_CTX
# initialize dv amd dk
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# initialize dk amd dv
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# k and v stay in SRAM throughout
k = tl.load(k_ptrs)
v = tl.load(v_ptrs)
# loop over rows
for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
for start_m in range(lo, num_block_q * BLOCK_M, BLOCK_M):
offs_m_curr = start_m + offs_m
# load q, k, v, do on-chip
q = tl.load(q_ptrs)
@@ -226,10 +236,10 @@ def _bwd_kernel(
q_ptrs += BLOCK_M * stride_qm
do_ptrs += BLOCK_M * stride_qm
# write-back
dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
tl.store(dv_ptrs, dv)
dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
tl.store(dk_ptrs, dk)
tl.store(dv_ptrs, dv)
@triton.jit
def _bwd_kernel_dk_dv(
@@ -456,6 +466,7 @@ class _attention(torch.autograd.Function):
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
BLOCK_M = 128
<<<<<<< HEAD
if torch.version.hip is None:
BLOCK_N = 64 if Lk <= 64 else 32
num_stages = 4 if Lk <= 64 else 3
@@ -469,6 +480,14 @@ class _attention(torch.autograd.Function):
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
P_SEQ = 0 if q.shape[-2] == k.shape[-2] else k.shape[-2] - q.shape[-2]
=======
BLOCK_N = 64 if Lk <= 64 else 32
num_stages = 4 if Lk <= 64 else 3
num_warps = 4
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
P_SEQ = 0 if q.shape[-2] == k.shape[-2] else k.shape[-2] - q.shape[-2]
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
_fwd_kernel[grid](
q, k, v, sm_scale,
L,
@@ -488,7 +507,10 @@ class _attention(torch.autograd.Function):
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = Lk
ctx.causal = causal
<<<<<<< HEAD
ctx.split_kernel = split_kernel
=======
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
ctx.P_SEQ = P_SEQ
return o
@@ -516,8 +538,28 @@ class _attention(torch.autograd.Function):
block_scale = (q.shape[2] // ctx.grid[0]) // BLOCK
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
o, do,
<<<<<<< HEAD
do_scaled, delta,
BLOCK_M=block_scale * BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
=======
delta,
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
_bwd_kernel[(ctx.grid[1],)](
q, k, v, ctx.sm_scale,
o, do,
dq, dk, dv,
L, delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2], ctx.P_SEQ,
ctx.grid[0], triton.cdiv(k.shape[2], BLOCK),
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
CAUSAL=ctx.causal,
num_stages=1,
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
)
if not ctx.split_kernel:
_bwd_kernel[(ctx.grid[1],)](
@@ -569,6 +611,7 @@ class _attention(torch.autograd.Function):
attention = _attention.apply
<<<<<<< HEAD
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, P_SEQ',
[(4, 48, 1024, 64, 128),
(4, 48, 2048, 64, 128),
@@ -578,10 +621,16 @@ attention = _attention.apply
])
@pytest.mark.parametrize('causal', [False, True])
def test_op_fwd(Z, H, N_CTX, D_HEAD, P_SEQ, causal, dtype=torch.float16):
=======
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, P_SEQ', [(6, 9, 1024, 64, 128)])
@pytest.mark.parametrize('causal', [False, True])
def test_op(Z, H, N_CTX, D_HEAD, P_SEQ, causal, dtype=torch.float16):
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
torch.manual_seed(20)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
k = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
v = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
<<<<<<< HEAD
sm_scale = q.shape[-1] ** (-0.5)
dout = torch.randn_like(q)
# reference implementation
@@ -614,6 +663,12 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, P_SEQ, dtype=torch.float16):
dout = torch.randn_like(q)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX + P_SEQ), device="cuda"), diagonal=P_SEQ)
=======
sm_scale = 0.5
dout = torch.randn_like(q)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX + P_SEQ), device="cuda"), diagonal=P_SEQ)
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
if causal:
p[:, :, M == 0] = float("-inf")

View File

@@ -0,0 +1,200 @@
"""
Matrix Multiplication with TMA (Experimental)
================================================
In this tutorial, you will write a very short high-performance multiplication kernel that achieves
performance on parallel with cuBLAS.
"""
# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import torch
from torch.testing import assert_close
import triton
import triton.language as tl
if torch.cuda.get_device_capability()[0] < 9:
import sys
print("Skipping TMA benchmark for GPU with compute capability < 9")
sys.exit(0)
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4, num_ctas=2),
# triton.Config({'BLOCK_SIZE_M': 512, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4, num_ctas=4),
],
key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
a_ptr, b_ptr, z_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_zm, stride_zn,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr,
B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr
):
pid = tl.program_id(axis=0)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
block_offset_m = pid_m * BLOCK_SIZE_M
block_offset_n = pid_n * BLOCK_SIZE_N
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
offsets=(block_offset_m, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K), order=(A_ORDER_0, A_ORDER_1))
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
offsets=(0, block_offset_n), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), order=(B_ORDER_0, B_ORDER_1))
z = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
offs_m = block_offset_m + tl.arange(0, BLOCK_SIZE_M)
offs_n = block_offset_n + tl.arange(0, BLOCK_SIZE_N)
z_ptrs = z_ptr + offs_m[:, None] * stride_zm + offs_n[None, :] * stride_zn
mask = (offs_m < M)[:, None] & (offs_n < N)[None, :]
for k in range(0, K, BLOCK_SIZE_K):
a = tl.load(a_tile_ptr)
b = tl.load(b_tile_ptr)
z += tl.dot(a, b)
a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_SIZE_K])
b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_SIZE_K, 0])
z = z.to(tl.float16)
tl.store(z_ptrs, z, mask=mask)
def matmul(a, b, a_order, b_order):
# checks constraints
assert a.shape[1] == b.shape[0], "incompatible dimensions"
M, K = a.shape
K, N = b.shape
z = torch.empty((M, N), device=a.device, dtype=torch.float16)
def grid(META):
return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
matmul_kernel[grid](a_ptr=a, b_ptr=b, z_ptr=z,
M=M, N=N, K=K,
stride_am=a.stride(0), stride_ak=a.stride(1),
stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_zm=z.stride(0), stride_zn=z.stride(1),
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1],
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1]
)
return z
problem_list = [
[2048, 512, 512, False, True],
[2048, 1024, 1024, False, False],
[2048, 2048, 2048, True, False],
[2048, 4096, 4096, True, True],
]
def test_matmul():
for case in problem_list:
M, N, K, TRANS_A, TRANS_B = case
print(M, N, K, TRANS_A, TRANS_B)
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
golden = torch.matmul(a, b)
z = matmul(a, b, a_order, b_order)
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
torch.set_printoptions(profile="full")
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
@triton.testing.perf_report(
triton.testing.Benchmark(
# argument names to use as an x-axis for the plot
x_names=['M', 'N', 'K', 'TRANS_A', 'TRANS_B'],
x_vals=problem_list, # different possible values for `x_name`
line_arg='provider',
# argument name whose value corresponds to a different line in the plot
# possible values for `line_arg``
line_vals=['cublas', 'triton'],
# label name for the lines
line_names=["cuBLAS", "Triton"],
# line styles
styles=[('green', '-'), ('green', '--'),
('blue', '-'), ('blue', '--')],
ylabel="TFLOPS", # label name for the y-axis
plot_name="matmul-performance",
# name for the plot. Used also as a file name for saving the plot.
args={},
)
)
def benchmark(M, N, K, TRANS_A, TRANS_B, provider):
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
quantiles = [0.5, 0.2, 0.8]
if provider == 'cublas':
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: torch.matmul(a, b), rep=100, quantiles=quantiles, fast_flush=False)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: matmul(a, b, a_order, b_order), rep=100, quantiles=quantiles, fast_flush=False)
def perf(ms):
return 2 * M * N * K * 1e-12 / (ms * 1e-3)
return perf(ms), perf(max_ms), perf(min_ms)
test_matmul()
benchmark.run(show_plots=False, print_data=True)

View File

@@ -0,0 +1,179 @@
"""
Matrix Multiplication with TMA Store (Experimental)
================================================
In this tutorial, you will write a very short high-performance multiplication kernel that achieves
performance on parallel with cuBLAS.
"""
# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import torch
from torch.testing import assert_close
import triton
import triton.language as tl
if torch.cuda.get_device_capability()[0] < 9:
import sys
print("Skipping TMA benchmark for GPU with compute capability < 9")
sys.exit(0)
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 512, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4, num_ctas=4),
],
key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
pid = tl.program_id(axis=0)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
block_offset_m = pid_m * BLOCK_SIZE_M
block_offset_n = pid_n * BLOCK_SIZE_N
a_tile_ptr = tl.make_block_ptr(
base=a_ptr, shape=(
M, K), strides=(
stride_am, stride_ak), offsets=(
block_offset_m, 0), block_shape=(
BLOCK_SIZE_M, BLOCK_SIZE_K), order=(
1, 0))
b_tile_ptr = tl.make_block_ptr(
base=b_ptr, shape=(
K, N), strides=(
stride_bk, stride_bn), offsets=(
0, block_offset_n), block_shape=(
BLOCK_SIZE_K, BLOCK_SIZE_N), order=(
0, 1))
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, K, BLOCK_SIZE_K):
a = tl.load(a_tile_ptr)
b = tl.load(b_tile_ptr)
accumulator += tl.dot(a, b)
a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_SIZE_K])
b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_SIZE_K, 0])
c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))
tl.store(c_block_ptr, accumulator)
def matmul(a, b):
# checks constraints
assert a.shape[1] == b.shape[0], "incompatible dimensions"
M, K = a.shape
K, N = b.shape
assert (
K % 32 == 0
), "We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K"
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
def grid(META):
return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
M=M, N=N, K=K,
stride_am=a.stride(0), stride_ak=a.stride(1),
stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_cm=c.stride(0), stride_cn=c.stride(1))
return c
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16).T
c = matmul(a, b)
c = torch.nn.functional.normalize(c)
golden = torch.nn.functional.normalize(torch.matmul(a, b))
torch.set_printoptions(profile="full")
assert_close(
c,
golden,
rtol=1e-2,
atol=1e-3,
check_dtype=False)
@triton.testing.perf_report(
triton.testing.Benchmark(
# argument names to use as an x-axis for the plot
x_names=['M', 'N', 'K'],
x_vals=[
[2048, 512, 512],
[2048, 1024, 1024],
[2048, 2048, 2048],
[2048, 4096, 4096],
[2048, 8192, 8192]
], # different possible values for `x_name`
line_arg='provider',
# argument name whose value corresponds to a different line in the plot
# possible values for `line_arg``
line_vals=['cublas', 'triton'],
# label name for the lines
line_names=["cuBLAS", "Triton"],
# line styles
styles=[('green', '-'), ('green', '--'),
('blue', '-'), ('blue', '--')],
ylabel="TFLOPS", # label name for the y-axis
plot_name="matmul-performance",
# name for the plot. Used also as a file name for saving the plot.
args={},
)
)
def benchmark(M, N, K, provider):
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
quantiles = [0.5, 0.2, 0.8]
if provider == 'cublas':
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: torch.matmul(a, b), rep=100, quantiles=quantiles, fast_flush=False)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: matmul(a, b), rep=100, quantiles=quantiles, fast_flush=False)
def perf(ms):
return 2 * M * N * K * 1e-12 / (ms * 1e-3)
return perf(ms), perf(max_ms), perf(min_ms)
benchmark.run(show_plots=False, print_data=True)