mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Initial code merge of Hopper support (#2036)
The initial code merge of Nvidia Hopper features support. Please be aware that the code merge is not finished yet and troubleshooting is still ongoing. The new hardware features (GMMA, TMA, STMATRIX etc.) and automatic warp-specialization are experimental for now and turned off by default. We recommend trying them out once version 3.0 is released. The work is contributed by: ben-zhang-609, bealwang, donproc, qliu93, jsh20, allatit23, LyricZhao, ivanyinwz, goostavz & yangjunpro from Nvidia, in cooperation with: ptillet, Jokeren, ThomasRaoux & zahimoud from OpenAI. Co-authored-by: Goostav Zhu <gzhu@nvidia.com>
This commit is contained in:
0
python/test/unit/hopper/__init__.py
Normal file
0
python/test/unit/hopper/__init__.py
Normal file
480
python/test/unit/hopper/test_flashattention.py
Normal file
480
python/test/unit/hopper/test_flashattention.py
Normal file
@@ -0,0 +1,480 @@
|
||||
# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files
|
||||
# (the "Software"), to deal in the Software without restriction,
|
||||
# including without limitation the rights to use, copy, modify, merge,
|
||||
# publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
# and to permit persons to whom the Software is furnished to do so,
|
||||
# subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
"""
|
||||
Fused Attention
|
||||
===============
|
||||
This is a Triton implementation of the Flash Attention algorithm
|
||||
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
|
||||
"""
|
||||
|
||||
# import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
def _fwd_kernel(
    Q, K, V, sm_scale,
    L, M,
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    Z, H, N_CTX, D0,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Forward flash-attention kernel (causal). Each program computes one
    # BLOCK_M-row tile of Out for one (batch, head) pair, streaming K/V in
    # BLOCK_N tiles with the online-softmax recurrence. The per-row softmax
    # denominator and running max are written to L and M for reuse by the
    # backward pass. D0 is the total row count of the flattened
    # (Z*H*N_CTX, BLOCK_DMODEL) 2-D view used by the block pointers.
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)

    # TODO: may replace with TMA store without range offset
    # initialize offsets for store
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize pointer to m and l
    m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    # rows per (batch, head) slice in the 2-D view; presumably assumes a
    # contiguous layout where stride_qh == N_CTX * stride_qm — TODO confirm
    stride_qh_2d = stride_qh // stride_qm // stride_qk

    q_tile_ptr = tl.make_block_ptr(base=Q,
                                   shape=(D0, BLOCK_DMODEL),
                                   strides=(stride_qm, stride_qk),
                                   offsets=(
                                       off_hz * stride_qh_2d + start_m * BLOCK_M, 0),
                                   block_shape=(BLOCK_M, BLOCK_DMODEL),
                                   order=(1, 0))
    k_tile_ptr = tl.make_block_ptr(base=K,
                                   shape=(D0, BLOCK_DMODEL),
                                   strides=(stride_kn, stride_kk),
                                   offsets=(off_hz * stride_qh_2d, 0),
                                   block_shape=(BLOCK_N, BLOCK_DMODEL),
                                   order=(1, 0))
    v_tile_ptr = tl.make_block_ptr(base=V,
                                   shape=(D0, BLOCK_DMODEL),
                                   strides=(stride_vk, stride_vn),
                                   offsets=(off_hz * stride_qh_2d, 0),
                                   block_shape=(BLOCK_N, BLOCK_DMODEL),
                                   order=(1, 0))
    out_tile_ptr = tl.make_block_ptr(base=Out,
                                     shape=(D0, BLOCK_DMODEL),
                                     strides=(stride_om, stride_on),
                                     offsets=(off_hz * stride_qh_2d + start_m * BLOCK_M, 0),
                                     block_shape=(BLOCK_M, BLOCK_DMODEL),
                                     order=(1, 0))
    # load q: it will stay in SRAM throughout
    q = tl.load(q_tile_ptr)

    # loop over k, v and update accumulators; causal bound means only tiles
    # up to (and including) the diagonal block are visited
    for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
        # -- compute qk ----
        k = tl.load(k_tile_ptr, boundary_check=(0, 1))
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))
        qk *= sm_scale
        # causal mask: row i only attends to columns <= i
        qk = tl.where(offs_m[:, None] >= (
            start_n + offs_n[None, :]), qk, float("-inf"))
        # compute new m
        m_curr = tl.maximum(tl.max(qk, 1), m_prev)
        # correct old l
        l_prev *= tl.exp(m_prev - m_curr)
        # attention weights
        p = tl.exp(qk - m_curr[:, None])
        l_curr = tl.sum(p, 1) + l_prev
        # rescale operands of matmuls
        l_rcp = 1. / l_curr
        p *= l_rcp[:, None]
        acc *= (l_prev * l_rcp)[:, None]
        # update acc
        p = p.to(tl.float16)
        v = tl.load(v_tile_ptr, boundary_check=(0, 1))
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_prev = l_curr
        m_prev = m_curr
        # update pointers
        k_tile_ptr = tl.advance(k_tile_ptr, [BLOCK_N, 0])
        v_tile_ptr = tl.advance(v_tile_ptr, [BLOCK_N, 0])
    # rematerialize offsets to save registers
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back l and m (consumed by _bwd_preprocess / _bwd_kernel)
    l_ptrs = L + off_hz * N_CTX + offs_m
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(l_ptrs, l_prev)
    tl.store(m_ptrs, m_prev)

    acc = acc.to(tl.float16)
    tl.store(out_tile_ptr, acc, boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
def _bwd_preprocess(
    Out, DO, L,
    NewDO, Delta,
    BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
    # Backward prologue: rescales dO by the softmax denominator saved in L
    # and computes delta = rowsum(O * dO). NewDO and Delta are consumed by
    # _bwd_kernel. One program handles BLOCK_M rows; pointers assume the
    # row-major flattened (rows, D_HEAD) layout.
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_n = tl.arange(0, D_HEAD)
    # load
    o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    denom = tl.load(L + off_m).to(tl.float32)
    # compute
    do = do / denom[:, None]
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
    tl.store(Delta + off_m, delta)
|
||||
|
||||
|
||||
@triton.jit
def _bwd_kernel(
    Q, K, V, sm_scale, Out, DO,
    DQ, DK, DV,
    L, M,
    D,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    Z, H, N_CTX, D0,
    num_block,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Backward flash-attention kernel: one program per (batch, head).
    # The outer loop walks column blocks (K/V tiles) accumulating dK/dV;
    # the inner loop walks row blocks (Q/dO tiles) accumulating dQ in place.
    # DO is expected to be pre-scaled by _bwd_preprocess; D holds the
    # per-row delta = rowsum(O * dO).
    # NOTE(review): several index vectors mix BLOCK_M and BLOCK_N (e.g.
    # offs_m = tl.arange(0, BLOCK_N) used for rows) — this only works
    # because the host launches with BLOCK_M == BLOCK_N == 128; confirm
    # before changing tile sizes.
    off_hz = tl.program_id(0)
    off_z = off_hz // H
    off_h = off_hz % H
    # init tile_ptr
    # 2-D (flattened) equivalents of the batch/head strides; presumably
    # assumes contiguous tensors — TODO confirm
    stride_qz_2d = stride_qz // stride_qm // stride_qk
    stride_qh_2d = stride_qh // stride_qm // stride_qk

    q_tile_ptr = tl.make_block_ptr(base=Q,
                                   shape=(D0, BLOCK_DMODEL),
                                   strides=(stride_qm, stride_qk),
                                   offsets=(
                                       off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
                                   block_shape=(BLOCK_M, BLOCK_DMODEL),
                                   order=(1, 0))
    k_tile_ptr = tl.make_block_ptr(base=K,
                                   shape=(D0, BLOCK_DMODEL),
                                   strides=(stride_kn, stride_kk),
                                   offsets=(
                                       off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
                                   block_shape=(BLOCK_M, BLOCK_DMODEL),
                                   order=(1, 0))
    v_tile_ptr = tl.make_block_ptr(base=V,
                                   shape=(D0, BLOCK_DMODEL),
                                   strides=(stride_vk, stride_vn),
                                   offsets=(
                                       off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
                                   block_shape=(BLOCK_M, BLOCK_DMODEL),
                                   order=(1, 0))
    do_tile_ptr = tl.make_block_ptr(base=DO,
                                    shape=(D0, BLOCK_DMODEL),
                                    strides=(stride_qm, stride_qk),
                                    offsets=(
                                        off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
                                    block_shape=(BLOCK_M, BLOCK_DMODEL),
                                    order=(1, 0))
    dq_tile_ptr = tl.make_block_ptr(base=DQ,
                                    shape=(D0, BLOCK_DMODEL),
                                    strides=(stride_qm, stride_qk),
                                    offsets=(
                                        off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
                                    block_shape=(BLOCK_M, BLOCK_DMODEL),
                                    order=(1, 0))
    dk_tile_ptr = tl.make_block_ptr(base=DK,
                                    shape=(D0, BLOCK_DMODEL),
                                    strides=(stride_qm, stride_qk),
                                    offsets=(
                                        off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
                                    block_shape=(BLOCK_M, BLOCK_DMODEL),
                                    order=(1, 0))
    dv_tile_ptr = tl.make_block_ptr(base=DV,
                                    shape=(D0, BLOCK_DMODEL),
                                    strides=(stride_qm, stride_qk),
                                    offsets=(
                                        off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
                                    block_shape=(BLOCK_M, BLOCK_DMODEL),
                                    order=(1, 0))
    # offset pointers for batch/head
    DQ += off_z * stride_qz + off_h * stride_qh
    for start_n in range(0, num_block):
        lo = start_n * BLOCK_M
        # initialize row/col offsets
        offs_qm = lo + tl.arange(0, BLOCK_M)
        offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_m = tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_DMODEL)
        # initialize pointers to value-like data
        dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        # pointer to row-wise quantities in value-like data
        D_ptrs = D + off_hz * N_CTX
        m_ptrs = M + off_hz * N_CTX
        # initialize dv amd dk
        dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        # k and v stay in SRAM throughout
        k = tl.load(k_tile_ptr, boundary_check=(0, 1))
        v = tl.load(v_tile_ptr, boundary_check=(0, 1))
        # loop over rows; starts at `lo` because causal masking makes
        # earlier rows contribute nothing to this column block
        for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
            offs_m_curr = start_m + offs_m
            # load q, k, v, do on-chip
            q = tl.load(q_tile_ptr, boundary_check=(0, 1))
            # recompute p = softmax(qk, dim=-1).T
            # NOTE: `do` is pre-divided by `l`; no normalization here
            qk = tl.dot(q, tl.trans(k))
            qk = tl.where(offs_m_curr[:, None] >= (
                offs_n[None, :]), qk, float("-inf"))
            m = tl.load(m_ptrs + offs_m_curr)
            p = tl.exp(qk * sm_scale - m[:, None])
            # compute dv
            do = tl.load(do_tile_ptr, boundary_check=(0, 1))
            dv += tl.dot(tl.trans(p.to(tl.float16)), do)
            # compute dp = dot(v, do); seeded with -Di so that
            # ds = p * (dp_raw - delta) below folds into one expression
            Di = tl.load(D_ptrs + offs_m_curr)
            dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
            dp += tl.dot(do, tl.trans(v))
            # compute ds = p * (dp - delta[:, None])
            ds = p * dp * sm_scale
            # compute dk = dot(ds.T, q)
            dk += tl.dot(tl.trans(ds.to(tl.float16)), q)
            # compute dq (read-modify-write: dq accumulates across the
            # outer column-block loop)
            dq = tl.load(dq_tile_ptr)
            dq += tl.dot(ds.to(tl.float16), k)
            tl.store(dq_tile_ptr, dq)
            # increment pointers
            dq_ptrs += BLOCK_M * stride_qm
            q_tile_ptr = tl.advance(q_tile_ptr, [BLOCK_M, 0])
            do_tile_ptr = tl.advance(do_tile_ptr, [BLOCK_M, 0])
            dq_tile_ptr = tl.advance(dq_tile_ptr, [BLOCK_M, 0])
        # rewind row pointers to the start row of the NEXT column block
        # (current position is num_block*BLOCK_M; target is lo + BLOCK_M)
        q_tile_ptr = tl.advance(q_tile_ptr, [lo + (1 - num_block) * BLOCK_M, 0])
        do_tile_ptr = tl.advance(do_tile_ptr, [lo + (1 - num_block) * BLOCK_M, 0])
        dq_tile_ptr = tl.advance(dq_tile_ptr, [lo + (1 - num_block) * BLOCK_M, 0])
        # increment tile pointers
        k_tile_ptr = tl.advance(k_tile_ptr, [BLOCK_M, 0])
        v_tile_ptr = tl.advance(v_tile_ptr, [BLOCK_M, 0])
        # write-back
        tl.store(dv_tile_ptr, dv.to(tl.float16), boundary_check=(0, 1))
        tl.store(dk_tile_ptr, dk.to(tl.float16), boundary_check=(0, 1))
        dv_tile_ptr = tl.advance(dv_tile_ptr, [BLOCK_M, 0])
        dk_tile_ptr = tl.advance(dk_tile_ptr, [BLOCK_M, 0])
|
||||
|
||||
|
||||
# Module-level CUDA allocation, not referenced elsewhere in this file;
# presumably kept from the upstream Triton tutorial to force CUDA context
# initialization at import time — TODO confirm it is still needed.
empty = torch.empty(128, device="cuda")
|
||||
|
||||
|
||||
class _attention(torch.autograd.Function):
    """Autograd wrapper around the fused flash-attention Triton kernels.

    Call through the module-level alias ``attention(q, k, v, sm_scale)``
    with ``q``/``k``/``v`` of shape (Z, H, N_CTX, D_HEAD); the kernels
    cast intermediates to float16, so float16 inputs are assumed here —
    TODO confirm against callers.
    """

    @staticmethod
    def forward(ctx, q, k, v, sm_scale):
        # Launch the forward kernel and save everything backward needs.
        BLOCK = 128
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}
        o = torch.empty_like(q)
        # one program per BLOCK rows of the sequence, per (batch * head)
        grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
        # per-row softmax denominator (L) and running max (m), filled by
        # the kernel and reused in backward
        L = torch.empty(
            (q.shape[0] * q.shape[1], q.shape[2]),
            device=q.device,
            dtype=torch.float32)
        m = torch.empty(
            (q.shape[0] * q.shape[1], q.shape[2]),
            device=q.device,
            dtype=torch.float32)
        num_warps = 4 if Lk <= 64 else 8
        # total row count of the flattened (Z*H*N_CTX, D_HEAD) view used
        # by the kernel's block pointers
        D0 = q.shape[0] * q.shape[1] * q.shape[2]
        _fwd_kernel[grid](
            q, k, v, sm_scale,
            L, m,
            o,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q.shape[0], q.shape[1], q.shape[2], D0,
            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
            BLOCK_DMODEL=Lk, num_warps=num_warps,
            num_stages=2,
        )

        ctx.save_for_backward(q, k, v, o, L, m)
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = Lk
        return o

    @staticmethod
    def backward(ctx, do):
        # Two-stage backward: _bwd_preprocess rescales dO by the softmax
        # denominator and computes delta; _bwd_kernel then accumulates
        # dQ/dK/dV. Returns one gradient per forward input (None for
        # sm_scale, which is a plain float).
        BLOCK = 128
        q, k, v, o, l, m = ctx.saved_tensors
        do = do.contiguous()
        # dq is accumulated in place across column blocks inside the
        # kernel, so it must start zeroed (unlike dk/dv)
        dq = torch.zeros_like(q, dtype=torch.float32)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        do_scaled = torch.empty_like(do)
        delta = torch.empty_like(l)
        D0 = q.shape[0] * q.shape[1] * q.shape[2]
        _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
            o, do, l,
            do_scaled, delta,
            BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
        )
        _bwd_kernel[(ctx.grid[1],)](
            q, k, v, ctx.sm_scale,
            o, do_scaled,
            dq, dk, dv,
            l, m,
            delta,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            q.shape[0], q.shape[1], q.shape[2], D0,
            ctx.grid[0],
            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
            num_stages=1,
        )
        return dq, dk, dv, None
|
||||
|
||||
|
||||
# Public functional entry point: attention(q, k, v, sm_scale) -> Out.
attention = _attention.apply
|
||||
|
||||
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 128, 64),
                                                 # (4, 48, 256, 64),
                                                 # (4, 48, 512, 64),
                                                 # (4, 48, 1024, 64),
                                                 # (4, 48, 2048, 64),
                                                 # (4, 48, 4096, 64),
                                                 # (4, 48, 8192, 64), out of memory
                                                 ])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires arch 9+")
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    """Check the Triton fused attention (forward and backward) against a
    plain PyTorch reference implementation of causal softmax attention."""
    torch.manual_seed(20)
    q = torch.empty(
        (Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(
        mean=0.1, std=0.2).requires_grad_()
    k = torch.empty(
        (Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(
        mean=0.4, std=0.2).requires_grad_()
    v = torch.empty(
        (Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(
        mean=0.3, std=0.2).requires_grad_()
    sm_scale = 0.2
    dout = torch.randn_like(q)
    # reference implementation
    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    # Causal mask: the 2-D boolean index broadcasts over the Z and H
    # dimensions, so one assignment masks every batch/head slice.
    # (The original wrapped this in `for z ... for h ...` loops whose
    # indices were unused, repeating the identical full-tensor masked
    # fill Z*H times.)
    p[:, :, M == 0] = float("-inf")
    p = torch.softmax(p.float(), dim=-1).half()
    # p = torch.exp(p)
    ref_out = torch.matmul(p, v)
    ref_out.backward(dout)
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None
    # triton implementation
    tri_out = attention(q, k, v, sm_scale)
    # print(ref_out)
    # print(tri_out)
    tri_out.backward(dout)
    tri_dv, v.grad = v.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dq, q.grad = q.grad.clone(), None
    # compare
    torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0)
    torch.testing.assert_close(ref_dq, tri_dq, atol=1e-2, rtol=0)
    torch.testing.assert_close(ref_dv, tri_dv, atol=1e-2, rtol=0)
    torch.testing.assert_close(ref_dk, tri_dk, atol=1e-2, rtol=0)
|
||||
|
||||
|
||||
# flash-attn is an optional benchmark baseline. BaseException (rather than
# ImportError) is caught deliberately so that any failure raised while
# importing the extension degrades to HAS_FLASH = False instead of
# breaking collection of this module.
try:
    from flash_attn.flash_attn_interface import flash_attn_func
    HAS_FLASH = True
except BaseException:
    HAS_FLASH = False
|
||||
|
||||
# Benchmark problem size; N_CTX is overridden by x_vals below.
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
# One Benchmark config per pass direction (forward / backward); the Flash
# baseline line is only added when flash-attn imported successfully.
configs = [triton.testing.Benchmark(
    x_names=['N_CTX'],
    x_vals=[2**i for i in range(10, 14)],
    line_arg='provider',
    line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
    line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
    styles=[('red', '-'), ('blue', '-')],
    ylabel='ms',
    plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
    args={
        'H': N_HEADS,
        'BATCH': BATCH,
        'D_HEAD': D_HEAD,
        'dtype': torch.float16,
        'mode': mode}
) for mode in ['fwd', 'bwd']]
|
||||
|
||||
|
||||
@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
    """Benchmark one (provider, mode) cell: the Triton fused attention or
    the external flash-attn baseline, forward or backward pass.

    Returns the median time in milliseconds as reported by
    ``triton.testing.do_bench`` (implicitly ``None`` for an unknown
    provider, matching the original behavior).
    """
    assert mode in ['fwd', 'bwd']
    warmup = 25
    rep = 100
    if provider == "triton":
        shape = (BATCH, H, N_CTX, D_HEAD)
        q = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)
        k = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)
        v = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)
        sm_scale = 1.3
        run = lambda: attention(q, k, v, sm_scale)
        if mode == 'bwd':
            # benchmark only the backward pass: run forward once outside
            # the timed closure
            out = run()
            grad = torch.randn_like(out)
            run = lambda: out.backward(grad, retain_graph=True)
        return triton.testing.do_bench(run, warmup=warmup, rep=rep)
    if provider == "flash":
        # flash-attn's varlen interface: equal-length sequences encoded as
        # cumulative offsets, with q/k/v packed into one tensor
        lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
        cu_seqlens = torch.zeros(
            (BATCH + 1,), device=device, dtype=torch.int32)
        cu_seqlens[1:] = lengths.cumsum(0)
        qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
        run = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
        if mode == 'bwd':
            out = run()
            grad = torch.randn_like(out)
            run = lambda: out.backward(grad, retain_graph=True)
        return triton.testing.do_bench(run, warmup=warmup, rep=rep)
|
||||
|
||||
|
||||
# only works on post-Ampere GPUs right now
|
||||
# bench_flash_attention.run(save_path='.', print_data=True)
|
||||
469
python/test/unit/hopper/test_gemm.py
Normal file
469
python/test/unit/hopper/test_gemm.py
Normal file
@@ -0,0 +1,469 @@
|
||||
# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files
|
||||
# (the "Software"), to deal in the Software without restriction,
|
||||
# including without limitation the rights to use, copy, modify, merge,
|
||||
# publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
# and to permit persons to whom the Software is furnished to do so,
|
||||
# subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
import itertools
|
||||
import os
|
||||
import re
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from .utils import get_proper_err, get_variant_golden
|
||||
|
||||
|
||||
@triton.jit
def matmul_no_scf_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr
):
    # Single-tile GEMM with no K-loop ("no scf" = no structured control
    # flow): the host launches with BLOCK_M == M, BLOCK_N == N,
    # BLOCK_K == K, so one dot covers the whole problem. Epilogue stores
    # either through a block pointer (TMA path) or plain pointers.
    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                    offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                                    offsets=(0, 0), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
    a = tl.load(a_block_ptr)
    b = tl.load(b_block_ptr)

    c = tl.dot(a, b)

    if FLOAT16_OUTPUT:
        c = c.to(tl.float16)

    if USE_TMA_EPILOGUE:
        c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
                                        offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
        tl.store(c_block_ptr, c)
    else:
        offs_m = tl.arange(0, BLOCK_M)
        offs_n = tl.arange(0, BLOCK_N)
        c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
        tl.store(c_ptrs, c)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,ENABLE_WS',
                         itertools.chain(
                             *[
                                 [
                                     # numCTAs = 1, no TMA multicast:
                                     [64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
                                     [64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
                                     [64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
                                     [64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
                                     [64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
                                     [64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
                                     [128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
                                     [128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
                                     # static mask, cluster 4x1
                                     [256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
                                     [256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
                                     # dynamic mask, cluster 2x2
                                     [128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
                                     [128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
                                     # small M, N
                                     [16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
                                     [16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
                                     [32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
                                     [32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
                                 ] for USE_TMA_EPILOGUE in [True, False]
                                 for ENABLE_WS in [False, True]
                             ]))
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE, ENABLE_WS):
    """Run matmul_no_scf_kernel on a single tile covering the whole
    problem and compare against a float32 torch.matmul golden, with
    tolerances widened by the fp16-accumulation error estimated via
    get_variant_golden/get_proper_err."""
    # known-broken combination, skipped rather than xfailed
    if '-'.join(map(str, [USE_TMA_EPILOGUE, ENABLE_WS])) in [
        'True-True'
    ]:
        pytest.skip("error, skip")

    # transposed variants exercise non-contiguous strides via .T views
    if (TRANS_A):
        a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
    else:
        a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    if (TRANS_B):
        b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
    else:
        b = torch.randn((K, N), device='cuda', dtype=torch.float16)

    if OUTPUT_TYPE == "float16":
        c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    else:
        c = torch.empty((M, N), device=a.device, dtype=torch.float32)

    # single program: block sizes equal the full problem size
    matmul_no_scf_kernel[(1, 1)](a_ptr=a, b_ptr=b, c_ptr=c,
                                 M=M, N=N, K=K,
                                 stride_am=a.stride(0), stride_ak=a.stride(1),
                                 stride_bk=b.stride(0), stride_bn=b.stride(1),
                                 stride_cm=c.stride(0), stride_cn=c.stride(1),
                                 BLOCK_M=M, BLOCK_N=N, BLOCK_K=K,
                                 num_warps=NUM_WARPS,
                                 num_ctas=NUM_CTAS,
                                 FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"),
                                 USE_TMA_EPILOGUE=USE_TMA_EPILOGUE,
                                 enable_warp_specialization=ENABLE_WS)
    a_f32 = a.to(torch.float32)
    b_f32 = b.to(torch.float32)
    golden = torch.matmul(a_f32, b_f32)
    golden_variant = get_variant_golden(a_f32, b_f32)
    golden_abs_err, golden_rel_err = get_proper_err(golden, golden_variant)
    # NOTE(review): mutates global torch print options for the rest of the
    # process (presumably to get full tensors in failure output) — confirm
    # this is intentional; it is never restored.
    torch.set_printoptions(profile="full")
    assert_close(
        c,
        golden,
        rtol=max(1e-2, 1.1 * golden_rel_err),
        atol=max(1e-3, 1.1 * golden_abs_err),
        check_dtype=False)
|
||||
|
||||
|
||||
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_wm, stride_wn,
    stride_zm, stride_zn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
    out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr,
    ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
    DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
    A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr,
    B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr,
    Z_ORDER_0: tl.constexpr, Z_ORDER_1: tl.constexpr
):
    # Tiled GEMM (Z = A @ B) with grouped program ordering for L2 reuse,
    # plus optional fused epilogues selected by constexpr flags:
    # bias add (full matrix / per-row / per-column), row-wise softmax, or
    # a chained second dot with W. Store path is either a block pointer
    # (TMA) or masked plain pointers.
    pid = tl.program_id(axis=0)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    # grouped ordering: programs walk GROUP_SIZE_M row-tiles before moving
    # to the next column, improving cache locality
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    block_offset_m = pid_m * BLOCK_M
    block_offset_n = pid_n * BLOCK_N

    a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                   offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(A_ORDER_0, A_ORDER_1))
    b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                                   offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(B_ORDER_0, B_ORDER_1))
    # NOTE(review): order here is (Z_ORDER_1, Z_ORDER_0) — reversed
    # relative to the z store below; confirm this is intentional for the
    # chain-dot operand layout.
    w_tile_ptr = tl.make_block_ptr(base=w_ptr, shape=(N, N), strides=(stride_wm, stride_wn),
                                   offsets=(0, block_offset_n), block_shape=(BLOCK_N, BLOCK_N), order=(Z_ORDER_1, Z_ORDER_0))
    z = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    offs_m = block_offset_m + tl.arange(0, BLOCK_M)
    offs_n = block_offset_n + tl.arange(0, BLOCK_N)
    z_ptrs = z_ptr + offs_m[:, None] * stride_zm + offs_n[None, :] * stride_zn
    # bias is addressed with z's strides, i.e. bias has the same layout as z
    bias_ptrs = bias_ptr + offs_m[:, None] * stride_zm + offs_n[None, :] * stride_zn
    mask = (offs_m < M)[:, None] & (offs_n < N)[None, :]

    # main K loop: accumulate in float32
    for k in range(0, K, BLOCK_K):
        a = tl.load(a_tile_ptr, boundary_check=(0, 1))
        b = tl.load(b_tile_ptr, boundary_check=(0, 1))
        z += tl.dot(a, b)
        a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_K])
        b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_K, 0])

    z = z.to(out_dtype)

    # optional fused epilogues (flags are mutually exclusive by test setup
    # — presumably; confirm against callers)
    if ADD_MATRIX:
        z += tl.load(bias_ptrs, mask=mask)
    if ADD_ROWS:
        ZRs = bias_ptr + offs_m * stride_zm
        z += tl.load(ZRs)[:, None]
    if ADD_COLS:
        ZCs = bias_ptr + offs_n * stride_zn
        z += tl.load(ZCs)[None, :]
    if DO_SOFTMAX:
        # numerically-stable row softmax within this tile
        max = tl.max(z, 1)
        z = z - max[:, None]
        num = tl.exp(z.to(tl.float32)).to(max.dtype)
        den = tl.sum(num, 1)
        z = num / den[:, None]
    if CHAIN_DOT:
        w = tl.load(w_tile_ptr)
        z = tl.dot(z.to(w.dtype), w)
        z = z.to(out_dtype)

    if USE_TMA_STORE:
        z_block_ptr = tl.make_block_ptr(base=z_ptr, shape=(M, N), strides=(stride_zm, stride_zn),
                                        offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(Z_ORDER_0, Z_ORDER_1))
        tl.store(z_block_ptr, z, boundary_check=(0, 1))
    else:
        tl.store(z_ptrs, z, mask=mask)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_C,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS',
|
||||
[(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws)
|
||||
for shape_w_c in [
|
||||
# badcase from cublas-important-layers
|
||||
[4096, 1, 1024, False, False, True],
|
||||
[2048, 204, 1000, True, False, True],
|
||||
[4096, 1, 1024, False, False, False],
|
||||
[2048, 204, 1000, True, False, False],
|
||||
]
|
||||
for out_dtype in ['float16', 'float32']
|
||||
for use_tma_store in [False, True]
|
||||
for enable_ws in [False, True]
|
||||
] + [(*shape_w_c, trans_a, trans_b, trans_c, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
# softmax works for one CTA
|
||||
for shape_w_c in [
|
||||
[64, 64, 16, 4, 1, 64, 64, 64],
|
||||
[128, 128, 64, 4, 1, None, None, None],
|
||||
[16, 16, 64, 4, 1, 16, 16, 64],
|
||||
[64, 64, 32, 8, 1, 64, 64, 64],
|
||||
[128, 128, 64, 4, 1, 128, 128, 128],
|
||||
]
|
||||
for epilogue in ['softmax']
|
||||
for out_dtype in ['float16', 'float32']
|
||||
for use_tma_store in [False, True]
|
||||
for trans_a in [False, True]
|
||||
for trans_b in [False, True]
|
||||
for trans_c in [False, True]
|
||||
for num_stages in [3]
|
||||
for enable_ws in [False, True]
|
||||
] + [(*shape_w_c, trans_a, trans_b, trans_c, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [
|
||||
[64, 64, 16, 4, 1, 128, 128, 64],
|
||||
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
|
||||
# for chain-dot
|
||||
[128, 128, 64, 4, 1, None, None, None],
|
||||
[64, 64, 16, 4, 1, None, None, None],
|
||||
# small BLOCK_M and BLOCK_K
|
||||
[16, 16, 64, 4, 1, 128, 128, 64],
|
||||
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
|
||||
# repeat
|
||||
[64, 64, 32, 8, 1, 128, 256, 64],
|
||||
[64, 64, 16, 8, 2, 128, 128, 64],
|
||||
# irregular shape
|
||||
[128, 128, 64, 4, 1, 500, 200, 128],
|
||||
[128, 128, 64, 4, 2, 513, 193, 192],
|
||||
]
|
||||
for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot']
|
||||
for out_dtype in ['float16', 'float32']
|
||||
for use_tma_store in [False, True]
|
||||
for trans_a in [False, True]
|
||||
for trans_b in [False, True]
|
||||
for trans_c in [False, True]
|
||||
for num_stages in [3]
|
||||
for enable_ws in [False, True]
|
||||
if not (epilogue == 'chain-dot' and (shape_w_c[5] is not None or shape_w_c[0] != shape_w_c[1]))
|
||||
] + [(*shape_w_c, trans_a, trans_b, trans_c, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [
|
||||
[64, 64, 32, 4, 1, 128, 256, 64],
|
||||
[128, 128, 16, 4, 4, 512, 256, 64],
|
||||
[128, 256, 32, 4, 8, 256, 256, 192],
|
||||
[512, 256, 32, 4, 8, 1024, 256, 192],
|
||||
# BLOCK_K >= 128
|
||||
[64, 128, 128, 4, 1, 512, 256, 256],
|
||||
[128, 128, 128, 4, 1, 256, 256, 192],
|
||||
[128, 128, 128, 4, 2, 256, 256, 192],
|
||||
# small BLOCK_M and BLOCK_K
|
||||
[16, 32, 32, 4, 1, 128, 256, 64],
|
||||
[32, 32, 16, 4, 1, 256, 256, 192],
|
||||
[16, 32, 64, 4, 4, 512, 256, 64],
|
||||
]
|
||||
for out_dtype in ['float16', 'float32']
|
||||
for use_tma_store in [False, True]
|
||||
for trans_a in [False, True]
|
||||
for trans_b in [False, True]
|
||||
for trans_c in [False, True]
|
||||
for num_stages in [3]
|
||||
for enable_ws in [False, True]
|
||||
] + [(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_c, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
# loop over instr shapes
|
||||
for n in [16, 32, 64, 128, 256]
|
||||
for trans_c in [False, True]
|
||||
for out_dtype in ['float16', 'float32']
|
||||
for use_tma_store in [False, True]
|
||||
for num_stages in [2, 4, 5, 7]
|
||||
for enable_ws in [False, True]
|
||||
] + [(*shape_w_c, *shape, False, True, trans_c, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
# irregular shapes
|
||||
for shape_w_c in [
|
||||
[128, 128, 64, 4, 1],
|
||||
[256, 128, 64, 4, 2],
|
||||
[128, 128, 128, 4, 2],
|
||||
]
|
||||
for shape in list(itertools.product([*range(512, 4096, 360)], [*range(512, 4096, 360)], [512, 1024]))
|
||||
for trans_c in [False, True]
|
||||
for out_dtype in ['float16', 'float32']
|
||||
for use_tma_store in [False, True]
|
||||
for num_stages in [2, 3, 4]
|
||||
for enable_ws in [False, True]
|
||||
])
|
||||
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_C, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS):
    """Run matmul_kernel for one parametrized configuration and compare against torch.

    Builds fp16 operands (optionally transposed via `.T` so the memory order
    differs from the logical shape), launches the Triton kernel with the
    requested epilogue, and checks the (normalized) result against a float32
    torch reference. When ENABLE_MMA_V3 is set in the environment, also
    asserts that the generated PTX contains a wgmma instruction.
    """
    # Known-bad configurations are skipped by matching a joined key string.
    if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
        '16-32-64-4-4-512-256-64-True-False',
        '16-32-64-4-4-512-256-64-True-True',
        '16-32-64-4-4-512-256-64-False-False',
        '16-32-64-4-4-512-256-64-False-True',
    ]:
        pytest.skip('shapePerCTA[1] < 16 not supported')

    # with ENABLE_TMA=0 and ENABLE_MMA_V3=0
    if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
        '16-32-64-4-1-256-256-256-False',
        '16-32-64-4-2-256-256-256-False',
        '16-32-64-4-2-256-256-256-True',
        '16-32-64-8-2-256-256-256-False',
        '16-32-64-8-2-256-256-256-True',
    ]:
        pytest.skip('illegal memory access.')

    # with ENABLE_TMA=1 and ENABLE_MMA_V3=1
    if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K])) in [
        '64-64-32-8-1-128-256-64',
    ]:
        pytest.skip('Tensor-likes are not close!')

    if NUM_CTAS > 1 and NUM_WARPS == 8:
        pytest.skip('Tensor-likes are not close!')

    # with ENABLE_TMA=1 and ENABLE_MMA_V3=1
    if ENABLE_WS:
        # example:
        # [128-128-64-4-1-None-None-None-False-False-False-chain-dot-float16-False-3-True]
        pytest.skip('hang!')

    # None means "single-tile problem": fall back to the block size.
    M = BLOCK_M if M is None else M
    N = BLOCK_N if N is None else N
    K = BLOCK_K if K is None else K

    # TRANS_* allocate the transposed shape and view it back with `.T`, so the
    # logical shape is the same but the strides (memory order) are swapped.
    if (TRANS_A):
        a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
        a_order = [0, 1]
    else:
        a = torch.randn((M, K), device='cuda', dtype=torch.float16)
        a_order = [1, 0]

    if (TRANS_B):
        b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
        b_order = [0, 1]
    else:
        b = torch.randn((K, N), device='cuda', dtype=torch.float16)
        b_order = [1, 0]

    if out_dtype == 'float16' and epilogue != 'softmax':
        # TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
        # fail with the following error: 'llvm.fmul' op requires the same type
        # for all operands and results
        out_dtype = tl.float16
        torch_out_dtype = torch.float16
    else:
        out_dtype = tl.float32
        torch_out_dtype = torch.float32

    # avoid out of memory
    if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
        if (TRANS_C):
            bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
        else:
            bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
    else:
        # Dummy 1x1 tensor: kernel still takes the pointer but won't use it.
        bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)

    if epilogue == 'chain-dot':
        # NOTE(review): chain-dot configs are filtered so BLOCK_M == BLOCK_N
        # upstream, hence the (N, N) / (M, M) shapes coincide — confirm if the
        # parametrization is ever widened.
        if (TRANS_C):
            w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
        else:
            w = torch.randn((M, M), device='cuda', dtype=torch.float16)
    else:
        w = torch.randn((1, 1), device='cuda', dtype=torch.float16).T

    if (TRANS_C):
        z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
        z_order = [1, 0]
    else:
        z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
        z_order = [0, 1]

    # torch result (float32 reference)
    a_f32 = a.to(torch.float32)
    b_f32 = b.to(torch.float32)
    dot = torch.matmul(a_f32, b_f32)

    def process_epilogue(d, bias, w, epilogue):
        """Apply the torch-side equivalent of the kernel's fused epilogue."""
        if epilogue == 'add-matrix':
            ref = d + bias
        elif epilogue == 'add-rows':
            ref = d + bias[:, 0][:, None]
        elif epilogue == 'add-cols':
            ref = d + bias[0, :][None, :]
        elif epilogue == 'softmax':
            num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
            denom = torch.sum(num, dim=-1, keepdims=True)
            ref = num / denom
            # ref = torch.softmax(d, 1)
        elif epilogue == 'chain-dot':
            ref = torch.matmul(d, w.to(torch.float32))
        else:
            ref = d
        return ref
    golden = process_epilogue(dot, bias, w, epilogue)

    def grid(META):
        # 1-D launch: one program per output tile.
        return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
    pgm = matmul_kernel[grid](a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z,
                              M=M, N=N, K=K,
                              stride_am=a.stride(0), stride_ak=a.stride(1),
                              stride_bk=b.stride(0), stride_bn=b.stride(1),
                              stride_wm=w.stride(0), stride_wn=w.stride(1),
                              stride_zm=z.stride(0), stride_zn=z.stride(1),
                              BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8,
                              out_dtype=out_dtype,
                              USE_TMA_STORE=USE_TMA_STORE,
                              ADD_MATRIX=epilogue == 'add-matrix',
                              ADD_ROWS=epilogue == 'add-rows',
                              ADD_COLS=epilogue == 'add-cols',
                              DO_SOFTMAX=epilogue == 'softmax',
                              CHAIN_DOT=epilogue == 'chain-dot',
                              A_ORDER_0=a_order[0], A_ORDER_1=a_order[1],
                              B_ORDER_0=b_order[0], B_ORDER_1=b_order[1],
                              Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1],
                              num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES,
                              enable_warp_specialization=ENABLE_WS)

    torch.set_printoptions(profile="full")
    # print("abs_err: {}, rel_err: {}".format(golden_abs_err, golden_rel_err))
    # print("golden: ")
    # print(golden)
    # print("result: ")
    # print(z)
    # print("max_gap: {}".format(torch.max(torch.abs(z - golden))))
    # Row-normalize both sides before comparing so large-K accumulation error
    # does not dominate the tolerance.
    golden = torch.nn.functional.normalize(golden)
    z = torch.nn.functional.normalize(z)
    assert_close(z, golden,
                 rtol=1e-2,
                 atol=1e-3,
                 check_dtype=False)

    # Only check PTX for MMA v3 (wgmma) when explicitly enabled and the shape
    # is one that the wgmma path covers.
    enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower()
    if enable_mmav3 in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
        ptx = pgm.asm['ptx']
        assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(BLOCK_N), ptx)
|
||||
166
python/test/unit/hopper/test_gemm_fusion.py
Normal file
166
python/test/unit/hopper/test_gemm_fusion.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files
|
||||
# (the "Software"), to deal in the Software without restriction,
|
||||
# including without limitation the rights to use, copy, modify, merge,
|
||||
# publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
# and to permit persons to whom the Software is furnished to do so,
|
||||
# subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
def gemm_fusion_kernel(A, B, C, E,
                       M, N, K,
                       stride_am, stride_ak, stride_bn, stride_bk, stride_cn, stride_ck, stride_em, stride_ek,
                       BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    """Fused two-GEMM kernel computing E = (A @ B^T) @ C.

    A and E are (M, K); B and C are (N, K). Each program produces one
    BLOCK_M-row slab of E, accumulating over N in BLOCK_N-sized slices so the
    intermediate A @ B^T tile never leaves on-chip storage.
    """
    pid = tl.program_id(0)

    # Block pointers: the A and E tiles are fixed per program; the B and C
    # tiles advance along the N dimension inside the loop below.
    a_tile_ptr = tl.make_block_ptr(base=A, shape=(M, K), strides=(stride_am, stride_ak), offsets=(pid * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    b_tile_ptr = tl.make_block_ptr(base=B, shape=(N, K), strides=(stride_bn, stride_bk), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_K), order=(1, 0))
    c_tile_ptr = tl.make_block_ptr(base=C, shape=(N, K), strides=(stride_cn, stride_ck), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_K), order=(1, 0))
    e_tile_ptr = tl.make_block_ptr(base=E, shape=(M, K), strides=(stride_em, stride_ek), offsets=(pid * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))

    acc_e = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    # The A tile is loaded once and reused for every N slice.
    a = tl.load(a_tile_ptr)
    for i in range(0, N, BLOCK_N):
        b = tl.load(b_tile_ptr)
        # (BLOCK_M, BLOCK_N) partial product of A @ B^T.
        o_ab = tl.dot(a, tl.trans(b))
        c = tl.load(c_tile_ptr)
        # Downcast so the second dot consumes fp16 operands.
        o_ab = o_ab.to(tl.float16)
        acc_e += tl.dot(o_ab, c)
        b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_N, 0])
        c_tile_ptr = tl.advance(c_tile_ptr, [BLOCK_N, 0])

    acc_e = acc_e.to(tl.float16)
    tl.store(e_tile_ptr, acc_e)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="not passed on ampere")
def test_gemm_fusion():
    """Check the fused (A @ B^T) @ C kernel against a two-step torch reference."""
    M, N, K = 4096, 4096, 64
    BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 64

    def _rand(shape):
        # Small-magnitude fp16 inputs keep accumulation error inside tolerance.
        return torch.empty(shape, dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2)

    A = _rand((M, K))
    B = _rand((N, K))
    C = _rand((N, K))
    E = torch.empty((M, K), dtype=torch.float16, device='cuda')
    ref_out = torch.matmul(torch.matmul(A, B.T), C)

    grid = (triton.cdiv(M, BLOCK_M), 1)
    gemm_fusion_kernel[grid](
        A, B, C, E, M, N, K,
        A.stride(0), A.stride(1),
        B.stride(0), B.stride(1),
        C.stride(0), C.stride(1),
        E.stride(0), E.stride(1),
        BLOCK_M, BLOCK_N, BLOCK_K, num_warps=4)

    torch.testing.assert_close(ref_out, E, atol=1e-2, rtol=0)
|
||||
|
||||
|
||||
@triton.jit
def batched_gemm_fusion(
        Q, K, V, Out,
        stride_qz, stride_qh, stride_qm, stride_qk,
        stride_kz, stride_kh, stride_kn, stride_kk,
        stride_vz, stride_vh, stride_vk, stride_vn,
        stride_oz, stride_oh, stride_om, stride_on,
        Z, NH, N_CTX,
        BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,
):
    """Batched fused Out = (Q @ K^T) @ V over 4-D (Z, NH, N_CTX, D) tensors.

    Uses 4-D block pointers; program id 0 selects the row tile, program id 1
    selects the flattened (batch, head) pair. No softmax is applied — this is
    a plain chained-GEMM fusion, not attention.
    """
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    # NOTE(review): the row offset below uses `start_m` directly rather than
    # `start_m * BLOCK_M`; block-pointer offsets are in elements, so this looks
    # off by a factor of BLOCK_M — the test using this kernel is skipped, so
    # confirm before enabling.
    q_tile_ptr = tl.make_block_ptr(base=Q,
                                   shape=(Z, NH, N_CTX, BLOCK_DMODEL),
                                   strides=(stride_qz, stride_qh, stride_qm, stride_qk),
                                   offsets=(off_hz // NH, off_hz % NH, start_m, 0),
                                   block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL),
                                   order=(3, 2, 1, 0))
    k_tile_ptr = tl.make_block_ptr(base=K,
                                   shape=(Z, NH, N_CTX, BLOCK_DMODEL),
                                   strides=(stride_kz, stride_kh, stride_kn, stride_kk),
                                   offsets=(off_hz // NH, off_hz % NH, 0, 0),
                                   block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL),
                                   order=(3, 2, 1, 0))
    v_tile_ptr = tl.make_block_ptr(base=V,
                                   shape=(Z, NH, N_CTX, BLOCK_DMODEL),
                                   strides=(stride_vz, stride_vh, stride_vk, stride_vn),
                                   offsets=(off_hz // NH, off_hz % NH, 0, 0),
                                   block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL),
                                   order=(3, 2, 1, 0))
    o_tile_ptr = tl.make_block_ptr(base=Out,
                                   shape=(Z, NH, N_CTX, BLOCK_DMODEL),
                                   strides=(stride_oz, stride_oh, stride_om, stride_on),
                                   offsets=(off_hz // NH, off_hz % NH, start_m, 0),
                                   block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL),
                                   order=(3, 2, 1, 0))

    # Q tile stays resident; loads come back 4-D and are viewed down to 2-D.
    q = tl.load(q_tile_ptr, boundary_check=(0, 1, 2, 3))
    q = tl.view(q, (BLOCK_M, BLOCK_DMODEL))
    for i in range(0, N_CTX, BLOCK_N):
        k = tl.load(k_tile_ptr, boundary_check=(0, 1, 2, 3))
        k = tl.view(k, (BLOCK_N, BLOCK_DMODEL))
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))

        # Downcast the first product so the second dot runs in fp16.
        p = qk.to(tl.float16)
        v = tl.load(v_tile_ptr, boundary_check=(0, 1, 2, 3))
        v = tl.view(v, (BLOCK_N, BLOCK_DMODEL))
        acc += tl.dot(p, v)

        k_tile_ptr = tl.advance(k_tile_ptr, [0, 0, BLOCK_N, 0])
        v_tile_ptr = tl.advance(v_tile_ptr, [0, 0, BLOCK_N, 0])

    acc = tl.view(acc, (1, 1, BLOCK_M, BLOCK_DMODEL))
    acc = acc.to(tl.float16)
    tl.store(o_tile_ptr, acc)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="don't support 4d across stack, left for future")
def test_batched_gemm_fusion():
    """Batched fused (A @ B^T) @ C over (Z, NH, N_CTX, H) tensors.

    Currently skipped: 4-D block pointers are not supported across the stack.
    """
    Z = 4
    NH = 48
    H = 64
    N_CTX = 2048
    BLOCK_M, BLOCK_N, BLOCK_DMODEL = 128, 128, H
    torch.manual_seed(20)
    A = torch.empty((Z, NH, N_CTX, H), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2)
    B = torch.empty((Z, NH, N_CTX, H), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2)
    C = torch.empty((Z, NH, N_CTX, H), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2)
    E = torch.empty_like(A)
    BT = B.transpose(-1, -2)
    ref_out = torch.matmul(torch.matmul(A, BT), C)
    num_warps = 4
    # BUG FIX: the second grid dimension was `B * NH`, which multiplies the
    # tensor B elementwise by NH instead of counting (batch, head) pairs.
    # The grid must have Z * NH programs along axis 1.
    grid = (triton.cdiv(N_CTX, BLOCK_M), Z * NH)
    batched_gemm_fusion[grid](A, B, C, E,
                              A.stride(0), A.stride(1), A.stride(2), A.stride(3),
                              B.stride(0), B.stride(1), B.stride(2), B.stride(3),
                              C.stride(0), C.stride(1), C.stride(2), C.stride(3),
                              E.stride(0), E.stride(1), E.stride(2), E.stride(3),
                              Z, NH, N_CTX,
                              BLOCK_M, BLOCK_DMODEL, BLOCK_N, num_warps=num_warps)

    torch.testing.assert_close(ref_out, E, atol=1e-2, rtol=0)
|
||||
89
python/test/unit/hopper/test_mixed_io.py
Normal file
89
python/test/unit/hopper/test_mixed_io.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import pytest
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
# Maps the dtype strings used in pytest parametrization to torch dtypes.
dtype_mapping = {
    'float16': torch.float16,
    'float32': torch.float32,
}
|
||||
|
||||
|
||||
@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Elementwise output = x + y, deliberately mixing addressing styles.

    x is read through a block pointer (zero-padded past the end of the
    buffer) while y uses a conventional masked pointer load — exercising
    mixed I/O paths is the purpose of this kernel.
    """
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Block-pointer load for x, padded with zeros at the boundary.
    x_block_ptr = tl.make_block_ptr(
        base=x_ptr, shape=(n_elements, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),
        block_shape=(BLOCK_SIZE, ), order=(0, )
    )
    x = tl.load(x_block_ptr, boundary_check=(0, ), padding_option='zero')

    # Plain masked load for y; the same mask guards the store.
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('SIZE,BLOCK_SIZE,dtype_str',
                         [(98432, 1024, dtype_str)
                          for dtype_str in ['float16', 'float32']
                          ])
def test_add(SIZE, BLOCK_SIZE, dtype_str):
    """Elementwise add mixing block-pointer and plain-pointer loads."""
    dtype = dtype_mapping[dtype_str]
    x = torch.randn(SIZE, device='cuda', dtype=dtype)
    y = torch.randn(SIZE, device='cuda', dtype=dtype)
    output = torch.empty(SIZE, device='cuda', dtype=dtype)

    grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']),)
    add_kernel[grid](x, y, output, SIZE, BLOCK_SIZE=BLOCK_SIZE)

    expected = x + y
    torch.set_printoptions(profile='full')
    assert_close(output, expected, rtol=1e-2, atol=1e-3, check_dtype=False)
|
||||
|
||||
|
||||
@triton.jit
def load_reduce_kernel(
    x_ptr,
    y_ptr,
    stride_xm,
    stride_xn,
    stride_y,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise max reduction: y[m] = max over n of x[m, n].

    Loads the whole (BLOCK_M, BLOCK_N) tile through a block pointer, reduces
    along axis 1, and writes the BLOCK_M results contiguously.
    """
    x_ptr = tl.make_block_ptr(
        base=x_ptr, shape=(BLOCK_M, BLOCK_N), strides=(stride_xm, stride_xn),
        offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)
    )
    x = tl.load(x_ptr)
    y = tl.max(x, axis=1)
    # NOTE(review): the store ignores stride_y and assumes y is contiguous —
    # confirm if callers ever pass a strided output.
    tl.store(y_ptr + tl.arange(0, BLOCK_M), y)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('BLOCK_M,BLOCK_N,dtype_str',
                         [(128, 64, dtype_str)
                          for dtype_str in ['float16']
                          ])
def test_load_reduce(BLOCK_M, BLOCK_N, dtype_str):
    """Row-wise max via block-pointer load, checked against torch.max."""
    torch_dtype = dtype_mapping[dtype_str]
    inp = torch.randn((BLOCK_M, BLOCK_N), device='cuda', dtype=torch_dtype)
    out = torch.empty((BLOCK_M, ), device='cuda', dtype=torch_dtype)

    load_reduce_kernel[(1,)](inp, out, inp.stride(0), inp.stride(1), out.stride(0), BLOCK_M, BLOCK_N)

    reference = inp.max(dim=1)[0]
    torch.set_printoptions(profile='full')
    assert_close(out, reference, rtol=1e-2, atol=1e-3, check_dtype=False)
|
||||
@@ -0,0 +1,395 @@
|
||||
# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files
|
||||
# (the "Software"), to deal in the Software without restriction,
|
||||
# including without limitation the rights to use, copy, modify, merge,
|
||||
# publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
# and to permit persons to whom the Software is furnished to do so,
|
||||
# subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
"""
|
||||
Fused Attention
|
||||
===============
|
||||
persistent warp specialized version of python/tutorials/06-fused-attention.py.
|
||||
As of now, it only supports non-persistent warp specialized version of _fwd kernel.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=4, num_stages=2, enable_warp_specialization=True),
    ],
    key=['Q', 'K', 'V'],
)
@triton.jit
def _fwd_kernel(
    Q, K, V, sm_scale,
    L, M,
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Causal attention forward pass with online softmax (warp-specialized).

    Writes the attention output to Out and the per-row softmax statistics
    (running sum `l` and running max `m`) to L and M for the backward pass.
    Program id 0 selects the query-row tile; program id 1 the (batch, head).
    """
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
    # NOTE(review): the K and V offsets reuse Q strides (stride_qh, and for V
    # also stride_qm/stride_qk) — this assumes Q, K, V share an identical
    # memory layout; confirm against callers.
    off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk
    off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
    # Initialize pointers to Q, K, V
    q_ptrs = Q + off_q
    k_ptrs = K + off_k
    v_ptrs = V + off_v
    # initialize pointer to m and l (running max and running sum)
    m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    q = tl.load(q_ptrs)
    # loop over k, v and update accumulator; the causal bound means only
    # key blocks up to and including the query block are visited
    for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
        # -- compute qk ----
        k = tl.load(k_ptrs)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale
        # causal mask within the block
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
        # compute new m
        m_curr = tl.maximum(tl.max(qk, 1), m_prev)
        # correct old l
        l_prev *= tl.exp(m_prev - m_curr)
        # attention weights
        p = tl.exp(qk - m_curr[:, None])
        l_curr = tl.sum(p, 1) + l_prev
        # rescale operands of matmuls
        l_rcp = 1. / l_curr
        p *= l_rcp[:, None]
        acc *= (l_prev * l_rcp)[:, None]
        # update acc
        p = p.to(Q.dtype.element_ty)
        v = tl.load(v_ptrs)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_prev = l_curr
        m_prev = m_curr
        # update pointers
        k_ptrs += BLOCK_N * stride_kn
        v_ptrs += BLOCK_N * stride_vk
    # rematerialize offsets to save registers
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back l and m
    l_ptrs = L + off_hz * N_CTX + offs_m
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(l_ptrs, l_prev)
    tl.store(m_ptrs, m_prev)
    # initialize pointers to output
    offs_n = tl.arange(0, BLOCK_DMODEL)
    off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc)
