mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Reformat Python code with yapf. (#2589)
I've add an option to yapf to do what we want for long lines, see https://github.com/google/yapf/pull/1177. We can now have a real Python formatter, yay! To make this PR, I ran my modified yapf over the repository, then looked over the full diff. Where yapf was mangling the param list of long function decls/calls (mostly kernels), I manually added `#` to put linebreaks where we want. I fixed up other formatting too -- mostly adding or removing a trailing comma from lists. Overall, trailing `#` was sufficient to get formatting similar to our current code. I didn't have to disable yapf anywhere. --------- Co-authored-by: Phil Tillet <phil@openai.com>
This commit is contained in:
@@ -18,7 +18,6 @@
|
||||
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
"""
|
||||
Fused Attention
|
||||
===============
|
||||
@@ -35,18 +34,15 @@ import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel(
|
||||
Q, K, V, sm_scale,
|
||||
L, M,
|
||||
Out,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
stride_oz, stride_oh, stride_om, stride_on,
|
||||
Z, H, N_CTX, D0,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
def _fwd_kernel(Q, K, V, sm_scale, #
|
||||
L, M, #
|
||||
Out, #
|
||||
stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vk, stride_vn, #
|
||||
stride_oz, stride_oh, stride_om, stride_on, #
|
||||
Z, H, N_CTX, D0, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr):
|
||||
start_m = tl.program_id(0)
|
||||
off_hz = tl.program_id(1)
|
||||
|
||||
@@ -61,31 +57,38 @@ def _fwd_kernel(
|
||||
|
||||
stride_qh_2d = stride_qh // stride_qm // stride_qk
|
||||
|
||||
q_tile_ptr = tl.make_block_ptr(base=Q,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(
|
||||
off_hz * stride_qh_2d + start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0))
|
||||
k_tile_ptr = tl.make_block_ptr(base=K,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_kn, stride_kk),
|
||||
offsets=(off_hz * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
order=(1, 0))
|
||||
v_tile_ptr = tl.make_block_ptr(base=V,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_vk, stride_vn),
|
||||
offsets=(off_hz * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
order=(1, 0))
|
||||
out_tile_ptr = tl.make_block_ptr(base=Out,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_om, stride_on),
|
||||
offsets=(off_hz * stride_qh_2d + start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0))
|
||||
q_tile_ptr = tl.make_block_ptr(
|
||||
base=Q,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(off_hz * stride_qh_2d + start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
k_tile_ptr = tl.make_block_ptr(
|
||||
base=K,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_kn, stride_kk),
|
||||
offsets=(off_hz * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
v_tile_ptr = tl.make_block_ptr(
|
||||
base=V,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_vk, stride_vn),
|
||||
offsets=(off_hz * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
out_tile_ptr = tl.make_block_ptr(
|
||||
base=Out,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_om, stride_on),
|
||||
offsets=(off_hz * stride_qh_2d + start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
# load q: it will stay in SRAM throughout
|
||||
q = tl.load(q_tile_ptr)
|
||||
|
||||
@@ -96,8 +99,7 @@ def _fwd_kernel(
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, tl.trans(k))
|
||||
qk *= sm_scale
|
||||
qk = tl.where(offs_m[:, None] >= (
|
||||
start_n + offs_n[None, :]), qk, float("-inf"))
|
||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
|
||||
# compute new m
|
||||
m_curr = tl.maximum(tl.max(qk, 1), m_prev)
|
||||
# correct old l
|
||||
@@ -133,11 +135,9 @@ def _fwd_kernel(
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bwd_preprocess(
|
||||
Out, DO, L,
|
||||
NewDO, Delta,
|
||||
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
|
||||
):
|
||||
def _bwd_preprocess(Out, DO, L, #
|
||||
NewDO, Delta, #
|
||||
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr):
|
||||
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_n = tl.arange(0, D_HEAD)
|
||||
# load
|
||||
@@ -153,19 +153,14 @@ def _bwd_preprocess(
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bwd_kernel(
|
||||
Q, K, V, sm_scale, Out, DO,
|
||||
DQ, DK, DV,
|
||||
L, M,
|
||||
D,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
Z, H, N_CTX, D0,
|
||||
num_block,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
def _bwd_kernel(Q, K, V, sm_scale, Out, DO, #
|
||||
DQ, DK, DV, #
|
||||
L, M, #
|
||||
D, stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vk, stride_vn, #
|
||||
Z, H, N_CTX, D0, #
|
||||
num_block, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr):
|
||||
off_hz = tl.program_id(0)
|
||||
off_z = off_hz // H
|
||||
off_h = off_hz % H
|
||||
@@ -173,55 +168,62 @@ def _bwd_kernel(
|
||||
stride_qz_2d = stride_qz // stride_qm // stride_qk
|
||||
stride_qh_2d = stride_qh // stride_qm // stride_qk
|
||||
|
||||
q_tile_ptr = tl.make_block_ptr(base=Q,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(
|
||||
off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0))
|
||||
k_tile_ptr = tl.make_block_ptr(base=K,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_kn, stride_kk),
|
||||
offsets=(
|
||||
off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0))
|
||||
v_tile_ptr = tl.make_block_ptr(base=V,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_vk, stride_vn),
|
||||
offsets=(
|
||||
off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0))
|
||||
do_tile_ptr = tl.make_block_ptr(base=DO,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(
|
||||
off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0))
|
||||
dq_tile_ptr = tl.make_block_ptr(base=DQ,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(
|
||||
off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0))
|
||||
dk_tile_ptr = tl.make_block_ptr(base=DK,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(
|
||||
off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0))
|
||||
dv_tile_ptr = tl.make_block_ptr(base=DV,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(
|
||||
off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0))
|
||||
q_tile_ptr = tl.make_block_ptr(
|
||||
base=Q,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
k_tile_ptr = tl.make_block_ptr(
|
||||
base=K,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_kn, stride_kk),
|
||||
offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
v_tile_ptr = tl.make_block_ptr(
|
||||
base=V,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_vk, stride_vn),
|
||||
offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
do_tile_ptr = tl.make_block_ptr(
|
||||
base=DO,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
dq_tile_ptr = tl.make_block_ptr(
|
||||
base=DQ,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
dk_tile_ptr = tl.make_block_ptr(
|
||||
base=DK,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
dv_tile_ptr = tl.make_block_ptr(
|
||||
base=DV,
|
||||
shape=(D0, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
# offset pointers for batch/head
|
||||
DQ += off_z * stride_qz + off_h * stride_qh
|
||||
for start_n in range(0, num_block):
|
||||
@@ -250,8 +252,7 @@ def _bwd_kernel(
|
||||
# recompute p = softmax(qk, dim=-1).T
|
||||
# NOTE: `do` is pre-divided by `l`; no normalization here
|
||||
qk = tl.dot(q, tl.trans(k))
|
||||
qk = tl.where(offs_m_curr[:, None] >= (
|
||||
offs_n[None, :]), qk, float("-inf"))
|
||||
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
|
||||
m = tl.load(m_ptrs + offs_m_curr)
|
||||
p = tl.exp(qk * sm_scale - m[:, None])
|
||||
# compute dv
|
||||
@@ -301,29 +302,21 @@ class _attention(torch.autograd.Function):
|
||||
assert Lk in {16, 32, 64, 128}
|
||||
o = torch.empty_like(q)
|
||||
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
|
||||
L = torch.empty(
|
||||
(q.shape[0] * q.shape[1], q.shape[2]),
|
||||
device=q.device,
|
||||
dtype=torch.float32)
|
||||
m = torch.empty(
|
||||
(q.shape[0] * q.shape[1], q.shape[2]),
|
||||
device=q.device,
|
||||
dtype=torch.float32)
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
D0 = q.shape[0] * q.shape[1] * q.shape[2]
|
||||
_fwd_kernel[grid](
|
||||
q, k, v, sm_scale,
|
||||
L, m,
|
||||
o,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2], D0,
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=Lk, num_warps=num_warps,
|
||||
num_stages=2,
|
||||
)
|
||||
q, k, v, sm_scale, #
|
||||
L, m, #
|
||||
o, #
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
|
||||
q.shape[0], q.shape[1], q.shape[2], D0, #
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK, BLOCK_DMODEL=Lk, #
|
||||
num_warps=num_warps, num_stages=2)
|
||||
|
||||
ctx.save_for_backward(q, k, v, o, L, m)
|
||||
ctx.grid = grid
|
||||
@@ -343,25 +336,22 @@ class _attention(torch.autograd.Function):
|
||||
delta = torch.empty_like(l)
|
||||
D0 = q.shape[0] * q.shape[1] * q.shape[2]
|
||||
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
|
||||
o, do, l,
|
||||
do_scaled, delta,
|
||||
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
|
||||
)
|
||||
_bwd_kernel[(ctx.grid[1],)](
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do_scaled,
|
||||
dq, dk, dv,
|
||||
l, m,
|
||||
delta,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2], D0,
|
||||
ctx.grid[0],
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
|
||||
num_stages=1,
|
||||
)
|
||||
o, do, l, #
|
||||
do_scaled, delta, #
|
||||
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL)
|
||||
_bwd_kernel[(ctx.grid[1], )](
|
||||
q, k, v, ctx.sm_scale, #
|
||||
o, do_scaled, #
|
||||
dq, dk, dv, #
|
||||
l, m, #
|
||||
delta, #
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
|
||||
q.shape[0], q.shape[1], q.shape[2], D0, #
|
||||
ctx.grid[0], #
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK, BLOCK_DMODEL=ctx.BLOCK_DMODEL, #
|
||||
num_warps=8, num_stages=1)
|
||||
return dq, dk, dv, None
|
||||
|
||||
|
||||
@@ -380,15 +370,9 @@ attention = _attention.apply
|
||||
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires arch 9+")
|
||||
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
torch.manual_seed(20)
|
||||
q = torch.empty(
|
||||
(Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(
|
||||
mean=0.1, std=0.2).requires_grad_()
|
||||
k = torch.empty(
|
||||
(Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(
|
||||
mean=0.4, std=0.2).requires_grad_()
|
||||
v = torch.empty(
|
||||
(Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(
|
||||
mean=0.3, std=0.2).requires_grad_()
|
||||
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_()
|
||||
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
|
||||
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
|
||||
sm_scale = 0.2
|
||||
dout = torch.randn_like(q)
|
||||
# reference implementation
|
||||
@@ -427,22 +411,25 @@ except BaseException:
|
||||
|
||||
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
|
||||
# vary seq length for fixed head and batch=4
|
||||
configs = [triton.testing.Benchmark(
|
||||
x_names=['N_CTX'],
|
||||
x_vals=[2**i for i in range(10, 14)],
|
||||
line_arg='provider',
|
||||
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
|
||||
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
|
||||
styles=[('red', '-'), ('blue', '-')],
|
||||
ylabel='ms',
|
||||
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
|
||||
args={
|
||||
'H': N_HEADS,
|
||||
'BATCH': BATCH,
|
||||
'D_HEAD': D_HEAD,
|
||||
'dtype': torch.float16,
|
||||
'mode': mode}
|
||||
) for mode in ['fwd', 'bwd']]
|
||||
configs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names=['N_CTX'],
|
||||
x_vals=[2**i for i in range(10, 14)],
|
||||
line_arg='provider',
|
||||
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
|
||||
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
|
||||
styles=[('red', '-'), ('blue', '-')],
|
||||
ylabel='ms',
|
||||
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
|
||||
args={
|
||||
'H': N_HEADS,
|
||||
'BATCH': BATCH,
|
||||
'D_HEAD': D_HEAD,
|
||||
'dtype': torch.float16,
|
||||
'mode': mode,
|
||||
},
|
||||
) for mode in ['fwd', 'bwd']
|
||||
]
|
||||
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
@@ -463,9 +450,8 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
|
||||
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
||||
return ms
|
||||
if provider == "flash":
|
||||
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
|
||||
cu_seqlens = torch.zeros(
|
||||
(BATCH + 1,), device=device, dtype=torch.int32)
|
||||
lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device)
|
||||
cu_seqlens = torch.zeros((BATCH + 1, ), device=device, dtype=torch.int32)
|
||||
cu_seqlens[1:] = lengths.cumsum(0)
|
||||
qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
|
||||
fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
|
||||
|
||||
@@ -32,19 +32,30 @@ import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def matmul_no_scf_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr
|
||||
):
|
||||
a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
|
||||
offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
|
||||
b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
|
||||
offsets=(0, 0), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
|
||||
def matmul_no_scf_kernel(a_ptr, b_ptr, c_ptr, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
|
||||
FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr #
|
||||
):
|
||||
a_block_ptr = tl.make_block_ptr(
|
||||
base=a_ptr,
|
||||
shape=(M, K),
|
||||
strides=(stride_am, stride_ak),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_K),
|
||||
order=(1, 0),
|
||||
)
|
||||
b_block_ptr = tl.make_block_ptr(
|
||||
base=b_ptr,
|
||||
shape=(K, N),
|
||||
strides=(stride_bk, stride_bn),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_K, BLOCK_N),
|
||||
order=(0, 1),
|
||||
)
|
||||
a = tl.load(a_block_ptr)
|
||||
b = tl.load(b_block_ptr)
|
||||
|
||||
@@ -54,8 +65,8 @@ def matmul_no_scf_kernel(
|
||||
c = c.to(tl.float16)
|
||||
|
||||
if USE_TMA_EPILOGUE:
|
||||
c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
|
||||
offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
|
||||
c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
|
||||
tl.store(c_block_ptr, c)
|
||||
else:
|
||||
offs_m = tl.arange(0, BLOCK_M)
|
||||
@@ -64,33 +75,30 @@ def matmul_no_scf_kernel(
|
||||
tl.store(c_ptrs, c)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,ENABLE_WS',
|
||||
itertools.chain(
|
||||
*[
|
||||
[
|
||||
# numCTAs = 1, no TMA multicast:
|
||||
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
# static mask, cluster 4x1
|
||||
[256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
# dynamic mask, cluster 2x2
|
||||
[128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
# small M, N
|
||||
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
] for USE_TMA_EPILOGUE in [True, False]
|
||||
for ENABLE_WS in [False, True]
|
||||
]))
|
||||
@pytest.mark.parametrize(
|
||||
'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,ENABLE_WS',
|
||||
itertools.chain(*[[
|
||||
# numCTAs = 1, no TMA multicast:
|
||||
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
# static mask, cluster 4x1
|
||||
[256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
# dynamic mask, cluster 2x2
|
||||
[128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
# small M, N
|
||||
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
|
||||
] for USE_TMA_EPILOGUE in [True, False] for ENABLE_WS in [False, True]]))
|
||||
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
|
||||
def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE, ENABLE_WS):
|
||||
if (TRANS_A):
|
||||
@@ -107,46 +115,41 @@ def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE
|
||||
else:
|
||||
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
|
||||
|
||||
matmul_no_scf_kernel[(1, 1)](a_ptr=a, b_ptr=b, c_ptr=c,
|
||||
M=M, N=N, K=K,
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1),
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1),
|
||||
stride_cm=c.stride(0), stride_cn=c.stride(1),
|
||||
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K,
|
||||
num_warps=NUM_WARPS,
|
||||
num_ctas=NUM_CTAS,
|
||||
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"),
|
||||
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE,
|
||||
enable_warp_specialization=ENABLE_WS)
|
||||
matmul_no_scf_kernel[(1, 1)](
|
||||
a_ptr=a, b_ptr=b, c_ptr=c, #
|
||||
M=M, N=N, K=K, #
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1), #
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1), #
|
||||
stride_cm=c.stride(0), stride_cn=c.stride(1), #
|
||||
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, #
|
||||
num_warps=NUM_WARPS, #
|
||||
num_ctas=NUM_CTAS, #
|
||||
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), #
|
||||
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE, #
|
||||
enable_warp_specialization=ENABLE_WS)
|
||||
a_f32 = a.to(torch.float32)
|
||||
b_f32 = b.to(torch.float32)
|
||||
golden = torch.matmul(a_f32, b_f32)
|
||||
torch.set_printoptions(profile="full")
|
||||
assert_close(
|
||||
c,
|
||||
golden,
|
||||
rtol=1e-2,
|
||||
atol=1e-3,
|
||||
check_dtype=False)
|
||||
assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def matmul_kernel(
|
||||
a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_wm, stride_wn,
|
||||
stride_zm, stride_zn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
|
||||
out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr,
|
||||
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
|
||||
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
|
||||
A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr,
|
||||
B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr,
|
||||
W_ORDER_0: tl.constexpr, W_ORDER_1: tl.constexpr,
|
||||
Z_ORDER_0: tl.constexpr, Z_ORDER_1: tl.constexpr
|
||||
):
|
||||
def matmul_kernel(a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_wm, stride_wn, #
|
||||
stride_zm, stride_zn, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, #
|
||||
out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr, #
|
||||
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, #
|
||||
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, #
|
||||
A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, #
|
||||
B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr, #
|
||||
W_ORDER_0: tl.constexpr, W_ORDER_1: tl.constexpr, #
|
||||
Z_ORDER_0: tl.constexpr, Z_ORDER_1: tl.constexpr #
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_N)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_M)
|
||||
@@ -159,13 +162,31 @@ def matmul_kernel(
|
||||
block_offset_m = pid_m * BLOCK_M
|
||||
block_offset_n = pid_n * BLOCK_N
|
||||
|
||||
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
|
||||
offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(A_ORDER_0, A_ORDER_1))
|
||||
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
|
||||
offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(B_ORDER_0, B_ORDER_1))
|
||||
a_tile_ptr = tl.make_block_ptr(
|
||||
base=a_ptr,
|
||||
shape=(M, K),
|
||||
strides=(stride_am, stride_ak),
|
||||
offsets=(block_offset_m, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_K),
|
||||
order=(A_ORDER_0, A_ORDER_1),
|
||||
)
|
||||
b_tile_ptr = tl.make_block_ptr(
|
||||
base=b_ptr,
|
||||
shape=(K, N),
|
||||
strides=(stride_bk, stride_bn),
|
||||
offsets=(0, block_offset_n),
|
||||
block_shape=(BLOCK_K, BLOCK_N),
|
||||
order=(B_ORDER_0, B_ORDER_1),
|
||||
)
|
||||
# for chain-dot, BLOCK_N must always be equal to N, and each program loads the whole W matrix
|
||||
w_tile_ptr = tl.make_block_ptr(base=w_ptr, shape=(N, N), strides=(stride_wm, stride_wn),
|
||||
offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_N), order=(W_ORDER_0, W_ORDER_1))
|
||||
w_tile_ptr = tl.make_block_ptr(
|
||||
base=w_ptr,
|
||||
shape=(N, N),
|
||||
strides=(stride_wm, stride_wn),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_N),
|
||||
order=(W_ORDER_0, W_ORDER_1),
|
||||
)
|
||||
z = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
||||
|
||||
offs_m = block_offset_m + tl.arange(0, BLOCK_M)
|
||||
@@ -204,139 +225,146 @@ def matmul_kernel(
|
||||
|
||||
if USE_TMA_STORE:
|
||||
z_block_ptr = tl.make_block_ptr(base=z_ptr, shape=(M, N), strides=(stride_zm, stride_zn),
|
||||
offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(Z_ORDER_0, Z_ORDER_1))
|
||||
offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N),
|
||||
order=(Z_ORDER_0, Z_ORDER_1))
|
||||
tl.store(z_block_ptr, z, boundary_check=(0, 1))
|
||||
else:
|
||||
tl.store(z_ptrs, z, mask=mask)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS',
|
||||
[
|
||||
# corner shapes
|
||||
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws)
|
||||
for shape_w_c in [
|
||||
[4096, 1, 1024, False, False, True],
|
||||
[2048, 204, 1000, True, False, True],
|
||||
[4096, 1, 1024, False, False, False],
|
||||
[2048, 204, 1000, True, False, False],
|
||||
]
|
||||
for out_dtype in ['float16', 'float32']
|
||||
for use_tma_store in [False, True]
|
||||
for enable_ws in [False, True]
|
||||
] + [
|
||||
# softmax epilogue
|
||||
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [
|
||||
[64, 64, 16, 4, 1, 64, 64, 64],
|
||||
[128, 128, 64, 4, 1, None, None, None],
|
||||
[16, 16, 64, 4, 1, 16, 16, 64],
|
||||
[64, 64, 32, 8, 1, 64, 64, 64],
|
||||
[128, 128, 64, 4, 1, 128, 128, 128],
|
||||
]
|
||||
for epilogue in ['softmax']
|
||||
for out_dtype in ['float16', 'float32']
|
||||
for use_tma_store in [False, True]
|
||||
for trans_a in [False,]
|
||||
for trans_b in [True,]
|
||||
for trans_output in [False,]
|
||||
for num_stages in [3]
|
||||
for enable_ws in [False, True]
|
||||
] + [
|
||||
# loop over epilogues besides of softmax
|
||||
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [
|
||||
[64, 64, 16, 4, 1, 128, 128, 64],
|
||||
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
|
||||
# for chain-dot
|
||||
[128, 128, 64, 4, 1, None, None, None],
|
||||
[64, 64, 16, 4, 1, None, None, None],
|
||||
# small BLOCK_M and BLOCK_K
|
||||
[16, 16, 64, 4, 1, 128, 128, 64],
|
||||
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
|
||||
# repeat
|
||||
[64, 64, 32, 8, 1, 128, 256, 64],
|
||||
[64, 64, 16, 8, 2, 128, 128, 64],
|
||||
# irregular shape
|
||||
[128, 128, 64, 4, 1, 500, 200, 128],
|
||||
[128, 128, 64, 4, 2, 513, 193, 192],
|
||||
]
|
||||
for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot']
|
||||
for out_dtype in ['float16', 'float32']
|
||||
for use_tma_store in [False, True]
|
||||
for trans_a in [False,]
|
||||
for trans_b in [True,]
|
||||
for trans_output in [False,]
|
||||
for num_stages in [3]
|
||||
for enable_ws in [False, True]
|
||||
if not (epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
|
||||
] + [
|
||||
# loop over tile shapes and transpose combinations
|
||||
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [
|
||||
[64, 64, 32, 4, 1, 128, 256, 64],
|
||||
[128, 128, 16, 4, 4, 512, 256, 64],
|
||||
[128, 256, 32, 4, 8, 256, 256, 192],
|
||||
[512, 256, 32, 4, 8, 1024, 256, 192],
|
||||
# BLOCK_K >= 128
|
||||
[64, 128, 128, 4, 1, 512, 256, 256],
|
||||
[128, 128, 128, 4, 1, 256, 256, 192],
|
||||
[128, 128, 128, 4, 2, 256, 256, 192],
|
||||
# small BLOCK_M and BLOCK_K
|
||||
[16, 32, 32, 4, 1, 128, 256, 64],
|
||||
[32, 32, 16, 4, 1, 256, 256, 192],
|
||||
[16, 32, 64, 4, 4, 512, 256, 64],
|
||||
]
|
||||
for out_dtype in ['float32',]
|
||||
for use_tma_store in [False,]
|
||||
for trans_a in [False, True]
|
||||
for trans_b in [False, True]
|
||||
for trans_output in [False, True]
|
||||
for num_stages in [3]
|
||||
for enable_ws in [False, True]
|
||||
] + [
|
||||
# loop over instr shapes & pipeline stages
|
||||
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for n in [16, 32, 64, 128, 256]
|
||||
for trans_output in [False,]
|
||||
for out_dtype in ['float32',]
|
||||
for use_tma_store in [False,]
|
||||
for num_stages in [2, 4, 5, 7]
|
||||
for enable_ws in [False, True]
|
||||
] + [
|
||||
# irregular shapes
|
||||
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [
|
||||
[128, 128, 64, 4, 1],
|
||||
[256, 128, 64, 4, 2],
|
||||
[128, 128, 128, 4, 2],
|
||||
]
|
||||
for shape in [
|
||||
[512, 360, 1024],
|
||||
[360, 4096, 512],
|
||||
]
|
||||
for trans_output in [False,]
|
||||
for out_dtype in ['float32',]
|
||||
for use_tma_store in [False, True]
|
||||
for num_stages in [3, 4]
|
||||
for enable_ws in [False, True]
|
||||
])
|
||||
@pytest.mark.skipif(torch.cuda.get_device_capability()
|
||||
[0] < 9, reason="Requires compute capability >= 9")
|
||||
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS):
|
||||
@pytest.mark.parametrize(
|
||||
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS',
|
||||
[
|
||||
# corner shapes
|
||||
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws)
|
||||
for shape_w_c in [
|
||||
[4096, 1, 1024, False, False, True],
|
||||
[2048, 204, 1000, True, False, True],
|
||||
[4096, 1, 1024, False, False, False],
|
||||
[2048, 204, 1000, True, False, False],
|
||||
]
|
||||
for out_dtype in ['float16', 'float32'] #
|
||||
for use_tma_store in [False, True] #
|
||||
for enable_ws in [False, True]
|
||||
] + [
|
||||
# softmax epilogue
|
||||
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [
|
||||
[64, 64, 16, 4, 1, 64, 64, 64],
|
||||
[128, 128, 64, 4, 1, None, None, None],
|
||||
[16, 16, 64, 4, 1, 16, 16, 64],
|
||||
[64, 64, 32, 8, 1, 64, 64, 64],
|
||||
[128, 128, 64, 4, 1, 128, 128, 128],
|
||||
]
|
||||
for epilogue in ['softmax']
|
||||
for out_dtype in ['float16', 'float32']
|
||||
for use_tma_store in [False, True]
|
||||
for trans_a in [False]
|
||||
for trans_b in [True]
|
||||
for trans_output in [False]
|
||||
for num_stages in [3]
|
||||
for enable_ws in [False, True]
|
||||
] + [
|
||||
# loop over epilogues besides of softmax
|
||||
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [
|
||||
[64, 64, 16, 4, 1, 128, 128, 64],
|
||||
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64]
|
||||
for num_warps in [4, 8]
|
||||
for num_ctas in [1, 2, 4]],
|
||||
# for chain-dot
|
||||
[128, 128, 64, 4, 1, None, None, None],
|
||||
[64, 64, 16, 4, 1, None, None, None],
|
||||
# small BLOCK_M and BLOCK_K
|
||||
[16, 16, 64, 4, 1, 128, 128, 64],
|
||||
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256]
|
||||
for num_warps in [4, 8]
|
||||
for num_ctas in [1, 2]],
|
||||
# repeat
|
||||
[64, 64, 32, 8, 1, 128, 256, 64],
|
||||
[64, 64, 16, 8, 2, 128, 128, 64],
|
||||
# irregular shape
|
||||
[128, 128, 64, 4, 1, 500, 200, 128],
|
||||
[128, 128, 64, 4, 2, 513, 193, 192],
|
||||
]
|
||||
for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot']
|
||||
for out_dtype in ['float16', 'float32']
|
||||
for use_tma_store in [False, True]
|
||||
for trans_a in [False]
|
||||
for trans_b in [True]
|
||||
for trans_output in [False]
|
||||
for num_stages in [3]
|
||||
for enable_ws in [False, True]
|
||||
if not (epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
|
||||
] + [
|
||||
# loop over tile shapes and transpose combinations
|
||||
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [
|
||||
[64, 64, 32, 4, 1, 128, 256, 64],
|
||||
[128, 128, 16, 4, 4, 512, 256, 64],
|
||||
[128, 256, 32, 4, 8, 256, 256, 192],
|
||||
[512, 256, 32, 4, 8, 1024, 256, 192],
|
||||
# BLOCK_K >= 128
|
||||
[64, 128, 128, 4, 1, 512, 256, 256],
|
||||
[128, 128, 128, 4, 1, 256, 256, 192],
|
||||
[128, 128, 128, 4, 2, 256, 256, 192],
|
||||
# small BLOCK_M and BLOCK_K
|
||||
[16, 32, 32, 4, 1, 128, 256, 64],
|
||||
[32, 32, 16, 4, 1, 256, 256, 192],
|
||||
[16, 32, 64, 4, 4, 512, 256, 64],
|
||||
]
|
||||
for out_dtype in ['float32']
|
||||
for use_tma_store in [False]
|
||||
for trans_a in [False, True]
|
||||
for trans_b in [False, True]
|
||||
for trans_output in [False, True]
|
||||
for num_stages in [3]
|
||||
for enable_ws in [False, True]
|
||||
] + [
|
||||
# loop over instr shapes & pipeline stages
|
||||
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages,
|
||||
enable_ws)
|
||||
for n in [16, 32, 64, 128, 256]
|
||||
for trans_output in [False]
|
||||
for out_dtype in ['float32']
|
||||
for use_tma_store in [False]
|
||||
for num_stages in [2, 4, 5, 7]
|
||||
for enable_ws in [False, True]
|
||||
] + [
|
||||
# irregular shapes
|
||||
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [
|
||||
[128, 128, 64, 4, 1],
|
||||
[256, 128, 64, 4, 2],
|
||||
[128, 128, 128, 4, 2],
|
||||
]
|
||||
for shape in [
|
||||
[512, 360, 1024],
|
||||
[360, 4096, 512],
|
||||
]
|
||||
for trans_output in [False]
|
||||
for out_dtype in ['float32']
|
||||
for use_tma_store in [False, True]
|
||||
for num_stages in [3, 4]
|
||||
for enable_ws in [False, True]
|
||||
])
|
||||
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
|
||||
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
|
||||
out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS):
|
||||
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
|
||||
'16-32-64-4-4-512-256-64-True-False',
|
||||
'16-32-64-4-4-512-256-64-True-True',
|
||||
'16-32-64-4-4-512-256-64-False-False',
|
||||
'16-32-64-4-4-512-256-64-False-True',
|
||||
'16-32-64-4-4-512-256-64-True-False',
|
||||
'16-32-64-4-4-512-256-64-True-True',
|
||||
'16-32-64-4-4-512-256-64-False-False',
|
||||
'16-32-64-4-4-512-256-64-False-True',
|
||||
]:
|
||||
pytest.skip('shapePerCTA[1] < 16 not supported')
|
||||
|
||||
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
|
||||
'16-32-64-4-1-256-256-256-False',
|
||||
'16-32-64-4-2-256-256-256-False',
|
||||
'16-32-64-4-2-256-256-256-True',
|
||||
'16-32-64-8-2-256-256-256-False',
|
||||
'16-32-64-8-2-256-256-256-True',
|
||||
'16-32-64-4-1-256-256-256-False',
|
||||
'16-32-64-4-2-256-256-256-False',
|
||||
'16-32-64-4-2-256-256-256-True',
|
||||
'16-32-64-8-2-256-256-256-False',
|
||||
'16-32-64-8-2-256-256-256-True',
|
||||
]:
|
||||
pytest.skip('Known legacy issue, ldmatrix can only support x4')
|
||||
enable_tma = os.environ.get('ENABLE_TMA', 'not found').lower()
|
||||
@@ -413,38 +441,38 @@ def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A,
|
||||
else:
|
||||
ref = d
|
||||
return ref
|
||||
|
||||
golden = process_epilogue(dot, bias, w, epilogue)
|
||||
|
||||
def grid(META):
|
||||
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
|
||||
pgm = matmul_kernel[grid](a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z,
|
||||
M=M, N=N, K=K,
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1),
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1),
|
||||
stride_wm=w.stride(0), stride_wn=w.stride(1),
|
||||
stride_zm=z.stride(0), stride_zn=z.stride(1),
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8,
|
||||
out_dtype=out_dtype,
|
||||
USE_TMA_STORE=USE_TMA_STORE,
|
||||
ADD_MATRIX=epilogue == 'add-matrix',
|
||||
ADD_ROWS=epilogue == 'add-rows',
|
||||
ADD_COLS=epilogue == 'add-cols',
|
||||
DO_SOFTMAX=epilogue == 'softmax',
|
||||
CHAIN_DOT=epilogue == 'chain-dot',
|
||||
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1],
|
||||
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1],
|
||||
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1],
|
||||
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1],
|
||||
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES,
|
||||
enable_warp_specialization=ENABLE_WS)
|
||||
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
|
||||
|
||||
pgm = matmul_kernel[grid](
|
||||
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
|
||||
M=M, N=N, K=K, #
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1), #
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1), #
|
||||
stride_wm=w.stride(0), stride_wn=w.stride(1), #
|
||||
stride_zm=z.stride(0), stride_zn=z.stride(1), #
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
|
||||
out_dtype=out_dtype, #
|
||||
USE_TMA_STORE=USE_TMA_STORE, #
|
||||
ADD_MATRIX=epilogue == 'add-matrix', #
|
||||
ADD_ROWS=epilogue == 'add-rows', #
|
||||
ADD_COLS=epilogue == 'add-cols', #
|
||||
DO_SOFTMAX=epilogue == 'softmax', #
|
||||
CHAIN_DOT=epilogue == 'chain-dot', #
|
||||
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
|
||||
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
|
||||
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
|
||||
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
|
||||
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES, #
|
||||
enable_warp_specialization=ENABLE_WS)
|
||||
|
||||
torch.set_printoptions(profile="full")
|
||||
golden = torch.nn.functional.normalize(golden)
|
||||
z = torch.nn.functional.normalize(z)
|
||||
assert_close(z, golden,
|
||||
rtol=1e-2,
|
||||
atol=1e-3,
|
||||
check_dtype=False)
|
||||
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
|
||||
|
||||
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
|
||||
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
|
||||
|
||||
@@ -27,16 +27,20 @@ import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def gemm_fusion_kernel(A, B, C, E,
|
||||
M, N, K,
|
||||
stride_am, stride_ak, stride_bn, stride_bk, stride_cn, stride_ck, stride_em, stride_ek,
|
||||
def gemm_fusion_kernel(A, B, C, E, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, stride_bn, stride_bk, stride_cn, stride_ck, stride_em, stride_ek, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
|
||||
pid = tl.program_id(0)
|
||||
|
||||
a_tile_ptr = tl.make_block_ptr(base=A, shape=(M, K), strides=(stride_am, stride_ak), offsets=(pid * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
|
||||
b_tile_ptr = tl.make_block_ptr(base=B, shape=(N, K), strides=(stride_bn, stride_bk), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_K), order=(1, 0))
|
||||
c_tile_ptr = tl.make_block_ptr(base=C, shape=(N, K), strides=(stride_cn, stride_ck), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_K), order=(1, 0))
|
||||
e_tile_ptr = tl.make_block_ptr(base=E, shape=(M, K), strides=(stride_em, stride_ek), offsets=(pid * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
|
||||
a_tile_ptr = tl.make_block_ptr(base=A, shape=(M, K), strides=(stride_am, stride_ak), offsets=(pid * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
|
||||
b_tile_ptr = tl.make_block_ptr(base=B, shape=(N, K), strides=(stride_bn, stride_bk), offsets=(0, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_K), order=(1, 0))
|
||||
c_tile_ptr = tl.make_block_ptr(base=C, shape=(N, K), strides=(stride_cn, stride_ck), offsets=(0, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_K), order=(1, 0))
|
||||
e_tile_ptr = tl.make_block_ptr(base=E, shape=(M, K), strides=(stride_em, stride_ek), offsets=(pid * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
|
||||
|
||||
acc_e = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
|
||||
a = tl.load(a_tile_ptr)
|
||||
@@ -57,66 +61,70 @@ def gemm_fusion_kernel(A, B, C, E,
|
||||
def test_gemm_fusion():
|
||||
M, N, K = 4096, 4096, 64
|
||||
BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 64
|
||||
A = torch.empty(
|
||||
(M, K), dtype=torch.float16, device='cuda').normal_(
|
||||
mean=0.1, std=0.2)
|
||||
B = torch.empty(
|
||||
(N, K), dtype=torch.float16, device='cuda').normal_(
|
||||
mean=0.1, std=0.2)
|
||||
C = torch.empty(
|
||||
(N, K), dtype=torch.float16, device='cuda').normal_(
|
||||
mean=0.1, std=0.2)
|
||||
A = torch.empty((M, K), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2)
|
||||
B = torch.empty((N, K), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2)
|
||||
C = torch.empty((N, K), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2)
|
||||
E = torch.empty((M, K), dtype=torch.float16, device='cuda')
|
||||
ref_out = torch.matmul(torch.matmul(A, B.T), C)
|
||||
num_warps = 4
|
||||
grid = (triton.cdiv(M, BLOCK_M), 1)
|
||||
gemm_fusion_kernel[grid](A, B, C, E, M, N, K,
|
||||
A.stride(0), A.stride(1), B.stride(0), B.stride(
|
||||
1), C.stride(0), C.stride(1), E.stride(0), E.stride(1),
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, num_warps=num_warps)
|
||||
gemm_fusion_kernel[grid](
|
||||
A, B, C, E, M, N, K, #
|
||||
A.stride(0), A.stride(1), #
|
||||
B.stride(0), B.stride(1), #
|
||||
C.stride(0), C.stride(1), #
|
||||
E.stride(0), E.stride(1), #
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, #
|
||||
num_warps=num_warps)
|
||||
|
||||
torch.testing.assert_close(ref_out, E, atol=1e-2, rtol=0)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def batched_gemm_fusion(
|
||||
Q, K, V, Out,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
stride_oz, stride_oh, stride_om, stride_on,
|
||||
Z, NH, N_CTX,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
def batched_gemm_fusion(Q, K, V, Out, #
|
||||
stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vk, stride_vn, #
|
||||
stride_oz, stride_oh, stride_om, stride_on, #
|
||||
Z, NH, N_CTX, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
|
||||
BLOCK_N: tl.constexpr):
|
||||
start_m = tl.program_id(0)
|
||||
off_hz = tl.program_id(1)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
|
||||
q_tile_ptr = tl.make_block_ptr(base=Q,
|
||||
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_qz, stride_qh, stride_qm, stride_qk),
|
||||
offsets=(off_hz // NH, off_hz % NH, start_m, 0),
|
||||
block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL),
|
||||
order=(3, 2, 1, 0))
|
||||
k_tile_ptr = tl.make_block_ptr(base=K,
|
||||
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_kz, stride_kh, stride_kn, stride_kk),
|
||||
offsets=(off_hz // NH, off_hz % NH, 0, 0),
|
||||
block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL),
|
||||
order=(3, 2, 1, 0))
|
||||
v_tile_ptr = tl.make_block_ptr(base=V,
|
||||
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_vz, stride_vh, stride_vk, stride_vn),
|
||||
offsets=(off_hz // NH, off_hz % NH, 0, 0),
|
||||
block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL),
|
||||
order=(3, 2, 1, 0))
|
||||
o_tile_ptr = tl.make_block_ptr(base=Out,
|
||||
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_oz, stride_oh, stride_om, stride_on),
|
||||
offsets=(off_hz // NH, off_hz % NH, start_m, 0),
|
||||
block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL),
|
||||
order=(3, 2, 1, 0))
|
||||
q_tile_ptr = tl.make_block_ptr(
|
||||
base=Q,
|
||||
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_qz, stride_qh, stride_qm, stride_qk),
|
||||
offsets=(off_hz // NH, off_hz % NH, start_m, 0),
|
||||
block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL),
|
||||
order=(3, 2, 1, 0),
|
||||
)
|
||||
k_tile_ptr = tl.make_block_ptr(
|
||||
base=K,
|
||||
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_kz, stride_kh, stride_kn, stride_kk),
|
||||
offsets=(off_hz // NH, off_hz % NH, 0, 0),
|
||||
block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL),
|
||||
order=(3, 2, 1, 0),
|
||||
)
|
||||
v_tile_ptr = tl.make_block_ptr(
|
||||
base=V,
|
||||
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_vz, stride_vh, stride_vk, stride_vn),
|
||||
offsets=(off_hz // NH, off_hz % NH, 0, 0),
|
||||
block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL),
|
||||
order=(3, 2, 1, 0),
|
||||
)
|
||||
o_tile_ptr = tl.make_block_ptr(
|
||||
base=Out,
|
||||
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_oz, stride_oh, stride_om, stride_on),
|
||||
offsets=(off_hz // NH, off_hz % NH, start_m, 0),
|
||||
block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL),
|
||||
order=(3, 2, 1, 0),
|
||||
)
|
||||
|
||||
q = tl.load(q_tile_ptr, boundary_check=(0, 1, 2, 3))
|
||||
q = tl.view(q, (BLOCK_M, BLOCK_DMODEL))
|
||||
@@ -155,12 +163,13 @@ def test_batched_gemm_fusion():
|
||||
ref_out = torch.matmul(torch.matmul(A, BT), C)
|
||||
num_warps = 4
|
||||
grid = (triton.cdiv(N_CTX, BLOCK_M), B * NH)
|
||||
batched_gemm_fusion[grid](A, B, C, E,
|
||||
A.stride(0), A.stride(1), A.stride(2), A.stride(3),
|
||||
B.stride(0), B.stride(1), B.stride(2), B.stride(3),
|
||||
C.stride(0), C.stride(1), C.stride(2), C.stride(3),
|
||||
E.stride(0), E.stride(1), E.stride(2), E.stride(3),
|
||||
Z, NH, N_CTX,
|
||||
BLOCK_M, BLOCK_DMODEL, BLOCK_N, num_warps=num_warps)
|
||||
batched_gemm_fusion[grid](
|
||||
A, B, C, E, #
|
||||
A.stride(0), A.stride(1), A.stride(2), A.stride(3), #
|
||||
B.stride(0), B.stride(1), B.stride(2), B.stride(3), #
|
||||
C.stride(0), C.stride(1), C.stride(2), C.stride(3), #
|
||||
E.stride(0), E.stride(1), E.stride(2), E.stride(3), #
|
||||
Z, NH, N_CTX, #
|
||||
BLOCK_M, BLOCK_DMODEL, BLOCK_N, num_warps=num_warps)
|
||||
|
||||
torch.testing.assert_close(ref_out, E, atol=1e-2, rtol=0)
|
||||
|
||||
@@ -24,10 +24,8 @@ def add_kernel(
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
|
||||
x_block_ptr = tl.make_block_ptr(
|
||||
base=x_ptr, shape=(n_elements, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),
|
||||
block_shape=(BLOCK_SIZE, ), order=(0, )
|
||||
)
|
||||
x_block_ptr = tl.make_block_ptr(base=x_ptr, shape=(n_elements, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),
|
||||
block_shape=(BLOCK_SIZE, ), order=(0, ))
|
||||
x = tl.load(x_block_ptr, boundary_check=(0, ), padding_option='zero')
|
||||
|
||||
y = tl.load(y_ptr + offsets, mask=mask)
|
||||
@@ -36,9 +34,7 @@ def add_kernel(
|
||||
|
||||
|
||||
@pytest.mark.parametrize('SIZE,BLOCK_SIZE,dtype_str',
|
||||
[(98432, 1024, dtype_str)
|
||||
for dtype_str in ['float16', 'float32']
|
||||
])
|
||||
[(98432, 1024, dtype_str) for dtype_str in ['float16', 'float32']])
|
||||
def test_add(SIZE, BLOCK_SIZE, dtype_str):
|
||||
dtype = dtype_mapping[dtype_str]
|
||||
output = torch.empty(SIZE, device='cuda', dtype=dtype)
|
||||
@@ -46,7 +42,8 @@ def test_add(SIZE, BLOCK_SIZE, dtype_str):
|
||||
y = torch.randn(SIZE, device='cuda', dtype=dtype)
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(SIZE, meta['BLOCK_SIZE']),)
|
||||
return (triton.cdiv(SIZE, meta['BLOCK_SIZE']), )
|
||||
|
||||
add_kernel[grid](x, y, output, SIZE, BLOCK_SIZE=BLOCK_SIZE)
|
||||
|
||||
output_torch = x + y
|
||||
@@ -64,25 +61,20 @@ def load_reduce_kernel(
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
x_ptr = tl.make_block_ptr(
|
||||
base=x_ptr, shape=(BLOCK_M, BLOCK_N), strides=(stride_xm, stride_xn),
|
||||
offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)
|
||||
)
|
||||
x_ptr = tl.make_block_ptr(base=x_ptr, shape=(BLOCK_M, BLOCK_N), strides=(stride_xm, stride_xn), offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
|
||||
x = tl.load(x_ptr)
|
||||
y = tl.max(x, axis=1)
|
||||
tl.store(y_ptr + tl.arange(0, BLOCK_M), y)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('BLOCK_M,BLOCK_N,dtype_str',
|
||||
[(128, 64, dtype_str)
|
||||
for dtype_str in ['float16']
|
||||
])
|
||||
@pytest.mark.parametrize('BLOCK_M,BLOCK_N,dtype_str', [(128, 64, dtype_str) for dtype_str in ['float16']])
|
||||
def test_load_reduce(BLOCK_M, BLOCK_N, dtype_str):
|
||||
dtype = dtype_mapping[dtype_str]
|
||||
x = torch.randn((BLOCK_M, BLOCK_N), device='cuda', dtype=dtype)
|
||||
y = torch.empty((BLOCK_M, ), device='cuda', dtype=dtype)
|
||||
|
||||
load_reduce_kernel[(1,)](x, y, x.stride(0), x.stride(1), y.stride(0), BLOCK_M, BLOCK_N)
|
||||
load_reduce_kernel[(1, )](x, y, x.stride(0), x.stride(1), y.stride(0), BLOCK_M, BLOCK_N)
|
||||
|
||||
golden = x.max(dim=1)[0]
|
||||
torch.set_printoptions(profile='full')
|
||||
|
||||
@@ -18,7 +18,6 @@
|
||||
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
"""
|
||||
Fused Attention
|
||||
===============
|
||||
@@ -40,18 +39,17 @@ import triton.language as tl
|
||||
key=['Q', 'K', 'V'],
|
||||
)
|
||||
@triton.jit
|
||||
def _fwd_kernel(
|
||||
Q, K, V, sm_scale,
|
||||
L, M,
|
||||
Out,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
stride_oz, stride_oh, stride_om, stride_on,
|
||||
Z, H, N_CTX,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
def _fwd_kernel(Q, K, V, sm_scale, #
|
||||
L, M, #
|
||||
Out, #
|
||||
stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vk, stride_vn, #
|
||||
stride_oz, stride_oh, stride_om, stride_on, #
|
||||
Z, H, N_CTX, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
|
||||
BLOCK_N: tl.constexpr #
|
||||
):
|
||||
start_m = tl.program_id(0)
|
||||
off_hz = tl.program_id(1)
|
||||
# initialize offsets
|
||||
@@ -116,11 +114,10 @@ def _fwd_kernel(
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bwd_preprocess(
|
||||
Out, DO, L,
|
||||
NewDO, Delta,
|
||||
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
|
||||
):
|
||||
def _bwd_preprocess(Out, DO, L, #
|
||||
NewDO, Delta, #
|
||||
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr #
|
||||
):
|
||||
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_n = tl.arange(0, D_HEAD)
|
||||
# load
|
||||
@@ -136,19 +133,18 @@ def _bwd_preprocess(
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bwd_kernel(
|
||||
Q, K, V, sm_scale, Out, DO,
|
||||
DQ, DK, DV,
|
||||
L, M,
|
||||
D,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
Z, H, N_CTX,
|
||||
num_block,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
def _bwd_kernel(Q, K, V, sm_scale, Out, DO, #
|
||||
DQ, DK, DV, #
|
||||
L, M, #
|
||||
D, #
|
||||
stride_qz, stride_qh, stride_qm, stride_qk, #
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, #
|
||||
stride_vz, stride_vh, stride_vk, stride_vn, #
|
||||
Z, H, N_CTX, #
|
||||
num_block, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
|
||||
BLOCK_N: tl.constexpr, #
|
||||
):
|
||||
off_hz = tl.program_id(0)
|
||||
off_z = off_hz // H
|
||||
off_h = off_hz % H
|
||||
@@ -240,16 +236,16 @@ class _attention(torch.autograd.Function):
|
||||
assert num_warps == 4
|
||||
|
||||
_fwd_kernel[grid](
|
||||
q, k, v, sm_scale,
|
||||
L, m,
|
||||
o,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=Lk,
|
||||
q, k, v, sm_scale, #
|
||||
L, m, #
|
||||
o, #
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
|
||||
q.shape[0], q.shape[1], q.shape[2], #
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK, #
|
||||
BLOCK_DMODEL=Lk #
|
||||
)
|
||||
|
||||
ctx.save_for_backward(q, k, v, o, L, m)
|
||||
@@ -269,24 +265,23 @@ class _attention(torch.autograd.Function):
|
||||
do_scaled = torch.empty_like(do)
|
||||
delta = torch.empty_like(l)
|
||||
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
|
||||
o, do, l,
|
||||
do_scaled, delta,
|
||||
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
|
||||
)
|
||||
_bwd_kernel[(ctx.grid[1],)](
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do_scaled,
|
||||
dq, dk, dv,
|
||||
l, m,
|
||||
delta,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
ctx.grid[0],
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
|
||||
num_stages=1,
|
||||
o, do, l, #
|
||||
do_scaled, delta, #
|
||||
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL)
|
||||
_bwd_kernel[(ctx.grid[1], )](
|
||||
q, k, v, ctx.sm_scale, #
|
||||
o, do_scaled, #
|
||||
dq, dk, dv, #
|
||||
l, m, #
|
||||
delta, #
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
|
||||
q.shape[0], q.shape[1], q.shape[2], #
|
||||
ctx.grid[0], #
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK, #
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, #
|
||||
num_warps=8, num_stages=1 #
|
||||
)
|
||||
return dq, dk, dv, None
|
||||
|
||||
@@ -339,19 +334,19 @@ except BaseException:
|
||||
|
||||
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
|
||||
# vary seq length for fixed head and batch=4
|
||||
configs = [triton.testing.Benchmark(
|
||||
x_names=['N_CTX'],
|
||||
# x_vals=[2**i for i in range(10, 14)],
|
||||
x_vals=[2**i for i in range(10, 11)],
|
||||
line_arg='provider',
|
||||
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
|
||||
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
|
||||
styles=[('red', '-'), ('blue', '-')],
|
||||
ylabel='ms',
|
||||
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
|
||||
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
|
||||
# ) for mode in ['fwd', 'bwd']]
|
||||
) for mode in ['fwd']]
|
||||
configs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names=['N_CTX'],
|
||||
# x_vals=[2**i for i in range(10, 14)],
|
||||
x_vals=[2**i
|
||||
for i in range(10, 11)], line_arg='provider', line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
|
||||
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), styles=[('red', '-'), ('blue', '-')], ylabel='ms',
|
||||
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
|
||||
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
|
||||
# ) for mode in ['fwd', 'bwd']]
|
||||
)
|
||||
for mode in ['fwd']
|
||||
]
|
||||
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
@@ -374,9 +369,8 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
|
||||
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
||||
return ms
|
||||
if provider == "flash":
|
||||
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
|
||||
cu_seqlens = torch.zeros(
|
||||
(BATCH + 1,), device=device, dtype=torch.int32)
|
||||
lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device)
|
||||
cu_seqlens = torch.zeros((BATCH + 1, ), device=device, dtype=torch.int32)
|
||||
cu_seqlens[1:] = lengths.cumsum(0)
|
||||
qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
|
||||
fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
|
||||
|
||||
@@ -29,14 +29,14 @@ import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def static_persistent_matmul_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
NUM_SM: tl.constexpr,
|
||||
def static_persistent_matmul_kernel( #
|
||||
a_ptr, b_ptr, c_ptr, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
|
||||
NUM_SM: tl.constexpr #
|
||||
):
|
||||
start_tile = tl.program_id(axis=0)
|
||||
m_tiles = tl.cdiv(M, BLOCK_M)
|
||||
@@ -68,14 +68,14 @@ def static_persistent_matmul_kernel(
|
||||
|
||||
|
||||
@triton.jit
|
||||
def static_persistent_tma_matmul_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
NUM_SM: tl.constexpr,
|
||||
def static_persistent_tma_matmul_kernel( #
|
||||
a_ptr, b_ptr, c_ptr, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
|
||||
NUM_SM: tl.constexpr #
|
||||
):
|
||||
start_tile = tl.program_id(axis=0)
|
||||
m_tiles = tl.cdiv(M, BLOCK_M)
|
||||
@@ -88,8 +88,10 @@ def static_persistent_tma_matmul_kernel(
|
||||
|
||||
block_offset_m = pre_pid_m * BLOCK_M
|
||||
block_offset_n = pre_pid_n * BLOCK_N
|
||||
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
|
||||
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
|
||||
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
|
||||
offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
|
||||
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
|
||||
offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
|
||||
for tile_id in range(start_tile, num_tiles, NUM_SM):
|
||||
pid_m = tile_id // n_tiles
|
||||
pid_n = tile_id % n_tiles
|
||||
@@ -114,21 +116,23 @@ def static_persistent_tma_matmul_kernel(
|
||||
pre_pid_n = pid_n
|
||||
|
||||
|
||||
@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA',
|
||||
[(*shape, use_tma)
|
||||
for shape in [
|
||||
[4096, 4096, 64, 64, 64, 16, 4, 1, False, True],
|
||||
[4096, 4096, 64, 64, 64, 32, 4, 1, False, True],
|
||||
[4096, 4096, 64, 256, 64, 16, 4, 1, False, True],
|
||||
[4096, 4096, 64, 128, 128, 16, 4, 1, False, True],
|
||||
# TODO: fix issue for 8-warp persistent kernel
|
||||
# [4096, 4096, 64, 128, 128, 16, 8, 1, False, True],
|
||||
# [4096, 4096, 64, 128, 256, 16, 8, 1, False, True],
|
||||
]
|
||||
for use_tma in [False, True]
|
||||
])
|
||||
@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA', [(
|
||||
*shape, use_tma
|
||||
) for shape in [
|
||||
[4096, 4096, 64, 64, 64, 16, 4, 1, False, True],
|
||||
[4096, 4096, 64, 64, 64, 32, 4, 1, False, True
|
||||
],
|
||||
[4096, 4096, 64, 256, 64, 16, 4, 1, False, True
|
||||
],
|
||||
[4096, 4096, 64, 128, 128, 16, 4, 1, False, True
|
||||
],
|
||||
# TODO: fix issue for 8-warp persistent kernel
|
||||
# [4096, 4096, 64, 128, 128, 16, 8, 1, False, True],
|
||||
# [4096, 4096, 64, 128, 256, 16, 8, 1, False, True],
|
||||
] for use_tma in [False, True]])
|
||||
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
|
||||
def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, TRANS_A, TRANS_B, USE_TMA):
|
||||
def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS,
|
||||
TRANS_A, TRANS_B, USE_TMA):
|
||||
if (TRANS_A):
|
||||
a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T
|
||||
else:
|
||||
@@ -141,25 +145,33 @@ def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLO
|
||||
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
|
||||
|
||||
num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
|
||||
grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),)
|
||||
grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), )
|
||||
|
||||
if USE_TMA:
|
||||
static_persistent_tma_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS)
|
||||
static_persistent_tma_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0),
|
||||
stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1),
|
||||
stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M,
|
||||
BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS,
|
||||
num_ctas=NUM_CTAS)
|
||||
else:
|
||||
static_persistent_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS)
|
||||
static_persistent_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0),
|
||||
stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1),
|
||||
stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M,
|
||||
BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS,
|
||||
num_ctas=NUM_CTAS)
|
||||
|
||||
th_c = torch.matmul(a, b)
|
||||
torch.testing.assert_close(th_c, c, atol=1e-2, rtol=0, check_dtype=False)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def warp_specialized_matmul_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
def warp_specialized_matmul_kernel( #
|
||||
a_ptr, b_ptr, c_ptr, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
|
||||
):
|
||||
tid = tl.program_id(axis=0)
|
||||
n_tiles = tl.cdiv(N, BLOCK_N)
|
||||
@@ -193,13 +205,13 @@ def warp_specialized_matmul_kernel(
|
||||
|
||||
|
||||
@triton.jit
|
||||
def tma_warp_specialized_matmul_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
def tma_warp_specialized_matmul_kernel( #
|
||||
a_ptr, b_ptr, c_ptr, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
|
||||
):
|
||||
tid = tl.program_id(axis=0)
|
||||
n_tiles = tl.cdiv(N, BLOCK_N)
|
||||
@@ -232,8 +244,7 @@ def tma_warp_specialized_matmul_kernel(
|
||||
|
||||
|
||||
@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA',
|
||||
[(*shape, use_tma)
|
||||
for shape in [
|
||||
[(*shape, use_tma) for shape in [
|
||||
[2048, 2048, 64, 64, 64, 16, 1, False, True],
|
||||
[4096, 4096, 64, 64, 64, 16, 1, False, True],
|
||||
[128, 4096, 64, 64, 64, 16, 1, False, True],
|
||||
@@ -257,9 +268,7 @@ def tma_warp_specialized_matmul_kernel(
|
||||
[4096, 4096, 128, 256, 128, 64, 4, False, True],
|
||||
[4096, 4096, 256, 128, 256, 64, 4, False, True],
|
||||
[4096, 4096, 256, 256, 256, 64, 4, False, True],
|
||||
]
|
||||
for use_tma in [False, True]
|
||||
])
|
||||
] for use_tma in [False, True]])
|
||||
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
|
||||
def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B, USE_TMA):
|
||||
if (TRANS_A):
|
||||
@@ -274,29 +283,29 @@ def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K
|
||||
|
||||
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
|
||||
|
||||
grid = lambda META: (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),)
|
||||
grid = lambda META: (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), )
|
||||
|
||||
if USE_TMA:
|
||||
tma_warp_specialized_matmul_kernel[grid](
|
||||
a, b, c,
|
||||
M, N, K,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
BLOCK_M, BLOCK_N, BLOCK_K,
|
||||
num_warps=4,
|
||||
num_ctas=NUM_CTAS,
|
||||
a, b, c, #
|
||||
M, N, K, #
|
||||
a.stride(0), a.stride(1), #
|
||||
b.stride(0), b.stride(1), #
|
||||
c.stride(0), c.stride(1), #
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, #
|
||||
num_warps=4, #
|
||||
num_ctas=NUM_CTAS, #
|
||||
enable_warp_specialization=True)
|
||||
else:
|
||||
warp_specialized_matmul_kernel[grid](
|
||||
a, b, c,
|
||||
M, N, K,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
BLOCK_M, BLOCK_N, BLOCK_K,
|
||||
num_warps=4,
|
||||
num_ctas=NUM_CTAS,
|
||||
a, b, c, #
|
||||
M, N, K, #
|
||||
a.stride(0), a.stride(1), #
|
||||
b.stride(0), b.stride(1), #
|
||||
c.stride(0), c.stride(1), #
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, #
|
||||
num_warps=4, #
|
||||
num_ctas=NUM_CTAS, #
|
||||
enable_warp_specialization=True)
|
||||
|
||||
th_c = torch.matmul(a, b)
|
||||
@@ -304,14 +313,14 @@ def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K
|
||||
|
||||
|
||||
@triton.jit
|
||||
def static_persistent_warp_specialized_matmul_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
NUM_SM: tl.constexpr,
|
||||
def static_persistent_warp_specialized_matmul_kernel( #
|
||||
a_ptr, b_ptr, c_ptr, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
|
||||
NUM_SM: tl.constexpr #
|
||||
):
|
||||
start_tile = tl.program_id(axis=0)
|
||||
m_tiles = tl.cdiv(M, BLOCK_M)
|
||||
@@ -343,14 +352,14 @@ def static_persistent_warp_specialized_matmul_kernel(
|
||||
|
||||
|
||||
@triton.jit
|
||||
def static_persistent_tma_warp_specialized_matmul_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
NUM_SM: tl.constexpr,
|
||||
def static_persistent_tma_warp_specialized_matmul_kernel( #
|
||||
a_ptr, b_ptr, c_ptr, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
|
||||
NUM_SM: tl.constexpr #
|
||||
):
|
||||
start_tile = tl.program_id(axis=0)
|
||||
m_tiles = tl.cdiv(M, BLOCK_M)
|
||||
@@ -363,8 +372,10 @@ def static_persistent_tma_warp_specialized_matmul_kernel(
|
||||
|
||||
block_offset_m = pre_pid_m * BLOCK_M
|
||||
block_offset_n = pre_pid_n * BLOCK_N
|
||||
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
|
||||
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
|
||||
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
|
||||
offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
|
||||
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
|
||||
offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
|
||||
for tile_id in range(start_tile, num_tiles, NUM_SM):
|
||||
pid_m = tile_id // n_tiles
|
||||
pid_n = tile_id % n_tiles
|
||||
@@ -390,8 +401,7 @@ def static_persistent_tma_warp_specialized_matmul_kernel(
|
||||
|
||||
|
||||
@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA',
|
||||
[(*shape, use_tma)
|
||||
for shape in [
|
||||
[(*shape, use_tma) for shape in [
|
||||
[2048, 2048, 64, 64, 64, 16, 1, False, True],
|
||||
[4096, 4096, 64, 64, 64, 16, 1, False, True],
|
||||
[128, 4096, 64, 64, 64, 16, 1, False, True],
|
||||
@@ -415,11 +425,10 @@ def static_persistent_tma_warp_specialized_matmul_kernel(
|
||||
[4096, 4096, 128, 256, 128, 64, 4, False, True],
|
||||
[4096, 4096, 256, 128, 256, 64, 4, False, True],
|
||||
[4096, 4096, 256, 256, 256, 64, 4, False, True],
|
||||
]
|
||||
for use_tma in [False, True]
|
||||
])
|
||||
] for use_tma in [False, True]])
|
||||
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
|
||||
def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B, USE_TMA):
|
||||
def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B,
|
||||
USE_TMA):
|
||||
if (TRANS_A):
|
||||
a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T
|
||||
else:
|
||||
@@ -432,27 +441,22 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N
|
||||
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
|
||||
|
||||
num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
|
||||
grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),)
|
||||
grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), )
|
||||
|
||||
if USE_TMA:
|
||||
static_persistent_tma_warp_specialized_matmul_kernel[grid](
|
||||
a, b, c,
|
||||
M, N, K,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, num_SMs,
|
||||
num_warps=4, num_ctas=NUM_CTAS,
|
||||
a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M,
|
||||
BLOCK_N, BLOCK_K, num_SMs, num_warps=4, num_ctas=NUM_CTAS, #
|
||||
enable_warp_specialization=True)
|
||||
else:
|
||||
static_persistent_warp_specialized_matmul_kernel[grid](
|
||||
a, b, c,
|
||||
M, N, K,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, num_SMs,
|
||||
num_warps=4, num_ctas=NUM_CTAS,
|
||||
a, b, c, #
|
||||
M, N, K, #
|
||||
a.stride(0), a.stride(1), #
|
||||
b.stride(0), b.stride(1), #
|
||||
c.stride(0), c.stride(1), #
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, num_SMs, #
|
||||
num_warps=4, num_ctas=NUM_CTAS, #
|
||||
enable_warp_specialization=True)
|
||||
|
||||
th_c = torch.matmul(a, b)
|
||||
@@ -460,16 +464,15 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N
|
||||
|
||||
|
||||
@triton.jit
|
||||
def static_persistent_matmul_no_scf_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr,
|
||||
NUM_SM: tl.constexpr, USE_TMA_LOAD: tl.constexpr,
|
||||
):
|
||||
def static_persistent_matmul_no_scf_kernel(a_ptr, b_ptr, c_ptr, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
|
||||
FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr, #
|
||||
NUM_SM: tl.constexpr, USE_TMA_LOAD: tl.constexpr #
|
||||
):
|
||||
start_tile = tl.program_id(axis=0)
|
||||
m_tiles = tl.cdiv(M, BLOCK_M)
|
||||
n_tiles = tl.cdiv(N, BLOCK_N)
|
||||
@@ -487,7 +490,8 @@ def static_persistent_matmul_no_scf_kernel(
|
||||
offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
|
||||
if USE_TMA_EPILOGUE:
|
||||
c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
|
||||
offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
|
||||
offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N),
|
||||
order=(1, 0))
|
||||
|
||||
for tile_id in range(start_tile, num_tiles, NUM_SM):
|
||||
pid_m = tile_id // n_tiles
|
||||
@@ -524,29 +528,27 @@ def static_persistent_matmul_no_scf_kernel(
|
||||
pre_pid_n = pid_n
|
||||
|
||||
|
||||
@pytest.mark.parametrize('M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,USE_TMA_LOAD',
|
||||
itertools.chain(
|
||||
*[
|
||||
[
|
||||
# numCTAs = 1, no TMA multicast:
|
||||
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
# small M, N
|
||||
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
] for USE_TMA_EPILOGUE in [True, False]
|
||||
for USE_TMA_LOAD in [True, False]
|
||||
]))
|
||||
@pytest.mark.parametrize(
|
||||
'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,USE_TMA_LOAD',
|
||||
itertools.chain(*[[
|
||||
# numCTAs = 1, no TMA multicast:
|
||||
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
# small M, N
|
||||
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
|
||||
] for USE_TMA_EPILOGUE in [True, False] for USE_TMA_LOAD in [True, False]]))
|
||||
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
|
||||
def test_static_persistent_matmul_no_scf_kernel(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE, USE_TMA_LOAD):
|
||||
def test_static_persistent_matmul_no_scf_kernel(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE,
|
||||
USE_TMA_EPILOGUE, USE_TMA_LOAD):
|
||||
if (TRANS_A):
|
||||
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
|
||||
else:
|
||||
@@ -564,46 +566,42 @@ def test_static_persistent_matmul_no_scf_kernel(M, N, K, NUM_CTAS, NUM_WARPS, TR
|
||||
num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
|
||||
|
||||
# TODO: set `enable_warp_specialization=False` will lead to compilation error.
|
||||
static_persistent_matmul_no_scf_kernel[(num_SMs,)](a_ptr=a, b_ptr=b, c_ptr=c,
|
||||
M=M, N=N, K=K,
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1),
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1),
|
||||
stride_cm=c.stride(0), stride_cn=c.stride(1),
|
||||
BLOCK_M=M if M < 128 else M // 2, BLOCK_N=N if N < 128 else N // 2, BLOCK_K=K, NUM_SM=num_SMs,
|
||||
num_warps=NUM_WARPS,
|
||||
num_ctas=NUM_CTAS,
|
||||
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"),
|
||||
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE,
|
||||
USE_TMA_LOAD=USE_TMA_LOAD,
|
||||
enable_warp_specialization=True)
|
||||
static_persistent_matmul_no_scf_kernel[(num_SMs, )](
|
||||
a_ptr=a, b_ptr=b, c_ptr=c, #
|
||||
M=M, N=N, K=K, #
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1), #
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1), #
|
||||
stride_cm=c.stride(0), stride_cn=c.stride(1), #
|
||||
BLOCK_M=M if M < 128 else M // 2, BLOCK_N=N if N < 128 else N // 2, BLOCK_K=K, NUM_SM=num_SMs, #
|
||||
num_warps=NUM_WARPS, #
|
||||
num_ctas=NUM_CTAS, #
|
||||
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), #
|
||||
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE, #
|
||||
USE_TMA_LOAD=USE_TMA_LOAD, #
|
||||
enable_warp_specialization=True)
|
||||
a_f32 = a.to(torch.float32)
|
||||
b_f32 = b.to(torch.float32)
|
||||
golden = torch.matmul(a_f32, b_f32)
|
||||
torch.set_printoptions(profile="full")
|
||||
assert_close(
|
||||
c,
|
||||
golden,
|
||||
rtol=1e-2,
|
||||
atol=1e-3,
|
||||
check_dtype=False)
|
||||
assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def full_static_persistent_matmul_kernel(
|
||||
a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_wm, stride_wn,
|
||||
stride_zm, stride_zn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
|
||||
out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr,
|
||||
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
|
||||
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
|
||||
A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr,
|
||||
B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr,
|
||||
NUM_SM: tl.constexpr
|
||||
):
|
||||
def full_static_persistent_matmul_kernel(a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_wm, stride_wn, #
|
||||
stride_zm, stride_zn, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr, #
|
||||
out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr, #
|
||||
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, #
|
||||
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, #
|
||||
A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, #
|
||||
B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr, #
|
||||
NUM_SM: tl.constexpr #
|
||||
):
|
||||
start_pid = tl.program_id(axis=0)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_N)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_M)
|
||||
@@ -618,15 +616,18 @@ def full_static_persistent_matmul_kernel(
|
||||
pre_block_offset_m = pre_pid_m * BLOCK_M
|
||||
pre_block_offset_n = pre_pid_n * BLOCK_N
|
||||
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
|
||||
offsets=(pre_block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(A_ORDER_0, A_ORDER_1))
|
||||
offsets=(pre_block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K),
|
||||
order=(A_ORDER_0, A_ORDER_1))
|
||||
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
|
||||
offsets=(0, pre_block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(B_ORDER_0, B_ORDER_1))
|
||||
offsets=(0, pre_block_offset_n), block_shape=(BLOCK_K, BLOCK_N),
|
||||
order=(B_ORDER_0, B_ORDER_1))
|
||||
w_tile_ptr = tl.make_block_ptr(base=w_ptr, shape=(N, N), strides=(stride_wm, stride_wn),
|
||||
offsets=(0, pre_block_offset_n), block_shape=(BLOCK_N, BLOCK_N), order=(0, 1))
|
||||
|
||||
if USE_TMA_STORE:
|
||||
z_block_ptr = tl.make_block_ptr(base=z_ptr, shape=(M, N), strides=(stride_zm, stride_zn),
|
||||
offsets=(pre_block_offset_m, pre_block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
|
||||
offsets=(pre_block_offset_m, pre_block_offset_n),
|
||||
block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
|
||||
|
||||
for tile_id in range(start_pid, num_tiles, NUM_SM):
|
||||
group_id = tile_id // num_pid_in_group
|
||||
@@ -694,136 +695,120 @@ def full_static_persistent_matmul_kernel(
|
||||
pre_pid_n = pid_n
|
||||
|
||||
|
||||
@pytest.mark.parametrize('BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS',
|
||||
[
|
||||
# corner shapes
|
||||
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws)
|
||||
for shape_w_c in [
|
||||
[4096, 1, 1024, False, False],
|
||||
[2048, 204, 1000, True, False],
|
||||
[16, 524288, 32, False, True],
|
||||
]
|
||||
for out_dtype in ['float16', 'float32']
|
||||
for use_tma_store in [False, True]
|
||||
for enable_ws in [True]
|
||||
] + [
|
||||
# softmax epilogue
|
||||
(*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
# softmax works for one CTA
|
||||
for shape_w_c in [
|
||||
[64, 64, 16, 4, 1, 64, 64, 64],
|
||||
[128, 128, 64, 4, 1, None, None, None],
|
||||
[16, 16, 64, 4, 1, 16, 16, 64],
|
||||
# TODO: enable when num_warps != 4 is supported.
|
||||
# [64, 64, 32, 8, 1, 64, 64, 64],
|
||||
[128, 128, 64, 4, 1, 128, 128, 128],
|
||||
]
|
||||
for epilogue in ['softmax']
|
||||
for out_dtype in ['float16', 'float32']
|
||||
for use_tma_store in [False, True]
|
||||
for trans_a in [False,]
|
||||
for trans_b in [True,]
|
||||
for num_stages in [3]
|
||||
for enable_ws in [True]
|
||||
] + [
|
||||
# loop over tile shapes and transpose combinations
|
||||
(*shape_w_c, trans_a, trans_b, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [
|
||||
[64, 64, 32, 4, 1, 128, 256, 64],
|
||||
[128, 128, 16, 4, 4, 512, 256, 64],
|
||||
[128, 256, 32, 4, 8, 256, 256, 192],
|
||||
[512, 256, 32, 4, 8, 1024, 256, 192],
|
||||
# BLOCK_K >= 128
|
||||
[64, 128, 128, 4, 1, 512, 256, 256],
|
||||
[128, 128, 128, 4, 1, 256, 256, 192],
|
||||
[128, 128, 128, 4, 2, 256, 256, 192],
|
||||
# small BLOCK_M and BLOCK_K
|
||||
[16, 32, 32, 4, 1, 128, 256, 64],
|
||||
[32, 32, 16, 4, 1, 256, 256, 192],
|
||||
[16, 32, 64, 4, 4, 512, 256, 64],
|
||||
]
|
||||
for out_dtype in ['float32',]
|
||||
for use_tma_store in [False,]
|
||||
for trans_a in [False, True]
|
||||
for trans_b in [False, True]
|
||||
for num_stages in [3]
|
||||
for enable_ws in [True]
|
||||
] + [
|
||||
# loop over epilogues besides of softmax
|
||||
(*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [
|
||||
[64, 64, 16, 4, 1, 128, 128, 64],
|
||||
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4] for num_ctas in [1, 2, 4]],
|
||||
# for chain-dot
|
||||
[128, 128, 64, 4, 1, None, None, None],
|
||||
[64, 64, 16, 4, 1, None, None, None],
|
||||
# small BLOCK_M and BLOCK_K
|
||||
[16, 16, 64, 4, 1, 128, 128, 64],
|
||||
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4] for num_ctas in [1, 2]],
|
||||
# # TODO: enable when num_warps != 4 is supported.
|
||||
# # repeat
|
||||
# # [64, 64, 32, 8, 1, 128, 256, 64],
|
||||
# # [64, 64, 16, 8, 2, 128, 128, 64],
|
||||
# irregular shape
|
||||
[128, 128, 64, 4, 1, 500, 200, 128],
|
||||
[128, 128, 64, 4, 1, 513, 193, 192],
|
||||
]
|
||||
for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot']
|
||||
for out_dtype in ['float16', 'float32']
|
||||
for use_tma_store in [False, True]
|
||||
for trans_a in [False,]
|
||||
for trans_b in [True,]
|
||||
for num_stages in [3]
|
||||
for enable_ws in [True]
|
||||
if not (epilogue == 'chain-dot' and (shape_w_c[5] is not None or shape_w_c[0] != shape_w_c[1]))
|
||||
] + [
|
||||
# loop over instr shapes & pipeline stages
|
||||
(64, n, 16, 4, 1, 512, 256, 256, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for n in [16, 32, 64, 128, 256]
|
||||
for out_dtype in ['float32']
|
||||
for use_tma_store in [False,]
|
||||
for num_stages in [2, 4, 5, 7]
|
||||
for enable_ws in [True]
|
||||
] + [
|
||||
# irregular shapes
|
||||
(*shape_w_c, *shape, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [
|
||||
[128, 128, 64, 4, 1],
|
||||
[256, 128, 64, 4, 2],
|
||||
[128, 128, 128, 4, 2]
|
||||
]
|
||||
for shape in [
|
||||
[512, 360, 1024],
|
||||
[360, 4096, 512],
|
||||
]
|
||||
for out_dtype in ['float32']
|
||||
for use_tma_store in [False, True]
|
||||
for num_stages in [3, 4]
|
||||
for enable_ws in [True]
|
||||
]
|
||||
)
|
||||
@pytest.mark.skipif(torch.cuda.get_device_capability()
|
||||
[0] < 9, reason="Requires compute capability >= 9")
|
||||
def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS):
|
||||
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS])) in [
|
||||
'128-128-128-4-1-256-256-192-none-float32-True-3-True',
|
||||
]:
|
||||
@pytest.mark.parametrize(
|
||||
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS',
|
||||
[
|
||||
# corner shapes
|
||||
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws) for shape_w_c in [
|
||||
[4096, 1, 1024, False, False],
|
||||
[2048, 204, 1000, True, False],
|
||||
[16, 524288, 32, False, True],
|
||||
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for enable_ws in [True]
|
||||
] + [
|
||||
# softmax epilogue
|
||||
(*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
# softmax works for one CTA
|
||||
for shape_w_c in [
|
||||
[64, 64, 16, 4, 1, 64, 64, 64],
|
||||
[128, 128, 64, 4, 1, None, None, None],
|
||||
[16, 16, 64, 4, 1, 16, 16, 64],
|
||||
# TODO: enable when num_warps != 4 is supported.
|
||||
# [64, 64, 32, 8, 1, 64, 64, 64],
|
||||
[128, 128, 64, 4, 1, 128, 128, 128],
|
||||
]
|
||||
for epilogue in ['softmax']
|
||||
for out_dtype in ['float16', 'float32']
|
||||
for use_tma_store in [False, True]
|
||||
for trans_a in [False]
|
||||
for trans_b in [True]
|
||||
for num_stages in [3]
|
||||
for enable_ws in [True]
|
||||
] + [
|
||||
# loop over tile shapes and transpose combinations
|
||||
(*shape_w_c, trans_a, trans_b, 'none', out_dtype, use_tma_store, num_stages, enable_ws) for shape_w_c in [
|
||||
[64, 64, 32, 4, 1, 128, 256, 64],
|
||||
[128, 128, 16, 4, 4, 512, 256, 64],
|
||||
[128, 256, 32, 4, 8, 256, 256, 192],
|
||||
[512, 256, 32, 4, 8, 1024, 256, 192],
|
||||
# BLOCK_K >= 128
|
||||
[64, 128, 128, 4, 1, 512, 256, 256],
|
||||
[128, 128, 128, 4, 1, 256, 256, 192],
|
||||
[128, 128, 128, 4, 2, 256, 256, 192],
|
||||
# small BLOCK_M and BLOCK_K
|
||||
[16, 32, 32, 4, 1, 128, 256, 64],
|
||||
[32, 32, 16, 4, 1, 256, 256, 192],
|
||||
[16, 32, 64, 4, 4, 512, 256, 64],
|
||||
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
|
||||
[False, True] for num_stages in [3] for enable_ws in [True]
|
||||
] + [
|
||||
# loop over epilogues besides of softmax
|
||||
(*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws) for shape_w_c in [
|
||||
[64, 64, 16, 4, 1, 128, 128, 64],
|
||||
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4] for num_ctas in [1, 2, 4]],
|
||||
# for chain-dot
|
||||
[128, 128, 64, 4, 1, None, None, None],
|
||||
[64, 64, 16, 4, 1, None, None, None],
|
||||
# small BLOCK_M and BLOCK_K
|
||||
[16, 16, 64, 4, 1, 128, 128, 64],
|
||||
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4] for num_ctas in [1, 2]],
|
||||
# # TODO: enable when num_warps != 4 is supported.
|
||||
# # repeat
|
||||
# # [64, 64, 32, 8, 1, 128, 256, 64],
|
||||
# # [64, 64, 16, 8, 2, 128, 128, 64],
|
||||
# irregular shape
|
||||
[128, 128, 64, 4, 1, 500, 200, 128],
|
||||
[128, 128, 64, 4, 1, 513, 193, 192],
|
||||
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'] for out_dtype in
|
||||
['float16', 'float32'] for use_tma_store in [False, True] for trans_a in [False] for trans_b in [True] for
|
||||
num_stages in [3] for enable_ws in [True] if not (epilogue == 'chain-dot' and
|
||||
(shape_w_c[5] is not None or shape_w_c[0] != shape_w_c[1]))
|
||||
] + [
|
||||
# loop over instr shapes & pipeline stages
|
||||
(64, n, 16, 4, 1, 512, 256, 256, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for n in [16, 32, 64, 128, 256]
|
||||
for out_dtype in ['float32']
|
||||
for use_tma_store in [False]
|
||||
for num_stages in [2, 4, 5, 7]
|
||||
for enable_ws in [True]
|
||||
] + [
|
||||
# irregular shapes
|
||||
(*shape_w_c, *shape, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [[128, 128, 64, 4, 1], [256, 128, 64, 4, 2], [128, 128, 128, 4, 2]]
|
||||
for shape in [
|
||||
[512, 360, 1024],
|
||||
[360, 4096, 512],
|
||||
]
|
||||
for out_dtype in ['float32']
|
||||
for use_tma_store in [False, True]
|
||||
for num_stages in [3, 4]
|
||||
for enable_ws in [True]
|
||||
])
|
||||
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
|
||||
def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B,
|
||||
epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS):
|
||||
if '-'.join(
|
||||
map(str, [
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES,
|
||||
ENABLE_WS
|
||||
])) in [
|
||||
'128-128-128-4-1-256-256-192-none-float32-True-3-True',
|
||||
]:
|
||||
pytest.skip('out of resource: shared memory, Required: 263168')
|
||||
|
||||
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
|
||||
'16-32-64-4-4-512-256-64-True-False',
|
||||
'16-32-64-4-4-512-256-64-True-True',
|
||||
'16-32-64-4-4-512-256-64-False-False',
|
||||
'16-32-64-4-4-512-256-64-False-True',
|
||||
'16-32-64-4-4-512-256-64-True-False',
|
||||
'16-32-64-4-4-512-256-64-True-True',
|
||||
'16-32-64-4-4-512-256-64-False-False',
|
||||
'16-32-64-4-4-512-256-64-False-True',
|
||||
]:
|
||||
pytest.skip('shapePerCTA[1] < 16 not supported')
|
||||
|
||||
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
|
||||
'16-32-64-4-1-256-256-256-False',
|
||||
'16-32-64-4-2-256-256-256-False',
|
||||
'16-32-64-4-2-256-256-256-True',
|
||||
'16-32-64-8-2-256-256-256-False',
|
||||
'16-32-64-8-2-256-256-256-True',
|
||||
'16-32-64-4-1-256-256-256-False',
|
||||
'16-32-64-4-2-256-256-256-False',
|
||||
'16-32-64-4-2-256-256-256-True',
|
||||
'16-32-64-8-2-256-256-256-False',
|
||||
'16-32-64-8-2-256-256-256-True',
|
||||
]:
|
||||
pytest.skip('Known legacy issue, ldmatrix can only support x4')
|
||||
|
||||
@@ -893,37 +878,36 @@ def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WAR
|
||||
else:
|
||||
ref = d
|
||||
return ref
|
||||
|
||||
golden = process_epilogue(dot, bias, w, epilogue)
|
||||
|
||||
num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
|
||||
|
||||
def grid(META):
|
||||
return (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),)
|
||||
return (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), )
|
||||
|
||||
full_static_persistent_matmul_kernel[grid](
|
||||
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z,
|
||||
M=M, N=N, K=K,
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1),
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1),
|
||||
stride_wm=w.stride(0), stride_wn=w.stride(1),
|
||||
stride_zm=z.stride(0), stride_zn=z.stride(1),
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8,
|
||||
out_dtype=out_dtype,
|
||||
USE_TMA_STORE=USE_TMA_STORE,
|
||||
ADD_MATRIX=epilogue == 'add-matrix',
|
||||
ADD_ROWS=epilogue == 'add-rows',
|
||||
ADD_COLS=epilogue == 'add-cols',
|
||||
DO_SOFTMAX=epilogue == 'softmax',
|
||||
CHAIN_DOT=epilogue == 'chain-dot',
|
||||
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1],
|
||||
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1],
|
||||
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES,
|
||||
enable_warp_specialization=ENABLE_WS,
|
||||
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
|
||||
M=M, N=N, K=K, #
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1), #
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1), #
|
||||
stride_wm=w.stride(0), stride_wn=w.stride(1), #
|
||||
stride_zm=z.stride(0), stride_zn=z.stride(1), #
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
|
||||
out_dtype=out_dtype, #
|
||||
USE_TMA_STORE=USE_TMA_STORE, #
|
||||
ADD_MATRIX=epilogue == 'add-matrix', #
|
||||
ADD_ROWS=epilogue == 'add-rows', #
|
||||
ADD_COLS=epilogue == 'add-cols', #
|
||||
DO_SOFTMAX=epilogue == 'softmax', #
|
||||
CHAIN_DOT=epilogue == 'chain-dot', #
|
||||
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
|
||||
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
|
||||
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES, #
|
||||
enable_warp_specialization=ENABLE_WS, #
|
||||
NUM_SM=num_SMs)
|
||||
|
||||
torch.set_printoptions(profile="full")
|
||||
golden = torch.nn.functional.normalize(golden)
|
||||
z = torch.nn.functional.normalize(z)
|
||||
assert_close(z, golden,
|
||||
rtol=1e-2,
|
||||
atol=1e-3,
|
||||
check_dtype=False)
|
||||
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
|
||||
|
||||
@@ -19,7 +19,6 @@
|
||||
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
@@ -29,21 +28,21 @@ import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def matmul_tma_load_store(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M, N, K,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
OUTPUT_F16: tl.constexpr
|
||||
def matmul_tma_load_store( #
|
||||
a_ptr, b_ptr, c_ptr, #
|
||||
M, N, K, #
|
||||
stride_am, stride_ak, #
|
||||
stride_bk, stride_bn, #
|
||||
stride_cm, stride_cn, #
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
|
||||
OUTPUT_F16: tl.constexpr #
|
||||
):
|
||||
a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
|
||||
offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
|
||||
b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
|
||||
offsets=(0, 0), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
|
||||
c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
|
||||
offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
|
||||
a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
|
||||
b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, 0),
|
||||
block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
|
||||
c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), offsets=(0, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
|
||||
a = tl.load(a_block_ptr)
|
||||
b = tl.load(b_block_ptr)
|
||||
|
||||
@@ -78,15 +77,15 @@ def test_tma_load_store(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_F
|
||||
if OUTPUT_F16:
|
||||
c = torch.empty((M, N), device=a.device, dtype=torch.float16)
|
||||
|
||||
matmul_tma_load_store[(1, 1)](a_ptr=a, b_ptr=b, c_ptr=c,
|
||||
M=M, N=N, K=K,
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1),
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1),
|
||||
stride_cm=c.stride(0), stride_cn=c.stride(1),
|
||||
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K,
|
||||
num_warps=NUM_WARPS,
|
||||
num_ctas=NUM_CTAS,
|
||||
OUTPUT_F16=OUTPUT_F16)
|
||||
matmul_tma_load_store[(1, 1)](
|
||||
a_ptr=a, b_ptr=b, c_ptr=c, #
|
||||
M=M, N=N, K=K, #
|
||||
stride_am=a.stride(0), stride_ak=a.stride(1), #
|
||||
stride_bk=b.stride(0), stride_bn=b.stride(1), #
|
||||
stride_cm=c.stride(0), stride_cn=c.stride(1), #
|
||||
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, #
|
||||
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, #
|
||||
OUTPUT_F16=OUTPUT_F16)
|
||||
golden = torch.matmul(a, b)
|
||||
torch.set_printoptions(profile="full")
|
||||
assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
|
||||
|
||||
@@ -54,17 +54,13 @@ def test_tma_wgmma_64_64_16_f16(TTGIR, TRANS_A, TRANS_B):
|
||||
|
||||
ttgir_path = os.path.dirname(__file__) + "/" + TTGIR
|
||||
kernel = triton.compile(ttgir_path)
|
||||
kernel[(1, 1, 1)](a.data_ptr(), b.data_ptr(), c.data_ptr(),
|
||||
SIZE_M, SIZE_N, SIZE_K,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0))
|
||||
kernel[(1, 1, 1)]( #
|
||||
a.data_ptr(), b.data_ptr(), c.data_ptr(), #
|
||||
SIZE_M, SIZE_N, SIZE_K, #
|
||||
a.stride(0), a.stride(1), #
|
||||
b.stride(0), b.stride(1), #
|
||||
c.stride(0))
|
||||
|
||||
golden = torch.matmul(a, b)
|
||||
torch.set_printoptions(profile="full", sci_mode=False)
|
||||
assert_close(
|
||||
c,
|
||||
golden,
|
||||
rtol=1e-2,
|
||||
atol=1e-3,
|
||||
check_dtype=False)
|
||||
assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
|
||||
|
||||
Reference in New Issue
Block a user