Initial code merge of Hopper support (#2036)

The initial code merge of Nvidia Hopper feature support. Please be
aware that the merge is not yet complete and troubleshooting is still
ongoing. The new hardware features (GMMA, TMA, STMATRIX, etc.) and
automatic warp specialization are experimental for now and turned off
by default. We recommend trying them once version 3.0 is released.

The work is contributed by:
ben-zhang-609, bealwang, donproc, qliu93, jsh20, allatit23, LyricZhao,
ivanyinwz, goostavz & yangjunpro
from Nvidia, in cooperation with:
ptillet, Jokeren, ThomasRaoux & zahimoud
from OpenAI.

Co-authored-by: Goostav Zhu <gzhu@nvidia.com>
This commit is contained in:
goostavz
2023-08-07 09:53:04 +08:00
committed by GitHub
parent 5df904233c
commit f1512bded1
220 changed files with 28448 additions and 2295 deletions

View File

View File

@@ -0,0 +1,480 @@
# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention algorithm
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
"""
# import numpy as np
import pytest
import torch
import triton
import triton.language as tl
@triton.jit
def _fwd_kernel(
    Q, K, V, sm_scale,
    L, M,
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    Z, H, N_CTX, D0,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Causal flash-attention forward kernel.

    Each program computes one BLOCK_M-row tile of Out for a single
    (batch, head) pair (program id 1), streaming over K/V in BLOCK_N-row
    tiles while keeping the running softmax max (``m_prev``) and
    denominator (``l_prev``) in registers.  The final per-row denominator
    and max are written to L and M (laid out as (Z * H, N_CTX)) for the
    backward pass.

    D0 is the row extent of the flattened 2-D view used by the block
    pointers below; the host wrapper passes Z * H * N_CTX.
    """
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    # TODO: may replace with TMA store without range offset
    # initialize offsets for store
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize pointer to m and l
    m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # rows of the flattened (D0, BLOCK_DMODEL) view spanned by one head.
    # NOTE(review): assumes a contiguous (Z, H, N_CTX, D) layout so that
    # stride_qh divides exactly by stride_qm * stride_qk — confirm.
    stride_qh_2d = stride_qh // stride_qm // stride_qk
    q_tile_ptr = tl.make_block_ptr(base=Q,
                                   shape=(D0, BLOCK_DMODEL),
                                   strides=(stride_qm, stride_qk),
                                   offsets=(
                                       off_hz * stride_qh_2d + start_m * BLOCK_M, 0),
                                   block_shape=(BLOCK_M, BLOCK_DMODEL),
                                   order=(1, 0))
    k_tile_ptr = tl.make_block_ptr(base=K,
                                   shape=(D0, BLOCK_DMODEL),
                                   strides=(stride_kn, stride_kk),
                                   offsets=(off_hz * stride_qh_2d, 0),
                                   block_shape=(BLOCK_N, BLOCK_DMODEL),
                                   order=(1, 0))
    v_tile_ptr = tl.make_block_ptr(base=V,
                                   shape=(D0, BLOCK_DMODEL),
                                   strides=(stride_vk, stride_vn),
                                   offsets=(off_hz * stride_qh_2d, 0),
                                   block_shape=(BLOCK_N, BLOCK_DMODEL),
                                   order=(1, 0))
    out_tile_ptr = tl.make_block_ptr(base=Out,
                                     shape=(D0, BLOCK_DMODEL),
                                     strides=(stride_om, stride_on),
                                     offsets=(off_hz * stride_qh_2d + start_m * BLOCK_M, 0),
                                     block_shape=(BLOCK_M, BLOCK_DMODEL),
                                     order=(1, 0))
    # load q: it will stay in SRAM throughout
    q = tl.load(q_tile_ptr)
    # loop over k, v and update accumulators; causality bounds the loop
    # at the diagonal tile, (start_m + 1) * BLOCK_M
    for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
        # -- compute qk ----
        k = tl.load(k_tile_ptr, boundary_check=(0, 1))
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))
        qk *= sm_scale
        # causal mask: row i may only attend to columns <= i
        qk = tl.where(offs_m[:, None] >= (
            start_n + offs_n[None, :]), qk, float("-inf"))
        # compute new m
        m_curr = tl.maximum(tl.max(qk, 1), m_prev)
        # correct old l
        l_prev *= tl.exp(m_prev - m_curr)
        # attention weights
        p = tl.exp(qk - m_curr[:, None])
        l_curr = tl.sum(p, 1) + l_prev
        # rescale operands of matmuls
        l_rcp = 1. / l_curr
        p *= l_rcp[:, None]
        acc *= (l_prev * l_rcp)[:, None]
        # update acc
        p = p.to(tl.float16)
        v = tl.load(v_tile_ptr, boundary_check=(0, 1))
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_prev = l_curr
        m_prev = m_curr
        # update pointers
        k_tile_ptr = tl.advance(k_tile_ptr, [BLOCK_N, 0])
        v_tile_ptr = tl.advance(v_tile_ptr, [BLOCK_N, 0])
    # rematerialize offsets to save registers
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back l and m
    l_ptrs = L + off_hz * N_CTX + offs_m
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(l_ptrs, l_prev)
    tl.store(m_ptrs, m_prev)
    acc = acc.to(tl.float16)
    tl.store(out_tile_ptr, acc, boundary_check=(0, 1))