|
||||
|
||||
|
||||
@triton.jit
def _bwd_preprocess(
    Out, DO, L,
    NewDO, Delta,
    BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
    """Backward prologue: pre-scale dO by the softmax denominator and compute delta.

    NewDO = DO / L (row-wise) and Delta = sum(Out * NewDO, axis=1); both are
    consumed by _bwd_kernel so it does not renormalize per block.
    """
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_n = tl.arange(0, D_HEAD)
    # load
    o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    denom = tl.load(L + off_m).to(tl.float32)
    # compute
    do = do / denom[:, None]
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
    tl.store(Delta + off_m, delta)
|
||||
|
||||
|
||||
@triton.jit
def _bwd_kernel(
    Q, K, V, sm_scale, Out, DO,
    DQ, DK, DV,
    L, M,
    D,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    Z, H, N_CTX,
    num_block,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Causal attention backward pass: accumulates dQ, dK, dV for one (batch, head).

    One program per (batch, head); the outer loop walks key/value blocks, the
    inner loop walks the query rows that attend to them. Expects DO already
    divided by `l` and D (= delta) from _bwd_preprocess.
    """
    off_hz = tl.program_id(0)
    off_z = off_hz // H
    off_h = off_hz % H
    # offset pointers for batch/head
    # NOTE(review): K/V/D* pointers are advanced with the Q strides — this
    # assumes all tensors share the same layout; confirm against callers.
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_qz + off_h * stride_qh
    V += off_z * stride_qz + off_h * stride_qh
    DO += off_z * stride_qz + off_h * stride_qh
    DQ += off_z * stride_qz + off_h * stride_qh
    DK += off_z * stride_qz + off_h * stride_qh
    DV += off_z * stride_qz + off_h * stride_qh
    for start_n in range(0, num_block):
        lo = start_n * BLOCK_M
        # initialize row/col offsets
        offs_qm = lo + tl.arange(0, BLOCK_M)
        offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_m = tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_DMODEL)
        # initialize pointers to value-like data
        q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        # pointer to row-wise quantities in value-like data
        D_ptrs = D + off_hz * N_CTX
        m_ptrs = M + off_hz * N_CTX
        # initialize dv amd dk
        dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        # k and v stay in SRAM throughout
        k = tl.load(k_ptrs)
        v = tl.load(v_ptrs)
        # loop over rows; causality means only rows >= this key block attend
        for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
            offs_m_curr = start_m + offs_m
            # load q, k, v, do on-chip
            q = tl.load(q_ptrs)
            # recompute p = softmax(qk, dim=-1).T
            # NOTE: `do` is pre-divided by `l`; no normalization here
            qk = tl.dot(q, tl.trans(k))
            qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
            m = tl.load(m_ptrs + offs_m_curr)
            p = tl.exp(qk * sm_scale - m[:, None])
            # compute dv
            do = tl.load(do_ptrs)
            dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
            # compute dp = dot(v, do); the -Di term folds delta in up front
            Di = tl.load(D_ptrs + offs_m_curr)
            dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
            dp += tl.dot(do, tl.trans(v))
            # compute ds = p * (dp - delta[:, None])
            ds = p * dp * sm_scale
            # compute dk = dot(ds.T, q)
            dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
            # compute dq (read-modify-write: other key blocks add to the same rows)
            dq = tl.load(dq_ptrs)
            dq += tl.dot(ds.to(Q.dtype.element_ty), k)
            tl.store(dq_ptrs, dq)
            # increment pointers
            dq_ptrs += BLOCK_M * stride_qm
            q_ptrs += BLOCK_M * stride_qm
            do_ptrs += BLOCK_M * stride_qm
        # write-back
        dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        tl.store(dv_ptrs, dv)
        tl.store(dk_ptrs, dk)
|
||||
|
||||
|
||||
# NOTE(review): module-level CUDA allocation; presumably warms up the CUDA
# context at import time — confirm whether it is still needed.
empty = torch.empty(128, device="cuda")
|
||||
|
||||
|
||||
class _attention(torch.autograd.Function):
    """Autograd wrapper around the Triton fused-attention kernels.

    forward launches _fwd_kernel (causal attention with online softmax);
    backward launches _bwd_preprocess and _bwd_kernel to produce dq, dk, dv.
    """

    @staticmethod
    def forward(ctx, q, k, v, sm_scale):
        """Compute causal attention output for (Z, H, N_CTX, D_HEAD) q, k, v.

        Saves the inputs plus the per-row softmax statistics (L, m) for backward.
        """
        BLOCK = 128
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}
        o = torch.empty_like(q)
        # grid: (query-row tiles, batch*heads, 1)
        grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
        L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        num_warps = 4 if Lk <= 64 else 8
        # Only support num_warps = 4 now
        assert num_warps == 4

        _fwd_kernel[grid](
            q, k, v, sm_scale,
            L, m,
            o,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q.shape[0], q.shape[1], q.shape[2],
            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
            BLOCK_DMODEL=Lk,
        )

        ctx.save_for_backward(q, k, v, o, L, m)
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = Lk
        return o

    @staticmethod
    def backward(ctx, do):
        """Compute (dq, dk, dv) from the incoming gradient `do`."""
        BLOCK = 128
        q, k, v, o, l, m = ctx.saved_tensors
        do = do.contiguous()
        # dq is accumulated in-place across key blocks, so it must start at zero
        dq = torch.zeros_like(q, dtype=torch.float32)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        do_scaled = torch.empty_like(do)
        delta = torch.empty_like(l)
        # Pre-divide do by l and compute delta = sum(o * do) per row.
        _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
            o, do, l,
            do_scaled, delta,
            BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
        )
        # One program per (batch, head); ctx.grid[0] = number of row blocks.
        _bwd_kernel[(ctx.grid[1],)](
            q, k, v, ctx.sm_scale,
            o, do_scaled,
            dq, dk, dv,
            l, m,
            delta,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            q.shape[0], q.shape[1], q.shape[2],
            ctx.grid[0],
            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
            num_stages=1,
        )
        return dq, dk, dv, None
|
||||
|
||||
|
||||
# Public functional entry point for the fused-attention autograd Function.
attention = _attention.apply
|
||||
|
||||
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    """Forward + backward of the Triton fused attention vs a dense torch reference."""
    # with ENABLE_TMA=0 and ENABLE_MMA_V3=0
    pytest.skip('unspecified launch failure')

    torch.manual_seed(20)
    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_()
    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
    sm_scale = 0.2
    dout = torch.randn_like(q)
    # reference implementation
    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    # Apply the causal mask once across all (Z, H); the original looped over
    # z and h assigning the identical full-tensor mask Z*H times (redundant).
    p[:, :, M == 0] = float("-inf")
    p = torch.softmax(p.float(), dim=-1).half()
    # p = torch.exp(p)
    ref_out = torch.matmul(p, v)
    ref_out.backward(dout)
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None
    # # triton implementation
    tri_out = attention(q, k, v, sm_scale)
    # print(ref_out)
    # print(tri_out)
    tri_out.backward(dout)
    tri_dv, v.grad = v.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dq, q.grad = q.grad.clone(), None
    # compare
    assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
    assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=0)
    assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=0)
    assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=0)
|
||||
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_func
|
||||
HAS_FLASH = True
|
||||
except BaseException:
|
||||
HAS_FLASH = False
|
||||
|
||||
# Benchmark setup: fixed batch/heads/head-dim, sweeping sequence length.
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
configs = [triton.testing.Benchmark(
    x_names=['N_CTX'],
    # x_vals=[2**i for i in range(10, 14)],
    x_vals=[2**i for i in range(10, 11)],
    line_arg='provider',
    line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
    line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
    styles=[('red', '-'), ('blue', '-')],
    ylabel='ms',
    plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
    args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
    # ) for mode in ['fwd', 'bwd']]
) for mode in ['fwd']]
|
||||
|
||||
|
||||
@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
    """Benchmark the Triton fused attention against flash-attn (if installed).

    Returns the measured time in milliseconds for the selected provider/mode.
    """
    assert mode in ['fwd', 'bwd']
    # warmup = 25
    # rep = 100
    # Warmup/rep deliberately minimized — this runs once as a smoke benchmark.
    warmup = 0
    rep = 1
    if provider == "triton":
        q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        sm_scale = 1.3
        fn = lambda: attention(q, k, v, sm_scale)
        if mode == 'bwd':
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
        return ms
    if provider == "flash":
        # flash-attn packs qkv and takes cumulative sequence lengths.
        lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
        cu_seqlens = torch.zeros(
            (BATCH + 1,), device=device, dtype=torch.int32)
        cu_seqlens[1:] = lengths.cumsum(0)
        qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
        fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
        if mode == 'bwd':
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
        return ms


# # only works on post-Ampere GPUs right now
# bench_flash_attention.run(save_path='.', print_data=True)
|
||||
987
python/test/unit/hopper/test_persistent_warp_specialized_gemm.py
Normal file
987
python/test/unit/hopper/test_persistent_warp_specialized_gemm.py
Normal file
@@ -0,0 +1,987 @@
|
||||
# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files
|
||||
# (the "Software"), to deal in the Software without restriction,
|
||||
# including without limitation the rights to use, copy, modify, merge,
|
||||
# publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
# and to permit persons to whom the Software is furnished to do so,
|
||||
# subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from .utils import get_proper_err, get_variant_golden
|
||||
|
||||
|
||||
def isMMAV3OrTMAEnabled():
    """Report whether the experimental MMA-v3 or TMA path is switched on.

    Both features are controlled by environment variables; any of the values
    '1', 'on' or 'true' (case-insensitive) counts as enabled, and an unset
    variable defaults to disabled.
    """
    import os
    truthy = ['1', 'on', 'true']
    return any(os.environ.get(flag, '0').lower() in truthy
               for flag in ('ENABLE_MMA_V3', 'ENABLE_TMA'))
|
||||
|
||||
|
||||
# Persistent GEMM with a static schedule: the grid is NUM_SM programs (one
# per SM) and each program loops over every NUM_SM-th output tile, instead of
# launching one program per tile.  Raw-pointer (non-block-pointer) variant.
@triton.jit
def static_persistent_matmul_kernel(
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
        NUM_SM: tl.constexpr,
):
    # Output tiles are numbered row-major over the (m_tiles x n_tiles) grid.
    start_tile = tl.program_id(axis=0)
    m_tiles = tl.cdiv(M, BLOCK_M)
    n_tiles = tl.cdiv(N, BLOCK_N)
    num_tiles = m_tiles * n_tiles
    offs_k = tl.arange(0, BLOCK_K)

    for tile_id in range(start_tile, num_tiles, NUM_SM):
        pid_m = tile_id // n_tiles
        pid_n = tile_id % n_tiles
        accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # '% M' / '% N' wrap out-of-range rows/cols instead of masking them.
        offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
        offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

        # March both operands along K, accumulating in fp32.
        for k in range(0, K, BLOCK_K):
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
            accumulator += tl.dot(a, b)
            a_ptrs += BLOCK_K * stride_ak
            b_ptrs += BLOCK_K * stride_bk

        offs_cm = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
        offs_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N

        # NOTE(review): the store is unmasked, so M and N appear to be
        # assumed multiples of BLOCK_M / BLOCK_N — confirm against callers.
        c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
        tl.store(c_ptrs, accumulator)
|
||||
|
||||
|
||||
# Persistent GEMM using block pointers (make_block_ptr / advance), the form
# the compiler can lower to TMA on Hopper.  The block pointers are created
# once before the tile loop and repositioned between tiles with tl.advance.
@triton.jit
def static_persistent_tma_matmul_kernel(
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
        NUM_SM: tl.constexpr,
):
    start_tile = tl.program_id(axis=0)
    m_tiles = tl.cdiv(M, BLOCK_M)
    n_tiles = tl.cdiv(N, BLOCK_N)
    k_tiles = tl.cdiv(K, BLOCK_K)
    num_tiles = m_tiles * n_tiles

    # pre_pid_* track the tile the block pointers currently point at, so the
    # next iteration can advance them by the *relative* displacement.
    pre_pid_m = start_tile // n_tiles
    pre_pid_n = start_tile % n_tiles

    block_offset_m = pre_pid_m * BLOCK_M
    block_offset_n = pre_pid_n * BLOCK_N
    a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
    for tile_id in range(start_tile, num_tiles, NUM_SM):
        pid_m = tile_id // n_tiles
        pid_n = tile_id % n_tiles
        accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        if tile_id >= NUM_SM:
            # Not the first tile of this program: rewind the K dimension
            # (the inner loop left it advanced by k_tiles * BLOCK_K) and
            # move to the new tile's row/column.
            a_tile_ptr = tl.advance(a_tile_ptr, [(pid_m - pre_pid_m) * BLOCK_M, -k_tiles * BLOCK_K])
            b_tile_ptr = tl.advance(b_tile_ptr, [-k_tiles * BLOCK_K, (pid_n - pre_pid_n) * BLOCK_N])

        # Accumulate the K dimension in fp32.
        for k in range(0, K, BLOCK_K):
            a = tl.load(a_tile_ptr)
            b = tl.load(b_tile_ptr)
            accumulator += tl.dot(a, b)
            a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_K])
            b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_K, 0])

        offs_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
        offs_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N

        # Epilogue still uses raw pointers; the store is unmasked.
        c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
        tl.store(c_ptrs, accumulator)
        pre_pid_m = pid_m
        pre_pid_n = pid_n
|
||||
|
||||
|
||||
@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,TRANS_A,TRANS_B', [
    [4096, 4096, 64, 64, 64, 16, 4, 1, False, True],
    [4096, 4096, 64, 64, 64, 32, 4, 1, False, True],
    [4096, 4096, 64, 256, 64, 16, 4, 1, False, True],
    [4096, 4096, 64, 128, 128, 16, 4, 1, False, True],
    # TODO: fix issue for 8-warp persistent kernel
    # [4096, 4096, 64, 128, 128, 16, 8, 1, False, True],
    # [4096, 4096, 64, 128, 256, 16, 8, 1, False, True],
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, TRANS_A, TRANS_B):
    """Validate both static-persistent GEMM kernels against torch.matmul.

    TRANS_A / TRANS_B build the operands as transposed views so the kernels
    are also exercised with non-unit leading strides.
    """
    # TODO: fix RewriteTensorPtrPass
    pytest.skip('RewriteTensorPtrPass issue')

    # The .1 scale keeps fp16 accumulation error small relative to atol.
    if TRANS_A:
        a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T
    else:
        a = .1 * torch.randn((M, K), device='cuda', dtype=torch.float16)

    if TRANS_B:
        b = .1 * torch.randn((N, K), device='cuda', dtype=torch.float16).T
    else:
        b = .1 * torch.randn((K, N), device='cuda', dtype=torch.float16)
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)

    # Persistent launch: exactly one program per SM.
    num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
    grid = lambda META: (num_SMs,)

    def call_vintage():
        # Old-style raw-pointer kernel.
        static_persistent_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS)
        return c

    def call_stylish():
        # make_block_ptr (TMA-friendly) kernel.
        static_persistent_tma_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS)
        return c

    th_c = torch.matmul(a, b)

    # Test using old style of ptr calculation.
    # torch.testing.assert_allclose is deprecated; use assert_close (already
    # imported at module level, and used by the no-scf test below).
    # check_dtype=False: the reference is fp16, the kernel output fp32.
    tt_c = call_vintage()
    assert_close(tt_c, th_c, atol=1e-2, rtol=0, check_dtype=False)

    # Clear c (overwrite with garbage) so a stale result from the first
    # kernel cannot mask a failure of the second.
    c = torch.randn((M, N), device=a.device, dtype=torch.float32)

    # Test using make_block_ptr
    tt_c = call_stylish()
    assert_close(tt_c, th_c, atol=1e-2, rtol=0, check_dtype=False)
|
||||
|
||||
|
||||
# One-tile-per-program GEMM used with enable_warp_specialization=True; the
# specialization itself is done by the compiler, not expressed in the kernel.
# Raw-pointer variant.
@triton.jit
def warp_specialized_matmul_kernel(
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # Row-major mapping of the 1-D program id onto the output tile grid.
    tid = tl.program_id(axis=0)
    n_tiles = tl.cdiv(N, BLOCK_N)
    pid_m = tid // n_tiles
    pid_n = tid % n_tiles

    offs_k = tl.arange(0, BLOCK_K)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # max_contiguous/multiple_of are compiler hints about the wrapped
    # offsets' alignment; '% M' / '% N' wrap out-of-range rows/cols.
    offs_am = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    offs_bn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Accumulate along K in fp32.
    for k in range(0, K, BLOCK_K):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    # Down-cast to the output element type before the store.
    accumulator = accumulator.to(c_ptr.dtype.element_ty)

    offs_cm = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
    offs_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N

    # Masked store: unlike the persistent kernels above, partial edge tiles
    # are handled here.
    c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
    mask = (offs_cm < M)[:, None] & (offs_cn < N)[None, :]
    tl.store(c_ptrs, accumulator, mask=mask)
|
||||
|
||||
|
||||
# Block-pointer (TMA-friendly) version of warp_specialized_matmul_kernel:
# same one-tile-per-program schedule, operands accessed via make_block_ptr
# and tl.advance instead of raw pointer arithmetic.
@triton.jit
def tma_warp_specialized_matmul_kernel(
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    tid = tl.program_id(axis=0)
    n_tiles = tl.cdiv(N, BLOCK_N)
    pid_m = tid // n_tiles
    pid_n = tid % n_tiles

    block_offset_m = pid_m * BLOCK_M
    block_offset_n = pid_n * BLOCK_N
    a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                   offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                                   offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Accumulate along K in fp32, advancing both block pointers each step.
    for k in range(0, K, BLOCK_K):
        a = tl.load(a_tile_ptr)
        b = tl.load(b_tile_ptr)
        accumulator += tl.dot(a, b)
        a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_K])
        b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_K, 0])
    # Down-cast to the output element type before the store.
    accumulator = accumulator.to(c_ptr.dtype.element_ty)

    offs_cm = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
    offs_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N

    # Masked raw-pointer epilogue handles partial edge tiles.
    c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
    mask = (offs_cm < M)[:, None] & (offs_cn < N)[None, :]
    tl.store(c_ptrs, accumulator, mask=mask)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_CTAS,TRANS_A,TRANS_B', [
    [2048, 2048, 64, 64, 64, 16, 1, False, True],
    [4096, 4096, 64, 64, 64, 16, 1, False, True],
    [128, 4096, 64, 64, 64, 16, 1, False, True],
    [4096, 128, 64, 64, 64, 16, 1, False, True],
    [4096, 4096, 64, 64, 64, 32, 1, False, True],
    [4096, 4096, 256, 128, 128, 16, 1, False, True],
    [4096, 4096, 320, 128, 64, 64, 1, False, True],
    [4096, 4096, 320, 64, 128, 64, 1, False, True],
    [4096, 4096, 320, 128, 128, 64, 1, False, True],
    [4096, 4096, 256, 256, 64, 16, 1, False, True],
    [4096, 4096, 256, 256, 64, 64, 1, False, True],
    [4096, 4096, 256, 64, 256, 16, 1, False, True],
    [4096, 4096, 256, 64, 256, 64, 1, False, True],
    [4096, 4096, 256, 256, 128, 16, 1, False, True],
    [4096, 4096, 256, 256, 128, 64, 1, False, True],
    [4096, 4096, 256, 128, 256, 16, 1, False, True],
    [4096, 4096, 256, 128, 256, 64, 1, False, True],
    # numCTAs > 1
    [2048, 2048, 64, 128, 128, 64, 2, False, True],
    # NOTE(review): duplicate of the previous row — kept for parity with
    # the original parameter list.
    [2048, 2048, 64, 128, 128, 64, 2, False, True],
    [2048, 2048, 128, 256, 128, 64, 4, False, True],
    [4096, 4096, 128, 256, 128, 64, 4, False, True],
    [4096, 4096, 256, 128, 256, 64, 4, False, True],
    [4096, 4096, 256, 256, 256, 64, 4, False, True],
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B):
    """Validate the two warp-specialized (one-tile-per-program) GEMM kernels
    against torch.matmul, for both raw-pointer and block-pointer forms.
    """
    pytest.skip('hang')

    # The .1 scale keeps fp16 accumulation error small relative to atol.
    if TRANS_A:
        a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T
    else:
        a = .1 * torch.randn((M, K), device='cuda', dtype=torch.float16)

    if TRANS_B:
        b = .1 * torch.randn((N, K), device='cuda', dtype=torch.float16).T
    else:
        b = .1 * torch.randn((K, N), device='cuda', dtype=torch.float16)

    c = torch.empty((M, N), device=a.device, dtype=torch.float32)

    # One program per output tile (non-persistent launch).
    grid = lambda META: (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),)

    def call_vintage():
        # Old-style raw-pointer kernel.
        warp_specialized_matmul_kernel[grid](
            a, b, c,
            M, N, K,
            a.stride(0), a.stride(1),
            b.stride(0), b.stride(1),
            c.stride(0), c.stride(1),
            BLOCK_M, BLOCK_N, BLOCK_K,
            num_warps=4,
            num_ctas=NUM_CTAS,
            enable_warp_specialization=True)
        return c

    def call_stylish():
        # make_block_ptr (TMA-friendly) kernel.
        tma_warp_specialized_matmul_kernel[grid](
            a, b, c,
            M, N, K,
            a.stride(0), a.stride(1),
            b.stride(0), b.stride(1),
            c.stride(0), c.stride(1),
            BLOCK_M, BLOCK_N, BLOCK_K,
            num_warps=4,
            num_ctas=NUM_CTAS,
            enable_warp_specialization=True)
        return c

    th_c = torch.matmul(a, b)

    # Test using old style of ptr calculation.
    # torch.testing.assert_allclose is deprecated; use assert_close (already
    # imported at module level).  check_dtype=False: fp16 reference vs fp32
    # kernel output.
    tt_c = call_vintage()
    assert_close(tt_c, th_c, atol=1e-2, rtol=0, check_dtype=False)

    # Clear c (overwrite with garbage) so a stale result from the first
    # kernel cannot mask a failure of the second.
    c = torch.randn((M, N), device=a.device, dtype=torch.float32)

    # # Test using make_block_ptr
    tt_c = call_stylish()
    assert_close(tt_c, th_c, atol=1e-2, rtol=0, check_dtype=False)

    # # #############################################Performance Evaluation#############################################
    # fn = lambda: call_vintage()
    # ms = triton.testing.do_bench(fn, warmup=25, rep=100)
    # cur_gpu_perf = round(2. * M * N * K / ms * 1e-9, 2)
    # print(' '.join(['Performance of', str(M), str(N), str(K), ':', str(ms), 'ms, ', str(cur_gpu_perf), 'TFLOPS']))
|
||||
|
||||
|
||||
# Same static-persistent raw-pointer GEMM as static_persistent_matmul_kernel,
# but wrapped in an autotuner whose single config turns warp specialization
# on (the non-specialized config is kept commented out for comparison).
@triton.autotune(
    configs=[
        triton.Config({}, num_stages=3, num_warps=4, enable_warp_specialization=True),
        # triton.Config({}, num_stages=3, num_warps=4, enable_warp_specialization=False),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def static_persistent_warp_specialized_matmul_kernel(
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
        NUM_SM: tl.constexpr,
):
    # Each of the NUM_SM programs handles every NUM_SM-th output tile.
    start_tile = tl.program_id(axis=0)
    m_tiles = tl.cdiv(M, BLOCK_M)
    n_tiles = tl.cdiv(N, BLOCK_N)
    num_tiles = m_tiles * n_tiles
    offs_k = tl.arange(0, BLOCK_K)

    for tile_id in range(start_tile, num_tiles, NUM_SM):
        pid_m = tile_id // n_tiles
        pid_n = tile_id % n_tiles
        accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # '% M' / '% N' wrap out-of-range rows/cols instead of masking them.
        offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
        offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

        # Accumulate along K in fp32.
        for k in range(0, K, BLOCK_K):
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
            accumulator += tl.dot(a, b)
            a_ptrs += BLOCK_K * stride_ak
            b_ptrs += BLOCK_K * stride_bk

        offs_cm = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
        offs_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N

        # Unmasked store — edge tiles are not handled here.
        c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
        tl.store(c_ptrs, accumulator)
|
||||
|
||||
|
||||
# Block-pointer (TMA-friendly) twin of the kernel above: static-persistent
# schedule, autotuned with warp specialization enabled.
@triton.autotune(
    configs=[
        triton.Config({}, num_stages=3, num_warps=4, enable_warp_specialization=True),
        # triton.Config({}, num_stages=3, num_warps=4, enable_warp_specialization=False),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def static_persistent_tma_warp_specialized_matmul_kernel(
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
        NUM_SM: tl.constexpr,
):
    start_tile = tl.program_id(axis=0)
    m_tiles = tl.cdiv(M, BLOCK_M)
    n_tiles = tl.cdiv(N, BLOCK_N)
    k_tiles = tl.cdiv(K, BLOCK_K)
    num_tiles = m_tiles * n_tiles

    # pre_pid_* track which tile the block pointers currently point at so
    # they can be advanced by relative displacements between tiles.
    pre_pid_m = start_tile // n_tiles
    pre_pid_n = start_tile % n_tiles

    block_offset_m = pre_pid_m * BLOCK_M
    block_offset_n = pre_pid_n * BLOCK_N
    a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
    for tile_id in range(start_tile, num_tiles, NUM_SM):
        pid_m = tile_id // n_tiles
        pid_n = tile_id % n_tiles
        accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        if tile_id >= NUM_SM:
            # Not the first tile of this program: rewind K (advanced by
            # k_tiles * BLOCK_K in the inner loop) and step to the new tile.
            a_tile_ptr = tl.advance(a_tile_ptr, [(pid_m - pre_pid_m) * BLOCK_M, -k_tiles * BLOCK_K])
            b_tile_ptr = tl.advance(b_tile_ptr, [-k_tiles * BLOCK_K, (pid_n - pre_pid_n) * BLOCK_N])

        # Accumulate along K in fp32.
        for k in range(0, K, BLOCK_K):
            a = tl.load(a_tile_ptr)
            b = tl.load(b_tile_ptr)
            accumulator += tl.dot(a, b)
            a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_K])
            b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_K, 0])

        offs_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
        offs_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N

        # Unmasked raw-pointer epilogue.
        c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
        tl.store(c_ptrs, accumulator)
        pre_pid_m = pid_m
        pre_pid_n = pid_n
|
||||
|
||||
|
||||
@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_CTAS,TRANS_A,TRANS_B', [
    [2048, 2048, 64, 64, 64, 16, 1, False, True],
    [4096, 4096, 64, 64, 64, 16, 1, False, True],
    [128, 4096, 64, 64, 64, 16, 1, False, True],
    [4096, 128, 64, 64, 64, 16, 1, False, True],
    [4096, 4096, 64, 64, 64, 32, 1, False, True],
    [4096, 4096, 256, 128, 128, 16, 1, False, True],
    [4096, 4096, 320, 128, 64, 64, 1, False, True],
    [4096, 4096, 320, 64, 128, 64, 1, False, True],
    [4096, 4096, 320, 128, 128, 64, 1, False, True],
    [4096, 4096, 256, 256, 64, 16, 1, False, True],
    [4096, 4096, 256, 256, 64, 64, 1, False, True],
    [4096, 4096, 256, 64, 256, 16, 1, False, True],
    [4096, 4096, 256, 64, 256, 64, 1, False, True],
    [4096, 4096, 256, 256, 128, 16, 1, False, True],
    [4096, 4096, 256, 256, 128, 64, 1, False, True],
    [4096, 4096, 256, 128, 256, 16, 1, False, True],
    [4096, 4096, 256, 128, 256, 64, 1, False, True],
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B):
    """Validate the autotuned warp-specialized persistent GEMM kernels
    (raw-pointer and block-pointer variants) against torch.matmul.

    NOTE(review): NUM_CTAS is parametrized but never passed to the kernel
    launches below (num_ctas comes from the autotune config) — confirm
    whether that is intentional.
    """
    # TODO: fix RewriteTensorPtrPass
    pytest.skip('RewriteTensorPtrPass issue')
    # The .1 scale keeps fp16 accumulation error small relative to atol.
    if TRANS_A:
        a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T
    else:
        a = .1 * torch.randn((M, K), device='cuda', dtype=torch.float16)

    if TRANS_B:
        b = .1 * torch.randn((N, K), device='cuda', dtype=torch.float16).T
    else:
        b = .1 * torch.randn((K, N), device='cuda', dtype=torch.float16)
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)

    # Persistent launch: exactly one program per SM.
    num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
    grid = lambda META: (num_SMs,)

    def call_vintage():
        # Old-style raw-pointer kernel.
        static_persistent_warp_specialized_matmul_kernel[grid](
            a, b, c,
            M, N, K,
            a.stride(0), a.stride(1),
            b.stride(0), b.stride(1),
            c.stride(0), c.stride(1),
            BLOCK_M, BLOCK_N, BLOCK_K, num_SMs)
        return c

    def call_stylish():
        # make_block_ptr (TMA-friendly) kernel.
        static_persistent_tma_warp_specialized_matmul_kernel[grid](
            a, b, c,
            M, N, K,
            a.stride(0), a.stride(1),
            b.stride(0), b.stride(1),
            c.stride(0), c.stride(1),
            BLOCK_M, BLOCK_N, BLOCK_K, num_SMs)
        return c

    th_c = torch.matmul(a, b)

    # Test using old style of ptr calculation.
    # torch.testing.assert_allclose is deprecated; use assert_close (already
    # imported at module level).  check_dtype=False: fp16 reference vs fp32
    # kernel output.
    tt_c = call_vintage()
    assert_close(tt_c, th_c, atol=1e-2, rtol=0, check_dtype=False)

    # Clear c (overwrite with garbage) so a stale result from the first
    # kernel cannot mask a failure of the second.
    c = torch.randn((M, N), device=a.device, dtype=torch.float32)

    # Test using make_block_ptr
    tt_c = call_stylish()
    assert_close(tt_c, th_c, atol=1e-2, rtol=0, check_dtype=False)

    # #############################################Performance Evaluation#############################################
    # fn = lambda: call_stylish()
    # ms = triton.testing.do_bench(fn, warmup=25, rep=100)
    # cur_gpu_perf = round(2. * M * N * K / ms * 1e-9, 2)
    # print(' '.join(['Performance of', str(M), str(N), str(K), ':', str(ms), 'ms, ', str(cur_gpu_perf), 'TFLOPS']))
|
||||
|
||||
|
||||
# Static-persistent GEMM with *no inner K loop* ("no scf"): each tile is
# computed by a single tl.dot over a (BLOCK_M x BLOCK_K) x (BLOCK_K x BLOCK_N)
# pair, so callers must pass BLOCK_K == K.  Load path (TMA block pointers vs
# raw pointers) and epilogue path are selected by constexpr flags.
@triton.jit
def static_persistent_matmul_no_scf_kernel(
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
        FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr,
        NUM_SM: tl.constexpr, USE_TMA_LOAD: tl.constexpr,
):
    # Each of the NUM_SM programs handles every NUM_SM-th output tile.
    start_tile = tl.program_id(axis=0)
    m_tiles = tl.cdiv(M, BLOCK_M)
    n_tiles = tl.cdiv(N, BLOCK_N)
    num_tiles = m_tiles * n_tiles
    offs_k = tl.arange(0, BLOCK_K)
    # pre_pid_* track the current position of the block pointers so they can
    # be advanced by relative displacements between tiles.
    pre_pid_m = start_tile // n_tiles
    pre_pid_n = start_tile % n_tiles
    block_offset_m = pre_pid_m * BLOCK_M
    block_offset_n = pre_pid_n * BLOCK_N

    if USE_TMA_LOAD:
        a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                        offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
        b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                                        offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
    if USE_TMA_EPILOGUE:
        c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
                                        offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))

    for tile_id in range(start_tile, num_tiles, NUM_SM):
        pid_m = tile_id // n_tiles
        pid_n = tile_id % n_tiles

        if USE_TMA_LOAD:
            # Move the block pointers from the previous tile to this one.
            a_block_ptr = tl.advance(a_block_ptr, [(pid_m - pre_pid_m) * BLOCK_M, 0])
            b_block_ptr = tl.advance(b_block_ptr, [0, (pid_n - pre_pid_n) * BLOCK_N])
            a = tl.load(a_block_ptr)
            b = tl.load(b_block_ptr)
        else:
            offs_am = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
            offs_bn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
            a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
            b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)

        # Single dot — no K accumulation loop (BLOCK_K covers all of K).
        c = tl.dot(a, b)

        if FLOAT16_OUTPUT:
            c = c.to(tl.float16)

        if USE_TMA_EPILOGUE:
            c_block_ptr = tl.advance(c_block_ptr, [(pid_m - pre_pid_m) * BLOCK_M, (pid_n - pre_pid_n) * BLOCK_N])
            tl.store(c_block_ptr, c)
        else:
            offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
            offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
            c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
            tl.store(c_ptrs, c)

        pre_pid_m = pid_m
        pre_pid_n = pid_n
|
||||
|
||||
|
||||
# The inner list of shape rows is replicated for every combination of
# USE_TMA_EPILOGUE x USE_TMA_LOAD and flattened with itertools.chain.
@pytest.mark.parametrize('M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,USE_TMA_LOAD',
                         itertools.chain(
                             *[
                                 [
                                     # numCTAs = 1, no TMA multicast:
                                     [64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
                                     [64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
                                     [64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
                                     [64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
                                     [64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
                                     [64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
                                     [128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
                                     [128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
                                     # small M, N
                                     [16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
                                     [16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
                                     [32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
                                     [32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
                                 ] for USE_TMA_EPILOGUE in [True, False]
                                 for USE_TMA_LOAD in [True, False]
                             ]))
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_static_persistent_matmul_no_scf_kernel(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE, USE_TMA_LOAD):
    """Validate the no-inner-loop persistent GEMM against a looser golden
    reference that accounts for split-accumulation rounding (see utils)."""
    if isMMAV3OrTMAEnabled():
        pytest.skip("known failure")
    if (TRANS_A):
        a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
    else:
        a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    if (TRANS_B):
        b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
    else:
        b = torch.randn((K, N), device='cuda', dtype=torch.float16)

    if OUTPUT_TYPE == "float16":
        c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    else:
        c = torch.empty((M, N), device=a.device, dtype=torch.float32)

    # Persistent launch: one program per SM.
    num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count

    # BLOCK_K=K: the kernel performs a single tl.dot per tile (no K loop).
    # Blocks are halved once M/N reach 128 so two tiles exist per dimension.
    # TODO: set `enable_warp_specialization=False` will lead to compilation error.
    static_persistent_matmul_no_scf_kernel[(num_SMs,)](a_ptr=a, b_ptr=b, c_ptr=c,
                                                       M=M, N=N, K=K,
                                                       stride_am=a.stride(0), stride_ak=a.stride(1),
                                                       stride_bk=b.stride(0), stride_bn=b.stride(1),
                                                       stride_cm=c.stride(0), stride_cn=c.stride(1),
                                                       BLOCK_M=M if M < 128 else M // 2, BLOCK_N=N if N < 128 else N // 2, BLOCK_K=K, NUM_SM=num_SMs,
                                                       num_warps=NUM_WARPS,
                                                       num_ctas=NUM_CTAS,
                                                       FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"),
                                                       USE_TMA_EPILOGUE=USE_TMA_EPILOGUE,
                                                       USE_TMA_LOAD=USE_TMA_LOAD,
                                                       enable_warp_specialization=True)
    # Golden computed in fp32; tolerances are widened by the error of a
    # variant accumulation order (get_variant_golden) to absorb legitimate
    # fp rounding differences.
    a_f32 = a.to(torch.float32)
    b_f32 = b.to(torch.float32)
    golden = torch.matmul(a_f32, b_f32)
    golden_variant = get_variant_golden(a_f32, b_f32)
    golden_abs_err, golden_rel_err = get_proper_err(golden, golden_variant)
    torch.set_printoptions(profile="full")
    assert_close(
        c,
        golden,
        rtol=max(1e-2, 1.1 * golden_rel_err),
        atol=max(1e-3, 1.1 * golden_abs_err),
        check_dtype=False)
|
||||
|
||||
|
||||
@triton.jit
def full_static_persistent_matmul_kernel(
    # Pointers: A/B matmul operands, W chain-dot weight, bias, Z output.
    a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_wm, stride_wn,
    stride_zm, stride_zn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
    out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr,
    # Epilogue selectors — at most one is expected to be set by the caller.
    ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
    DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
    A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr,
    B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr,
    NUM_SM: tl.constexpr
):
    # Persistent matmul: NUM_SM programs are launched and each one walks the
    # tile space with stride NUM_SM (see the range() loop below), computing
    # Z = epilogue(A @ B) one (BLOCK_M, BLOCK_N) tile at a time.
    #
    # Tile ids are mapped to (pid_m, pid_n) with GROUP_SIZE_M-row grouping
    # (same scheme as the classic grouped-ordering matmul tutorial).
    start_pid = tl.program_id(axis=0)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_tiles = num_pid_m * num_pid_n
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = start_pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    # pre_pid_m/pre_pid_n track the tile the block pointers currently point
    # at; each iteration advances the pointers by the delta to the new tile.
    pre_pid_m = first_pid_m + ((start_pid % num_pid_in_group) % group_size_m)
    pre_pid_n = (start_pid % num_pid_in_group) // group_size_m

    pre_block_offset_m = pre_pid_m * BLOCK_M
    pre_block_offset_n = pre_pid_n * BLOCK_N
    a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                   offsets=(pre_block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(A_ORDER_0, A_ORDER_1))
    b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                                   offsets=(0, pre_block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(B_ORDER_0, B_ORDER_1))
    w_tile_ptr = tl.make_block_ptr(base=w_ptr, shape=(N, N), strides=(stride_wm, stride_wn),
                                   offsets=(0, pre_block_offset_n), block_shape=(BLOCK_N, BLOCK_N), order=(0, 1))

    if USE_TMA_STORE:
        z_block_ptr = tl.make_block_ptr(base=z_ptr, shape=(M, N), strides=(stride_zm, stride_zn),
                                        offsets=(pre_block_offset_m, pre_block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))

    for tile_id in range(start_pid, num_tiles, NUM_SM):
        # Recompute the grouped (pid_m, pid_n) mapping for this tile.
        group_id = tile_id // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
        pid_n = (tile_id % num_pid_in_group) // group_size_m
        block_offset_m = pid_m * BLOCK_M
        block_offset_n = pid_n * BLOCK_N

        # Plain (non-block-pointer) addresses used for the masked-store path
        # and for the bias epilogues.
        offs_m = block_offset_m + tl.arange(0, BLOCK_M)
        offs_n = block_offset_n + tl.arange(0, BLOCK_N)
        z_ptrs = z_ptr + offs_m[:, None] * stride_zm + offs_n[None, :] * stride_zn
        # NOTE(review): bias addressing reuses the Z strides — presumably bias
        # and Z share a layout; confirm against the caller.
        bias_ptrs = bias_ptr + offs_m[:, None] * stride_zm + offs_n[None, :] * stride_zn
        mask = (offs_m < M)[:, None] & (offs_n < N)[None, :]

        # TODO: lib/Dialect/TritonGPU/Transforms/RewriteTensorPointer.cpp does not support scf.if yet.
        # if tile_id >= NUM_SM:
        #     a_tile_ptr = tl.advance(a_tile_ptr, [(pid_m - pre_pid_m) * BLOCK_M, -tl.cdiv(K, BLOCK_K) * BLOCK_K])
        #     b_tile_ptr = tl.advance(b_tile_ptr, [-tl.cdiv(K, BLOCK_K) * BLOCK_K, (pid_n - pre_pid_n) * BLOCK_N])

        # Move the operand block pointers from the previous tile to this one.
        a_tile_ptr = tl.advance(a_tile_ptr, [(pid_m - pre_pid_m) * BLOCK_M, 0])
        b_tile_ptr = tl.advance(b_tile_ptr, [0, (pid_n - pre_pid_n) * BLOCK_N])
        z = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for k in range(0, K, BLOCK_K):
            a = tl.load(a_tile_ptr, boundary_check=(0, 1))
            b = tl.load(b_tile_ptr, boundary_check=(0, 1))
            z += tl.dot(a, b)
            a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_K])
            b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_K, 0])
        # Rewind the K dimension so the pointers are positioned for the next
        # tile's advance above.
        a_tile_ptr = tl.advance(a_tile_ptr, [0, -tl.cdiv(K, BLOCK_K) * BLOCK_K])
        b_tile_ptr = tl.advance(b_tile_ptr, [-tl.cdiv(K, BLOCK_K) * BLOCK_K, 0])

        if (out_dtype == tl.constexpr(tl.float16)):
            z = z.to(tl.float16)

        # ---- optional epilogues ----
        if ADD_MATRIX:
            z += tl.load(bias_ptrs, mask=mask)
        if ADD_ROWS:
            ZRs = bias_ptr + offs_m * stride_zm
            z += tl.load(ZRs)[:, None]
        if ADD_COLS:
            ZCs = bias_ptr + offs_n * stride_zn
            z += tl.load(ZCs)[None, :]
        if DO_SOFTMAX:
            # Numerically stable row softmax: subtract the row max first.
            max = tl.max(z, 1)
            z = z - max[:, None]
            num = tl.exp(z.to(tl.float32)).to(max.dtype)
            den = tl.sum(num, 1)
            z = num / den[:, None]
        if CHAIN_DOT:
            w = tl.load(w_tile_ptr)
            w_tile_ptr = tl.advance(w_tile_ptr, [0, (pid_n - pre_pid_n) * BLOCK_N])
            z = tl.dot(z.to(w.dtype), w)
            if (out_dtype == tl.constexpr(tl.float16)):
                z = z.to(tl.float16)

        if USE_TMA_STORE:
            z_block_ptr = tl.advance(z_block_ptr, [(pid_m - pre_pid_m) * BLOCK_M, (pid_n - pre_pid_n) * BLOCK_N])
            tl.store(z_block_ptr, z, boundary_check=(0, 1))
        else:
            tl.store(z_ptrs, z, mask=mask)

        # Remember which tile the pointers now address.
        pre_pid_m = pid_m
        pre_pid_n = pid_n
|
||||
|
||||
|
||||
@pytest.mark.parametrize('BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS',
                         [
                             (128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws)
                             # bad from cublas-important-layers
                             for shape_w_c in [
                                 [4096, 1, 1024, False, False],
                                 [2048, 204, 1000, True, False],
                             ]
                             for out_dtype in ['float16', 'float32']
                             for use_tma_store in [False, True]
                             for enable_ws in [True]
                         ] + [
                             (*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
                             # softmax works for one CTA
                             for shape_w_c in [
                                 [64, 64, 16, 4, 1, 64, 64, 64],
                                 [128, 128, 64, 4, 1, None, None, None],
                                 [16, 16, 64, 4, 1, 16, 16, 64],
                                 # TODO: enable when num_warps != 4 is supported.
                                 # [64, 64, 32, 8, 1, 64, 64, 64],
                                 [128, 128, 64, 4, 1, 128, 128, 128],
                             ]
                             for epilogue in ['softmax']
                             for out_dtype in ['float16', 'float32']
                             for use_tma_store in [False, True]
                             for trans_a in [False, True]
                             for trans_b in [False, True]
                             for num_stages in [3]
                             for enable_ws in [True]
                         ] + [
                             (*shape_w_c, trans_a, trans_b, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
                             for shape_w_c in [
                                 [64, 64, 32, 4, 1, 128, 256, 64],
                                 # TODO: enable when num_ctas != 1 is supported.
                                 # [128, 128, 16, 4, 4, 512, 256, 64],
                                 # [128, 256, 32, 4, 8, 256, 256, 192],
                                 # [512, 256, 32, 4, 8, 1024, 256, 192],
                                 # BLOCK_K >= 128
                                 [64, 128, 128, 4, 1, 512, 256, 256],
                                 [128, 128, 128, 4, 1, 256, 256, 192],
                                 # TODO: enable when num_ctas != 1 is supported.
                                 # [128, 128, 128, 4, 2, 256, 256, 192],
                                 # small BLOCK_M and BLOCK_K
                                 [16, 32, 32, 4, 1, 128, 256, 64],
                                 [32, 32, 16, 4, 1, 256, 256, 192],
                                 # TODO: enable when num_ctas != 1 is supported.
                                 # [16, 32, 64, 4, 4, 512, 256, 64],
                             ]
                             for out_dtype in ['float16', 'float32']
                             for use_tma_store in [False, True]
                             for trans_a in [False, True]
                             for trans_b in [False, True]
                             for num_stages in [3]
                             for enable_ws in [True]
                         ] + [
                             (*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
                             for shape_w_c in [
                                 [64, 64, 16, 4, 1, 128, 128, 64],
                                 *[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4] for num_ctas in [1]],
                                 # for chain-dot (M/N/K of None default to the block shape below)
                                 [128, 128, 64, 4, 1, None, None, None],
                                 [64, 64, 16, 4, 1, None, None, None],
                                 # small BLOCK_M and BLOCK_K
                                 [16, 16, 64, 4, 1, 128, 128, 64],
                                 *[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4] for num_ctas in [1]],
                                 # # TODO: enable when num_warps != 4 is supported.
                                 # # repeat
                                 # # [64, 64, 32, 8, 1, 128, 256, 64],
                                 # # [64, 64, 16, 8, 2, 128, 128, 64],
                                 # irregular shape
                                 [128, 128, 64, 4, 1, 500, 200, 128],
                                 [128, 128, 64, 4, 1, 513, 193, 192],
                             ]
                             for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot']
                             for out_dtype in ['float16', 'float32']
                             for use_tma_store in [False, True]
                             for trans_a in [False, True]
                             for trans_b in [False, True]
                             for num_stages in [3]
                             for enable_ws in [True]
                             # chain-dot is only valid for a single square tile (M/N unset, BLOCK_M == BLOCK_N).
                             if not (epilogue == 'chain-dot' and (shape_w_c[5] is not None or shape_w_c[0] != shape_w_c[1]))
                         ] + [
                             (64, n, 16, 4, 1, 512, 256, 256, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
                             # loop over instr shapes
                             for n in [16, 32, 64, 128, 256]
                             for out_dtype in ['float16', 'float32']
                             for use_tma_store in [False, True]
                             for num_stages in [2, 4, 5, 7]
                             for enable_ws in [True]
                         ] + [
                             (*shape_w_c, *shape, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
                             # irregular shapes
                             for shape_w_c in [
                                 [128, 128, 64, 4, 1]
                                 # [256, 128, 64, 4, 2],
                                 # [128, 128, 128, 4, 2],
                             ]
                             for shape in list(itertools.product([*range(512, 4096, 360)], [*range(512, 4096, 360)], [512, 1024]))
                             for out_dtype in ['float16', 'float32']
                             for use_tma_store in [False, True]
                             for num_stages in [2, 3, 4]
                             for enable_ws in [True]
                         ]
                         )
@pytest.mark.skipif(torch.cuda.get_device_capability()
                    [0] < 9, reason="Requires compute capability >= 9")
def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS):
    """End-to-end check of full_static_persistent_matmul_kernel against a
    torch reference with the selected epilogue.

    Results are normalized before comparison, so only the direction of each
    row is checked, with fixed rtol/atol.
    """
    # Currently disabled wholesale; the skip tables below document the known
    # bad configurations for when this is re-enabled.
    pytest.skip("known failure, will fix it later!!!")
    if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS])) in [
        '128-128-128-4-1-256-256-192-none-float32-True-3-True',
    ]:
        pytest.skip('out of resource: shared memory, Required: 263168')

    if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, USE_TMA_STORE, ENABLE_WS])) in ([
        '64-16-16-4-1-512-256-256-True-True',
    ] + [
        f'128-128-64-4-1-{m}-{n}-{k}-True-True' for m in range(512, 4096, 360) for n in range(512, 4096, 360) for k in [512, 1024]
    ]):
        pytest.skip('known kernel hang problem when tma store is enabled')

    if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
        '16-32-64-4-4-512-256-64-True-False',
        '16-32-64-4-4-512-256-64-True-True',
        '16-32-64-4-4-512-256-64-False-False',
        '16-32-64-4-4-512-256-64-False-True',
    ]:
        pytest.skip('shapePerCTA[1] < 16 not supported')

    if epilogue == 'chain-dot':
        pytest.skip('known failure: Assertion !region.empty() && unexpected empty region.')

    # M/N/K of None mean "one tile": default to the block shape.
    M = BLOCK_M if M is None else M
    N = BLOCK_N if N is None else N
    K = BLOCK_K if K is None else K

    # a_order/b_order are the block-pointer `order` attributes matching the
    # physical layout of the (possibly transposed) operand.
    if (TRANS_A):
        a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
        a_order = [0, 1]
    else:
        a = torch.randn((M, K), device='cuda', dtype=torch.float16)
        a_order = [1, 0]

    if (TRANS_B):
        b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
        b_order = [0, 1]
    else:
        b = torch.randn((K, N), device='cuda', dtype=torch.float16)
        b_order = [1, 0]

    if out_dtype == 'float16' and epilogue != 'softmax':
        # TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
        # fail with the following error: 'llvm.fmul' op requires the same type
        # for all operands and results
        out_dtype = tl.float16
        torch_out_dtype = torch.float16
    else:
        out_dtype = tl.float32
        torch_out_dtype = torch.float32

    # avoid out of memory: only materialize a full-size bias when an
    # add-* epilogue actually reads it.
    if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
        bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
    else:
        bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)

    if epilogue == 'chain-dot':
        w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
    else:
        w = torch.randn((1, 1), device='cuda', dtype=torch.float16).T

    z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)

    # torch result
    a_f32 = a.to(torch.float32)
    b_f32 = b.to(torch.float32)
    dot = torch.matmul(a_f32, b_f32)

    def process_epilogue(d, bias, w, epilogue):
        # Mirror the kernel's epilogue on the torch reference.
        if epilogue == 'add-matrix':
            ref = d + bias
        elif epilogue == 'add-rows':
            ref = d + bias[:, 0][:, None]
        elif epilogue == 'add-cols':
            ref = d + bias[0, :][None, :]
        elif epilogue == 'softmax':
            num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
            denom = torch.sum(num, dim=-1, keepdims=True)
            ref = num / denom
            # ref = torch.softmax(d, 1)
        elif epilogue == 'chain-dot':
            ref = torch.matmul(d, w.to(torch.float32))
        else:
            ref = d
        return ref
    golden = process_epilogue(dot, bias, w, epilogue)

    num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count

    # Persistent launch: exactly one program per SM.
    def grid(META):
        return (num_SMs,)
    full_static_persistent_matmul_kernel[grid](
        a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z,
        M=M, N=N, K=K,
        stride_am=a.stride(0), stride_ak=a.stride(1),
        stride_bk=b.stride(0), stride_bn=b.stride(1),
        stride_wm=w.stride(0), stride_wn=w.stride(1),
        stride_zm=z.stride(0), stride_zn=z.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8,
        out_dtype=out_dtype,
        USE_TMA_STORE=USE_TMA_STORE,
        ADD_MATRIX=epilogue == 'add-matrix',
        ADD_ROWS=epilogue == 'add-rows',
        ADD_COLS=epilogue == 'add-cols',
        DO_SOFTMAX=epilogue == 'softmax',
        CHAIN_DOT=epilogue == 'chain-dot',
        A_ORDER_0=a_order[0], A_ORDER_1=a_order[1],
        B_ORDER_0=b_order[0], B_ORDER_1=b_order[1],
        num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES,
        enable_warp_specialization=ENABLE_WS,
        NUM_SM=num_SMs)

    torch.set_printoptions(profile="full")
    # Compare normalized tensors with fixed tolerances.
    golden = torch.nn.functional.normalize(golden)
    z = torch.nn.functional.normalize(z)
    assert_close(z, golden,
                 rtol=1e-2,
                 atol=1e-3,
                 check_dtype=False)
|
||||
123
python/test/unit/hopper/test_tma_store_gemm.py
Normal file
123
python/test/unit/hopper/test_tma_store_gemm.py
Normal file
@@ -0,0 +1,123 @@
|
||||
# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files
|
||||
# (the "Software"), to deal in the Software without restriction,
|
||||
# including without limitation the rights to use, copy, modify, merge,
|
||||
# publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
# and to permit persons to whom the Software is furnished to do so,
|
||||
# subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
def get_variant_golden(a, b):
|
||||
SIZE_M = a.shape[0]
|
||||
SIZE_K = a.shape[1]
|
||||
SIZE_N = b.shape[1]
|
||||
assert a.shape[1] == b.shape[0]
|
||||
zero_M_K = torch.zeros((SIZE_M, SIZE_K)).cuda()
|
||||
zero_3M_K = torch.zeros((3 * SIZE_M, SIZE_K)).cuda()
|
||||
zero_K_N = torch.zeros((SIZE_K, SIZE_N)).cuda()
|
||||
zero_3K_N = torch.zeros((3 * SIZE_K, SIZE_N)).cuda()
|
||||
a_padded = torch.cat((a, zero_M_K, zero_M_K), 0)
|
||||
a_padded = torch.cat((a_padded, zero_3M_K, zero_3M_K), 1)
|
||||
b_padded = torch.cat((b, zero_K_N, zero_K_N), 0)
|
||||
b_padded = torch.cat((b_padded, zero_3K_N, zero_3K_N), 1)
|
||||
c_padded = torch.matmul(a_padded, b_padded)
|
||||
return c_padded[:SIZE_M, :SIZE_N]
|
||||
|
||||
# It's not easy to get a proper error threshold in different size
|
||||
# Here the gemm calculation is padded to a different size in order to get
|
||||
# a variant version of the golden result. And the error between golden and
|
||||
# golden_variant provide reference on selecting the proper rtol / atol.
|
||||
|
||||
|
||||
def get_proper_err(a, b, golden):
    """Return ``(abs_err, rel_err)`` tolerances for comparing against *golden*.

    The error scale is measured as the max absolute / relative deviation of
    *golden* from the padded-variant product of *a* and *b*
    (see get_variant_golden).
    """
    diff = golden - get_variant_golden(a, b)
    abs_err = diff.abs().max().item()
    rel_err = (diff / golden).abs().max().item()
    return (abs_err, rel_err)
|
||||
|
||||
|
||||
@triton.jit
def matmul_tma_load_store(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    OUTPUT_F16: tl.constexpr
):
    # Single-tile matmul C = A @ B using block pointers for both loads and
    # the store (the launcher uses a (1, 1) grid and BLOCK_* == M/N/K).
    # Note the B block pointer uses order=(0, 1): the caller passes B as the
    # transpose of a contiguous (N, K) buffer.
    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                    offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                                    offsets=(0, 0), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
    c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
                                    offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    a = tl.load(a_block_ptr)
    b = tl.load(b_block_ptr)

    c = tl.dot(a, b)
    # Downcast the accumulator when the caller wants an fp16 output buffer.
    if OUTPUT_F16:
        c = c.to(tl.float16)

    tl.store(c_block_ptr, c)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_F16', [
    [64, 64, 16, 1, 4, False, True, False],
    [64, 64, 16, 1, 4, False, True, True],
    [128, 64, 32, 1, 4, False, True, False],
    [128, 64, 32, 1, 4, False, True, True],
    [64, 128, 32, 1, 4, False, True, False],
    [64, 128, 32, 1, 4, False, True, True],
    [128, 128, 64, 1, 4, False, True, False],
    [128, 128, 64, 1, 4, False, True, True],
])
def test_tma_load_store(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_F16):
    """Run the single-tile TMA load/store matmul kernel and compare against
    torch.matmul with size-calibrated tolerances."""
    # Transposed operands are the .T view of a contiguous buffer, giving the
    # kernel non-trivial strides.
    if (TRANS_A):
        a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
    else:
        a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    if (TRANS_B):
        b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
    else:
        b = torch.randn((K, N), device='cuda', dtype=torch.float16)

    # Allocate the output once with the requested dtype. (The original
    # allocated a float32 tensor and then immediately replaced it with a
    # float16 one whenever OUTPUT_F16 was set, wasting one allocation.)
    c = torch.empty((M, N), device=a.device,
                    dtype=torch.float16 if OUTPUT_F16 else torch.float32)

    # One tile covers the whole problem: grid (1, 1), BLOCK_* == M/N/K.
    matmul_tma_load_store[(1, 1)](a_ptr=a, b_ptr=b, c_ptr=c,
                                  M=M, N=N, K=K,
                                  stride_am=a.stride(0), stride_ak=a.stride(1),
                                  stride_bk=b.stride(0), stride_bn=b.stride(1),
                                  stride_cm=c.stride(0), stride_cn=c.stride(1),
                                  BLOCK_M=M, BLOCK_N=N, BLOCK_K=K,
                                  num_warps=NUM_WARPS,
                                  num_ctas=NUM_CTAS,
                                  OUTPUT_F16=OUTPUT_F16)
    golden = torch.matmul(a, b)
    golden_abs_err, golden_rel_err = get_proper_err(a, b, golden)
    torch.set_printoptions(profile="full")
    assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err), check_dtype=False)
|
||||
75
python/test/unit/hopper/ttgir_tests/test_tma.py
Normal file
75
python/test/unit/hopper/ttgir_tests/test_tma.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files
|
||||
# (the "Software"), to deal in the Software without restriction,
|
||||
# including without limitation the rights to use, copy, modify, merge,
|
||||
# publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
# and to permit persons to whom the Software is furnished to do so,
|
||||
# subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from test_util import get_proper_err
|
||||
from torch.testing import assert_close
|
||||
|
||||
import triton
|
||||
|
||||
|
||||
@pytest.mark.parametrize('TTGIR,TRANS_A,TRANS_B', [
    # TODO: uncomment when it's done
    # ["wgmma_tma_64_64_16_f16.ttgir", False, True],
])
def test_tma_wgmma_64_64_16_f16(TTGIR, TRANS_A, TRANS_B):
    """Compile a hand-written TTGIR fixture and run a 64x64x16 fp16 matmul,
    comparing against torch.matmul. (Parameter list is currently empty, so
    this is collected but never run.)"""
    capability = torch.cuda.get_device_capability()
    if capability[0] < 9:
        pytest.skip("Only test wgmma on devices with sm >= 90")

    SIZE_M, SIZE_N, SIZE_K = 64, 64, 16
    # A transposed operand is the .T view of a contiguous buffer.
    a = (torch.randn((SIZE_K, SIZE_M), device='cuda', dtype=torch.float16).T
         if TRANS_A else
         torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16))
    b = (torch.randn((SIZE_N, SIZE_K), device='cuda', dtype=torch.float16).T
         if TRANS_B else
         torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16))
    c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)

    # The TTGIR fixture lives next to this test file.
    ttgir_path = os.path.dirname(__file__) + "/" + TTGIR
    kernel = triton.compile(ttgir_path)
    kernel[(1, 1, 1)](a.data_ptr(), b.data_ptr(), c.data_ptr(),
                      SIZE_M, SIZE_N, SIZE_K,
                      a.stride(0), a.stride(1),
                      b.stride(0), b.stride(1),
                      c.stride(0))

    golden = torch.matmul(a, b)
    torch.set_printoptions(profile="full", sci_mode=False)
    golden_abs_err, golden_rel_err = get_proper_err(a, b, golden)
    assert_close(c,
                 golden,
                 rtol=max(1e-4, 1.5 * golden_rel_err),
                 atol=max(1e-4, 1.5 * golden_abs_err),
                 check_dtype=False)
|
||||
52
python/test/unit/hopper/ttgir_tests/test_util.py
Normal file
52
python/test/unit/hopper/ttgir_tests/test_util.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files
|
||||
# (the "Software"), to deal in the Software without restriction,
|
||||
# including without limitation the rights to use, copy, modify, merge,
|
||||
# publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
# and to permit persons to whom the Software is furnished to do so,
|
||||
# subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def get_variant_golden(a, b):
|
||||
SIZE_M = a.shape[0]
|
||||
SIZE_K = a.shape[1]
|
||||
SIZE_N = b.shape[1]
|
||||
assert a.shape[1] == b.shape[0]
|
||||
zero_M_K = torch.zeros((SIZE_M, SIZE_K)).cuda()
|
||||
zero_3M_K = torch.zeros((3 * SIZE_M, SIZE_K)).cuda()
|
||||
zero_K_N = torch.zeros((SIZE_K, SIZE_N)).cuda()
|
||||
zero_3K_N = torch.zeros((3 * SIZE_K, SIZE_N)).cuda()
|
||||
a_padded = torch.cat((a, zero_M_K, zero_M_K), 0)
|
||||
a_padded = torch.cat((a_padded, zero_3M_K, zero_3M_K), 1)
|
||||
b_padded = torch.cat((b, zero_K_N, zero_K_N), 0)
|
||||
b_padded = torch.cat((b_padded, zero_3K_N, zero_3K_N), 1)
|
||||
c_padded = torch.matmul(a_padded, b_padded)
|
||||
return c_padded[:SIZE_M, :SIZE_N]
|
||||
|
||||
# It's not easy to get a proper error threshold in different size
|
||||
# Here the gemm calculation is padded to a different size in order to get
|
||||
# a variant version of the golden result. And the error between golden and
|
||||
# golden_variant provide reference on selecting the proper rtol / atol.
|
||||
|
||||
|
||||
def get_proper_err(a, b, golden):
    """Return ``(abs_err, rel_err)`` tolerances for comparing against *golden*.

    The error scale is measured as the max absolute / relative deviation of
    *golden* from the padded-variant product of *a* and *b*
    (see get_variant_golden).
    """
    diff = golden - get_variant_golden(a, b)
    abs_err = diff.abs().max().item()
    rel_err = (diff / golden).abs().max().item()
    return (abs_err, rel_err)
|
||||
@@ -0,0 +1,52 @@
|
||||
// Hand-written TTGIR fixture: a single-program 64x64x16 fp16 matmul with an
// fp32 result, using a versionMajor = 3 MMA layout (Hopper WGMMA path).
// Arguments: A ptr, B ptr, C ptr, then three i32 leading-dim strides.
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mma = #triton_gpu.mma<{versionMajor = 3, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
  func.func public @matmul_kernel_0d1d2d3d4c5d6c7d8c(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
    // fp32 accumulator initialized to zero in the MMA layout.
    %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
    // Build the 64x16 A-tile address grid: row * stride(%arg3) + col.
    %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
    %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<64x1xi32, #blocked0>
    %3 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1>
    %4 = tt.splat %arg3 : (i32) -> tensor<64x1xi32, #blocked0>
    %5 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked0>
    %6 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
    %7 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x16xi32, #blocked0>
    %8 = tt.broadcast %7 : (tensor<1x16xi32, #blocked0>) -> tensor<64x16xi32, #blocked0>
    // Build the 16x64 B-tile address grid: row * stride(%arg4) + col.
    %9 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
    %10 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16x1xi32, #blocked2>
    %11 = tt.splat %arg4 : (i32) -> tensor<16x1xi32, #blocked2>
    %12 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<16x1x!tt.ptr<f16>, #blocked2>
    %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
    %14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %15 = tt.expand_dims %13 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2>
    %16 = tt.broadcast %15 : (tensor<1x64xi32, #blocked2>) -> tensor<16x64xi32, #blocked2>
    // Load the A tile (64x16 f16).
    %17 = arith.muli %2, %4 : tensor<64x1xi32, #blocked0>
    %18 = tt.addptr %5, %17 : tensor<64x1x!tt.ptr<f16>, #blocked0>, tensor<64x1xi32, #blocked0>
    %19 = tt.broadcast %18 : (tensor<64x1x!tt.ptr<f16>, #blocked0>) -> tensor<64x16x!tt.ptr<f16>, #blocked0>
    %20 = tt.addptr %19, %8 : tensor<64x16x!tt.ptr<f16>, #blocked0>, tensor<64x16xi32, #blocked0>
    %21 = tt.load %20 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x16xf16, #blocked0>
    // Load the B tile (16x64 f16).
    %22 = arith.muli %10, %11 : tensor<16x1xi32, #blocked2>
    %23 = tt.addptr %12, %22 : tensor<16x1x!tt.ptr<f16>, #blocked2>, tensor<16x1xi32, #blocked2>
    %24 = tt.broadcast %23 : (tensor<16x1x!tt.ptr<f16>, #blocked2>) -> tensor<16x64x!tt.ptr<f16>, #blocked2>
    %25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi32, #blocked2>
    %26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2>
    // Move both operands into shared-memory layouts and run the MMA.
    %27 = triton_gpu.convert_layout %21 : (tensor<64x16xf16, #blocked0>) -> tensor<64x16xf16, #shared0>
    %28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1>
    %29 = tt.dot %27, %28, %cst {allowTF32 = true, transA = true, transB = true} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
    // Store the 64x64 f32 result: row * stride(%arg5) + col.
    %30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
    %31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
    %32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
    %33 = tt.broadcast %32 : (tensor<1x64xi32, #blocked1>) -> tensor<64x64xi32, #blocked1>
    %34 = arith.muli %3, %30 : tensor<64x1xi32, #blocked1>
    %35 = tt.addptr %31, %34 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
    %36 = tt.broadcast %35 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
    %37 = tt.addptr %36, %33 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
    %38 = triton_gpu.convert_layout %29 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked1>
    tt.store %37, %38 : tensor<64x64xf32, #blocked1>
    return
  }
}
|
||||
@@ -0,0 +1,52 @@
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
#mma = #triton_gpu.mma<{versionMajor = 3, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
|
||||
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
|
||||
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
|
||||
func.func public @matmul_kernel_0d1d2d3d4c5d6c7d8c(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
|
||||
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
|
||||
%1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<64x1xi32, #blocked0>
|
||||
%3 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1>
|
||||
%4 = tt.splat %arg3 : (i32) -> tensor<64x1xi32, #blocked0>
|
||||
%5 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked0>
|
||||
%6 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
|
||||
%7 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x16xi32, #blocked0>
|
||||
%8 = tt.broadcast %7 : (tensor<1x16xi32, #blocked0>) -> tensor<64x16xi32, #blocked0>
|
||||
%9 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
|
||||
%10 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16x1xi32, #blocked2>
|
||||
%11 = tt.splat %arg4 : (i32) -> tensor<16x1xi32, #blocked2>
|
||||
%12 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<16x1x!tt.ptr<f16>, #blocked2>
|
||||
%13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
|
||||
%14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
|
||||
%15 = tt.expand_dims %13 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2>
|
||||
%16 = tt.broadcast %15 : (tensor<1x64xi32, #blocked2>) -> tensor<16x64xi32, #blocked2>
|
||||
%17 = arith.muli %2, %4 : tensor<64x1xi32, #blocked0>
|
||||
%18 = tt.addptr %5, %17 : tensor<64x1x!tt.ptr<f16>, #blocked0>, tensor<64x1xi32, #blocked0>
|
||||
%19 = tt.broadcast %18 : (tensor<64x1x!tt.ptr<f16>, #blocked0>) -> tensor<64x16x!tt.ptr<f16>, #blocked0>
|
||||
%20 = tt.addptr %19, %8 : tensor<64x16x!tt.ptr<f16>, #blocked0>, tensor<64x16xi32, #blocked0>
|
||||
%21 = tt.load %20 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x16xf16, #blocked0>
|
||||
%22 = arith.muli %10, %11 : tensor<16x1xi32, #blocked2>
|
||||
%23 = tt.addptr %12, %22 : tensor<16x1x!tt.ptr<f16>, #blocked2>, tensor<16x1xi32, #blocked2>
|
||||
%24 = tt.broadcast %23 : (tensor<16x1x!tt.ptr<f16>, #blocked2>) -> tensor<16x64x!tt.ptr<f16>, #blocked2>
|
||||
%25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi32, #blocked2>
|
||||
%26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2>
|
||||
%27 = triton_gpu.convert_layout %21 : (tensor<64x16xf16, #blocked0>) -> tensor<64x16xf16, #shared0>
|
||||
%28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1>
|
||||
%29 = tt.dot %27, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
|
||||
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
|
||||
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
|
||||
%33 = tt.broadcast %32 : (tensor<1x64xi32, #blocked1>) -> tensor<64x64xi32, #blocked1>
|
||||
%34 = arith.muli %3, %30 : tensor<64x1xi32, #blocked1>
|
||||
%35 = tt.addptr %31, %34 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
|
||||
%36 = tt.broadcast %35 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%37 = tt.addptr %36, %33 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
|
||||
%38 = triton_gpu.convert_layout %29 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked1>
|
||||
tt.store %37, %38 : tensor<64x64xf32, #blocked1>
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
#mma = #triton_gpu.mma<{versionMajor = 3, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
|
||||
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
|
||||
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
|
||||
func.func public @matmul_kernel_0d1d2d3d4c5d6c7d8c(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
|
||||
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
|
||||
%1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<64x1xi32, #blocked0>
|
||||
%3 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1>
|
||||
%4 = tt.splat %arg3 : (i32) -> tensor<64x1xi32, #blocked0>
|
||||
%5 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked0>
|
||||
%6 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
|
||||
%7 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x16xi32, #blocked0>
|
||||
%8 = tt.broadcast %7 : (tensor<1x16xi32, #blocked0>) -> tensor<64x16xi32, #blocked0>
|
||||
%9 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
|
||||
%10 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16x1xi32, #blocked2>
|
||||
%11 = tt.splat %arg4 : (i32) -> tensor<16x1xi32, #blocked2>
|
||||
%12 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<16x1x!tt.ptr<f16>, #blocked2>
|
||||
%13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
|
||||
%14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
|
||||
%15 = tt.expand_dims %13 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2>
|
||||
%16 = tt.broadcast %15 : (tensor<1x64xi32, #blocked2>) -> tensor<16x64xi32, #blocked2>
|
||||
%17 = arith.muli %2, %4 : tensor<64x1xi32, #blocked0>
|
||||
%18 = tt.addptr %5, %17 : tensor<64x1x!tt.ptr<f16>, #blocked0>, tensor<64x1xi32, #blocked0>
|
||||
%19 = tt.broadcast %18 : (tensor<64x1x!tt.ptr<f16>, #blocked0>) -> tensor<64x16x!tt.ptr<f16>, #blocked0>
|
||||
%20 = tt.addptr %19, %8 : tensor<64x16x!tt.ptr<f16>, #blocked0>, tensor<64x16xi32, #blocked0>
|
||||
%ci0 = arith.constant 0 : i32
|
||||
%i1_true = arith.constant 1 : i1
|
||||
%i1_false = arith.constant 0 : i1
|
||||
%t = triton_gpu.alloc_tensor : tensor<1x64x16xf16, #shared0>
|
||||
%mask0 = tt.splat %i1_true : (i1) -> tensor<64x16xi1, #blocked0>
|
||||
%t0 = triton_gpu.insert_slice_async %20, %t, %ci0, %mask0 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, type = 1 : i32} : tensor<64x16x!tt.ptr<f16>, #blocked0> -> tensor<1x64x16xf16, #shared0>
|
||||
triton_gpu.async_commit_group
|
||||
triton_gpu.async_wait {num = 1 : i32}
|
||||
%21 = triton_gpu.extract_slice %t0[%ci0, 0, 0][1, 64, 16][1, 1, 1] : tensor<1x64x16xf16, #shared0> to tensor<64x16xf16, #shared0>
|
||||
%22 = arith.muli %10, %11 : tensor<16x1xi32, #blocked2>
|
||||
%23 = tt.addptr %12, %22 : tensor<16x1x!tt.ptr<f16>, #blocked2>, tensor<16x1xi32, #blocked2>
|
||||
%24 = tt.broadcast %23 : (tensor<16x1x!tt.ptr<f16>, #blocked2>) -> tensor<16x64x!tt.ptr<f16>, #blocked2>
|
||||
%25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi32, #blocked2>
|
||||
%26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2>
|
||||
%28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1>
|
||||
%29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
|
||||
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
|
||||
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
|
||||
%33 = tt.broadcast %32 : (tensor<1x64xi32, #blocked1>) -> tensor<64x64xi32, #blocked1>
|
||||
%34 = arith.muli %3, %30 : tensor<64x1xi32, #blocked1>
|
||||
%35 = tt.addptr %31, %34 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
|
||||
%36 = tt.broadcast %35 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%37 = tt.addptr %36, %33 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
|
||||
%38 = triton_gpu.convert_layout %29 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked1>
|
||||
tt.store %37, %38 : tensor<64x64xf32, #blocked1>
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
#mma = #triton_gpu.mma<{versionMajor = 3, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
|
||||
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
|
||||
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
|
||||
func.func public @matmul_kernel_0d1d2d3d4c5d6c7d8c(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
|
||||
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
|
||||
%1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<64x1xi32, #blocked0>
|
||||
%3 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1>
|
||||
%4 = tt.splat %arg3 : (i32) -> tensor<64x1xi32, #blocked0>
|
||||
%5 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked0>
|
||||
%6 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
|
||||
%7 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x16xi32, #blocked0>
|
||||
%8 = tt.broadcast %7 : (tensor<1x16xi32, #blocked0>) -> tensor<64x16xi32, #blocked0>
|
||||
%9 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
|
||||
%10 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16x1xi32, #blocked2>
|
||||
%11 = tt.splat %arg4 : (i32) -> tensor<16x1xi32, #blocked2>
|
||||
%12 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<16x1x!tt.ptr<f16>, #blocked2>
|
||||
%13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
|
||||
%14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
|
||||
%15 = tt.expand_dims %13 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2>
|
||||
%16 = tt.broadcast %15 : (tensor<1x64xi32, #blocked2>) -> tensor<16x64xi32, #blocked2>
|
||||
%17 = arith.muli %2, %4 : tensor<64x1xi32, #blocked0>
|
||||
%18 = tt.addptr %5, %17 : tensor<64x1x!tt.ptr<f16>, #blocked0>, tensor<64x1xi32, #blocked0>
|
||||
%19 = tt.broadcast %18 : (tensor<64x1x!tt.ptr<f16>, #blocked0>) -> tensor<64x16x!tt.ptr<f16>, #blocked0>
|
||||
%20 = tt.addptr %19, %8 : tensor<64x16x!tt.ptr<f16>, #blocked0>, tensor<64x16xi32, #blocked0>
|
||||
%ci0 = arith.constant 0 : i32
|
||||
%i1_true = arith.constant 1 : i1
|
||||
%i1_false = arith.constant 0 : i1
|
||||
%t = triton_gpu.alloc_tensor : tensor<1x64x16xf16, #shared0>
|
||||
// TODO: even an empty init external call here will break the UT
|
||||
%mbar = triton_nvidia_gpu.alloc_mbarrier { count = 128 : i32 } : tensor<1xi64, #shared0>
|
||||
// %mbar0 = triton_nvidia_gpu.extract_mbarrier %mbar[%ci0] : tensor<1xi64, #shared0>, i32 -> !tt.ptr<i64, 3>
|
||||
%mask0 = tt.splat %i1_true : (i1) -> tensor<64x16xi1, #blocked0>
|
||||
%t0 = triton_gpu.insert_slice_async %20, %t, %ci0, %mask0 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, type = 1 : i32} : tensor<64x16x!tt.ptr<f16>, #blocked0> -> tensor<1x64x16xf16, #shared0>
|
||||
triton_gpu.async_wait {num = 1 : i32}
|
||||
// triton_nvidia_gpu.mbarrier_arrive %mbar0 {trackAsyncOp = true} : !tt.ptr<i64, 3>
|
||||
// triton_nvidia_gpu.mbarrier_wait %mbar0, %i1_false : !tt.ptr<i64, 3>
|
||||
%21 = triton_gpu.extract_slice %t0[%ci0, 0, 0][1, 64, 16][1, 1, 1] : tensor<1x64x16xf16, #shared0> to tensor<64x16xf16, #shared0>
|
||||
%22 = arith.muli %10, %11 : tensor<16x1xi32, #blocked2>
|
||||
%23 = tt.addptr %12, %22 : tensor<16x1x!tt.ptr<f16>, #blocked2>, tensor<16x1xi32, #blocked2>
|
||||
%24 = tt.broadcast %23 : (tensor<16x1x!tt.ptr<f16>, #blocked2>) -> tensor<16x64x!tt.ptr<f16>, #blocked2>
|
||||
%25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi32, #blocked2>
|
||||
%26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2>
|
||||
%28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1>
|
||||
%29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
|
||||
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
|
||||
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
|
||||
%33 = tt.broadcast %32 : (tensor<1x64xi32, #blocked1>) -> tensor<64x64xi32, #blocked1>
|
||||
%34 = arith.muli %3, %30 : tensor<64x1xi32, #blocked1>
|
||||
%35 = tt.addptr %31, %34 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
|
||||
%36 = tt.broadcast %35 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%37 = tt.addptr %36, %33 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
|
||||
%38 = triton_gpu.convert_layout %29 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked1>
|
||||
tt.store %37, %38 : tensor<64x64xf32, #blocked1>
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
#mma = #triton_gpu.mma<{versionMajor = 3, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
|
||||
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
|
||||
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
|
||||
func.func public @matmul_kernel_0d1d2d3d4c5d6c7d8c(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
|
||||
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
|
||||
%1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<64x1xi32, #blocked0>
|
||||
%3 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1>
|
||||
%4 = tt.splat %arg3 : (i32) -> tensor<64x1xi32, #blocked0>
|
||||
%5 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked0>
|
||||
%6 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
|
||||
%7 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x16xi32, #blocked0>
|
||||
%8 = tt.broadcast %7 : (tensor<1x16xi32, #blocked0>) -> tensor<64x16xi32, #blocked0>
|
||||
%9 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
|
||||
%10 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16x1xi32, #blocked2>
|
||||
%11 = tt.splat %arg4 : (i32) -> tensor<1x64xi32, #blocked2>
|
||||
%12 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<16x1x!tt.ptr<f16>, #blocked2>
|
||||
%13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
|
||||
%14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
|
||||
%15 = tt.expand_dims %13 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2>
|
||||
%bstride1 = arith.muli %11, %15 : tensor<1x64xi32, #blocked2>
|
||||
%16 = tt.broadcast %bstride1 : (tensor<1x64xi32, #blocked2>) -> tensor<16x64xi32, #blocked2>
|
||||
%17 = arith.muli %2, %4 : tensor<64x1xi32, #blocked0>
|
||||
%18 = tt.addptr %5, %17 : tensor<64x1x!tt.ptr<f16>, #blocked0>, tensor<64x1xi32, #blocked0>
|
||||
%19 = tt.broadcast %18 : (tensor<64x1x!tt.ptr<f16>, #blocked0>) -> tensor<64x16x!tt.ptr<f16>, #blocked0>
|
||||
%20 = tt.addptr %19, %8 : tensor<64x16x!tt.ptr<f16>, #blocked0>, tensor<64x16xi32, #blocked0>
|
||||
%ci0 = arith.constant 0 : i32
|
||||
%i1_true = arith.constant 1 : i1
|
||||
%i1_false = arith.constant 0 : i1
|
||||
%at = triton_gpu.alloc_tensor : tensor<1x64x16xf16, #shared0>
|
||||
%maska = tt.splat %i1_true : (i1) -> tensor<64x16xi1, #blocked0>
|
||||
%at0 = triton_gpu.insert_slice_async %20, %at, %ci0, %maska {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, type = 1 : i32} : tensor<64x16x!tt.ptr<f16>, #blocked0> -> tensor<1x64x16xf16, #shared0>
|
||||
triton_gpu.async_commit_group
|
||||
triton_gpu.async_wait {num = 1 : i32}
|
||||
%21 = triton_gpu.extract_slice %at0[%ci0, 0, 0][1, 64, 16][1, 1, 1] : tensor<1x64x16xf16, #shared0> to tensor<64x16xf16, #shared0>
|
||||
%23 = tt.addptr %12, %10 : tensor<16x1x!tt.ptr<f16>, #blocked2>, tensor<16x1xi32, #blocked2>
|
||||
%24 = tt.broadcast %23 : (tensor<16x1x!tt.ptr<f16>, #blocked2>) -> tensor<16x64x!tt.ptr<f16>, #blocked2>
|
||||
%25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi32, #blocked2>
|
||||
%bt = triton_gpu.alloc_tensor : tensor<1x16x64xf16, #shared1>
|
||||
%maskb = tt.splat %i1_true : (i1) -> tensor<16x64xi1, #blocked2>
|
||||
%bt0 = triton_gpu.insert_slice_async %25, %bt, %ci0, %maskb {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, type = 1 : i32} : tensor<16x64x!tt.ptr<f16>, #blocked2> -> tensor<1x16x64xf16, #shared1>
|
||||
triton_gpu.async_commit_group
|
||||
triton_gpu.async_wait {num = 1 : i32}
|
||||
%28 = triton_gpu.extract_slice %bt0[%ci0, 0, 0][1, 16, 64][1, 1, 1] : tensor<1x16x64xf16, #shared1> to tensor<16x64xf16, #shared1>
|
||||
%29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
|
||||
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
|
||||
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
|
||||
%33 = tt.broadcast %32 : (tensor<1x64xi32, #blocked1>) -> tensor<64x64xi32, #blocked1>
|
||||
%34 = arith.muli %3, %30 : tensor<64x1xi32, #blocked1>
|
||||
%35 = tt.addptr %31, %34 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
|
||||
%36 = tt.broadcast %35 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%37 = tt.addptr %36, %33 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
|
||||
%38 = triton_gpu.convert_layout %29 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked1>
|
||||
tt.store %37, %38 : tensor<64x64xf32, #blocked1>
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
#mma = #triton_gpu.mma<{versionMajor = 3, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
|
||||
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
|
||||
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
|
||||
func.func public @matmul_kernel_0d1d2d3d4c5d6c7d8c(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
|
||||
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
|
||||
%1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<64x1xi32, #blocked0>
|
||||
%3 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1>
|
||||
%4 = tt.splat %arg3 : (i32) -> tensor<64x1xi32, #blocked0>
|
||||
%5 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked0>
|
||||
%6 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
|
||||
%7 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x16xi32, #blocked0>
|
||||
%8 = tt.broadcast %7 : (tensor<1x16xi32, #blocked0>) -> tensor<64x16xi32, #blocked0>
|
||||
%9 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
|
||||
%10 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16x1xi32, #blocked2>
|
||||
%11 = tt.splat %arg4 : (i32) -> tensor<1x64xi32, #blocked2>
|
||||
%12 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<16x1x!tt.ptr<f16>, #blocked2>
|
||||
%13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
|
||||
%14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
|
||||
%15 = tt.expand_dims %13 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2>
|
||||
%bstride1 = arith.muli %11, %15 : tensor<1x64xi32, #blocked2>
|
||||
%16 = tt.broadcast %bstride1 : (tensor<1x64xi32, #blocked2>) -> tensor<16x64xi32, #blocked2>
|
||||
%17 = arith.muli %2, %4 : tensor<64x1xi32, #blocked0>
|
||||
%18 = tt.addptr %5, %17 : tensor<64x1x!tt.ptr<f16>, #blocked0>, tensor<64x1xi32, #blocked0>
|
||||
%19 = tt.broadcast %18 : (tensor<64x1x!tt.ptr<f16>, #blocked0>) -> tensor<64x16x!tt.ptr<f16>, #blocked0>
|
||||
%20 = tt.addptr %19, %8 : tensor<64x16x!tt.ptr<f16>, #blocked0>, tensor<64x16xi32, #blocked0>
|
||||
%ci0 = arith.constant 0 : i32
|
||||
%i1_true = arith.constant 1 : i1
|
||||
%i1_false = arith.constant 0 : i1
|
||||
%at = triton_gpu.alloc_tensor : tensor<1x64x16xf16, #shared0>
|
||||
%maska = tt.splat %i1_true : (i1) -> tensor<64x16xi1, #blocked0>
|
||||
%mbar0 = triton_nvidia_gpu.alloc_mbarrier { count = 128 : i32 } : !tt.ptr<i64, 3>
|
||||
%at0 = triton_gpu.insert_slice_async %20, %at, %ci0, %maska {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, type = 1 : i32} : tensor<64x16x!tt.ptr<f16>, #blocked0> -> tensor<1x64x16xf16, #shared0>
|
||||
triton_nvidia_gpu.mbarrier_arrive %mbar0 {trackAsyncOp = true} : !tt.ptr<i64, 3>
|
||||
triton_nvidia_gpu.mbarrier_wait %mbar0, %i1_false : !tt.ptr<i64, 3>
|
||||
// triton_gpu.async_wait {num = 1 : i32}
|
||||
%21 = triton_gpu.extract_slice %at0[%ci0, 0, 0][1, 64, 16][1, 1, 1] : tensor<1x64x16xf16, #shared0> to tensor<64x16xf16, #shared0>
|
||||
%23 = tt.addptr %12, %10 : tensor<16x1x!tt.ptr<f16>, #blocked2>, tensor<16x1xi32, #blocked2>
|
||||
%24 = tt.broadcast %23 : (tensor<16x1x!tt.ptr<f16>, #blocked2>) -> tensor<16x64x!tt.ptr<f16>, #blocked2>
|
||||
%25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi32, #blocked2>
|
||||
%bt = triton_gpu.alloc_tensor : tensor<1x16x64xf16, #shared1>
|
||||
%maskb = tt.splat %i1_true : (i1) -> tensor<16x64xi1, #blocked2>
|
||||
%mbar1 = triton_nvidia_gpu.alloc_mbarrier { count = 128 : i32 } : !tt.ptr<i64, 3>
|
||||
%bt0 = triton_gpu.insert_slice_async %25, %bt, %ci0, %maskb {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, type = 1 : i32} : tensor<16x64x!tt.ptr<f16>, #blocked2> -> tensor<1x16x64xf16, #shared1>
|
||||
triton_nvidia_gpu.mbarrier_arrive %mbar1 {trackAsyncOp = true} : !tt.ptr<i64, 3>
|
||||
triton_nvidia_gpu.mbarrier_wait %mbar1, %i1_false : !tt.ptr<i64, 3>
|
||||
// triton_gpu.async_wait {num = 1 : i32}
|
||||
%28 = triton_gpu.extract_slice %bt0[%ci0, 0, 0][1, 16, 64][1, 1, 1] : tensor<1x16x64xf16, #shared1> to tensor<16x64xf16, #shared1>
|
||||
%29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
|
||||
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
|
||||
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
|
||||
%33 = tt.broadcast %32 : (tensor<1x64xi32, #blocked1>) -> tensor<64x64xi32, #blocked1>
|
||||
%34 = arith.muli %3, %30 : tensor<64x1xi32, #blocked1>
|
||||
%35 = tt.addptr %31, %34 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
|
||||
%36 = tt.broadcast %35 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%37 = tt.addptr %36, %33 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
|
||||
%38 = triton_gpu.convert_layout %29 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked1>
|
||||
tt.store %37, %38 : tensor<64x64xf32, #blocked1>
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
|
||||
#mma = #triton_gpu.mma<{versionMajor = 3, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
|
||||
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
|
||||
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
|
||||
func.func public @matmul_kernel_0d1d2d3d4c5d6c7d8c(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
|
||||
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
|
||||
%1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
%2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<64x1xi32, #blocked0>
|
||||
%3 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1>
|
||||
%4 = tt.splat %arg3 : (i32) -> tensor<64x1xi32, #blocked0>
|
||||
%5 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>, #blocked0>
|
||||
%6 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
|
||||
%7 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x16xi32, #blocked0>
|
||||
%8 = tt.broadcast %7 : (tensor<1x16xi32, #blocked0>) -> tensor<64x16xi32, #blocked0>
|
||||
%9 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
|
||||
%10 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<16x1xi32, #blocked2>
|
||||
%11 = tt.splat %arg4 : (i32) -> tensor<1x64xi32, #blocked2>
|
||||
%12 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<16x1x!tt.ptr<f16>, #blocked2>
|
||||
%13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
|
||||
%14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
|
||||
%15 = tt.expand_dims %13 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2>
|
||||
%bstride1 = arith.muli %11, %15 : tensor<1x64xi32, #blocked2>
|
||||
%16 = tt.broadcast %bstride1 : (tensor<1x64xi32, #blocked2>) -> tensor<16x64xi32, #blocked2>
|
||||
%17 = arith.muli %2, %4 : tensor<64x1xi32, #blocked0>
|
||||
%18 = tt.addptr %5, %17 : tensor<64x1x!tt.ptr<f16>, #blocked0>, tensor<64x1xi32, #blocked0>
|
||||
%19 = tt.broadcast %18 : (tensor<64x1x!tt.ptr<f16>, #blocked0>) -> tensor<64x16x!tt.ptr<f16>, #blocked0>
|
||||
%20 = tt.addptr %19, %8 : tensor<64x16x!tt.ptr<f16>, #blocked0>, tensor<64x16xi32, #blocked0>
|
||||
%ci0 = arith.constant 0 : i32
|
||||
%ci1 = arith.constant 1 : i32
|
||||
%i1_true = arith.constant 1 : i1
|
||||
%i1_false = arith.constant 0 : i1
|
||||
%at = triton_gpu.alloc_tensor : tensor<1x64x16xf16, #shared0>
|
||||
%maska = tt.splat %i1_true : (i1) -> tensor<64x16xi1, #blocked0>
|
||||
%mbar = triton_nvidia_gpu.alloc_mbarrier { count = 128 : i32 } : tensor<2xi64, #shared0>
|
||||
%mbar0_s = triton_nvidia_gpu.extract_mbarrier %mbar[%ci0] : tensor<2xi64, #shared0>, i32 -> !tt.ptr<i64, 3>
|
||||
%at0 = triton_gpu.insert_slice_async %20, %at, %ci0, %maska {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, type = 1 : i32} : tensor<64x16x!tt.ptr<f16>, #blocked0> -> tensor<1x64x16xf16, #shared0>
|
||||
triton_nvidia_gpu.mbarrier_arrive %mbar0_s {trackAsyncOp = true} : !tt.ptr<i64, 3>
|
||||
triton_nvidia_gpu.mbarrier_wait %mbar0_s, %i1_false : !tt.ptr<i64, 3>
|
||||
// triton_gpu.async_wait {num = 1 : i32}
|
||||
%21 = triton_gpu.extract_slice %at0[%ci0, 0, 0][1, 64, 16][1, 1, 1] : tensor<1x64x16xf16, #shared0> to tensor<64x16xf16, #shared0>
|
||||
%23 = tt.addptr %12, %10 : tensor<16x1x!tt.ptr<f16>, #blocked2>, tensor<16x1xi32, #blocked2>
|
||||
%24 = tt.broadcast %23 : (tensor<16x1x!tt.ptr<f16>, #blocked2>) -> tensor<16x64x!tt.ptr<f16>, #blocked2>
|
||||
%25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi32, #blocked2>
|
||||
%bt = triton_gpu.alloc_tensor : tensor<1x16x64xf16, #shared1>
|
||||
%maskb = tt.splat %i1_true : (i1) -> tensor<16x64xi1, #blocked2>
|
||||
%mbar1_s = triton_nvidia_gpu.extract_mbarrier %mbar[%ci1] : tensor<2xi64, #shared0>, i32 -> !tt.ptr<i64, 3>
|
||||
%bt0 = triton_gpu.insert_slice_async %25, %bt, %ci0, %maskb {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, type = 1 : i32} : tensor<16x64x!tt.ptr<f16>, #blocked2> -> tensor<1x16x64xf16, #shared1>
|
||||
triton_nvidia_gpu.mbarrier_arrive %mbar1_s {trackAsyncOp = true} : !tt.ptr<i64, 3>
|
||||
triton_nvidia_gpu.mbarrier_wait %mbar1_s, %i1_false : !tt.ptr<i64, 3>
|
||||
// triton_gpu.async_wait {num = 1 : i32}
|
||||
%28 = triton_gpu.extract_slice %bt0[%ci0, 0, 0][1, 16, 64][1, 1, 1] : tensor<1x16x64xf16, #shared1> to tensor<16x64xf16, #shared1>
|
||||
%29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
|
||||
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
|
||||
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
|
||||
%33 = tt.broadcast %32 : (tensor<1x64xi32, #blocked1>) -> tensor<64x64xi32, #blocked1>
|
||||
%34 = arith.muli %3, %30 : tensor<64x1xi32, #blocked1>
|
||||
%35 = tt.addptr %31, %34 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
|
||||
%36 = tt.broadcast %35 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%37 = tt.addptr %36, %33 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
|
||||
%38 = triton_gpu.convert_layout %29 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked1>
|
||||
tt.store %37, %38 : tensor<64x64xf32, #blocked1>
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
// Register/thread layouts; identical to the previous test case except #blocked2.
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
// NOTE(review): unlike every other layout in these tests, this #blocked2 omits
// CTASplitNum and uses order = [0, 1] — confirm this is intentional.
#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], CTAsPerCGA = [1, 1], order = [0, 1], CTAOrder = [1, 0]}>
#mma = #triton_gpu.mma<{versionMajor = 3, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
  // Variant of the single-tile matmul in which A and B are fetched with
  // load_tile_async (base pointer + shape + strides + start coordinates)
  // instead of per-element pointer grids, synchronized via two mbarriers
  // carved out of one 2-element allocation.
  func.func public @matmul_kernel_0d1d2d3d4c5d6c7d8c(%aBasePtr : !tt.ptr<f16> {tt.divisibility = 16 : i32},
                                                     %bBasePtr : !tt.ptr<f16> {tt.divisibility = 16 : i32},
                                                     %cBasePtr : !tt.ptr<f32> {tt.divisibility = 16 : i32},
                                                     %sizeM : i32 {tt.divisibility = 16 : i32},
                                                     %sizeN : i32 {tt.divisibility = 16 : i32},
                                                     %sizeK : i32 {tt.divisibility = 16 : i32},
                                                     %aStride0 : i32 {tt.divisibility = 16 : i32},
                                                     %aStride1 : i32 {tt.divisibility = 16 : i32},
                                                     %bStride0 : i32 {tt.divisibility = 16 : i32},
                                                     %bStride1 : i32 {tt.divisibility = 16 : i32},
                                                     %cStride0 : i32 {tt.divisibility = 16 : i32}) {
    // Zero accumulator for the dot product.
    %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
    %ci0 = arith.constant 0 : i32
    %ci1 = arith.constant 1 : i32
    %i1_true = arith.constant 1 : i1
    // Both tiles are loaded from the matrix origin (coordinates (0, 0)).
    %aCoord0 = arith.constant 0 : i32
    %aCoord1 = arith.constant 0 : i32
    %bCoord0 = arith.constant 0 : i32
    %bCoord1 = arith.constant 0 : i32
    // txCount = 2048: expected transaction bytes per barrier
    // (64x16 f16 tile = 2048 bytes).
    %mbar = triton_nvidia_gpu.alloc_mbarrier { count = 128 : i32, txCount = 2048 : i32 } : tensor<2xi64, #shared0>
    %mbar_a = triton_nvidia_gpu.extract_mbarrier %mbar[%ci0] : tensor<2xi64, #shared0>, i32 -> !tt.ptr<i64, 3>
    %mbar_b = triton_nvidia_gpu.extract_mbarrier %mbar[%ci1] : tensor<2xi64, #shared0>, i32 -> !tt.ptr<i64, 3>
    // Index vectors for the C epilogue pointer grid.
    %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %3 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<64x1xi32, #blocked1>
    %14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>

    // Load A
    %a_smem = triton_gpu.alloc_tensor : tensor<1x64x16xf16, #shared0>
    %a_smem_loaded = triton_gpu.load_tile_async %aBasePtr[%sizeM, %sizeK][%aStride0, %aStride1][%aCoord0, %aCoord1], %mbar_a, %a_smem[%ci0] {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, order = dense<[1, 0]> : tensor<2xi32>} : !tt.ptr<i64, 3> -> tensor<1x64x16xf16, #shared0>
    // trackAsyncOp = false here: completion is tracked by the barrier's
    // transaction count rather than the async-op group.
    triton_nvidia_gpu.mbarrier_arrive %mbar_a {trackAsyncOp = false} : !tt.ptr<i64, 3>
    // NOTE(review): this waits on phase %i1_true while the insert_slice_async
    // test cases wait on phase false — confirm the expected initial phase.
    triton_nvidia_gpu.mbarrier_wait %mbar_a, %i1_true : !tt.ptr<i64, 3>
    %21 = triton_gpu.extract_slice %a_smem_loaded[%ci0, 0, 0][1, 64, 16][1, 1, 1] : tensor<1x64x16xf16, #shared0> to tensor<64x16xf16, #shared0>

    // Load B
    %b_smem = triton_gpu.alloc_tensor : tensor<1x16x64xf16, #shared1>
    %b_smem_loaded = triton_gpu.load_tile_async %bBasePtr[%sizeK, %sizeN][%bStride0, %bStride1][%bCoord0, %bCoord1], %mbar_b, %b_smem[%ci0] {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, order = dense<[0, 1]> : tensor<2xi32>} : !tt.ptr<i64, 3> -> tensor<1x16x64xf16, #shared1>
    triton_nvidia_gpu.mbarrier_arrive %mbar_b {trackAsyncOp = false} : !tt.ptr<i64, 3>
    triton_nvidia_gpu.mbarrier_wait %mbar_b, %i1_true : !tt.ptr<i64, 3>
    %28 = triton_gpu.extract_slice %b_smem_loaded[%ci0, 0, 0][1, 16, 64][1, 1, 1] : tensor<1x16x64xf16, #shared1> to tensor<16x64xf16, #shared1>

    // Calling MMA
    %29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>

    // Epilogue: build the C pointer grid (base + row * cStride0 + col)
    // and store the result converted out of the MMA layout.
    %30 = tt.splat %cStride0: (i32) -> tensor<64x1xi32, #blocked1>
    %31 = tt.splat %cBasePtr: (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
    %32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
    %33 = tt.broadcast %32 : (tensor<1x64xi32, #blocked1>) -> tensor<64x64xi32, #blocked1>
    %34 = arith.muli %3, %30 : tensor<64x1xi32, #blocked1>
    %35 = tt.addptr %31, %34 : tensor<64x1x!tt.ptr<f32>, #blocked1>, tensor<64x1xi32, #blocked1>
    %36 = tt.broadcast %35 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
    %37 = tt.addptr %36, %33 : tensor<64x64x!tt.ptr<f32>, #blocked1>, tensor<64x64xi32, #blocked1>
    %38 = triton_gpu.convert_layout %29 : (tensor<64x64xf32, #mma>) -> tensor<64x64xf32, #blocked1>
    tt.store %37, %38 : tensor<64x64xf32, #blocked1>
    return
  }
}
|
||||
32
python/test/unit/hopper/utils.py
Normal file
32
python/test/unit/hopper/utils.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import torch
|
||||
|
||||
|
||||
def get_variant_golden(a, b):
    """Compute a size-variant golden result for ``a @ b``.

    The operands are zero-padded to three times their size in every
    dimension before the matmul, and the top-left ``(M, N)`` corner of the
    padded product is returned.  Mathematically this equals ``a @ b``, but
    the different reduction size perturbs float rounding, which is what the
    callers use to derive a realistic rtol/atol.

    Args:
        a: 2-D tensor of shape ``(M, K)``.
        b: 2-D tensor of shape ``(K, N)``.

    Returns:
        A tensor of shape ``(M, N)`` equal (up to rounding) to ``a @ b``.
    """
    SIZE_M = a.shape[0]
    SIZE_K = a.shape[1]
    SIZE_N = b.shape[1]
    assert a.shape[1] == b.shape[0]
    # Allocate padding on the inputs' own devices; the original hard-coded
    # .cuda(), which broke CPU (and non-default-GPU) use for no reason.
    zero_M_K = torch.zeros((SIZE_M, SIZE_K), dtype=a.dtype, device=a.device)
    zero_3M_K = torch.zeros((3 * SIZE_M, SIZE_K), dtype=a.dtype, device=a.device)
    zero_K_N = torch.zeros((SIZE_K, SIZE_N), dtype=b.dtype, device=b.device)
    zero_3K_N = torch.zeros((3 * SIZE_K, SIZE_N), dtype=b.dtype, device=b.device)
    # Pad A to (3M, 3K) and B to (3K, 3N).
    a_padded = torch.cat((a, zero_M_K, zero_M_K), 0)
    a_padded = torch.cat((a_padded, zero_3M_K, zero_3M_K), 1)
    b_padded = torch.cat((b, zero_K_N, zero_K_N), 0)
    b_padded = torch.cat((b_padded, zero_3K_N, zero_3K_N), 1)
    c_padded = torch.matmul(a_padded, b_padded)
    return c_padded[:SIZE_M, :SIZE_N]
|
||||
|
||||
# It's not easy to choose a proper error threshold across different sizes.
# Here the gemm calculation is padded to a different size in order to get
# a variant version of the golden result, and the error between golden and
# golden_variant provides a reference for selecting proper rtol / atol.
|
||||
|
||||
|
||||
def get_proper_err(golden, golden_variant):
    """Return ``(max_abs_err, max_rel_err)`` between two golden results.

    The pair gives callers a reference magnitude for picking atol / rtol
    when comparing a kernel's output against ``golden``.
    """
    abs_diff = torch.abs(golden - golden_variant)
    golden_abs_err = abs_diff.max().item()
    # Offset the denominator by the smallest normal value of the dtype so
    # the relative error stays finite where golden contains zeros.
    tiny = torch.finfo(golden.dtype).smallest_normal
    golden_rel_err = (abs_diff / (torch.abs(golden) + tiny)).max().item()
    return (golden_abs_err, golden_rel_err)
|
||||
Reference in New Issue
Block a user