@triton.jit
def _bwd_preprocess(
    Out, DO, L,
    NewDO, Delta,
    BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
    """Backward-pass preprocessing.

    Divides the incoming output gradient DO by the softmax denominator L
    (writing the result into NewDO) and computes the per-row reduction
    Delta = sum(Out * DO / L, axis=-1) consumed by _bwd_kernel.
    All matrices are addressed as contiguous (rows, D_HEAD) arrays; one
    program handles BLOCK_M rows.
    """
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_n = tl.arange(0, D_HEAD)
    # load
    o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    denom = tl.load(L + off_m).to(tl.float32)
    # compute
    do = do / denom[:, None]
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
    tl.store(Delta + off_m, delta)
@triton.jit
def _bwd_kernel(
    Q, K, V, sm_scale, Out, DO,
    DQ, DK, DV,
    L, M,
    D,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    Z, H, N_CTX, D0,
    num_block,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Causal flash-attention backward kernel.

    One program handles one (batch, head) pair (program id 0).  The outer
    loop walks BLOCK_M-row K/V column tiles; the inner loop re-walks the
    query rows at or below the diagonal, recomputing p from the saved row
    max M, accumulating dv and dk in registers and read-modify-writing dq
    in global memory.  DO must already be divided by the softmax
    denominator and D must hold the per-row delta (see _bwd_preprocess).

    NOTE(review): the do/dq/dk/dv block pointers all reuse Q's strides,
    i.e. those tensors are assumed to share Q's exact memory layout —
    confirm against the host-side launcher.
    """
    off_hz = tl.program_id(0)
    off_z = off_hz // H
    off_h = off_hz % H
    # init tile_ptr
    # rows of the flattened (D0, BLOCK_DMODEL) view per batch / per head;
    # assumes a contiguous layout so the divisions are exact
    stride_qz_2d = stride_qz // stride_qm // stride_qk
    stride_qh_2d = stride_qh // stride_qm // stride_qk
    q_tile_ptr = tl.make_block_ptr(base=Q,
                                   shape=(D0, BLOCK_DMODEL),
                                   strides=(stride_qm, stride_qk),
                                   offsets=(
                                       off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
                                   block_shape=(BLOCK_M, BLOCK_DMODEL),
                                   order=(1, 0))
    k_tile_ptr = tl.make_block_ptr(base=K,
                                   shape=(D0, BLOCK_DMODEL),
                                   strides=(stride_kn, stride_kk),
                                   offsets=(
                                       off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
                                   block_shape=(BLOCK_M, BLOCK_DMODEL),
                                   order=(1, 0))
    v_tile_ptr = tl.make_block_ptr(base=V,
                                   shape=(D0, BLOCK_DMODEL),
                                   strides=(stride_vk, stride_vn),
                                   offsets=(
                                       off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
                                   block_shape=(BLOCK_M, BLOCK_DMODEL),
                                   order=(1, 0))
    do_tile_ptr = tl.make_block_ptr(base=DO,
                                    shape=(D0, BLOCK_DMODEL),
                                    strides=(stride_qm, stride_qk),
                                    offsets=(
                                        off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
                                    block_shape=(BLOCK_M, BLOCK_DMODEL),
                                    order=(1, 0))
    dq_tile_ptr = tl.make_block_ptr(base=DQ,
                                    shape=(D0, BLOCK_DMODEL),
                                    strides=(stride_qm, stride_qk),
                                    offsets=(
                                        off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
                                    block_shape=(BLOCK_M, BLOCK_DMODEL),
                                    order=(1, 0))
    dk_tile_ptr = tl.make_block_ptr(base=DK,
                                    shape=(D0, BLOCK_DMODEL),
                                    strides=(stride_qm, stride_qk),
                                    offsets=(
                                        off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
                                    block_shape=(BLOCK_M, BLOCK_DMODEL),
                                    order=(1, 0))
    dv_tile_ptr = tl.make_block_ptr(base=DV,
                                    shape=(D0, BLOCK_DMODEL),
                                    strides=(stride_qm, stride_qk),
                                    offsets=(
                                        off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
                                    block_shape=(BLOCK_M, BLOCK_DMODEL),
                                    order=(1, 0))
    # offset pointers for batch/head
    DQ += off_z * stride_qz + off_h * stride_qh
    for start_n in range(0, num_block):
        lo = start_n * BLOCK_M
        # initialize row/col offsets
        offs_qm = lo + tl.arange(0, BLOCK_M)
        offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_m = tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_DMODEL)
        # initialize pointers to value-like data
        # NOTE(review): dq_ptrs is advanced below but never loaded from or
        # stored to (dq_tile_ptr is used instead) — looks like leftover
        # dead code; confirm before removing.
        dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        # pointer to row-wise quantities in value-like data
        D_ptrs = D + off_hz * N_CTX
        m_ptrs = M + off_hz * N_CTX
        # initialize dv amd dk
        dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        # k and v stay in SRAM throughout
        k = tl.load(k_tile_ptr, boundary_check=(0, 1))
        v = tl.load(v_tile_ptr, boundary_check=(0, 1))
        # loop over rows
        for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
            offs_m_curr = start_m + offs_m
            # load q, k, v, do on-chip
            q = tl.load(q_tile_ptr, boundary_check=(0, 1))
            # recompute p = softmax(qk, dim=-1).T
            # NOTE: `do` is pre-divided by `l`; no normalization here
            qk = tl.dot(q, tl.trans(k))
            # causal mask: keep only columns <= the current row index
            qk = tl.where(offs_m_curr[:, None] >= (
                offs_n[None, :]), qk, float("-inf"))
            m = tl.load(m_ptrs + offs_m_curr)
            p = tl.exp(qk * sm_scale - m[:, None])
            # compute dv
            do = tl.load(do_tile_ptr, boundary_check=(0, 1))
            dv += tl.dot(tl.trans(p.to(tl.float16)), do)
            # compute dp = dot(v, do)
            Di = tl.load(D_ptrs + offs_m_curr)
            dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
            dp += tl.dot(do, tl.trans(v))
            # compute ds = p * (dp - delta[:, None])
            ds = p * dp * sm_scale
            # compute dk = dot(ds.T, q)
            dk += tl.dot(tl.trans(ds.to(tl.float16)), q)
            # compute dq (read-modify-write in global memory)
            dq = tl.load(dq_tile_ptr)
            dq += tl.dot(ds.to(tl.float16), k)
            tl.store(dq_tile_ptr, dq)
            # increment pointers
            dq_ptrs += BLOCK_M * stride_qm
            q_tile_ptr = tl.advance(q_tile_ptr, [BLOCK_M, 0])
            do_tile_ptr = tl.advance(do_tile_ptr, [BLOCK_M, 0])
            dq_tile_ptr = tl.advance(dq_tile_ptr, [BLOCK_M, 0])
        # rewind the row pointers so the next outer iteration starts at
        # its own diagonal tile (net advance of exactly one BLOCK_M)
        q_tile_ptr = tl.advance(q_tile_ptr, [lo + (1 - num_block) * BLOCK_M, 0])
        do_tile_ptr = tl.advance(do_tile_ptr, [lo + (1 - num_block) * BLOCK_M, 0])
        dq_tile_ptr = tl.advance(dq_tile_ptr, [lo + (1 - num_block) * BLOCK_M, 0])
        # increment tile pointers
        k_tile_ptr = tl.advance(k_tile_ptr, [BLOCK_M, 0])
        v_tile_ptr = tl.advance(v_tile_ptr, [BLOCK_M, 0])
        # write-back
        tl.store(dv_tile_ptr, dv.to(tl.float16), boundary_check=(0, 1))
        tl.store(dk_tile_ptr, dk.to(tl.float16), boundary_check=(0, 1))
        dv_tile_ptr = tl.advance(dv_tile_ptr, [BLOCK_M, 0])
        dk_tile_ptr = tl.advance(dk_tile_ptr, [BLOCK_M, 0])
# NOTE(review): module-level scratch tensor; it is not referenced anywhere
# in this file — confirm it is needed (e.g. as CUDA warm-up) before removing.
empty = torch.empty(128, device="cuda")
class _attention(torch.autograd.Function):
    """Autograd wrapper around the Triton flash-attention kernels.

    forward launches _fwd_kernel and stashes the softmax statistics;
    backward launches _bwd_preprocess followed by _bwd_kernel.
    """

    @staticmethod
    def forward(ctx, q, k, v, sm_scale):
        block = 128
        # shape constraints: q/k/v must share one supported head dimension
        head_dim = q.shape[-1]
        assert head_dim == k.shape[-1] == v.shape[-1]
        assert head_dim in {16, 32, 64, 128}
        batch, heads, seq_len = q.shape[0], q.shape[1], q.shape[2]
        o = torch.empty_like(q)
        # one program per BLOCK_M rows per (batch, head) pair
        grid = (triton.cdiv(seq_len, block), batch * heads, 1)
        # per-row softmax denominator (L) and running max (m), kept for backward
        stats_shape = (batch * heads, seq_len)
        L = torch.empty(stats_shape, device=q.device, dtype=torch.float32)
        m = torch.empty(stats_shape, device=q.device, dtype=torch.float32)
        _fwd_kernel[grid](
            q, k, v, sm_scale,
            L, m,
            o,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            batch, heads, seq_len, batch * heads * seq_len,
            BLOCK_M=block, BLOCK_N=block,
            BLOCK_DMODEL=head_dim,
            num_warps=4 if head_dim <= 64 else 8,
            num_stages=2,
        )
        ctx.save_for_backward(q, k, v, o, L, m)
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = head_dim
        return o

    @staticmethod
    def backward(ctx, do):
        block = 128
        q, k, v, o, l, m = ctx.saved_tensors
        do = do.contiguous()
        # dq is accumulated (read-modify-write) by the kernel, so it must
        # start zeroed; dk/dv are fully overwritten
        dq = torch.zeros_like(q, dtype=torch.float32)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        do_scaled = torch.empty_like(do)
        delta = torch.empty_like(l)
        total_rows = q.shape[0] * q.shape[1] * q.shape[2]
        # stage 1: normalize `do` by the softmax denominator, compute delta
        _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
            o, do, l,
            do_scaled, delta,
            BLOCK_M=block, D_HEAD=ctx.BLOCK_DMODEL,
        )
        # stage 2: accumulate dq/dk/dv, one program per (batch, head)
        _bwd_kernel[(ctx.grid[1],)](
            q, k, v, ctx.sm_scale,
            o, do_scaled,
            dq, dk, dv,
            l, m,
            delta,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            q.shape[0], q.shape[1], q.shape[2], total_rows,
            ctx.grid[0],
            BLOCK_M=block, BLOCK_N=block,
            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
            num_stages=1,
        )
        return dq, dk, dv, None


attention = _attention.apply
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 128, 64),
                                                 # (4, 48, 256, 64),
                                                 # (4, 48, 512, 64),
                                                 # (4, 48, 1024, 64),
                                                 # (4, 48, 2048, 64),
                                                 # (4, 48, 4096, 64),
                                                 # (4, 48, 8192, 64), out of memory
                                                 ])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires arch 9+")
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    """Compare the fused Triton attention against a PyTorch reference.

    Checks the forward output and the dq/dk/dv gradients of causal
    attention for one (Z, H, N_CTX, D_HEAD) configuration.
    """
    torch.manual_seed(20)
    q = torch.empty(
        (Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(
        mean=0.1, std=0.2).requires_grad_()
    k = torch.empty(
        (Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(
        mean=0.4, std=0.2).requires_grad_()
    v = torch.empty(
        (Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(
        mean=0.3, std=0.2).requires_grad_()
    sm_scale = 0.2
    dout = torch.randn_like(q)
    # reference implementation
    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    # Apply the causal mask once: the boolean mask indexes the last two
    # dimensions and broadcasts over Z and H, so the previous
    # `for z ... for h ...` loop performed this identical assignment
    # Z * H times (the loop variables were never used).
    p[:, :, M == 0] = float("-inf")
    p = torch.softmax(p.float(), dim=-1).half()
    ref_out = torch.matmul(p, v)
    ref_out.backward(dout)
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None
    # triton implementation
    tri_out = attention(q, k, v, sm_scale)
    tri_out.backward(dout)
    tri_dv, v.grad = v.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dq, q.grad = q.grad.clone(), None
    # compare (absolute tolerance only; values are O(1) by construction)
    torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0)
    torch.testing.assert_close(ref_dq, tri_dq, atol=1e-2, rtol=0)
    torch.testing.assert_close(ref_dv, tri_dv, atol=1e-2, rtol=0)
    torch.testing.assert_close(ref_dk, tri_dk, atol=1e-2, rtol=0)
# Optional-dependency probe: benchmark against the flash-attn package when
# it is importable.  BaseException is deliberately broad so that any import
# failure (including binary/ABI load errors) merely disables the comparison.
try:
    from flash_attn.flash_attn_interface import flash_attn_func
    HAS_FLASH = True
except BaseException:
    HAS_FLASH = False
# Benchmark setup: fixed batch / heads / head-dim, sweeping sequence length
# 2**10..2**13; one Benchmark (one plot) per mode, forward and backward.
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
configs = [triton.testing.Benchmark(
    x_names=['N_CTX'],
    x_vals=[2**i for i in range(10, 14)],
    line_arg='provider',
    # the 'flash' line is only drawn when flash-attn imported successfully
    line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
    line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
    styles=[('red', '-'), ('blue', '-')],
    ylabel='ms',
    plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
    args={
        'H': N_HEADS,
        'BATCH': BATCH,
        'D_HEAD': D_HEAD,
        'dtype': torch.float16,
        'mode': mode}
) for mode in ['fwd', 'bwd']]
@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
    """Time one fused-attention configuration for the given provider.

    Returns the measured milliseconds for the 'triton' or 'flash'
    provider; any other provider falls through (returns None).
    """
    assert mode in ['fwd', 'bwd']
    warmup = 25
    rep = 100
    if provider == "triton":
        shape = (BATCH, H, N_CTX, D_HEAD)
        q = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)
        k = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)
        v = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)
        sm_scale = 1.3

        def fn():
            return attention(q, k, v, sm_scale)
        if mode == 'bwd':
            # benchmark the backward pass of a single forward result
            o = fn()
            do = torch.randn_like(o)

            def fn():
                return o.backward(do, retain_graph=True)
        return triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    if provider == "flash":
        # flash-attn's varlen interface: packed qkv plus cumulative lengths
        lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
        cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
        cu_seqlens[1:] = lengths.cumsum(0)
        qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)

        def fn():
            return flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
        if mode == 'bwd':
            o = fn()
            do = torch.randn_like(o)

            def fn():
                return o.backward(do, retain_graph=True)
        return triton.testing.do_bench(fn, warmup=warmup, rep=rep)


# only works on post-Ampere GPUs right now
# bench_flash_attention.run(save_path='.', print_data=True)

View File

@@ -0,0 +1,469 @@
# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import itertools
import os
import re
import pytest
import torch
from torch.testing import assert_close
import triton
import triton.language as tl
from .utils import get_proper_err, get_variant_golden
@triton.jit
def matmul_no_scf_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr
):
    """Single-tile GEMM with no K loop (no structured control flow).

    Computes C = A @ B for exactly one (BLOCK_M, BLOCK_N) tile; the test
    below launches it on a (1, 1) grid with BLOCK_* equal to M/N/K.  The
    result is optionally cast to float16 and stored either through a
    block pointer (TMA-eligible epilogue) or plain computed pointers.
    """
    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                    offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                                    offsets=(0, 0), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
    a = tl.load(a_block_ptr)
    b = tl.load(b_block_ptr)
    c = tl.dot(a, b)
    if FLOAT16_OUTPUT:
        c = c.to(tl.float16)
    if USE_TMA_EPILOGUE:
        # block-pointer store path (TMA-eligible)
        c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
                                        offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
        tl.store(c_block_ptr, c)
    else:
        # element-wise store through computed pointers
        offs_m = tl.arange(0, BLOCK_M)
        offs_n = tl.arange(0, BLOCK_N)
        c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
        tl.store(c_ptrs, c)
@pytest.mark.parametrize('M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,ENABLE_WS',
                         itertools.chain(
                             *[
                                 [
                                     # numCTAs = 1, no TMA multicast:
                                     [64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
                                     [64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
                                     [64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
                                     [64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
                                     [64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
                                     [64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
                                     [128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
                                     [128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
                                     # static mask, cluster 4x1
                                     [256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
                                     [256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
                                     # dynamic mask, cluster 2x2
                                     [128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
                                     [128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
                                     # small M, N
                                     [16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
                                     [16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
                                     [32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
                                     [32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
                                 ] for USE_TMA_EPILOGUE in [True, False]
                                 for ENABLE_WS in [False, True]
                             ]))
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE, ENABLE_WS):
    """Single-tile GEMM correctness test (block sizes == problem sizes).

    Launches matmul_no_scf_kernel on a (1, 1) grid and compares against a
    float32 torch.matmul reference, with tolerances widened by the error
    of a reference variant computed by helpers from .utils.
    """
    # known-bad combination, skipped wholesale
    if '-'.join(map(str, [USE_TMA_EPILOGUE, ENABLE_WS])) in [
        'True-True'
    ]:
        pytest.skip("error, skip")
    # TRANS_A/TRANS_B build transposed-layout (non-contiguous) inputs via .T
    if (TRANS_A):
        a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
    else:
        a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    if (TRANS_B):
        b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
    else:
        b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    if OUTPUT_TYPE == "float16":
        c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    else:
        c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    matmul_no_scf_kernel[(1, 1)](a_ptr=a, b_ptr=b, c_ptr=c,
                                 M=M, N=N, K=K,
                                 stride_am=a.stride(0), stride_ak=a.stride(1),
                                 stride_bk=b.stride(0), stride_bn=b.stride(1),
                                 stride_cm=c.stride(0), stride_cn=c.stride(1),
                                 BLOCK_M=M, BLOCK_N=N, BLOCK_K=K,
                                 num_warps=NUM_WARPS,
                                 num_ctas=NUM_CTAS,
                                 FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"),
                                 USE_TMA_EPILOGUE=USE_TMA_EPILOGUE,
                                 enable_warp_specialization=ENABLE_WS)
    a_f32 = a.to(torch.float32)
    b_f32 = b.to(torch.float32)
    golden = torch.matmul(a_f32, b_f32)
    # widen tolerances by the numeric error the variant reference incurs
    golden_variant = get_variant_golden(a_f32, b_f32)
    golden_abs_err, golden_rel_err = get_proper_err(golden, golden_variant)
    # NOTE(review): mutates global print options for the whole pytest
    # session — presumably a debugging aid; confirm it is intentional.
    torch.set_printoptions(profile="full")
    assert_close(
        c,
        golden,
        rtol=max(1e-2, 1.1 * golden_rel_err),
        atol=max(1e-3, 1.1 * golden_abs_err),
        check_dtype=False)
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_wm, stride_wn,
    stride_zm, stride_zn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
    out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr,
    ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
    DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
    A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr,
    B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr,
    Z_ORDER_0: tl.constexpr, Z_ORDER_1: tl.constexpr
):
    """GEMM Z = A @ B with a constexpr-selected epilogue.

    Programs are remapped into GROUP_SIZE_M-tall groups over the M axis
    before being assigned (pid_m, pid_n) tiles; A/B tiles are read via
    block pointers whose memory order is passed in as constexpr pairs.
    The epilogue is one of: bias add (full matrix / per-row / per-col),
    row softmax, or a chained dot with W; the store goes through either a
    block pointer (TMA path) or plain masked pointers.
    """
    # ---- grouped program-id remap over the M dimension ----
    pid = tl.program_id(axis=0)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    block_offset_m = pid_m * BLOCK_M
    block_offset_n = pid_n * BLOCK_N
    a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                   offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(A_ORDER_0, A_ORDER_1))
    b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                                   offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(B_ORDER_0, B_ORDER_1))
    # NOTE(review): W is addressed as an (N, N) matrix with Z's order
    # constants reversed — only used when CHAIN_DOT is set; confirm this
    # matches the host-side construction of `w`.
    w_tile_ptr = tl.make_block_ptr(base=w_ptr, shape=(N, N), strides=(stride_wm, stride_wn),
                                   offsets=(0, block_offset_n), block_shape=(BLOCK_N, BLOCK_N), order=(Z_ORDER_1, Z_ORDER_0))
    z = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    offs_m = block_offset_m + tl.arange(0, BLOCK_M)
    offs_n = block_offset_n + tl.arange(0, BLOCK_N)
    z_ptrs = z_ptr + offs_m[:, None] * stride_zm + offs_n[None, :] * stride_zn
    # bias is read with Z's strides, i.e. assumed to share Z's layout
    bias_ptrs = bias_ptr + offs_m[:, None] * stride_zm + offs_n[None, :] * stride_zn
    mask = (offs_m < M)[:, None] & (offs_n < N)[None, :]
    # ---- main K loop ----
    for k in range(0, K, BLOCK_K):
        a = tl.load(a_tile_ptr, boundary_check=(0, 1))
        b = tl.load(b_tile_ptr, boundary_check=(0, 1))
        z += tl.dot(a, b)
        a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_K])
        b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_K, 0])
    z = z.to(out_dtype)
    # ---- epilogues (selected by constexpr flags) ----
    if ADD_MATRIX:
        z += tl.load(bias_ptrs, mask=mask)
    if ADD_ROWS:
        ZRs = bias_ptr + offs_m * stride_zm
        z += tl.load(ZRs)[:, None]
    if ADD_COLS:
        ZCs = bias_ptr + offs_n * stride_zn
        z += tl.load(ZCs)[None, :]
    if DO_SOFTMAX:
        # numerically-stable row softmax (subtract the row max first)
        max = tl.max(z, 1)
        z = z - max[:, None]
        num = tl.exp(z.to(tl.float32)).to(max.dtype)
        den = tl.sum(num, 1)
        z = num / den[:, None]
    if CHAIN_DOT:
        w = tl.load(w_tile_ptr)
        z = tl.dot(z.to(w.dtype), w)
        z = z.to(out_dtype)
    if USE_TMA_STORE:
        z_block_ptr = tl.make_block_ptr(base=z_ptr, shape=(M, N), strides=(stride_zm, stride_zn),
                                        offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(Z_ORDER_0, Z_ORDER_1))
        tl.store(z_block_ptr, z, boundary_check=(0, 1))
    else:
        tl.store(z_ptrs, z, mask=mask)
@pytest.mark.parametrize('BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_C,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS',
                         [(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws)
                          for shape_w_c in [
                              # badcase from cublas-important-layers
                              [4096, 1, 1024, False, False, True],
                              [2048, 204, 1000, True, False, True],
                              [4096, 1, 1024, False, False, False],
                              [2048, 204, 1000, True, False, False],
                          ]
                          for out_dtype in ['float16', 'float32']
                          for use_tma_store in [False, True]
                          for enable_ws in [False, True]
                          ] + [(*shape_w_c, trans_a, trans_b, trans_c, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
                               # softmax works for one CTA
                               for shape_w_c in [
                                   [64, 64, 16, 4, 1, 64, 64, 64],
                                   [128, 128, 64, 4, 1, None, None, None],
                                   [16, 16, 64, 4, 1, 16, 16, 64],
                                   [64, 64, 32, 8, 1, 64, 64, 64],
                                   [128, 128, 64, 4, 1, 128, 128, 128],
                               ]
                               for epilogue in ['softmax']
                               for out_dtype in ['float16', 'float32']
                               for use_tma_store in [False, True]
                               for trans_a in [False, True]
                               for trans_b in [False, True]
                               for trans_c in [False, True]
                               for num_stages in [3]
                               for enable_ws in [False, True]
                               ] + [(*shape_w_c, trans_a, trans_b, trans_c, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
                                    for shape_w_c in [
                                        [64, 64, 16, 4, 1, 128, 128, 64],
                                        *[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
                                        # for chain-dot
                                        [128, 128, 64, 4, 1, None, None, None],
                                        [64, 64, 16, 4, 1, None, None, None],
                                        # small BLOCK_M and BLOCK_K
                                        [16, 16, 64, 4, 1, 128, 128, 64],
                                        *[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
                                        # repeat
                                        [64, 64, 32, 8, 1, 128, 256, 64],
                                        [64, 64, 16, 8, 2, 128, 128, 64],
                                        # irregular shape
                                        [128, 128, 64, 4, 1, 500, 200, 128],
                                        [128, 128, 64, 4, 2, 513, 193, 192],
                                    ]
                                    for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot']
                                    for out_dtype in ['float16', 'float32']
                                    for use_tma_store in [False, True]
                                    for trans_a in [False, True]
                                    for trans_b in [False, True]
                                    for trans_c in [False, True]
                                    for num_stages in [3]
                                    for enable_ws in [False, True]
                                    # chain-dot requires square blocks and M/N defaulted to them
                                    if not (epilogue == 'chain-dot' and (shape_w_c[5] is not None or shape_w_c[0] != shape_w_c[1]))
                                    ] + [(*shape_w_c, trans_a, trans_b, trans_c, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
                                         for shape_w_c in [
                                             [64, 64, 32, 4, 1, 128, 256, 64],
                                             [128, 128, 16, 4, 4, 512, 256, 64],
                                             [128, 256, 32, 4, 8, 256, 256, 192],
                                             [512, 256, 32, 4, 8, 1024, 256, 192],
                                             # BLOCK_K >= 128
                                             [64, 128, 128, 4, 1, 512, 256, 256],
                                             [128, 128, 128, 4, 1, 256, 256, 192],
                                             [128, 128, 128, 4, 2, 256, 256, 192],
                                             # small BLOCK_M and BLOCK_K
                                             [16, 32, 32, 4, 1, 128, 256, 64],
                                             [32, 32, 16, 4, 1, 256, 256, 192],
                                             [16, 32, 64, 4, 4, 512, 256, 64],
                                         ]
                                         for out_dtype in ['float16', 'float32']
                                         for use_tma_store in [False, True]
                                         for trans_a in [False, True]
                                         for trans_b in [False, True]
                                         for trans_c in [False, True]
                                         for num_stages in [3]
                                         for enable_ws in [False, True]
                                         ] + [(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_c, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
                                              # loop over instr shapes
                                              for n in [16, 32, 64, 128, 256]
                                              for trans_c in [False, True]
                                              for out_dtype in ['float16', 'float32']
                                              for use_tma_store in [False, True]
                                              for num_stages in [2, 4, 5, 7]
                                              for enable_ws in [False, True]
                                              ] + [(*shape_w_c, *shape, False, True, trans_c, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
                                                   # irregular shapes
                                                   for shape_w_c in [
                                                       [128, 128, 64, 4, 1],
                                                       [256, 128, 64, 4, 2],
                                                       [128, 128, 128, 4, 2],
                                                   ]
                                                   for shape in list(itertools.product([*range(512, 4096, 360)], [*range(512, 4096, 360)], [512, 1024]))
                                                   for trans_c in [False, True]
                                                   for out_dtype in ['float16', 'float32']
                                                   for use_tma_store in [False, True]
                                                   for num_stages in [2, 3, 4]
                                                   for enable_ws in [False, True]
                                                   ])
@pytest.mark.skipif(torch.cuda.get_device_capability()
                    [0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_C, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS):
    """End-to-end GEMM + epilogue test for matmul_kernel.

    Builds (optionally transposed-layout) fp16 inputs, runs the kernel,
    and compares against a float32 torch reference with the same epilogue
    applied.  When ENABLE_MMA_V3 is set in the environment, additionally
    asserts that wgmma instructions appear in the generated PTX.
    """
    # ---- known-bad configurations, skipped by exact string match ----
    if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
        '16-32-64-4-4-512-256-64-True-False',
        '16-32-64-4-4-512-256-64-True-True',
        '16-32-64-4-4-512-256-64-False-False',
        '16-32-64-4-4-512-256-64-False-True',
    ]:
        pytest.skip('shapePerCTA[1] < 16 not supported')
    # with ENABLE_TMA=0 and ENABLE_MMA_V3=0
    if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
        '16-32-64-4-1-256-256-256-False',
        '16-32-64-4-2-256-256-256-False',
        '16-32-64-4-2-256-256-256-True',
        '16-32-64-8-2-256-256-256-False',
        '16-32-64-8-2-256-256-256-True',
    ]:
        pytest.skip('illegal memory access.')
    # with ENABLE_TMA=1 and ENABLE_MMA_V3=1
    if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K])) in [
        '64-64-32-8-1-128-256-64',
    ]:
        pytest.skip('Tensor-likes are not close!')
    if NUM_CTAS > 1 and NUM_WARPS == 8:
        pytest.skip('Tensor-likes are not close!')
    # with ENABLE_TMA=1 and ENABLE_MMA_V3=1
    if ENABLE_WS:
        # example:
        # [128-128-64-4-1-None-None-None-False-False-False-chain-dot-float16-False-3-True]
        pytest.skip('hang!')
    # None in the shape tuple means "problem size == block size"
    M = BLOCK_M if M is None else M
    N = BLOCK_N if N is None else N
    K = BLOCK_K if K is None else K
    # TRANS_* build transposed-layout (non-contiguous) tensors via .T
    if (TRANS_A):
        a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
        a_order = [0, 1]
    else:
        a = torch.randn((M, K), device='cuda', dtype=torch.float16)
        a_order = [1, 0]
    if (TRANS_B):
        b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
        b_order = [0, 1]
    else:
        b = torch.randn((K, N), device='cuda', dtype=torch.float16)
        b_order = [1, 0]
    if out_dtype == 'float16' and epilogue != 'softmax':
        # TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
        # fail with the following error: 'llvm.fmul' op requires the same type
        # for all operands and results
        out_dtype = tl.float16
        torch_out_dtype = torch.float16
    else:
        out_dtype = tl.float32
        torch_out_dtype = torch.float32
    # avoid out of memory
    if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
        if (TRANS_C):
            bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
        else:
            bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
    else:
        bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
    if epilogue == 'chain-dot':
        # NOTE(review): the non-TRANS_C branch builds w as (M, M); the
        # parametrization only enables chain-dot when BLOCK_M == BLOCK_N
        # and M/N default to the block sizes, so M == N here — confirm.
        if (TRANS_C):
            w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
        else:
            w = torch.randn((M, M), device='cuda', dtype=torch.float16)
    else:
        w = torch.randn((1, 1), device='cuda', dtype=torch.float16).T
    if (TRANS_C):
        z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
        z_order = [1, 0]
    else:
        z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
        z_order = [0, 1]
    # torch result
    a_f32 = a.to(torch.float32)
    b_f32 = b.to(torch.float32)
    dot = torch.matmul(a_f32, b_f32)
    def process_epilogue(d, bias, w, epilogue):
        # apply the same epilogue the kernel is configured with
        if epilogue == 'add-matrix':
            ref = d + bias
        elif epilogue == 'add-rows':
            ref = d + bias[:, 0][:, None]
        elif epilogue == 'add-cols':
            ref = d + bias[0, :][None, :]
        elif epilogue == 'softmax':
            num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
            denom = torch.sum(num, dim=-1, keepdims=True)
            ref = num / denom
            # ref = torch.softmax(d, 1)
        elif epilogue == 'chain-dot':
            ref = torch.matmul(d, w.to(torch.float32))
        else:
            ref = d
        return ref
    golden = process_epilogue(dot, bias, w, epilogue)
    def grid(META):
        return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
    pgm = matmul_kernel[grid](a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z,
                              M=M, N=N, K=K,
                              stride_am=a.stride(0), stride_ak=a.stride(1),
                              stride_bk=b.stride(0), stride_bn=b.stride(1),
                              stride_wm=w.stride(0), stride_wn=w.stride(1),
                              stride_zm=z.stride(0), stride_zn=z.stride(1),
                              BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8,
                              out_dtype=out_dtype,
                              USE_TMA_STORE=USE_TMA_STORE,
                              ADD_MATRIX=epilogue == 'add-matrix',
                              ADD_ROWS=epilogue == 'add-rows',
                              ADD_COLS=epilogue == 'add-cols',
                              DO_SOFTMAX=epilogue == 'softmax',
                              CHAIN_DOT=epilogue == 'chain-dot',
                              A_ORDER_0=a_order[0], A_ORDER_1=a_order[1],
                              B_ORDER_0=b_order[0], B_ORDER_1=b_order[1],
                              Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1],
                              num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES,
                              enable_warp_specialization=ENABLE_WS)
    # NOTE(review): mutates global print options for the whole session
    torch.set_printoptions(profile="full")
    # print("abs_err: {}, rel_err: {}".format(golden_abs_err, golden_rel_err))
    # print("golden: ")
    # print(golden)
    # print("result: ")
    # print(z)
    # print("max_gap: {}".format(torch.max(torch.abs(z - golden))))
    # NOTE(review): both sides are normalized (torch.nn.functional.normalize)
    # before comparison, which weakens the check considerably — confirm
    # that this looser comparison is intended.
    golden = torch.nn.functional.normalize(golden)
    z = torch.nn.functional.normalize(z)
    assert_close(z, golden,
                 rtol=1e-2,
                 atol=1e-3,
                 check_dtype=False)
    # with ENABLE_MMA_V3 set, require wgmma (MMAv3) instructions in the PTX
    enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower()
    if enable_mmav3 in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
        ptx = pgm.asm['ptx']
        assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(BLOCK_N), ptx)

View File

@@ -0,0 +1,166 @@
# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import pytest
import torch
import triton
import triton.language as tl
@triton.jit
def gemm_fusion_kernel(A, B, C, E,
                       M, N, K,
                       stride_am, stride_ak, stride_bn, stride_bk, stride_cn, stride_ck, stride_em, stride_ek,
                       BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    """Fused GEMM chain: E = (A @ B^T) @ C, computed tile-by-tile so the
    intermediate (A @ B^T) never touches global memory.

    Each program handles one BLOCK_M-row stripe of A/E; the grid is 1-D over M.
    """
    pid = tl.program_id(0)
    # Block pointers for the stripe of A/E owned by this program, plus the
    # B and C tiles that will be advanced along the N dimension.
    a_tile_ptr = tl.make_block_ptr(base=A, shape=(M, K), strides=(stride_am, stride_ak), offsets=(pid * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    b_tile_ptr = tl.make_block_ptr(base=B, shape=(N, K), strides=(stride_bn, stride_bk), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_K), order=(1, 0))
    c_tile_ptr = tl.make_block_ptr(base=C, shape=(N, K), strides=(stride_cn, stride_ck), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_K), order=(1, 0))
    e_tile_ptr = tl.make_block_ptr(base=E, shape=(M, K), strides=(stride_em, stride_ek), offsets=(pid * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    # fp32 accumulator for the second dot.
    acc_e = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    # The A tile stays resident for the entire reduction over N.
    a = tl.load(a_tile_ptr)
    for i in range(0, N, BLOCK_N):
        b = tl.load(b_tile_ptr)
        # First dot: (BLOCK_M, BLOCK_K) @ (BLOCK_K, BLOCK_N) via transposed B tile.
        o_ab = tl.dot(a, tl.trans(b))
        c = tl.load(c_tile_ptr)
        # Downcast so the second dot runs on fp16 operands.
        o_ab = o_ab.to(tl.float16)
        acc_e += tl.dot(o_ab, c)
        b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_N, 0])
        c_tile_ptr = tl.advance(c_tile_ptr, [BLOCK_N, 0])
    acc_e = acc_e.to(tl.float16)
    tl.store(e_tile_ptr, acc_e)
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="not passed on ampere")
def test_gemm_fusion():
    """Check the fused (A @ B^T) @ C kernel against an unfused torch reference."""
    M, N, K = 4096, 4096, 64
    BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 64

    def _rand_fp16(shape):
        # Small-magnitude operands keep the fp16 accumulation well-conditioned.
        return torch.empty(shape, dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2)

    A = _rand_fp16((M, K))
    B = _rand_fp16((N, K))
    C = _rand_fp16((N, K))
    E = torch.empty((M, K), dtype=torch.float16, device='cuda')
    ref_out = torch.matmul(torch.matmul(A, B.T), C)
    num_warps = 4
    grid = (triton.cdiv(M, BLOCK_M), 1)
    gemm_fusion_kernel[grid](A, B, C, E, M, N, K,
                             A.stride(0), A.stride(1), B.stride(0), B.stride(
                                 1), C.stride(0), C.stride(1), E.stride(0), E.stride(1),
                             BLOCK_M, BLOCK_N, BLOCK_K, num_warps=num_warps)
    torch.testing.assert_close(ref_out, E, atol=1e-2, rtol=0)
@triton.jit
def batched_gemm_fusion(
    Q, K, V, Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    Z, NH, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Batched fused GEMM chain on 4-D (Z, NH, N_CTX, BLOCK_DMODEL) tensors:
    Out = (Q @ K^T) @ V per (batch, head), without an intermediate in global
    memory. Grid axis 0 tiles N_CTX; axis 1 enumerates (batch, head) pairs.
    """
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # 4-D block pointers: the first two dims select one (batch, head) slice.
    q_tile_ptr = tl.make_block_ptr(base=Q,
                                   shape=(Z, NH, N_CTX, BLOCK_DMODEL),
                                   strides=(stride_qz, stride_qh, stride_qm, stride_qk),
                                   offsets=(off_hz // NH, off_hz % NH, start_m, 0),
                                   block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL),
                                   order=(3, 2, 1, 0))
    k_tile_ptr = tl.make_block_ptr(base=K,
                                   shape=(Z, NH, N_CTX, BLOCK_DMODEL),
                                   strides=(stride_kz, stride_kh, stride_kn, stride_kk),
                                   offsets=(off_hz // NH, off_hz % NH, 0, 0),
                                   block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL),
                                   order=(3, 2, 1, 0))
    v_tile_ptr = tl.make_block_ptr(base=V,
                                   shape=(Z, NH, N_CTX, BLOCK_DMODEL),
                                   strides=(stride_vz, stride_vh, stride_vk, stride_vn),
                                   offsets=(off_hz // NH, off_hz % NH, 0, 0),
                                   block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL),
                                   order=(3, 2, 1, 0))
    o_tile_ptr = tl.make_block_ptr(base=Out,
                                   shape=(Z, NH, N_CTX, BLOCK_DMODEL),
                                   strides=(stride_oz, stride_oh, stride_om, stride_on),
                                   offsets=(off_hz // NH, off_hz % NH, start_m, 0),
                                   block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL),
                                   order=(3, 2, 1, 0))
    # Collapse the unit batch/head dims so the dots operate on 2-D tiles.
    q = tl.load(q_tile_ptr, boundary_check=(0, 1, 2, 3))
    q = tl.view(q, (BLOCK_M, BLOCK_DMODEL))
    for i in range(0, N_CTX, BLOCK_N):
        k = tl.load(k_tile_ptr, boundary_check=(0, 1, 2, 3))
        k = tl.view(k, (BLOCK_N, BLOCK_DMODEL))
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))
        # Downcast the first dot's output for the fp16 second dot.
        p = qk.to(tl.float16)
        v = tl.load(v_tile_ptr, boundary_check=(0, 1, 2, 3))
        v = tl.view(v, (BLOCK_N, BLOCK_DMODEL))
        acc += tl.dot(p, v)
        k_tile_ptr = tl.advance(k_tile_ptr, [0, 0, BLOCK_N, 0])
        v_tile_ptr = tl.advance(v_tile_ptr, [0, 0, BLOCK_N, 0])
    # Restore the 4-D block shape before the store.
    acc = tl.view(acc, (1, 1, BLOCK_M, BLOCK_DMODEL))
    acc = acc.to(tl.float16)
    tl.store(o_tile_ptr, acc)
@pytest.mark.skip(reason="don't support 4d across stack, left for future")
def test_batched_gemm_fusion():
    """Check the 4-D batched fused GEMM chain against a torch reference.

    Currently skipped: 4-D block pointers are not supported across the stack.
    """
    Z = 4
    NH = 48
    H = 64
    N_CTX = 2048
    BLOCK_M, BLOCK_N, BLOCK_DMODEL = 128, 128, H
    torch.manual_seed(20)
    A = torch.empty((Z, NH, N_CTX, H), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2)
    B = torch.empty((Z, NH, N_CTX, H), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2)
    C = torch.empty((Z, NH, N_CTX, H), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2)
    E = torch.empty_like(A)
    BT = B.transpose(-1, -2)
    ref_out = torch.matmul(torch.matmul(A, BT), C)
    num_warps = 4
    # BUGFIX: grid axis 1 must be the number of (batch, head) pairs, Z * NH.
    # The previous `B * NH` multiplied the tensor B by NH, which is not a
    # valid grid dimension.
    grid = (triton.cdiv(N_CTX, BLOCK_M), Z * NH)
    batched_gemm_fusion[grid](A, B, C, E,
                              A.stride(0), A.stride(1), A.stride(2), A.stride(3),
                              B.stride(0), B.stride(1), B.stride(2), B.stride(3),
                              C.stride(0), C.stride(1), C.stride(2), C.stride(3),
                              E.stride(0), E.stride(1), E.stride(2), E.stride(3),
                              Z, NH, N_CTX,
                              BLOCK_M, BLOCK_DMODEL, BLOCK_N, num_warps=num_warps)
    torch.testing.assert_close(ref_out, E, atol=1e-2, rtol=0)

View File

@@ -0,0 +1,89 @@
import pytest
import torch
from torch.testing import assert_close
import triton
import triton.language as tl
# Maps the pytest `dtype_str` parameter to the corresponding torch dtype.
dtype_mapping = {
    'float16': torch.float16,
    'float32': torch.float32,
}
@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Elementwise add: output = x + y over a 1-D launch grid.

    x is loaded through a block pointer with boundary checking (exercising
    the block-pointer path), while y uses a plain masked pointer load —
    the mix appears deliberate so both load paths are covered by one test.
    """
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x_block_ptr = tl.make_block_ptr(
        base=x_ptr, shape=(n_elements, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),
        block_shape=(BLOCK_SIZE, ), order=(0, )
    )
    # Out-of-bounds lanes of x are padded with zero by the block-pointer load.
    x = tl.load(x_block_ptr, boundary_check=(0, ), padding_option='zero')
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)
@pytest.mark.parametrize('SIZE,BLOCK_SIZE,dtype_str',
                         [(98432, 1024, dtype_str)
                          for dtype_str in ['float16', 'float32']
                          ])
def test_add(SIZE, BLOCK_SIZE, dtype_str):
    """Run add_kernel on a deliberately non-multiple-of-BLOCK_SIZE input and
    compare against the torch elementwise sum."""
    dtype = dtype_mapping[dtype_str]
    output = torch.empty(SIZE, device='cuda', dtype=dtype)
    x = torch.randn(SIZE, device='cuda', dtype=dtype)
    y = torch.randn(SIZE, device='cuda', dtype=dtype)
    grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']),)
    add_kernel[grid](x, y, output, SIZE, BLOCK_SIZE=BLOCK_SIZE)
    expected = x + y
    torch.set_printoptions(profile='full')
    assert_close(output, expected, rtol=1e-2, atol=1e-3, check_dtype=False)
@triton.jit
def load_reduce_kernel(
    x_ptr,
    y_ptr,
    stride_xm,
    stride_xn,
    stride_y,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Load one (BLOCK_M, BLOCK_N) tile via a block pointer and store its
    row-wise maximum into y. Designed for a single-program launch that
    covers the whole input.
    """
    x_ptr = tl.make_block_ptr(
        base=x_ptr, shape=(BLOCK_M, BLOCK_N), strides=(stride_xm, stride_xn),
        offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)
    )
    x = tl.load(x_ptr)
    # Reduce along the N axis: one max per row.
    y = tl.max(x, axis=1)
    tl.store(y_ptr + tl.arange(0, BLOCK_M), y)
@pytest.mark.parametrize('BLOCK_M,BLOCK_N,dtype_str',
                         [(128, 64, dtype_str)
                          for dtype_str in ['float16']
                          ])
def test_load_reduce(BLOCK_M, BLOCK_N, dtype_str):
    """Single-program launch of load_reduce_kernel, checked against torch.max."""
    dtype = dtype_mapping[dtype_str]
    inp = torch.randn((BLOCK_M, BLOCK_N), device='cuda', dtype=dtype)
    out = torch.empty((BLOCK_M, ), device='cuda', dtype=dtype)
    load_reduce_kernel[(1,)](inp, out, inp.stride(0), inp.stride(1), out.stride(0), BLOCK_M, BLOCK_N)
    expected = torch.max(inp, dim=1).values
    torch.set_printoptions(profile='full')
    assert_close(out, expected, rtol=1e-2, atol=1e-3, check_dtype=False)

View File

@@ -0,0 +1,395 @@
# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
Fused Attention
===============
persistent warp specialized version of python/tutorials/06-fused-attention.py.
As of now, it only supports non-persistent warp specialized version of _fwd kernel.
"""
import pytest
import torch
import triton
import triton.language as tl
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=4, num_stages=2, enable_warp_specialization=True),
    ],
    key=['Q', 'K', 'V'],
)
@triton.jit
def _fwd_kernel(
    Q, K, V, sm_scale,
    L, M,
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Causal flash-attention forward pass (warp-specialized autotune config).

    Computes Out = softmax(Q K^T * sm_scale) V for one BLOCK_M stripe of
    queries per program, streaming K/V tiles with the online-softmax
    rescaling trick. Also writes the per-row softmax denominator (L) and
    running max (M) consumed by the backward pass.

    NOTE(review): the K and V offsets below reuse Q's strides (stride_qh,
    stride_qm, stride_qk) — this assumes Q, K and V share an identical
    layout; confirm at call sites.
    """
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
    # K is addressed transposed (rows = head dim, cols = keys) for the dot.
    off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk
    off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
    # Initialize pointers to Q, K, V
    q_ptrs = Q + off_q
    k_ptrs = K + off_k
    v_ptrs = V + off_v
    # initialize pointer to m and l
    # m_prev: running row max; l_prev: running softmax denominator.
    m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    q = tl.load(q_ptrs)
    # loop over k, v and update accumulator
    # Causality: only key blocks up to (start_m + 1) * BLOCK_M can attend.
    for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
        # -- compute qk ----
        k = tl.load(k_ptrs)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale
        # Mask out future positions within the block.
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
        # compute new m
        m_curr = tl.maximum(tl.max(qk, 1), m_prev)
        # correct old l
        l_prev *= tl.exp(m_prev - m_curr)
        # attention weights
        p = tl.exp(qk - m_curr[:, None])
        l_curr = tl.sum(p, 1) + l_prev
        # rescale operands of matmuls
        l_rcp = 1. / l_curr
        p *= l_rcp[:, None]
        acc *= (l_prev * l_rcp)[:, None]
        # update acc
        p = p.to(Q.dtype.element_ty)
        v = tl.load(v_ptrs)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_prev = l_curr
        m_prev = m_curr
        # update pointers
        k_ptrs += BLOCK_N * stride_kn
        v_ptrs += BLOCK_N * stride_vk
    # rematerialize offsets to save registers
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back l and m
    l_ptrs = L + off_hz * N_CTX + offs_m
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(l_ptrs, l_prev)
    tl.store(m_ptrs, m_prev)
    # initialize pointers to output
    offs_n = tl.arange(0, BLOCK_DMODEL)
    off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc)
@triton.jit
def _bwd_preprocess(
    Out, DO, L,
    NewDO, Delta,
    BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
    """Backward-pass preprocessing: normalize the incoming output gradient by
    the softmax denominator (NewDO = DO / L) and compute the per-row
    correction term Delta = sum(Out * NewDO, axis=-1) used by _bwd_kernel.
    """
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_n = tl.arange(0, D_HEAD)
    # load
    o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    denom = tl.load(L + off_m).to(tl.float32)
    # compute
    do = do / denom[:, None]
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
    tl.store(Delta + off_m, delta)
@triton.jit
def _bwd_kernel(
    Q, K, V, sm_scale, Out, DO,
    DQ, DK, DV,
    L, M,
    D,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    Z, H, N_CTX,
    num_block,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Causal flash-attention backward pass: one program per (batch, head),
    recomputing the attention probabilities block-by-block to accumulate
    dQ, dK and dV without storing the full attention matrix.

    NOTE(review): several offsets below reuse Q's strides for K/V/DO/DQ/DV
    (e.g. `off_z * stride_qz`, `offs_n[:, None] * stride_qm` for V) — this
    assumes all tensors share one layout; confirm at call sites.
    """
    off_hz = tl.program_id(0)
    off_z = off_hz // H
    off_h = off_hz % H
    # offset pointers for batch/head
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_qz + off_h * stride_qh
    V += off_z * stride_qz + off_h * stride_qh
    DO += off_z * stride_qz + off_h * stride_qh
    DQ += off_z * stride_qz + off_h * stride_qh
    DK += off_z * stride_qz + off_h * stride_qh
    DV += off_z * stride_qz + off_h * stride_qh
    # Outer loop: one key/value block (column block) at a time.
    for start_n in range(0, num_block):
        # Causality: query rows before `lo` never attend to this key block.
        lo = start_n * BLOCK_M
        # initialize row/col offsets
        offs_qm = lo + tl.arange(0, BLOCK_M)
        offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_m = tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_DMODEL)
        # initialize pointers to value-like data
        q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        # pointer to row-wise quantities in value-like data
        D_ptrs = D + off_hz * N_CTX
        m_ptrs = M + off_hz * N_CTX
        # initialize dv amd dk
        dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        # k and v stay in SRAM throughout
        k = tl.load(k_ptrs)
        v = tl.load(v_ptrs)
        # loop over rows
        for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
            offs_m_curr = start_m + offs_m
            # load q, k, v, do on-chip
            q = tl.load(q_ptrs)
            # recompute p = softmax(qk, dim=-1).T
            # NOTE: `do` is pre-divided by `l`; no normalization here
            qk = tl.dot(q, tl.trans(k))
            qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
            m = tl.load(m_ptrs + offs_m_curr)
            p = tl.exp(qk * sm_scale - m[:, None])
            # compute dv
            do = tl.load(do_ptrs)
            dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
            # compute dp = dot(v, do)
            Di = tl.load(D_ptrs + offs_m_curr)
            dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
            dp += tl.dot(do, tl.trans(v))
            # compute ds = p * (dp - delta[:, None])
            ds = p * dp * sm_scale
            # compute dk = dot(ds.T, q)
            dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
            # compute dq
            # dQ is accumulated via read-modify-write across key blocks.
            dq = tl.load(dq_ptrs)
            dq += tl.dot(ds.to(Q.dtype.element_ty), k)
            tl.store(dq_ptrs, dq)
            # increment pointers
            dq_ptrs += BLOCK_M * stride_qm
            q_ptrs += BLOCK_M * stride_qm
            do_ptrs += BLOCK_M * stride_qm
        # write-back
        dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        tl.store(dv_ptrs, dv)
        tl.store(dk_ptrs, dk)
# Kept for parity with the upstream fused-attention tutorial.
# NOTE(review): `empty` appears unused in this file — confirm before removing.
empty = torch.empty(128, device="cuda")
class _attention(torch.autograd.Function):
    """Autograd wrapper around the Triton fused-attention kernels.

    forward launches _fwd_kernel; backward launches _bwd_preprocess followed
    by _bwd_kernel. Saved tensors: q, k, v, output, softmax denominators L
    and running maxima m.
    """
    @staticmethod
    def forward(ctx, q, k, v, sm_scale):
        # q, k, v: (Z, H, N_CTX, D_HEAD); sm_scale: softmax scaling factor.
        BLOCK = 128
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}
        o = torch.empty_like(q)
        # Grid: (query blocks, batch * heads, 1).
        grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
        # Per-row softmax denominator and running max, consumed by backward.
        L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        num_warps = 4 if Lk <= 64 else 8
        # Only support num_warps = 4 now
        assert num_warps == 4
        _fwd_kernel[grid](
            q, k, v, sm_scale,
            L, m,
            o,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q.shape[0], q.shape[1], q.shape[2],
            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
            BLOCK_DMODEL=Lk,
        )
        ctx.save_for_backward(q, k, v, o, L, m)
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = Lk
        return o
    @staticmethod
    def backward(ctx, do):
        BLOCK = 128
        q, k, v, o, l, m = ctx.saved_tensors
        do = do.contiguous()
        # dq is accumulated in-place by the kernel, so it must start zeroed.
        dq = torch.zeros_like(q, dtype=torch.float32)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        do_scaled = torch.empty_like(do)
        delta = torch.empty_like(l)
        # Stage 1: normalize do by l and compute the per-row delta term.
        _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
            o, do, l,
            do_scaled, delta,
            BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
        )
        # Stage 2: one program per (batch, head) computes dq, dk, dv.
        _bwd_kernel[(ctx.grid[1],)](
            q, k, v, ctx.sm_scale,
            o, do_scaled,
            dq, dk, dv,
            l, m,
            delta,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            q.shape[0], q.shape[1], q.shape[2],
            ctx.grid[0],
            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
            num_stages=1,
        )
        return dq, dk, dv, None
# Public entry point: attention(q, k, v, sm_scale) -> output tensor.
attention = _attention.apply
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    """Compare Triton fused attention (fwd + bwd) against a PyTorch reference.

    Currently skipped: fails with an unspecified launch failure.
    """
    # with ENABLE_TMA=0 and ENABLE_MMA_V3=0
    pytest.skip('unspecified launch failure')
    torch.manual_seed(20)
    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_()
    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
    sm_scale = 0.2
    dout = torch.randn_like(q)
    # reference implementation
    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    # Apply the causal mask once; it broadcasts over batch and heads.
    # (The previous code re-applied the identical mask Z*H times inside
    # nested loops whose loop variables were never used.)
    p[:, :, M == 0] = float("-inf")
    p = torch.softmax(p.float(), dim=-1).half()
    # p = torch.exp(p)
    ref_out = torch.matmul(p, v)
    ref_out.backward(dout)
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None
    # # triton implementation
    tri_out = attention(q, k, v, sm_scale)
    # print(ref_out)
    # print(tri_out)
    tri_out.backward(dout)
    tri_dv, v.grad = v.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dq, q.grad = q.grad.clone(), None
    # compare
    assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
    assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=0)
    assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=0)
    assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=0)
# Optional dependency probe: benchmark against flash-attn only if installed.
# BaseException is deliberately broad — any import-time failure (including
# CUDA extension load errors) should just disable the flash baseline.
try:
    from flash_attn.flash_attn_interface import flash_attn_func
    HAS_FLASH = True
except BaseException:
    HAS_FLASH = False
# Benchmark configuration: sweep sequence length for a fixed batch/head/dim.
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
configs = [triton.testing.Benchmark(
    x_names=['N_CTX'],
    # x_vals=[2**i for i in range(10, 14)],
    x_vals=[2**i for i in range(10, 11)],
    line_arg='provider',
    line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
    line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
    styles=[('red', '-'), ('blue', '-')],
    ylabel='ms',
    plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
    args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
    # ) for mode in ['fwd', 'bwd']]
) for mode in ['fwd']]
@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
    """Benchmark the Triton fused-attention path against flash-attn (if
    available). Returns milliseconds per iteration for the chosen provider
    and mode ('fwd' or 'bwd').
    """
    assert mode in ['fwd', 'bwd']
    # warmup = 25
    # rep = 100
    # Single-shot timing for now — presumably a smoke-test setting while the
    # Hopper path stabilizes; restore the commented values for real numbers.
    warmup = 0
    rep = 1
    if provider == "triton":
        q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        sm_scale = 1.3
        fn = lambda: attention(q, k, v, sm_scale)
        if mode == 'bwd':
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
        return ms
    if provider == "flash":
        # flash-attn packs qkv together and uses cumulative sequence lengths.
        lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
        cu_seqlens = torch.zeros(
            (BATCH + 1,), device=device, dtype=torch.int32)
        cu_seqlens[1:] = lengths.cumsum(0)
        qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
        fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
        if mode == 'bwd':
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
        return ms
# # only works on post-Ampere GPUs right now
# bench_flash_attention.run(save_path='.', print_data=True)

View File

@@ -0,0 +1,987 @@
# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import itertools
import pytest
import torch
from torch.testing import assert_close
import triton
import triton.language as tl
from .utils import get_proper_err, get_variant_golden
def isMMAV3OrTMAEnabled():
    """Return True when either the ENABLE_MMA_V3 or ENABLE_TMA environment
    variable is set to a truthy value ('1', 'on' or 'true', case-insensitive).
    """
    import os
    truthy = ('1', 'on', 'true')
    return any(os.environ.get(flag, '0').lower() in truthy
               for flag in ('ENABLE_MMA_V3', 'ENABLE_TMA'))
@triton.jit
def static_persistent_matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    NUM_SM: tl.constexpr,
):
    """Persistent matmul C = A @ B using classic pointer arithmetic.

    NUM_SM programs are launched; each one walks the output tiles with a
    stride of NUM_SM, recomputing raw input pointers for every tile.
    """
    start_tile = tl.program_id(axis=0)
    m_tiles = tl.cdiv(M, BLOCK_M)
    n_tiles = tl.cdiv(N, BLOCK_N)
    num_tiles = m_tiles * n_tiles
    offs_k = tl.arange(0, BLOCK_K)
    # Persistent loop: this program owns tiles start_tile, start_tile+NUM_SM, ...
    for tile_id in range(start_tile, num_tiles, NUM_SM):
        pid_m = tile_id // n_tiles
        pid_n = tile_id % n_tiles
        accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # Wrap row/col offsets so out-of-range lanes stay in bounds.
        offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
        offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
        for k in range(0, K, BLOCK_K):
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
            accumulator += tl.dot(a, b)
            a_ptrs += BLOCK_K * stride_ak
            b_ptrs += BLOCK_K * stride_bk
        offs_cm = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
        offs_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
        c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
        tl.store(c_ptrs, accumulator)
@triton.jit
def static_persistent_tma_matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    NUM_SM: tl.constexpr,
):
    """Persistent matmul C = A @ B using block pointers (TMA-friendly form).

    Same tile schedule as static_persistent_matmul_kernel, but the A/B block
    pointers are created once and re-positioned with tl.advance between
    tiles instead of being recomputed from raw strides.
    """
    start_tile = tl.program_id(axis=0)
    m_tiles = tl.cdiv(M, BLOCK_M)
    n_tiles = tl.cdiv(N, BLOCK_N)
    k_tiles = tl.cdiv(K, BLOCK_K)
    num_tiles = m_tiles * n_tiles
    # Track the previous tile coordinates to compute relative advances.
    pre_pid_m = start_tile // n_tiles
    pre_pid_n = start_tile % n_tiles
    block_offset_m = pre_pid_m * BLOCK_M
    block_offset_n = pre_pid_n * BLOCK_N
    a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
    for tile_id in range(start_tile, num_tiles, NUM_SM):
        pid_m = tile_id // n_tiles
        pid_n = tile_id % n_tiles
        accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        if tile_id >= NUM_SM:
            # Not the first tile: move to the new tile row/column and rewind
            # the K dimension (which advanced k_tiles * BLOCK_K last tile).
            a_tile_ptr = tl.advance(a_tile_ptr, [(pid_m - pre_pid_m) * BLOCK_M, -k_tiles * BLOCK_K])
            b_tile_ptr = tl.advance(b_tile_ptr, [-k_tiles * BLOCK_K, (pid_n - pre_pid_n) * BLOCK_N])
        for k in range(0, K, BLOCK_K):
            a = tl.load(a_tile_ptr)
            b = tl.load(b_tile_ptr)
            accumulator += tl.dot(a, b)
            a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_K])
            b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_K, 0])
        offs_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
        offs_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
        c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
        tl.store(c_ptrs, accumulator)
        pre_pid_m = pid_m
        pre_pid_n = pid_n
@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,TRANS_A,TRANS_B', [
    [4096, 4096, 64, 64, 64, 16, 4, 1, False, True],
    [4096, 4096, 64, 64, 64, 32, 4, 1, False, True],
    [4096, 4096, 64, 256, 64, 16, 4, 1, False, True],
    [4096, 4096, 64, 128, 128, 16, 4, 1, False, True],
    # TODO: fix issue for 8-warp persistent kernel
    # [4096, 4096, 64, 128, 128, 16, 8, 1, False, True],
    # [4096, 4096, 64, 128, 256, 16, 8, 1, False, True],
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, TRANS_A, TRANS_B):
    """Run the persistent matmul both with classic pointer arithmetic and
    with block pointers, checking each against torch.matmul.

    Currently skipped: RewriteTensorPtrPass issue.
    """
    # TODO: fix RewriteTensorPtrPass
    pytest.skip('RewriteTensorPtrPass issue')
    if (TRANS_A):
        a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T
    else:
        a = .1 * torch.randn((M, K), device='cuda', dtype=torch.float16)
    if (TRANS_B):
        b = .1 * torch.randn((N, K), device='cuda', dtype=torch.float16).T
    else:
        b = .1 * torch.randn((K, N), device='cuda', dtype=torch.float16)
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
    grid = lambda META: (num_SMs,)

    def call_vintage():
        static_persistent_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS)
        return c

    def call_stylish():
        static_persistent_tma_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS)
        return c

    th_c = torch.matmul(a, b)
    # Test using old style of ptr calculation.
    # NOTE: torch.testing.assert_allclose is deprecated; assert_close is the
    # supported API. check_dtype=False preserves the old behavior of
    # comparing the fp16 reference against the fp32 kernel output.
    tt_c = call_vintage()
    torch.testing.assert_close(th_c, tt_c, atol=1e-2, rtol=0, check_dtype=False)
    # Clear c (fixes the old "Cealr" typo)
    c = torch.randn((M, N), device=a.device, dtype=torch.float32)
    # Test using make_block_ptr
    tt_c = call_stylish()
    torch.testing.assert_close(th_c, tt_c, atol=1e-2, rtol=0, check_dtype=False)
@triton.jit
def warp_specialized_matmul_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
tid = tl.program_id(axis=0)
n_tiles = tl.cdiv(N, BLOCK_N)
pid_m = tid // n_tiles
pid_n = tid % n_tiles
offs_k = tl.arange(0, BLOCK_K)
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_am = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
offs_bn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
accumulator = accumulator.to(c_ptr.dtype.element_ty)
offs_cm = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offs_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
mask = (offs_cm < M)[:, None] & (offs_cn < N)[None, :]
tl.store(c_ptrs, accumulator, mask=mask)
@triton.jit
def tma_warp_specialized_matmul_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
tid = tl.program_id(axis=0)
n_tiles = tl.cdiv(N, BLOCK_N)
pid_m = tid // n_tiles
pid_n = tid % n_tiles
block_offset_m = pid_m * BLOCK_M
block_offset_n = pid_n * BLOCK_N
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
a = tl.load(a_tile_ptr)
b = tl.load(b_tile_ptr)
accumulator += tl.dot(a, b)
a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_K])
b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_K, 0])
accumulator = accumulator.to(c_ptr.dtype.element_ty)
offs_cm = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
offs_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
mask = (offs_cm < M)[:, None] & (offs_cn < N)[None, :]
tl.store(c_ptrs, accumulator, mask=mask)
@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_CTAS,TRANS_A,TRANS_B', [
    [2048, 2048, 64, 64, 64, 16, 1, False, True],
    [4096, 4096, 64, 64, 64, 16, 1, False, True],
    [128, 4096, 64, 64, 64, 16, 1, False, True],
    [4096, 128, 64, 64, 64, 16, 1, False, True],
    [4096, 4096, 64, 64, 64, 32, 1, False, True],
    [4096, 4096, 256, 128, 128, 16, 1, False, True],
    [4096, 4096, 320, 128, 64, 64, 1, False, True],
    [4096, 4096, 320, 64, 128, 64, 1, False, True],
    [4096, 4096, 320, 128, 128, 64, 1, False, True],
    [4096, 4096, 256, 256, 64, 16, 1, False, True],
    [4096, 4096, 256, 256, 64, 64, 1, False, True],
    [4096, 4096, 256, 64, 256, 16, 1, False, True],
    [4096, 4096, 256, 64, 256, 64, 1, False, True],
    [4096, 4096, 256, 256, 128, 16, 1, False, True],
    [4096, 4096, 256, 256, 128, 64, 1, False, True],
    [4096, 4096, 256, 128, 256, 16, 1, False, True],
    [4096, 4096, 256, 128, 256, 64, 1, False, True],
    # numCTAs > 1
    [2048, 2048, 64, 128, 128, 64, 2, False, True],
    [2048, 2048, 64, 128, 128, 64, 2, False, True],
    [2048, 2048, 128, 256, 128, 64, 4, False, True],
    [4096, 4096, 128, 256, 128, 64, 4, False, True],
    [4096, 4096, 256, 128, 256, 64, 4, False, True],
    [4096, 4096, 256, 256, 256, 64, 4, False, True],
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B):
    """Compare the two non-persistent warp-specialized GEMM kernels to torch.matmul.

    Runs the raw-pointer kernel (warp_specialized_matmul_kernel) and the
    block-pointer/TMA kernel (tma_warp_specialized_matmul_kernel) on the same
    inputs, checking each result against the cuBLAS reference.
    """
    # TODO(review): the test currently hangs; drop this skip once resolved.
    pytest.skip('hang')
    # TRANS_A / TRANS_B materialize the operand transposed so the kernels see
    # non-contiguous strides.
    if TRANS_A:
        a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T
    else:
        a = .1 * torch.randn((M, K), device='cuda', dtype=torch.float16)
    if TRANS_B:
        b = .1 * torch.randn((N, K), device='cuda', dtype=torch.float16).T
    else:
        b = .1 * torch.randn((K, N), device='cuda', dtype=torch.float16)
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    grid = lambda META: (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),)

    def call_vintage():
        # Old-style pointer-arithmetic kernel.
        warp_specialized_matmul_kernel[grid](
            a, b, c,
            M, N, K,
            a.stride(0), a.stride(1),
            b.stride(0), b.stride(1),
            c.stride(0), c.stride(1),
            BLOCK_M, BLOCK_N, BLOCK_K,
            num_warps=4,
            num_ctas=NUM_CTAS,
            enable_warp_specialization=True)
        return c

    def call_stylish():
        # Block-pointer (make_block_ptr / TMA) kernel.
        tma_warp_specialized_matmul_kernel[grid](
            a, b, c,
            M, N, K,
            a.stride(0), a.stride(1),
            b.stride(0), b.stride(1),
            c.stride(0), c.stride(1),
            BLOCK_M, BLOCK_N, BLOCK_K,
            num_warps=4,
            num_ctas=NUM_CTAS,
            enable_warp_specialization=True)
        return c

    th_c = torch.matmul(a, b)
    # Test using old style of ptr calculation.
    tt_c = call_vintage()
    # torch.testing.assert_allclose is deprecated; assert_close is its
    # replacement.  check_dtype=False because the reference is float16 while
    # the kernels accumulate into a float32 output buffer.
    torch.testing.assert_close(th_c, tt_c, atol=1e-2, rtol=0, check_dtype=False)
    # Clear c so the second kernel cannot pass on stale results.
    c = torch.randn((M, N), device=a.device, dtype=torch.float32)
    # Test using make_block_ptr.
    tt_c = call_stylish()
    torch.testing.assert_close(th_c, tt_c, atol=1e-2, rtol=0, check_dtype=False)
    # ############################################Performance Evaluation#############################################
    # fn = lambda: call_vintage()
    # ms = triton.testing.do_bench(fn, warmup=25, rep=100)
    # cur_gpu_perf = round(2. * M * N * K / ms * 1e-9, 2)
    # print(' '.join(['Performance of', str(M), str(N), str(K), ':', str(ms), 'ms, ', str(cur_gpu_perf), 'TFLOPS']))
@triton.autotune(
    configs=[
        triton.Config({}, num_stages=3, num_warps=4, enable_warp_specialization=True),
        # triton.Config({}, num_stages=3, num_warps=4, enable_warp_specialization=False),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def static_persistent_warp_specialized_matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    NUM_SM: tl.constexpr,
):
    # Static-persistent GEMM: the grid holds NUM_SM programs and each program
    # serially computes every NUM_SM-th tile of the (M-tiles x N-tiles) output
    # space, using plain pointer arithmetic for the loads.
    start_tile = tl.program_id(axis=0)
    m_tiles = tl.cdiv(M, BLOCK_M)
    n_tiles = tl.cdiv(N, BLOCK_N)
    num_tiles = m_tiles * n_tiles
    offs_k = tl.arange(0, BLOCK_K)
    for tile_id in range(start_tile, num_tiles, NUM_SM):
        # Linear tile id -> 2-D tile coordinates (row-major over tiles).
        pid_m = tile_id // n_tiles
        pid_n = tile_id % n_tiles
        accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # The % M / % N wrap keeps the loads in bounds on ragged edge tiles.
        offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
        offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
        for k in range(0, K, BLOCK_K):
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
            accumulator += tl.dot(a, b)
            a_ptrs += BLOCK_K * stride_ak
            b_ptrs += BLOCK_K * stride_bk
        offs_cm = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
        offs_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
        c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
        # NOTE(review): the store is unmasked, so M and N appear to be assumed
        # multiples of BLOCK_M / BLOCK_N — confirm against callers.
        tl.store(c_ptrs, accumulator)
@triton.autotune(
    configs=[
        triton.Config({}, num_stages=3, num_warps=4, enable_warp_specialization=True),
        # triton.Config({}, num_stages=3, num_warps=4, enable_warp_specialization=False),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def static_persistent_tma_warp_specialized_matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    NUM_SM: tl.constexpr,
):
    # Static-persistent GEMM using block pointers (TMA-capable loads): NUM_SM
    # programs each iterate over every NUM_SM-th output tile, reusing the same
    # two block pointers across tiles via relative tl.advance moves.
    start_tile = tl.program_id(axis=0)
    m_tiles = tl.cdiv(M, BLOCK_M)
    n_tiles = tl.cdiv(N, BLOCK_N)
    k_tiles = tl.cdiv(K, BLOCK_K)
    num_tiles = m_tiles * n_tiles
    # pre_pid_* track the previously processed tile so the block pointers can
    # be repositioned by a relative distance at the top of each iteration.
    pre_pid_m = start_tile // n_tiles
    pre_pid_n = start_tile % n_tiles
    block_offset_m = pre_pid_m * BLOCK_M
    block_offset_n = pre_pid_n * BLOCK_N
    a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
    for tile_id in range(start_tile, num_tiles, NUM_SM):
        pid_m = tile_id // n_tiles
        pid_n = tile_id % n_tiles
        accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # From the second visited tile on, rewind the K axis (the previous K
        # loop advanced it by k_tiles * BLOCK_K) and shift to the new tile.
        if tile_id >= NUM_SM:
            a_tile_ptr = tl.advance(a_tile_ptr, [(pid_m - pre_pid_m) * BLOCK_M, -k_tiles * BLOCK_K])
            b_tile_ptr = tl.advance(b_tile_ptr, [-k_tiles * BLOCK_K, (pid_n - pre_pid_n) * BLOCK_N])
        for k in range(0, K, BLOCK_K):
            a = tl.load(a_tile_ptr)
            b = tl.load(b_tile_ptr)
            accumulator += tl.dot(a, b)
            a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_K])
            b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_K, 0])
        offs_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
        offs_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
        c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
        # NOTE(review): unmasked store — appears to assume M % BLOCK_M == 0 and
        # N % BLOCK_N == 0; confirm against callers.
        tl.store(c_ptrs, accumulator)
        pre_pid_m = pid_m
        pre_pid_n = pid_n
@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_CTAS,TRANS_A,TRANS_B', [
    [2048, 2048, 64, 64, 64, 16, 1, False, True],
    [4096, 4096, 64, 64, 64, 16, 1, False, True],
    [128, 4096, 64, 64, 64, 16, 1, False, True],
    [4096, 128, 64, 64, 64, 16, 1, False, True],
    [4096, 4096, 64, 64, 64, 32, 1, False, True],
    [4096, 4096, 256, 128, 128, 16, 1, False, True],
    [4096, 4096, 320, 128, 64, 64, 1, False, True],
    [4096, 4096, 320, 64, 128, 64, 1, False, True],
    [4096, 4096, 320, 128, 128, 64, 1, False, True],
    [4096, 4096, 256, 256, 64, 16, 1, False, True],
    [4096, 4096, 256, 256, 64, 64, 1, False, True],
    [4096, 4096, 256, 64, 256, 16, 1, False, True],
    [4096, 4096, 256, 64, 256, 64, 1, False, True],
    [4096, 4096, 256, 256, 128, 16, 1, False, True],
    [4096, 4096, 256, 256, 128, 64, 1, False, True],
    [4096, 4096, 256, 128, 256, 16, 1, False, True],
    [4096, 4096, 256, 128, 256, 64, 1, False, True],
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B):
    """Compare the two static-persistent warp-specialized GEMM kernels to torch.matmul.

    Launches one program per SM; each program loops over its share of output
    tiles.  Checks both the raw-pointer and the block-pointer/TMA variants
    against the cuBLAS reference.
    """
    # TODO: fix RewriteTensorPtrPass
    pytest.skip('RewriteTensorPtrPass issue')
    # TRANS_A / TRANS_B materialize the operand transposed so the kernels see
    # non-contiguous strides.
    if TRANS_A:
        a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T
    else:
        a = .1 * torch.randn((M, K), device='cuda', dtype=torch.float16)
    if TRANS_B:
        b = .1 * torch.randn((N, K), device='cuda', dtype=torch.float16).T
    else:
        b = .1 * torch.randn((K, N), device='cuda', dtype=torch.float16)
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
    grid = lambda META: (num_SMs,)

    def call_vintage():
        # Old-style pointer-arithmetic kernel.
        static_persistent_warp_specialized_matmul_kernel[grid](
            a, b, c,
            M, N, K,
            a.stride(0), a.stride(1),
            b.stride(0), b.stride(1),
            c.stride(0), c.stride(1),
            BLOCK_M, BLOCK_N, BLOCK_K, num_SMs)
        return c

    def call_stylish():
        # Block-pointer (make_block_ptr / TMA) kernel.
        static_persistent_tma_warp_specialized_matmul_kernel[grid](
            a, b, c,
            M, N, K,
            a.stride(0), a.stride(1),
            b.stride(0), b.stride(1),
            c.stride(0), c.stride(1),
            BLOCK_M, BLOCK_N, BLOCK_K, num_SMs)
        return c

    th_c = torch.matmul(a, b)
    # Test using old style of ptr calculation.
    tt_c = call_vintage()
    # torch.testing.assert_allclose is deprecated; assert_close is its
    # replacement.  check_dtype=False because the reference is float16 while
    # the kernels accumulate into a float32 output buffer.
    torch.testing.assert_close(th_c, tt_c, atol=1e-2, rtol=0, check_dtype=False)
    # Clear c so the second kernel cannot pass on stale results.
    c = torch.randn((M, N), device=a.device, dtype=torch.float32)
    # Test using make_block_ptr.
    tt_c = call_stylish()
    torch.testing.assert_close(th_c, tt_c, atol=1e-2, rtol=0, check_dtype=False)
    # ############################################Performance Evaluation#############################################
    # fn = lambda: call_stylish()
    # ms = triton.testing.do_bench(fn, warmup=25, rep=100)
    # cur_gpu_perf = round(2. * M * N * K / ms * 1e-9, 2)
    # print(' '.join(['Performance of', str(M), str(N), str(K), ':', str(ms), 'ms, ', str(cur_gpu_perf), 'TFLOPS']))
@triton.jit
def static_persistent_matmul_no_scf_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr,
    NUM_SM: tl.constexpr, USE_TMA_LOAD: tl.constexpr,
):
    # Static-persistent GEMM with no inner K loop ("no scf"): each tile is
    # computed with a single tl.dot, so BLOCK_K must cover all of K.
    # USE_TMA_LOAD / USE_TMA_EPILOGUE switch the loads / the store between a
    # block-pointer (TMA-capable) path and plain pointer arithmetic.
    start_tile = tl.program_id(axis=0)
    m_tiles = tl.cdiv(M, BLOCK_M)
    n_tiles = tl.cdiv(N, BLOCK_N)
    num_tiles = m_tiles * n_tiles
    offs_k = tl.arange(0, BLOCK_K)
    # pre_pid_* track the previous tile so block pointers can be moved by a
    # relative tl.advance on each persistent iteration.
    pre_pid_m = start_tile // n_tiles
    pre_pid_n = start_tile % n_tiles
    block_offset_m = pre_pid_m * BLOCK_M
    block_offset_n = pre_pid_n * BLOCK_N
    if USE_TMA_LOAD:
        a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                        offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
        b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                                        offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
    if USE_TMA_EPILOGUE:
        c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
                                        offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    for tile_id in range(start_tile, num_tiles, NUM_SM):
        pid_m = tile_id // n_tiles
        pid_n = tile_id % n_tiles
        if USE_TMA_LOAD:
            # Reposition relative to the previous tile (no K rewind needed —
            # there is no K loop advancing the pointers).
            a_block_ptr = tl.advance(a_block_ptr, [(pid_m - pre_pid_m) * BLOCK_M, 0])
            b_block_ptr = tl.advance(b_block_ptr, [0, (pid_n - pre_pid_n) * BLOCK_N])
            a = tl.load(a_block_ptr)
            b = tl.load(b_block_ptr)
        else:
            offs_am = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
            offs_bn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
            a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
            b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
        # Whole-K dot; accumulate and (optionally) down-cast for the store.
        c = tl.dot(a, b)
        if FLOAT16_OUTPUT:
            c = c.to(tl.float16)
        if USE_TMA_EPILOGUE:
            c_block_ptr = tl.advance(c_block_ptr, [(pid_m - pre_pid_m) * BLOCK_M, (pid_n - pre_pid_n) * BLOCK_N])
            tl.store(c_block_ptr, c)
        else:
            offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
            offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
            c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
            tl.store(c_ptrs, c)
        pre_pid_m = pid_m
        pre_pid_n = pid_n
@pytest.mark.parametrize('M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,USE_TMA_LOAD',
                         itertools.chain(
                             *[
                                 [
                                     # numCTAs = 1, no TMA multicast:
                                     [64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
                                     [64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
                                     [64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
                                     [64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
                                     [64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
                                     [64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
                                     [128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
                                     [128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
                                     # small M, N
                                     [16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
                                     [16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
                                     [32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
                                     [32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
                                 ] for USE_TMA_EPILOGUE in [True, False]
                                 for USE_TMA_LOAD in [True, False]
                             ]))
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_static_persistent_matmul_no_scf_kernel(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE, USE_TMA_LOAD):
    """Check the single-dot static-persistent kernel against a padded-matmul golden."""
    if isMMAV3OrTMAEnabled():
        pytest.skip("known failure")
    # TRANS_A / TRANS_B materialize the operand transposed so the kernel sees
    # non-contiguous strides.
    a = torch.randn((K, M) if TRANS_A else (M, K), device='cuda', dtype=torch.float16)
    if TRANS_A:
        a = a.T
    b = torch.randn((N, K) if TRANS_B else (K, N), device='cuda', dtype=torch.float16)
    if TRANS_B:
        b = b.T
    out_torch_dtype = torch.float16 if OUTPUT_TYPE == "float16" else torch.float32
    c = torch.empty((M, N), device=a.device, dtype=out_torch_dtype)
    num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
    # TODO: set `enable_warp_specialization=False` will lead to compilation error.
    static_persistent_matmul_no_scf_kernel[(num_SMs,)](
        a_ptr=a, b_ptr=b, c_ptr=c,
        M=M, N=N, K=K,
        stride_am=a.stride(0), stride_ak=a.stride(1),
        stride_bk=b.stride(0), stride_bn=b.stride(1),
        stride_cm=c.stride(0), stride_cn=c.stride(1),
        # Halve blocks for large problems so two tiles exist per axis.
        BLOCK_M=M if M < 128 else M // 2,
        BLOCK_N=N if N < 128 else N // 2,
        BLOCK_K=K,
        NUM_SM=num_SMs,
        num_warps=NUM_WARPS,
        num_ctas=NUM_CTAS,
        FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"),
        USE_TMA_EPILOGUE=USE_TMA_EPILOGUE,
        USE_TMA_LOAD=USE_TMA_LOAD,
        enable_warp_specialization=True)
    a_f32 = a.to(torch.float32)
    b_f32 = b.to(torch.float32)
    golden = torch.matmul(a_f32, b_f32)
    # A padded-matmul variant of the golden estimates how much rounding error
    # a correct implementation may legitimately accumulate.
    golden_variant = get_variant_golden(a_f32, b_f32)
    golden_abs_err, golden_rel_err = get_proper_err(golden, golden_variant)
    torch.set_printoptions(profile="full")
    assert_close(
        c,
        golden,
        rtol=max(1e-2, 1.1 * golden_rel_err),
        atol=max(1e-3, 1.1 * golden_abs_err),
        check_dtype=False)
@triton.jit
def full_static_persistent_matmul_kernel(
    a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_wm, stride_wn,
    stride_zm, stride_zn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
    out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr,
    ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
    DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
    A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr,
    B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr,
    NUM_SM: tl.constexpr
):
    # Static-persistent GEMM z = a @ b with optional fused epilogues (bias
    # add, row softmax, chain-dot with w) and an optional TMA store.  Tiles
    # are visited in grouped order: GROUP_SIZE_M rows of tiles are interleaved
    # so consecutive tile ids stay close in M.
    start_pid = tl.program_id(axis=0)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_tiles = num_pid_m * num_pid_n
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = start_pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    # pre_pid_* remember the previously processed tile; block pointers are
    # moved by the relative distance (pid - pre_pid) on each iteration.
    pre_pid_m = first_pid_m + ((start_pid % num_pid_in_group) % group_size_m)
    pre_pid_n = (start_pid % num_pid_in_group) // group_size_m
    pre_block_offset_m = pre_pid_m * BLOCK_M
    pre_block_offset_n = pre_pid_n * BLOCK_N
    a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                   offsets=(pre_block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(A_ORDER_0, A_ORDER_1))
    b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                                   offsets=(0, pre_block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(B_ORDER_0, B_ORDER_1))
    w_tile_ptr = tl.make_block_ptr(base=w_ptr, shape=(N, N), strides=(stride_wm, stride_wn),
                                   offsets=(0, pre_block_offset_n), block_shape=(BLOCK_N, BLOCK_N), order=(0, 1))
    if USE_TMA_STORE:
        z_block_ptr = tl.make_block_ptr(base=z_ptr, shape=(M, N), strides=(stride_zm, stride_zn),
                                        offsets=(pre_block_offset_m, pre_block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    for tile_id in range(start_pid, num_tiles, NUM_SM):
        # Grouped tile-id -> (pid_m, pid_n) mapping, recomputed per tile.
        group_id = tile_id // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
        pid_n = (tile_id % num_pid_in_group) // group_size_m
        block_offset_m = pid_m * BLOCK_M
        block_offset_n = pid_n * BLOCK_N
        offs_m = block_offset_m + tl.arange(0, BLOCK_M)
        offs_n = block_offset_n + tl.arange(0, BLOCK_N)
        z_ptrs = z_ptr + offs_m[:, None] * stride_zm + offs_n[None, :] * stride_zn
        # NOTE(review): bias is addressed with z's strides — assumes bias has
        # the same (M, N) layout as z; confirm against callers.
        bias_ptrs = bias_ptr + offs_m[:, None] * stride_zm + offs_n[None, :] * stride_zn
        mask = (offs_m < M)[:, None] & (offs_n < N)[None, :]
        # TODO: lib/Dialect/TritonGPU/Transforms/RewriteTensorPointer.cpp does not support scf.if yet.
        # if tile_id >= NUM_SM:
        #     a_tile_ptr = tl.advance(a_tile_ptr, [(pid_m - pre_pid_m) * BLOCK_M, -tl.cdiv(K, BLOCK_K) * BLOCK_K])
        #     b_tile_ptr = tl.advance(b_tile_ptr, [-tl.cdiv(K, BLOCK_K) * BLOCK_K, (pid_n - pre_pid_n) * BLOCK_N])
        a_tile_ptr = tl.advance(a_tile_ptr, [(pid_m - pre_pid_m) * BLOCK_M, 0])
        b_tile_ptr = tl.advance(b_tile_ptr, [0, (pid_n - pre_pid_n) * BLOCK_N])
        z = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for k in range(0, K, BLOCK_K):
            a = tl.load(a_tile_ptr, boundary_check=(0, 1))
            b = tl.load(b_tile_ptr, boundary_check=(0, 1))
            z += tl.dot(a, b)
            a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_K])
            b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_K, 0])
        # Rewind the K axis so the next persistent iteration starts at k = 0
        # (unconditional, as a workaround for the scf.if limitation above).
        a_tile_ptr = tl.advance(a_tile_ptr, [0, -tl.cdiv(K, BLOCK_K) * BLOCK_K])
        b_tile_ptr = tl.advance(b_tile_ptr, [-tl.cdiv(K, BLOCK_K) * BLOCK_K, 0])
        if (out_dtype == tl.constexpr(tl.float16)):
            z = z.to(tl.float16)
        # ---- optional fused epilogues ----
        if ADD_MATRIX:
            z += tl.load(bias_ptrs, mask=mask)
        if ADD_ROWS:
            ZRs = bias_ptr + offs_m * stride_zm
            z += tl.load(ZRs)[:, None]
        if ADD_COLS:
            ZCs = bias_ptr + offs_n * stride_zn
            z += tl.load(ZCs)[None, :]
        if DO_SOFTMAX:
            # Numerically stable row softmax: subtract the row max first.
            max = tl.max(z, 1)
            z = z - max[:, None]
            num = tl.exp(z.to(tl.float32)).to(max.dtype)
            den = tl.sum(num, 1)
            z = num / den[:, None]
        if CHAIN_DOT:
            # NOTE(review): w_tile_ptr is advanced only AFTER the load, so from
            # the second tile on the load sees the previous tile's offset —
            # looks off by one iteration; confirm (the chain-dot test is
            # currently skipped as a known failure).
            w = tl.load(w_tile_ptr)
            w_tile_ptr = tl.advance(w_tile_ptr, [0, (pid_n - pre_pid_n) * BLOCK_N])
            z = tl.dot(z.to(w.dtype), w)
            if (out_dtype == tl.constexpr(tl.float16)):
                z = z.to(tl.float16)
        if USE_TMA_STORE:
            z_block_ptr = tl.advance(z_block_ptr, [(pid_m - pre_pid_m) * BLOCK_M, (pid_n - pre_pid_n) * BLOCK_N])
            tl.store(z_block_ptr, z, boundary_check=(0, 1))
        else:
            tl.store(z_ptrs, z, mask=mask)
        pre_pid_m = pid_m
        pre_pid_n = pid_n
@pytest.mark.parametrize('BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS',
                         [
                             (128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws)
                             for shape_w_c in [
                                 # bad from cublas-important-layers
                                 [4096, 1, 1024, False, False],
                                 [2048, 204, 1000, True, False],
                             ]
                             for out_dtype in ['float16', 'float32']
                             for use_tma_store in [False, True]
                             for enable_ws in [True]
                         ] + [
                             (*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
                             # softmax works for one CTA
                             for shape_w_c in [
                                 [64, 64, 16, 4, 1, 64, 64, 64],
                                 [128, 128, 64, 4, 1, None, None, None],
                                 [16, 16, 64, 4, 1, 16, 16, 64],
                                 # TODO: enable when num_warps != 4 is supported.
                                 # [64, 64, 32, 8, 1, 64, 64, 64],
                                 [128, 128, 64, 4, 1, 128, 128, 128],
                             ]
                             for epilogue in ['softmax']
                             for out_dtype in ['float16', 'float32']
                             for use_tma_store in [False, True]
                             for trans_a in [False, True]
                             for trans_b in [False, True]
                             for num_stages in [3]
                             for enable_ws in [True]
                         ] + [
                             (*shape_w_c, trans_a, trans_b, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
                             for shape_w_c in [
                                 [64, 64, 32, 4, 1, 128, 256, 64],
                                 # TODO: enable when num_ctas != 1 is supported.
                                 # [128, 128, 16, 4, 4, 512, 256, 64],
                                 # [128, 256, 32, 4, 8, 256, 256, 192],
                                 # [512, 256, 32, 4, 8, 1024, 256, 192],
                                 # BLOCK_K >= 128
                                 [64, 128, 128, 4, 1, 512, 256, 256],
                                 [128, 128, 128, 4, 1, 256, 256, 192],
                                 # TODO: enable when num_ctas != 1 is supported.
                                 # [128, 128, 128, 4, 2, 256, 256, 192],
                                 # small BLOCK_M and BLOCK_K
                                 [16, 32, 32, 4, 1, 128, 256, 64],
                                 [32, 32, 16, 4, 1, 256, 256, 192],
                                 # TODO: enable when num_ctas != 1 is supported.
                                 # [16, 32, 64, 4, 4, 512, 256, 64],
                             ]
                             for out_dtype in ['float16', 'float32']
                             for use_tma_store in [False, True]
                             for trans_a in [False, True]
                             for trans_b in [False, True]
                             for num_stages in [3]
                             for enable_ws in [True]
                         ] + [(*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
                             for shape_w_c in [
                                 [64, 64, 16, 4, 1, 128, 128, 64],
                                 *[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4] for num_ctas in [1]],
                                 # for chain-dot
                                 [128, 128, 64, 4, 1, None, None, None],
                                 [64, 64, 16, 4, 1, None, None, None],
                                 # small BLOCK_M and BLOCK_K
                                 [16, 16, 64, 4, 1, 128, 128, 64],
                                 *[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4] for num_ctas in [1]],
                                 # # TODO: enable when num_warps != 4 is supported.
                                 # # repeat
                                 # # [64, 64, 32, 8, 1, 128, 256, 64],
                                 # # [64, 64, 16, 8, 2, 128, 128, 64],
                                 # irregular shape
                                 [128, 128, 64, 4, 1, 500, 200, 128],
                                 [128, 128, 64, 4, 1, 513, 193, 192],
                             ]
                             for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot']
                             for out_dtype in ['float16', 'float32']
                             for use_tma_store in [False, True]
                             for trans_a in [False, True]
                             for trans_b in [False, True]
                             for num_stages in [3]
                             for enable_ws in [True]
                             if not (epilogue == 'chain-dot' and (shape_w_c[5] is not None or shape_w_c[0] != shape_w_c[1]))
                         ] + [
                             (64, n, 16, 4, 1, 512, 256, 256, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
                             # loop over instr shapes
                             for n in [16, 32, 64, 128, 256]
                             for out_dtype in ['float16', 'float32']
                             for use_tma_store in [False, True]
                             for num_stages in [2, 4, 5, 7]
                             for enable_ws in [True]
                         ] + [
                             (*shape_w_c, *shape, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
                             # irregular shapes
                             for shape_w_c in [
                                 [128, 128, 64, 4, 1]
                                 # [256, 128, 64, 4, 2],
                                 # [128, 128, 128, 4, 2],
                             ]
                             for shape in list(itertools.product([*range(512, 4096, 360)], [*range(512, 4096, 360)], [512, 1024]))
                             for out_dtype in ['float16', 'float32']
                             for use_tma_store in [False, True]
                             for num_stages in [2, 3, 4]
                             for enable_ws in [True]
                         ]
                         )
@pytest.mark.skipif(torch.cuda.get_device_capability()
                    [0] < 9, reason="Requires compute capability >= 9")
def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS):
    """End-to-end check of full_static_persistent_matmul_kernel.

    Computes a torch reference (matmul + the requested epilogue), launches the
    persistent kernel on a num_SMs-sized grid, and compares normalized
    results.  M/N/K default to the block sizes when given as None.
    """
    pytest.skip("known failure, will fix it later!!!")
    # Individually known-bad configurations are skipped by matching the joined
    # parameter string.
    if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS])) in [
        '128-128-128-4-1-256-256-192-none-float32-True-3-True',
    ]:
        pytest.skip('out of resource: shared memory, Required: 263168')
    if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, USE_TMA_STORE, ENABLE_WS])) in ([
        '64-16-16-4-1-512-256-256-True-True',
    ] + [
        f'128-128-64-4-1-{m}-{n}-{k}-True-True' for m in range(512, 4096, 360) for n in range(512, 4096, 360) for k in [512, 1024]
    ]):
        pytest.skip('known kernel hang problem when tma store is enabled')
    if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
        '16-32-64-4-4-512-256-64-True-False',
        '16-32-64-4-4-512-256-64-True-True',
        '16-32-64-4-4-512-256-64-False-False',
        '16-32-64-4-4-512-256-64-False-True',
    ]:
        pytest.skip('shapePerCTA[1] < 16 not supported')
    if epilogue == 'chain-dot':
        pytest.skip('known failure: Assertion !region.empty() && unexpected empty region.')
    # None means "single tile": problem size equals the block size.
    M = BLOCK_M if M is None else M
    N = BLOCK_N if N is None else N
    K = BLOCK_K if K is None else K
    # TRANS_* materialize the operand transposed; *_order feeds the kernel's
    # make_block_ptr `order` so it matches the memory layout.
    if (TRANS_A):
        a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
        a_order = [0, 1]
    else:
        a = torch.randn((M, K), device='cuda', dtype=torch.float16)
        a_order = [1, 0]
    if (TRANS_B):
        b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
        b_order = [0, 1]
    else:
        b = torch.randn((K, N), device='cuda', dtype=torch.float16)
        b_order = [1, 0]
    if out_dtype == 'float16' and epilogue != 'softmax':
        # TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
        # fail with the following error: 'llvm.fmul' op requires the same type
        # for all operands and results
        out_dtype = tl.float16
        torch_out_dtype = torch.float16
    else:
        out_dtype = tl.float32
        torch_out_dtype = torch.float32
    # avoid out of memory
    if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
        bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
    else:
        bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
    if epilogue == 'chain-dot':
        w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
    else:
        w = torch.randn((1, 1), device='cuda', dtype=torch.float16).T
    z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
    # torch result
    a_f32 = a.to(torch.float32)
    b_f32 = b.to(torch.float32)
    dot = torch.matmul(a_f32, b_f32)

    def process_epilogue(d, bias, w, epilogue):
        # Apply the reference version of the selected fused epilogue to d.
        if epilogue == 'add-matrix':
            ref = d + bias
        elif epilogue == 'add-rows':
            ref = d + bias[:, 0][:, None]
        elif epilogue == 'add-cols':
            ref = d + bias[0, :][None, :]
        elif epilogue == 'softmax':
            num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
            denom = torch.sum(num, dim=-1, keepdims=True)
            ref = num / denom
            # ref = torch.softmax(d, 1)
        elif epilogue == 'chain-dot':
            ref = torch.matmul(d, w.to(torch.float32))
        else:
            ref = d
        return ref
    golden = process_epilogue(dot, bias, w, epilogue)
    num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count

    def grid(META):
        # Persistent launch: one program per SM.
        return (num_SMs,)
    full_static_persistent_matmul_kernel[grid](
        a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z,
        M=M, N=N, K=K,
        stride_am=a.stride(0), stride_ak=a.stride(1),
        stride_bk=b.stride(0), stride_bn=b.stride(1),
        stride_wm=w.stride(0), stride_wn=w.stride(1),
        stride_zm=z.stride(0), stride_zn=z.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8,
        out_dtype=out_dtype,
        USE_TMA_STORE=USE_TMA_STORE,
        ADD_MATRIX=epilogue == 'add-matrix',
        ADD_ROWS=epilogue == 'add-rows',
        ADD_COLS=epilogue == 'add-cols',
        DO_SOFTMAX=epilogue == 'softmax',
        CHAIN_DOT=epilogue == 'chain-dot',
        A_ORDER_0=a_order[0], A_ORDER_1=a_order[1],
        B_ORDER_0=b_order[0], B_ORDER_1=b_order[1],
        num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES,
        enable_warp_specialization=ENABLE_WS,
        NUM_SM=num_SMs)
    torch.set_printoptions(profile="full")
    # Normalize both sides before comparing so the tolerance is scale-free.
    golden = torch.nn.functional.normalize(golden)
    z = torch.nn.functional.normalize(z)
    assert_close(z, golden,
                 rtol=1e-2,
                 atol=1e-3,
                 check_dtype=False)

View File

@@ -0,0 +1,123 @@
# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import pytest
import torch
from torch.testing import assert_close
import triton
import triton.language as tl
def get_variant_golden(a, b):
    """Return a "variant" golden result for a @ b computed via zero padding.

    Both operands are padded with zeros to 3x their size before the matmul;
    mathematically the (unpadded) result slice is identical to
    torch.matmul(a, b), but the different reduction size perturbs the
    floating-point rounding, giving a reference for how large a legitimate
    rounding error can be.
    """
    SIZE_M = a.shape[0]
    SIZE_K = a.shape[1]
    SIZE_N = b.shape[1]
    assert a.shape[1] == b.shape[0]
    # Allocate padding on the operands' device (the original hard-coded
    # .cuda(), which broke CPU use).  dtype stays float32 so mixed-precision
    # inputs keep the original type-promotion behavior of torch.cat.
    zero_M_K = torch.zeros((SIZE_M, SIZE_K), device=a.device)
    zero_3M_K = torch.zeros((3 * SIZE_M, SIZE_K), device=a.device)
    zero_K_N = torch.zeros((SIZE_K, SIZE_N), device=b.device)
    zero_3K_N = torch.zeros((3 * SIZE_K, SIZE_N), device=b.device)
    a_padded = torch.cat((a, zero_M_K, zero_M_K), 0)
    a_padded = torch.cat((a_padded, zero_3M_K, zero_3M_K), 1)
    b_padded = torch.cat((b, zero_K_N, zero_K_N), 0)
    b_padded = torch.cat((b_padded, zero_3K_N, zero_3K_N), 1)
    c_padded = torch.matmul(a_padded, b_padded)
    # Only the top-left SIZE_M x SIZE_N block corresponds to a @ b.
    return c_padded[:SIZE_M, :SIZE_N]
def get_proper_err(a, b, golden):
    """Estimate legitimate rounding-error bounds for a matmul of a and b.

    Picking fixed rtol/atol thresholds is hard across problem sizes, so the
    gemm is recomputed on zero-padded operands (get_variant_golden).  The
    padded product is mathematically equal to `golden` but rounds differently,
    and the max absolute / relative deviation between the two serves as a
    reference for acceptable error.

    Returns a (max_abs_err, max_rel_err) tuple of Python floats.
    """
    variant = get_variant_golden(a, b)
    diff = golden - variant
    abs_err = torch.abs(diff).max().item()
    rel_err = torch.abs(diff / golden).max().item()
    return (abs_err, rel_err)
@triton.jit
def matmul_tma_load_store(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    OUTPUT_F16: tl.constexpr
):
    # Single-tile matmul: one program computes the whole C = A @ B with a
    # single tl.dot (callers pass BLOCK_M/N/K == M/N/K).  All accesses go
    # through block pointers so both loads and the store can map to TMA on
    # Hopper.  `order` lists axes fastest- to slowest-varying; B uses
    # order=(0, 1), matching the TRANS_B layouts built by the test.
    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                    offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                                    offsets=(0, 0), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
    c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
                                    offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    a = tl.load(a_block_ptr)
    b = tl.load(b_block_ptr)
    c = tl.dot(a, b)
    # tl.dot accumulates in float32; optionally down-cast before the store.
    if OUTPUT_F16:
        c = c.to(tl.float16)
    tl.store(c_block_ptr, c)
@pytest.mark.parametrize('M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_F16', [
    [64, 64, 16, 1, 4, False, True, False],
    [64, 64, 16, 1, 4, False, True, True],
    [128, 64, 32, 1, 4, False, True, False],
    [128, 64, 32, 1, 4, False, True, True],
    [64, 128, 32, 1, 4, False, True, False],
    [64, 128, 32, 1, 4, False, True, True],
    [128, 128, 64, 1, 4, False, True, False],
    [128, 128, 64, 1, 4, False, True, True],
])
def test_tma_load_store(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_F16):
    """Run the single-tile TMA load/store matmul kernel and compare to torch.matmul.

    BLOCK_M/N/K are set to M/N/K so the whole problem is one tile computed by
    a single program.
    """
    # TMA is a Hopper (sm90+) feature; skip on older GPUs like the other TMA
    # tests in this suite instead of failing at kernel compile time.
    if torch.cuda.get_device_capability()[0] < 9:
        pytest.skip("Requires compute capability >= 9")
    # TRANS_A / TRANS_B materialize the operand transposed so the kernel sees
    # non-contiguous strides.
    if TRANS_A:
        a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
    else:
        a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    if TRANS_B:
        b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
    else:
        b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    # The kernel accumulates in float32; OUTPUT_F16 makes it cast on store.
    if OUTPUT_F16:
        c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    else:
        c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    matmul_tma_load_store[(1, 1)](a_ptr=a, b_ptr=b, c_ptr=c,
                                  M=M, N=N, K=K,
                                  stride_am=a.stride(0), stride_ak=a.stride(1),
                                  stride_bk=b.stride(0), stride_bn=b.stride(1),
                                  stride_cm=c.stride(0), stride_cn=c.stride(1),
                                  BLOCK_M=M, BLOCK_N=N, BLOCK_K=K,
                                  num_warps=NUM_WARPS,
                                  num_ctas=NUM_CTAS,
                                  OUTPUT_F16=OUTPUT_F16)
    golden = torch.matmul(a, b)
    # Derive tolerances from a padded-matmul variant of the golden result.
    golden_abs_err, golden_rel_err = get_proper_err(a, b, golden)
    torch.set_printoptions(profile="full")
    assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err), check_dtype=False)

View File

@@ -0,0 +1,75 @@
# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import os
import pytest
import torch
from test_util import get_proper_err
from torch.testing import assert_close
import triton
@pytest.mark.parametrize('TTGIR,TRANS_A,TRANS_B', [
    # TODO: uncomment when it's done
    # ["wgmma_tma_64_64_16_f16.ttgir", False, True],
])
def test_tma_wgmma_64_64_16_f16(TTGIR, TRANS_A, TRANS_B):
    """Compile a hand-written TTGIR wgmma+TMA kernel and check one 64x64x16
    f16 GEMM tile against torch.matmul.

    Currently disabled: the parametrize list above is empty (see the TODO).
    """
    major, _minor = torch.cuda.get_device_capability()
    if major < 9:
        pytest.skip("Only test wgmma on devices with sm >= 90")

    SIZE_M, SIZE_N, SIZE_K = 64, 64, 16
    # Transposed variants are strided views of a buffer allocated in the
    # swapped shape.
    if TRANS_A:
        a = torch.randn((SIZE_K, SIZE_M), device='cuda', dtype=torch.float16).T
    else:
        a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
    if TRANS_B:
        b = torch.randn((SIZE_N, SIZE_K), device='cuda', dtype=torch.float16).T
    else:
        b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
    c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)

    # The TTGIR fixture lives next to this test file.
    ttgir_path = os.path.dirname(__file__) + "/" + TTGIR
    kernel = triton.compile(ttgir_path)
    kernel[(1, 1, 1)](a.data_ptr(), b.data_ptr(), c.data_ptr(),
                      SIZE_M, SIZE_N, SIZE_K,
                      a.stride(0), a.stride(1),
                      b.stride(0), b.stride(1),
                      c.stride(0))

    golden = torch.matmul(a, b)
    torch.set_printoptions(profile="full", sci_mode=False)
    # Tolerances are derived from a size-dependent reference error estimate.
    golden_abs_err, golden_rel_err = get_proper_err(a, b, golden)
    assert_close(c, golden,
                 rtol=max(1e-4, 1.5 * golden_rel_err),
                 atol=max(1e-4, 1.5 * golden_abs_err),
                 check_dtype=False)

View File

@@ -0,0 +1,52 @@
# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import torch
def get_variant_golden(a, b):
SIZE_M = a.shape[0]
SIZE_K = a.shape[1]
SIZE_N = b.shape[1]
assert a.shape[1] == b.shape[0]
zero_M_K = torch.zeros((SIZE_M, SIZE_K)).cuda()
zero_3M_K = torch.zeros((3 * SIZE_M, SIZE_K)).cuda()
zero_K_N = torch.zeros((SIZE_K, SIZE_N)).cuda()
zero_3K_N = torch.zeros((3 * SIZE_K, SIZE_N)).cuda()
a_padded = torch.cat((a, zero_M_K, zero_M_K), 0)
a_padded = torch.cat((a_padded, zero_3M_K, zero_3M_K), 1)
b_padded = torch.cat((b, zero_K_N, zero_K_N), 0)
b_padded = torch.cat((b_padded, zero_3K_N, zero_3K_N), 1)
c_padded = torch.matmul(a_padded, b_padded)
return c_padded[:SIZE_M, :SIZE_N]
def get_proper_err(a, b, golden):
    """Derive reference error bounds for comparing against *golden*.

    Picking a fixed threshold that works across problem sizes is hard, so the
    same product is recomputed via a padded GEMM (``get_variant_golden``) and
    the spread between the two "golden" results serves as a reference when
    selecting rtol / atol.

    :return: ``(abs_err, rel_err)`` — the maximum absolute and relative
        deviation between the two golden computations, as Python floats.
    """
    variant = get_variant_golden(a, b)
    diff = golden - variant
    abs_err = torch.abs(diff).max().item()
    rel_err = torch.abs(diff / golden).max().item()
    return (abs_err, rel_err)

View File

@@ -0,0 +1,52 @@
// ----------------------------------------------------------------------------
// Hand-written TTGIR fixture: a single-tile 64x64 (K = 16) f16 GEMM lowered
// for Hopper GMMA (#mma versionMajor = 3, instrShape = [16, 64, 16]).
// This variant tags the dot with transA = true / transB = true; both operands
// are staged through swizzled shared memory (hasLeadingOffset = true,
// #shared0 order [0, 1], #shared1 order [1, 0]).
// ----------------------------------------------------------------------------
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mma = #triton_gpu.mma<{versionMajor = 3, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// %arg0/%arg1: A/B f16 base pointers; %arg2: C f32 base pointer;
// %arg3/%arg4/%arg5: multiplied with row indices below, i.e. the row strides
// of A, B and C respectively.
func.func public @matmul_kernel_0d1d2d3d4c5d6c7d8c(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
// f32 accumulator, initialized to zero.
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
// Row / column index vectors for the tile, one per blocked layout.
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
%1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<64x1xi32, #blocked0>
%3 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1>
%4 = tt.splat %arg3 : (i32) -> tensor<64x1xi32, #blocked0>
%5 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked0>
%6 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
%7 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x16xi32, #blocked0>
%8 = tt.broadcast %7 : (tensor<1x16xi32, #blocked0>) -> tensor<64x16xi32, #blocked0>
%9 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%10 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16x1xi32, #blocked2>
%11 = tt.splat %arg4 : (i32) -> tensor<16x1xi32, #blocked2>
%12 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<16x1x!tt.ptr<f16>, #blocked2>
%13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%15 = tt.expand_dims %13 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2>
%16 = tt.broadcast %15 : (tensor<1x64xi32, #blocked2>) -> tensor<16x64xi32, #blocked2>
// A tile: %arg0 + row * %arg3 + col; load 64x16 f16.
%17 = arith.muli %2, %4 : tensor<64x1xi32, #blocked0>
%18 = tt.addptr %5, %17 : tensor<64x1x!tt.ptr<f16>, #blocked0>, tensor<64x1xi32, #blocked0>
%19 = tt.broadcast %18 : (tensor<64x1x!tt.ptr<f16>, #blocked0>) -> tensor<64x16x!tt.ptr<f16>, #blocked0>
%20 = tt.addptr %19, %8 : tensor<64x16x!tt.ptr<f16>, #blocked0>, tensor<64x16xi32, #blocked0>
%21 = tt.load %20 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x16xf16, #blocked0>
// B tile: %arg1 + row * %arg4 + col; load 16x64 f16.
%22 = arith.muli %10, %11 : tensor<16x1xi32, #blocked2>
%23 = tt.addptr %12, %22 : tensor<16x1x!tt.ptr<f16>, #blocked2>, tensor<16x1xi32, #blocked2>
%24 = tt.broadcast %23 : (tensor<16x1x!tt.ptr<f16>, #blocked2>) -> tensor<16x64x!tt.ptr<f16>, #blocked2>
%25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi32, #blocked2>
%26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2>
// Stage both operands into swizzled shared memory, then the GMMA dot.
%27 = triton_gpu.convert_layout %21 : (tensor<64x16xf16, #blocked0>) -> tensor<64x16xf16, #shared0>
%28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1>
%29 = tt.dot %27, %28, %cst {allowTF32 = true, transA = true, transB = true} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
// C tile: %arg2 + row * %arg5 + col; store the f32 accumulator.
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
%33 = tt.broadcast %32 : (tensor<1x64xi32, #blocked1>) -> tensor<64x64xi32, #blocked1>
%34 = arith.muli %3, %30 : tensor<64x1xi32, #blocked1>
%35 = tt.addptr %31, %34 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
%36 = tt.broadcast %35 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%37 = tt.addptr %36, %33 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
%38 = triton_gpu.convert_layout %29 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked1>
tt.store %37, %38 : tensor<64x64xf32, #blocked1>
return
}
}

View File

@@ -0,0 +1,52 @@
// ----------------------------------------------------------------------------
// Hand-written TTGIR fixture: a single-tile 64x64 (K = 16) f16 GEMM lowered
// for Hopper GMMA (#mma versionMajor = 3, instrShape = [16, 64, 16]).
// This variant tags the dot with transA = false / transB = false; shared
// layouts use perPhase = 4 / maxPhase = 2 with #shared0 order [1, 0] and
// #shared1 order [0, 1] (hasLeadingOffset = true).
// ----------------------------------------------------------------------------
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mma = #triton_gpu.mma<{versionMajor = 3, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// %arg0/%arg1: A/B f16 base pointers; %arg2: C f32 base pointer;
// %arg3/%arg4/%arg5: multiplied with row indices below, i.e. the row strides
// of A, B and C respectively.
func.func public @matmul_kernel_0d1d2d3d4c5d6c7d8c(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
// f32 accumulator, initialized to zero.
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
// Row / column index vectors for the tile, one per blocked layout.
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
%1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<64x1xi32, #blocked0>
%3 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1>
%4 = tt.splat %arg3 : (i32) -> tensor<64x1xi32, #blocked0>
%5 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked0>
%6 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
%7 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x16xi32, #blocked0>
%8 = tt.broadcast %7 : (tensor<1x16xi32, #blocked0>) -> tensor<64x16xi32, #blocked0>
%9 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%10 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16x1xi32, #blocked2>
%11 = tt.splat %arg4 : (i32) -> tensor<16x1xi32, #blocked2>
%12 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<16x1x!tt.ptr<f16>, #blocked2>
%13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%15 = tt.expand_dims %13 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2>
%16 = tt.broadcast %15 : (tensor<1x64xi32, #blocked2>) -> tensor<16x64xi32, #blocked2>
// A tile: %arg0 + row * %arg3 + col; load 64x16 f16.
%17 = arith.muli %2, %4 : tensor<64x1xi32, #blocked0>
%18 = tt.addptr %5, %17 : tensor<64x1x!tt.ptr<f16>, #blocked0>, tensor<64x1xi32, #blocked0>
%19 = tt.broadcast %18 : (tensor<64x1x!tt.ptr<f16>, #blocked0>) -> tensor<64x16x!tt.ptr<f16>, #blocked0>
%20 = tt.addptr %19, %8 : tensor<64x16x!tt.ptr<f16>, #blocked0>, tensor<64x16xi32, #blocked0>
%21 = tt.load %20 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x16xf16, #blocked0>
// B tile: %arg1 + row * %arg4 + col; load 16x64 f16.
%22 = arith.muli %10, %11 : tensor<16x1xi32, #blocked2>
%23 = tt.addptr %12, %22 : tensor<16x1x!tt.ptr<f16>, #blocked2>, tensor<16x1xi32, #blocked2>
%24 = tt.broadcast %23 : (tensor<16x1x!tt.ptr<f16>, #blocked2>) -> tensor<16x64x!tt.ptr<f16>, #blocked2>
%25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi32, #blocked2>
%26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2>
// Stage both operands into swizzled shared memory, then the GMMA dot.
%27 = triton_gpu.convert_layout %21 : (tensor<64x16xf16, #blocked0>) -> tensor<64x16xf16, #shared0>
%28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1>
%29 = tt.dot %27, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
// C tile: %arg2 + row * %arg5 + col; store the f32 accumulator.
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
%33 = tt.broadcast %32 : (tensor<1x64xi32, #blocked1>) -> tensor<64x64xi32, #blocked1>
%34 = arith.muli %3, %30 : tensor<64x1xi32, #blocked1>
%35 = tt.addptr %31, %34 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
%36 = tt.broadcast %35 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%37 = tt.addptr %36, %33 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
%38 = triton_gpu.convert_layout %29 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked1>
tt.store %37, %38 : tensor<64x64xf32, #blocked1>
return
}
}

View File

@@ -0,0 +1,59 @@
// ----------------------------------------------------------------------------
// Hand-written TTGIR fixture: a single-tile 64x64 (K = 16) f16 GEMM lowered
// for Hopper GMMA (#mma versionMajor = 3).  Unlike the plain variant, the A
// operand is brought into a 1-deep shared-memory buffer with an asynchronous
// copy (insert_slice_async + async_commit_group + async_wait) and then
// extracted for the dot; B still goes through a plain load + convert_layout.
// ----------------------------------------------------------------------------
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mma = #triton_gpu.mma<{versionMajor = 3, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// %arg0/%arg1: A/B f16 base pointers; %arg2: C f32 base pointer;
// %arg3/%arg4/%arg5: multiplied with row indices below, i.e. the row strides
// of A, B and C respectively.
func.func public @matmul_kernel_0d1d2d3d4c5d6c7d8c(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
// f32 accumulator, initialized to zero.
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
// Row / column index vectors for the tile, one per blocked layout.
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
%1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<64x1xi32, #blocked0>
%3 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1>
%4 = tt.splat %arg3 : (i32) -> tensor<64x1xi32, #blocked0>
%5 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked0>
%6 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
%7 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x16xi32, #blocked0>
%8 = tt.broadcast %7 : (tensor<1x16xi32, #blocked0>) -> tensor<64x16xi32, #blocked0>
%9 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%10 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16x1xi32, #blocked2>
%11 = tt.splat %arg4 : (i32) -> tensor<16x1xi32, #blocked2>
%12 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<16x1x!tt.ptr<f16>, #blocked2>
%13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%15 = tt.expand_dims %13 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2>
%16 = tt.broadcast %15 : (tensor<1x64xi32, #blocked2>) -> tensor<16x64xi32, #blocked2>
// A tile addresses: %arg0 + row * %arg3 + col.
%17 = arith.muli %2, %4 : tensor<64x1xi32, #blocked0>
%18 = tt.addptr %5, %17 : tensor<64x1x!tt.ptr<f16>, #blocked0>, tensor<64x1xi32, #blocked0>
%19 = tt.broadcast %18 : (tensor<64x1x!tt.ptr<f16>, #blocked0>) -> tensor<64x16x!tt.ptr<f16>, #blocked0>
%20 = tt.addptr %19, %8 : tensor<64x16x!tt.ptr<f16>, #blocked0>, tensor<64x16xi32, #blocked0>
// Asynchronously copy the (fully unmasked) A tile into slice 0 of a 1-deep
// shared-memory buffer, wait for it, and extract the 64x16 slice.
%ci0 = arith.constant 0 : i32
%i1_true = arith.constant 1 : i1
%i1_false = arith.constant 0 : i1
%t = triton_gpu.alloc_tensor : tensor<1x64x16xf16, #shared0>
%mask0 = tt.splat %i1_true : (i1) -> tensor<64x16xi1, #blocked0>
%t0 = triton_gpu.insert_slice_async %20, %t, %ci0, %mask0 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, type = 1 : i32} : tensor<64x16x!tt.ptr<f16>, #blocked0> -> tensor<1x64x16xf16, #shared0>
triton_gpu.async_commit_group
triton_gpu.async_wait {num = 1 : i32}
%21 = triton_gpu.extract_slice %t0[%ci0, 0, 0][1, 64, 16][1, 1, 1] : tensor<1x64x16xf16, #shared0> to tensor<64x16xf16, #shared0>
// B tile: %arg1 + row * %arg4 + col; plain load then convert to shared.
%22 = arith.muli %10, %11 : tensor<16x1xi32, #blocked2>
%23 = tt.addptr %12, %22 : tensor<16x1x!tt.ptr<f16>, #blocked2>, tensor<16x1xi32, #blocked2>
%24 = tt.broadcast %23 : (tensor<16x1x!tt.ptr<f16>, #blocked2>) -> tensor<16x64x!tt.ptr<f16>, #blocked2>
%25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi32, #blocked2>
%26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2>
%28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1>
%29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
// C tile: %arg2 + row * %arg5 + col; store the f32 accumulator.
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
%33 = tt.broadcast %32 : (tensor<1x64xi32, #blocked1>) -> tensor<64x64xi32, #blocked1>
%34 = arith.muli %3, %30 : tensor<64x1xi32, #blocked1>
%35 = tt.addptr %31, %34 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
%36 = tt.broadcast %35 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%37 = tt.addptr %36, %33 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
%38 = triton_gpu.convert_layout %29 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked1>
tt.store %37, %38 : tensor<64x64xf32, #blocked1>
return
}
}

View File

@@ -0,0 +1,63 @@
// ----------------------------------------------------------------------------
// Hand-written TTGIR fixture: same async-copy GEMM as the previous variant,
// but it additionally allocates an mbarrier (triton_nvidia_gpu.alloc_mbarrier).
// The mbarrier extract / arrive / wait ops are still commented out (see the
// TODO below), so completion is tracked by async_wait alone; note there is
// also no async_commit_group in this variant.
// ----------------------------------------------------------------------------
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mma = #triton_gpu.mma<{versionMajor = 3, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// %arg0/%arg1: A/B f16 base pointers; %arg2: C f32 base pointer;
// %arg3/%arg4/%arg5: multiplied with row indices below, i.e. the row strides
// of A, B and C respectively.
func.func public @matmul_kernel_0d1d2d3d4c5d6c7d8c(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
// f32 accumulator, initialized to zero.
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
// Row / column index vectors for the tile, one per blocked layout.
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
%1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<64x1xi32, #blocked0>
%3 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1>
%4 = tt.splat %arg3 : (i32) -> tensor<64x1xi32, #blocked0>
%5 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked0>
%6 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
%7 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x16xi32, #blocked0>
%8 = tt.broadcast %7 : (tensor<1x16xi32, #blocked0>) -> tensor<64x16xi32, #blocked0>
%9 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%10 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16x1xi32, #blocked2>
%11 = tt.splat %arg4 : (i32) -> tensor<16x1xi32, #blocked2>
%12 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<16x1x!tt.ptr<f16>, #blocked2>
%13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%15 = tt.expand_dims %13 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2>
%16 = tt.broadcast %15 : (tensor<1x64xi32, #blocked2>) -> tensor<16x64xi32, #blocked2>
// A tile addresses: %arg0 + row * %arg3 + col.
%17 = arith.muli %2, %4 : tensor<64x1xi32, #blocked0>
%18 = tt.addptr %5, %17 : tensor<64x1x!tt.ptr<f16>, #blocked0>, tensor<64x1xi32, #blocked0>
%19 = tt.broadcast %18 : (tensor<64x1x!tt.ptr<f16>, #blocked0>) -> tensor<64x16x!tt.ptr<f16>, #blocked0>
%20 = tt.addptr %19, %8 : tensor<64x16x!tt.ptr<f16>, #blocked0>, tensor<64x16xi32, #blocked0>
// Asynchronously copy the (fully unmasked) A tile into slice 0 of a 1-deep
// shared-memory buffer; the mbarrier is allocated but not yet used.
%ci0 = arith.constant 0 : i32
%i1_true = arith.constant 1 : i1
%i1_false = arith.constant 0 : i1
%t = triton_gpu.alloc_tensor : tensor<1x64x16xf16, #shared0>
// TODO: even an empty init external call here will break the UT
%mbar = triton_nvidia_gpu.alloc_mbarrier { count = 128 : i32 } : tensor<1xi64, #shared0>
// %mbar0 = triton_nvidia_gpu.extract_mbarrier %mbar[%ci0] : tensor<1xi64, #shared0>, i32 -> !tt.ptr<i64, 3>
%mask0 = tt.splat %i1_true : (i1) -> tensor<64x16xi1, #blocked0>
%t0 = triton_gpu.insert_slice_async %20, %t, %ci0, %mask0 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, type = 1 : i32} : tensor<64x16x!tt.ptr<f16>, #blocked0> -> tensor<1x64x16xf16, #shared0>
triton_gpu.async_wait {num = 1 : i32}
// triton_nvidia_gpu.mbarrier_arrive %mbar0 {trackAsyncOp = true} : !tt.ptr<i64, 3>
// triton_nvidia_gpu.mbarrier_wait %mbar0, %i1_false : !tt.ptr<i64, 3>
%21 = triton_gpu.extract_slice %t0[%ci0, 0, 0][1, 64, 16][1, 1, 1] : tensor<1x64x16xf16, #shared0> to tensor<64x16xf16, #shared0>
// B tile: %arg1 + row * %arg4 + col; plain load then convert to shared.
%22 = arith.muli %10, %11 : tensor<16x1xi32, #blocked2>
%23 = tt.addptr %12, %22 : tensor<16x1x!tt.ptr<f16>, #blocked2>, tensor<16x1xi32, #blocked2>
%24 = tt.broadcast %23 : (tensor<16x1x!tt.ptr<f16>, #blocked2>) -> tensor<16x64x!tt.ptr<f16>, #blocked2>
%25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi32, #blocked2>
%26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2>
%28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1>
%29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
// C tile: %arg2 + row * %arg5 + col; store the f32 accumulator.
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
%33 = tt.broadcast %32 : (tensor<1x64xi32, #blocked1>) -> tensor<64x64xi32, #blocked1>
%34 = arith.muli %3, %30 : tensor<64x1xi32, #blocked1>
%35 = tt.addptr %31, %34 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
%36 = tt.broadcast %35 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%37 = tt.addptr %36, %33 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
%38 = triton_gpu.convert_layout %29 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked1>
tt.store %37, %38 : tensor<64x64xf32, #blocked1>
return
}
}

View File

@@ -0,0 +1,63 @@
// Hand-written TritonGPU IR: one-tile matmul C(64x64, f32) = A(64x16, f16) x
// B(16x64, f16) targeting Hopper MMAv3 (#mma versionMajor = 3).
// Both operands are staged into swizzled shared memory (hasLeadingOffset)
// via cp.async-style insert_slice_async ops, synchronized with
// async_commit_group / async_wait before the dot.
// NOTE(review): presumably a lit test fixture — confirm against the CHECK
// directives in the enclosing test file.
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mma = #triton_gpu.mma<{versionMajor = 3, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
func.func public @matmul_kernel_0d1d2d3d4c5d6c7d8c(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
// Index ranges and strided pointer arithmetic for the A, B and C tiles
// (%arg3/%arg4/%arg5 are the respective leading-dimension strides).
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
%1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<64x1xi32, #blocked0>
%3 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1>
%4 = tt.splat %arg3 : (i32) -> tensor<64x1xi32, #blocked0>
%5 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked0>
%6 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
%7 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x16xi32, #blocked0>
%8 = tt.broadcast %7 : (tensor<1x16xi32, #blocked0>) -> tensor<64x16xi32, #blocked0>
%9 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%10 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16x1xi32, #blocked2>
%11 = tt.splat %arg4 : (i32) -> tensor<1x64xi32, #blocked2>
%12 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<16x1x!tt.ptr<f16>, #blocked2>
%13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%15 = tt.expand_dims %13 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2>
%bstride1 = arith.muli %11, %15 : tensor<1x64xi32, #blocked2>
%16 = tt.broadcast %bstride1 : (tensor<1x64xi32, #blocked2>) -> tensor<16x64xi32, #blocked2>
%17 = arith.muli %2, %4 : tensor<64x1xi32, #blocked0>
%18 = tt.addptr %5, %17 : tensor<64x1x!tt.ptr<f16>, #blocked0>, tensor<64x1xi32, #blocked0>
%19 = tt.broadcast %18 : (tensor<64x1x!tt.ptr<f16>, #blocked0>) -> tensor<64x16x!tt.ptr<f16>, #blocked0>
%20 = tt.addptr %19, %8 : tensor<64x16x!tt.ptr<f16>, #blocked0>, tensor<64x16xi32, #blocked0>
// Insertion index and all-true load masks (full tile, no boundary check).
%ci0 = arith.constant 0 : i32
%i1_true = arith.constant 1 : i1
%i1_false = arith.constant 0 : i1
// Stage the A tile into shared memory asynchronously, then wait on it.
%at = triton_gpu.alloc_tensor : tensor<1x64x16xf16, #shared0>
%maska = tt.splat %i1_true : (i1) -> tensor<64x16xi1, #blocked0>
%at0 = triton_gpu.insert_slice_async %20, %at, %ci0, %maska {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, type = 1 : i32} : tensor<64x16x!tt.ptr<f16>, #blocked0> -> tensor<1x64x16xf16, #shared0>
triton_gpu.async_commit_group
triton_gpu.async_wait {num = 1 : i32}
%21 = triton_gpu.extract_slice %at0[%ci0, 0, 0][1, 64, 16][1, 1, 1] : tensor<1x64x16xf16, #shared0> to tensor<64x16xf16, #shared0>
// Stage the B tile the same way (note column-major #shared1 order).
%23 = tt.addptr %12, %10 : tensor<16x1x!tt.ptr<f16>, #blocked2>, tensor<16x1xi32, #blocked2>
%24 = tt.broadcast %23 : (tensor<16x1x!tt.ptr<f16>, #blocked2>) -> tensor<16x64x!tt.ptr<f16>, #blocked2>
%25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi32, #blocked2>
%bt = triton_gpu.alloc_tensor : tensor<1x16x64xf16, #shared1>
%maskb = tt.splat %i1_true : (i1) -> tensor<16x64xi1, #blocked2>
%bt0 = triton_gpu.insert_slice_async %25, %bt, %ci0, %maskb {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, type = 1 : i32} : tensor<16x64x!tt.ptr<f16>, #blocked2> -> tensor<1x16x64xf16, #shared1>
triton_gpu.async_commit_group
triton_gpu.async_wait {num = 1 : i32}
%28 = triton_gpu.extract_slice %bt0[%ci0, 0, 0][1, 16, 64][1, 1, 1] : tensor<1x16x64xf16, #shared1> to tensor<16x64xf16, #shared1>
// MMAv3 dot on the shared-memory operands.
%29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
// Epilogue: build C pointers and store the f32 accumulator.
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
%33 = tt.broadcast %32 : (tensor<1x64xi32, #blocked1>) -> tensor<64x64xi32, #blocked1>
%34 = arith.muli %3, %30 : tensor<64x1xi32, #blocked1>
%35 = tt.addptr %31, %34 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
%36 = tt.broadcast %35 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%37 = tt.addptr %36, %33 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
%38 = triton_gpu.convert_layout %29 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked1>
tt.store %37, %38 : tensor<64x64xf32, #blocked1>
return
}
}

View File

@@ -0,0 +1,67 @@
// Hand-written TritonGPU IR: same one-tile 64x64 (K=16) f16 matmul as the
// previous module, but the async copies are tracked with Hopper mbarriers —
// one standalone alloc_mbarrier per operand, with mbarrier_arrive
// {trackAsyncOp = true} / mbarrier_wait replacing async_commit_group /
// async_wait (the latter kept commented out for reference).
// NOTE(review): presumably a lit test fixture — confirm against the CHECK
// directives in the enclosing test file.
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mma = #triton_gpu.mma<{versionMajor = 3, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
func.func public @matmul_kernel_0d1d2d3d4c5d6c7d8c(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
// Index ranges and strided pointer arithmetic for the A, B and C tiles.
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
%1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<64x1xi32, #blocked0>
%3 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1>
%4 = tt.splat %arg3 : (i32) -> tensor<64x1xi32, #blocked0>
%5 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked0>
%6 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
%7 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x16xi32, #blocked0>
%8 = tt.broadcast %7 : (tensor<1x16xi32, #blocked0>) -> tensor<64x16xi32, #blocked0>
%9 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%10 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16x1xi32, #blocked2>
%11 = tt.splat %arg4 : (i32) -> tensor<1x64xi32, #blocked2>
%12 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<16x1x!tt.ptr<f16>, #blocked2>
%13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%15 = tt.expand_dims %13 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2>
%bstride1 = arith.muli %11, %15 : tensor<1x64xi32, #blocked2>
%16 = tt.broadcast %bstride1 : (tensor<1x64xi32, #blocked2>) -> tensor<16x64xi32, #blocked2>
%17 = arith.muli %2, %4 : tensor<64x1xi32, #blocked0>
%18 = tt.addptr %5, %17 : tensor<64x1x!tt.ptr<f16>, #blocked0>, tensor<64x1xi32, #blocked0>
%19 = tt.broadcast %18 : (tensor<64x1x!tt.ptr<f16>, #blocked0>) -> tensor<64x16x!tt.ptr<f16>, #blocked0>
%20 = tt.addptr %19, %8 : tensor<64x16x!tt.ptr<f16>, #blocked0>, tensor<64x16xi32, #blocked0>
// Insertion index, all-true load masks, and the mbarrier wait phase.
%ci0 = arith.constant 0 : i32
%i1_true = arith.constant 1 : i1
%i1_false = arith.constant 0 : i1
// Stage the A tile; completion is tracked by a dedicated mbarrier
// (count = 128 — NOTE(review): presumably one arrival per thread of the
// 4-warp CTA; confirm against the mbarrier lowering).
%at = triton_gpu.alloc_tensor : tensor<1x64x16xf16, #shared0>
%maska = tt.splat %i1_true : (i1) -> tensor<64x16xi1, #blocked0>
%mbar0 = triton_nvidia_gpu.alloc_mbarrier { count = 128 : i32 } : !tt.ptr<i64, 3>
%at0 = triton_gpu.insert_slice_async %20, %at, %ci0, %maska {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, type = 1 : i32} : tensor<64x16x!tt.ptr<f16>, #blocked0> -> tensor<1x64x16xf16, #shared0>
triton_nvidia_gpu.mbarrier_arrive %mbar0 {trackAsyncOp = true} : !tt.ptr<i64, 3>
triton_nvidia_gpu.mbarrier_wait %mbar0, %i1_false : !tt.ptr<i64, 3>
// triton_gpu.async_wait {num = 1 : i32}
%21 = triton_gpu.extract_slice %at0[%ci0, 0, 0][1, 64, 16][1, 1, 1] : tensor<1x64x16xf16, #shared0> to tensor<64x16xf16, #shared0>
// Stage the B tile with its own mbarrier.
%23 = tt.addptr %12, %10 : tensor<16x1x!tt.ptr<f16>, #blocked2>, tensor<16x1xi32, #blocked2>
%24 = tt.broadcast %23 : (tensor<16x1x!tt.ptr<f16>, #blocked2>) -> tensor<16x64x!tt.ptr<f16>, #blocked2>
%25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi32, #blocked2>
%bt = triton_gpu.alloc_tensor : tensor<1x16x64xf16, #shared1>
%maskb = tt.splat %i1_true : (i1) -> tensor<16x64xi1, #blocked2>
%mbar1 = triton_nvidia_gpu.alloc_mbarrier { count = 128 : i32 } : !tt.ptr<i64, 3>
%bt0 = triton_gpu.insert_slice_async %25, %bt, %ci0, %maskb {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, type = 1 : i32} : tensor<16x64x!tt.ptr<f16>, #blocked2> -> tensor<1x16x64xf16, #shared1>
triton_nvidia_gpu.mbarrier_arrive %mbar1 {trackAsyncOp = true} : !tt.ptr<i64, 3>
triton_nvidia_gpu.mbarrier_wait %mbar1, %i1_false : !tt.ptr<i64, 3>
// triton_gpu.async_wait {num = 1 : i32}
%28 = triton_gpu.extract_slice %bt0[%ci0, 0, 0][1, 16, 64][1, 1, 1] : tensor<1x16x64xf16, #shared1> to tensor<16x64xf16, #shared1>
// MMAv3 dot on the shared-memory operands.
%29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
// Epilogue: build C pointers and store the f32 accumulator.
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
%33 = tt.broadcast %32 : (tensor<1x64xi32, #blocked1>) -> tensor<64x64xi32, #blocked1>
%34 = arith.muli %3, %30 : tensor<64x1xi32, #blocked1>
%35 = tt.addptr %31, %34 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
%36 = tt.broadcast %35 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%37 = tt.addptr %36, %33 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
%38 = triton_gpu.convert_layout %29 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked1>
tt.store %37, %38 : tensor<64x64xf32, #blocked1>
return
}
}

View File

@@ -0,0 +1,69 @@
// Hand-written TritonGPU IR: same mbarrier-tracked one-tile matmul as the
// previous module, but the two barriers live in a single 2-slot mbarrier
// tensor (tensor<2xi64>) and are addressed with extract_mbarrier — slot 0
// guards the A-tile copy, slot 1 the B-tile copy.
// NOTE(review): presumably a lit test fixture — confirm against the CHECK
// directives in the enclosing test file.
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mma = #triton_gpu.mma<{versionMajor = 3, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
func.func public @matmul_kernel_0d1d2d3d4c5d6c7d8c(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
// Index ranges and strided pointer arithmetic for the A, B and C tiles.
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
%1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<64x1xi32, #blocked0>
%3 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1>
%4 = tt.splat %arg3 : (i32) -> tensor<64x1xi32, #blocked0>
%5 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked0>
%6 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
%7 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x16xi32, #blocked0>
%8 = tt.broadcast %7 : (tensor<1x16xi32, #blocked0>) -> tensor<64x16xi32, #blocked0>
%9 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%10 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16x1xi32, #blocked2>
%11 = tt.splat %arg4 : (i32) -> tensor<1x64xi32, #blocked2>
%12 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<16x1x!tt.ptr<f16>, #blocked2>
%13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%15 = tt.expand_dims %13 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2>
%bstride1 = arith.muli %11, %15 : tensor<1x64xi32, #blocked2>
%16 = tt.broadcast %bstride1 : (tensor<1x64xi32, #blocked2>) -> tensor<16x64xi32, #blocked2>
%17 = arith.muli %2, %4 : tensor<64x1xi32, #blocked0>
%18 = tt.addptr %5, %17 : tensor<64x1x!tt.ptr<f16>, #blocked0>, tensor<64x1xi32, #blocked0>
%19 = tt.broadcast %18 : (tensor<64x1x!tt.ptr<f16>, #blocked0>) -> tensor<64x16x!tt.ptr<f16>, #blocked0>
%20 = tt.addptr %19, %8 : tensor<64x16x!tt.ptr<f16>, #blocked0>, tensor<64x16xi32, #blocked0>
// Slot indices into the mbarrier tensor, all-true load masks, wait phase.
%ci0 = arith.constant 0 : i32
%ci1 = arith.constant 1 : i32
%i1_true = arith.constant 1 : i1
%i1_false = arith.constant 0 : i1
// Stage the A tile; slot 0 of the shared 2-slot mbarrier tracks it.
%at = triton_gpu.alloc_tensor : tensor<1x64x16xf16, #shared0>
%maska = tt.splat %i1_true : (i1) -> tensor<64x16xi1, #blocked0>
%mbar = triton_nvidia_gpu.alloc_mbarrier { count = 128 : i32 } : tensor<2xi64, #shared0>
%mbar0_s = triton_nvidia_gpu.extract_mbarrier %mbar[%ci0] : tensor<2xi64, #shared0>, i32 -> !tt.ptr<i64, 3>
%at0 = triton_gpu.insert_slice_async %20, %at, %ci0, %maska {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, type = 1 : i32} : tensor<64x16x!tt.ptr<f16>, #blocked0> -> tensor<1x64x16xf16, #shared0>
triton_nvidia_gpu.mbarrier_arrive %mbar0_s {trackAsyncOp = true} : !tt.ptr<i64, 3>
triton_nvidia_gpu.mbarrier_wait %mbar0_s, %i1_false : !tt.ptr<i64, 3>
// triton_gpu.async_wait {num = 1 : i32}
%21 = triton_gpu.extract_slice %at0[%ci0, 0, 0][1, 64, 16][1, 1, 1] : tensor<1x64x16xf16, #shared0> to tensor<64x16xf16, #shared0>
// Stage the B tile; slot 1 tracks it.
%23 = tt.addptr %12, %10 : tensor<16x1x!tt.ptr<f16>, #blocked2>, tensor<16x1xi32, #blocked2>
%24 = tt.broadcast %23 : (tensor<16x1x!tt.ptr<f16>, #blocked2>) -> tensor<16x64x!tt.ptr<f16>, #blocked2>
%25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi32, #blocked2>
%bt = triton_gpu.alloc_tensor : tensor<1x16x64xf16, #shared1>
%maskb = tt.splat %i1_true : (i1) -> tensor<16x64xi1, #blocked2>
%mbar1_s = triton_nvidia_gpu.extract_mbarrier %mbar[%ci1] : tensor<2xi64, #shared0>, i32 -> !tt.ptr<i64, 3>
%bt0 = triton_gpu.insert_slice_async %25, %bt, %ci0, %maskb {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, type = 1 : i32} : tensor<16x64x!tt.ptr<f16>, #blocked2> -> tensor<1x16x64xf16, #shared1>
triton_nvidia_gpu.mbarrier_arrive %mbar1_s {trackAsyncOp = true} : !tt.ptr<i64, 3>
triton_nvidia_gpu.mbarrier_wait %mbar1_s, %i1_false : !tt.ptr<i64, 3>
// triton_gpu.async_wait {num = 1 : i32}
%28 = triton_gpu.extract_slice %bt0[%ci0, 0, 0][1, 16, 64][1, 1, 1] : tensor<1x16x64xf16, #shared1> to tensor<16x64xf16, #shared1>
// MMAv3 dot on the shared-memory operands.
%29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
// Epilogue: build C pointers and store the f32 accumulator.
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
%33 = tt.broadcast %32 : (tensor<1x64xi32, #blocked1>) -> tensor<64x64xi32, #blocked1>
%34 = arith.muli %3, %30 : tensor<64x1xi32, #blocked1>
%35 = tt.addptr %31, %34 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
%36 = tt.broadcast %35 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%37 = tt.addptr %36, %33 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
%38 = triton_gpu.convert_layout %29 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked1>
tt.store %37, %38 : tensor<64x64xf32, #blocked1>
return
}
}

View File

@@ -0,0 +1,64 @@
// Hand-written TritonGPU IR: the same one-tile matmul, now loading both
// operands with TMA-style load_tile_async (descriptor-like base pointer +
// sizes + strides + coordinates) instead of per-element pointer tensors.
// Each load signals a slot of a 2-slot mbarrier; note the wait phase here
// is %i1_true and the alloc carries a txCount attribute.
// NOTE(review): presumably a lit test fixture — confirm against the CHECK
// directives in the enclosing test file.
// NOTE(review): #blocked2 omits CTASplitNum, unlike the other layouts in
// this family of tests — confirm it is intentionally defaulted.
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], CTAsPerCGA = [1, 1], order = [0, 1], CTAOrder = [1, 0]}>
#mma = #triton_gpu.mma<{versionMajor = 3, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// Unlike the earlier variants, strides are explicit kernel arguments here
// because the TMA loads consume (sizes, strides, coordinates) directly.
func.func public @matmul_kernel_0d1d2d3d4c5d6c7d8c(%aBasePtr : !tt.ptr<f16> {tt.divisibility = 16 : i32},
%bBasePtr : !tt.ptr<f16> {tt.divisibility = 16 : i32},
%cBasePtr : !tt.ptr<f32> {tt.divisibility = 16 : i32},
%sizeM : i32 {tt.divisibility = 16 : i32},
%sizeN : i32 {tt.divisibility = 16 : i32},
%sizeK : i32 {tt.divisibility = 16 : i32},
%aStride0 : i32 {tt.divisibility = 16 : i32},
%aStride1 : i32 {tt.divisibility = 16 : i32},
%bStride0 : i32 {tt.divisibility = 16 : i32},
%bStride1 : i32 {tt.divisibility = 16 : i32},
%cStride0 : i32 {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
%ci0 = arith.constant 0 : i32
%ci1 = arith.constant 1 : i32
%i1_true = arith.constant 1 : i1
// Tile origin coordinates for both operands (top-left tile).
%aCoord0 = arith.constant 0 : i32
%aCoord1 = arith.constant 0 : i32
%bCoord0 = arith.constant 0 : i32
%bCoord1 = arith.constant 0 : i32
// Two-slot mbarrier; txCount = 2048 — NOTE(review): presumably the
// expected transaction byte count per load (64*16 f16 = 2048 bytes),
// confirm against the TMA lowering.
%mbar = triton_nvidia_gpu.alloc_mbarrier { count = 128 : i32, txCount = 2048 : i32 } : tensor<2xi64, #shared0>
%mbar_a = triton_nvidia_gpu.extract_mbarrier %mbar[%ci0] : tensor<2xi64, #shared0>, i32 -> !tt.ptr<i64, 3>
%mbar_b = triton_nvidia_gpu.extract_mbarrier %mbar[%ci1] : tensor<2xi64, #shared0>, i32 -> !tt.ptr<i64, 3>
%1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%3 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1>
%14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
// Load A
%a_smem = triton_gpu.alloc_tensor : tensor<1x64x16xf16, #shared0>
%a_smem_loaded = triton_gpu.load_tile_async %aBasePtr[%sizeM, %sizeK][%aStride0, %aStride1][%aCoord0, %aCoord1], %mbar_a, %a_smem[%ci0] {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, order = dense<[1, 0]> : tensor<2xi32>} : !tt.ptr<i64, 3> -> tensor<1x64x16xf16, #shared0>
triton_nvidia_gpu.mbarrier_arrive %mbar_a {trackAsyncOp = false} : !tt.ptr<i64, 3>
triton_nvidia_gpu.mbarrier_wait %mbar_a, %i1_true : !tt.ptr<i64, 3>
%21 = triton_gpu.extract_slice %a_smem_loaded[%ci0, 0, 0][1, 64, 16][1, 1, 1] : tensor<1x64x16xf16, #shared0> to tensor<64x16xf16, #shared0>
// Load B
%b_smem = triton_gpu.alloc_tensor : tensor<1x16x64xf16, #shared1>
%b_smem_loaded= triton_gpu.load_tile_async %bBasePtr[%sizeK, %sizeN][%bStride0, %bStride1][%bCoord0, %bCoord1], %mbar_b, %b_smem[%ci0] {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, order = dense<[0, 1]> : tensor<2xi32>} : !tt.ptr<i64, 3> -> tensor<1x16x64xf16, #shared1>
triton_nvidia_gpu.mbarrier_arrive %mbar_b {trackAsyncOp = false} : !tt.ptr<i64, 3>
triton_nvidia_gpu.mbarrier_wait %mbar_b, %i1_true : !tt.ptr<i64, 3>
%28 = triton_gpu.extract_slice %b_smem_loaded[%ci0, 0, 0][1, 16, 64][1, 1, 1] : tensor<1x16x64xf16, #shared1> to tensor<16x64xf16, #shared1>
// Calling MMA
%29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
// Epilogue
%30 = tt.splat %cStride0: (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %cBasePtr: (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
%33 = tt.broadcast %32 : (tensor<1x64xi32, #blocked1>) -> tensor<64x64xi32, #blocked1>
%34 = arith.muli %3, %30 : tensor<64x1xi32, #blocked1>
%35 = tt.addptr %31, %34 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
%36 = tt.broadcast %35 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%37 = tt.addptr %36, %33 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
%38 = triton_gpu.convert_layout %29 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked1>
tt.store %37, %38 : tensor<64x64xf32, #blocked1>
return
}
}

View File

@@ -0,0 +1,32 @@
import torch
def get_variant_golden(a, b):
    """Compute a "variant" golden matmul result for error-threshold calibration.

    Both operands are zero-padded to 3x their size along each dimension before
    the matmul. Mathematically the top-left (M, N) corner is identical to
    ``a @ b``, but the padded problem exercises a different accumulation
    pattern, so comparing it against the plain golden gives a realistic
    reference for picking rtol/atol.

    Args:
        a: left operand of shape (M, K).
        b: right operand of shape (K, N); ``a.shape[1]`` must equal
           ``b.shape[0]``.

    Returns:
        Tensor of shape (M, N): the top-left corner of the padded matmul.
    """
    size_m, size_k = a.shape
    size_n = b.shape[1]
    assert a.shape[1] == b.shape[0]
    # Allocate the zero padding on the inputs' own device/dtype via new_zeros
    # (the original hard-coded .cuda(), which broke CPU inputs and mixed
    # devices when `a`/`b` lived on a non-default GPU).
    a_padded = a.new_zeros((3 * size_m, 3 * size_k))
    a_padded[:size_m, :size_k] = a
    b_padded = b.new_zeros((3 * size_k, 3 * size_n))
    b_padded[:size_k, :size_n] = b
    c_padded = torch.matmul(a_padded, b_padded)
    return c_padded[:size_m, :size_n]
# It is not easy to pick a proper error threshold that works across sizes.
# Here the GEMM calculation is padded to a different size in order to get
# a variant version of the golden result. The error between golden and
# golden_variant provides a reference for selecting proper rtol / atol values.
def get_proper_err(golden, golden_variant):
    """Return (max absolute error, max relative error) between two goldens."""
    abs_diff = torch.abs(golden - golden_variant)
    max_abs_err = abs_diff.max().item()
    # Pad the denominator with the dtype's smallest normal value so a zero
    # entry in `golden` cannot produce an infinite relative error.
    tiny = torch.finfo(golden.dtype).smallest_normal
    denominator = torch.abs(golden) + torch.full_like(golden, tiny)
    max_rel_err = (abs_diff / denominator).max().item()
    return (max_abs_err, max_rel_err)