Reformat Python code with yapf. (#2589)

I've added an option to yapf to do what we want for long lines, see
https://github.com/google/yapf/pull/1177.  We can now have a real Python
formatter, yay!

To make this PR, I ran my modified yapf over the repository, then looked
over the full diff.  Where yapf was mangling the param list of long
function decls/calls (mostly kernels), I manually added `#` to put
linebreaks where we want.  I fixed up other formatting too -- mostly
adding or removing a trailing comma from lists.

Overall, trailing `#` was sufficient to get formatting similar to our
current code.  I didn't have to disable yapf anywhere.

---------

Co-authored-by: Phil Tillet <phil@openai.com>
This commit is contained in:
Justin Lebar
2023-11-02 20:44:17 -07:00
committed by GitHub
parent dced22c4b7
commit df08301e76
85 changed files with 3802 additions and 3880 deletions

View File

@@ -18,7 +18,6 @@
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
Fused Attention
===============
@@ -35,18 +34,15 @@ import triton.language as tl
@triton.jit
def _fwd_kernel(
Q, K, V, sm_scale,
L, M,
Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on,
Z, H, N_CTX, D0,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
def _fwd_kernel(Q, K, V, sm_scale, #
L, M, #
Out, #
stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vk, stride_vn, #
stride_oz, stride_oh, stride_om, stride_on, #
Z, H, N_CTX, D0, #
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
@@ -61,31 +57,38 @@ def _fwd_kernel(
stride_qh_2d = stride_qh // stride_qm // stride_qk
q_tile_ptr = tl.make_block_ptr(base=Q,
shape=(D0, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(
off_hz * stride_qh_2d + start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0))
k_tile_ptr = tl.make_block_ptr(base=K,
shape=(D0, BLOCK_DMODEL),
strides=(stride_kn, stride_kk),
offsets=(off_hz * stride_qh_2d, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0))
v_tile_ptr = tl.make_block_ptr(base=V,
shape=(D0, BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
offsets=(off_hz * stride_qh_2d, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0))
out_tile_ptr = tl.make_block_ptr(base=Out,
shape=(D0, BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(off_hz * stride_qh_2d + start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0))
q_tile_ptr = tl.make_block_ptr(
base=Q,
shape=(D0, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(off_hz * stride_qh_2d + start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
k_tile_ptr = tl.make_block_ptr(
base=K,
shape=(D0, BLOCK_DMODEL),
strides=(stride_kn, stride_kk),
offsets=(off_hz * stride_qh_2d, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0),
)
v_tile_ptr = tl.make_block_ptr(
base=V,
shape=(D0, BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
offsets=(off_hz * stride_qh_2d, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0),
)
out_tile_ptr = tl.make_block_ptr(
base=Out,
shape=(D0, BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(off_hz * stride_qh_2d + start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
# load q: it will stay in SRAM throughout
q = tl.load(q_tile_ptr)
@@ -96,8 +99,7 @@ def _fwd_kernel(
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(k))
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (
start_n + offs_n[None, :]), qk, float("-inf"))
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
# compute new m
m_curr = tl.maximum(tl.max(qk, 1), m_prev)
# correct old l
@@ -133,11 +135,9 @@ def _fwd_kernel(
@triton.jit
def _bwd_preprocess(
Out, DO, L,
NewDO, Delta,
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
def _bwd_preprocess(Out, DO, L, #
NewDO, Delta, #
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = tl.arange(0, D_HEAD)
# load
@@ -153,19 +153,14 @@ def _bwd_preprocess(
@triton.jit
def _bwd_kernel(
Q, K, V, sm_scale, Out, DO,
DQ, DK, DV,
L, M,
D,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX, D0,
num_block,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
def _bwd_kernel(Q, K, V, sm_scale, Out, DO, #
DQ, DK, DV, #
L, M, #
D, stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vk, stride_vn, #
Z, H, N_CTX, D0, #
num_block, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr):
off_hz = tl.program_id(0)
off_z = off_hz // H
off_h = off_hz % H
@@ -173,55 +168,62 @@ def _bwd_kernel(
stride_qz_2d = stride_qz // stride_qm // stride_qk
stride_qh_2d = stride_qh // stride_qm // stride_qk
q_tile_ptr = tl.make_block_ptr(base=Q,
shape=(D0, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(
off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0))
k_tile_ptr = tl.make_block_ptr(base=K,
shape=(D0, BLOCK_DMODEL),
strides=(stride_kn, stride_kk),
offsets=(
off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0))
v_tile_ptr = tl.make_block_ptr(base=V,
shape=(D0, BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
offsets=(
off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0))
do_tile_ptr = tl.make_block_ptr(base=DO,
shape=(D0, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(
off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0))
dq_tile_ptr = tl.make_block_ptr(base=DQ,
shape=(D0, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(
off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0))
dk_tile_ptr = tl.make_block_ptr(base=DK,
shape=(D0, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(
off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0))
dv_tile_ptr = tl.make_block_ptr(base=DV,
shape=(D0, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(
off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0))
q_tile_ptr = tl.make_block_ptr(
base=Q,
shape=(D0, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
k_tile_ptr = tl.make_block_ptr(
base=K,
shape=(D0, BLOCK_DMODEL),
strides=(stride_kn, stride_kk),
offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
v_tile_ptr = tl.make_block_ptr(
base=V,
shape=(D0, BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
do_tile_ptr = tl.make_block_ptr(
base=DO,
shape=(D0, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
dq_tile_ptr = tl.make_block_ptr(
base=DQ,
shape=(D0, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
dk_tile_ptr = tl.make_block_ptr(
base=DK,
shape=(D0, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
dv_tile_ptr = tl.make_block_ptr(
base=DV,
shape=(D0, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
# offset pointers for batch/head
DQ += off_z * stride_qz + off_h * stride_qh
for start_n in range(0, num_block):
@@ -250,8 +252,7 @@ def _bwd_kernel(
# recompute p = softmax(qk, dim=-1).T
# NOTE: `do` is pre-divided by `l`; no normalization here
qk = tl.dot(q, tl.trans(k))
qk = tl.where(offs_m_curr[:, None] >= (
offs_n[None, :]), qk, float("-inf"))
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
m = tl.load(m_ptrs + offs_m_curr)
p = tl.exp(qk * sm_scale - m[:, None])
# compute dv
@@ -301,29 +302,21 @@ class _attention(torch.autograd.Function):
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
L = torch.empty(
(q.shape[0] * q.shape[1], q.shape[2]),
device=q.device,
dtype=torch.float32)
m = torch.empty(
(q.shape[0] * q.shape[1], q.shape[2]),
device=q.device,
dtype=torch.float32)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
num_warps = 4 if Lk <= 64 else 8
D0 = q.shape[0] * q.shape[1] * q.shape[2]
_fwd_kernel[grid](
q, k, v, sm_scale,
L, m,
o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
q.shape[0], q.shape[1], q.shape[2], D0,
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=Lk, num_warps=num_warps,
num_stages=2,
)
q, k, v, sm_scale, #
L, m, #
o, #
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
q.shape[0], q.shape[1], q.shape[2], D0, #
BLOCK_M=BLOCK, BLOCK_N=BLOCK, BLOCK_DMODEL=Lk, #
num_warps=num_warps, num_stages=2)
ctx.save_for_backward(q, k, v, o, L, m)
ctx.grid = grid
@@ -343,25 +336,22 @@ class _attention(torch.autograd.Function):
delta = torch.empty_like(l)
D0 = q.shape[0] * q.shape[1] * q.shape[2]
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
o, do, l,
do_scaled, delta,
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
_bwd_kernel[(ctx.grid[1],)](
q, k, v, ctx.sm_scale,
o, do_scaled,
dq, dk, dv,
l, m,
delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2], D0,
ctx.grid[0],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
num_stages=1,
)
o, do, l, #
do_scaled, delta, #
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL)
_bwd_kernel[(ctx.grid[1], )](
q, k, v, ctx.sm_scale, #
o, do_scaled, #
dq, dk, dv, #
l, m, #
delta, #
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
q.shape[0], q.shape[1], q.shape[2], D0, #
ctx.grid[0], #
BLOCK_M=BLOCK, BLOCK_N=BLOCK, BLOCK_DMODEL=ctx.BLOCK_DMODEL, #
num_warps=8, num_stages=1)
return dq, dk, dv, None
@@ -380,15 +370,9 @@ attention = _attention.apply
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires arch 9+")
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20)
q = torch.empty(
(Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(
mean=0.1, std=0.2).requires_grad_()
k = torch.empty(
(Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(
mean=0.4, std=0.2).requires_grad_()
v = torch.empty(
(Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(
mean=0.3, std=0.2).requires_grad_()
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
sm_scale = 0.2
dout = torch.randn_like(q)
# reference implementation
@@ -427,22 +411,25 @@ except BaseException:
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
configs = [triton.testing.Benchmark(
x_names=['N_CTX'],
x_vals=[2**i for i in range(10, 14)],
line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
args={
'H': N_HEADS,
'BATCH': BATCH,
'D_HEAD': D_HEAD,
'dtype': torch.float16,
'mode': mode}
) for mode in ['fwd', 'bwd']]
configs = [
triton.testing.Benchmark(
x_names=['N_CTX'],
x_vals=[2**i for i in range(10, 14)],
line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
args={
'H': N_HEADS,
'BATCH': BATCH,
'D_HEAD': D_HEAD,
'dtype': torch.float16,
'mode': mode,
},
) for mode in ['fwd', 'bwd']
]
@triton.testing.perf_report(configs)
@@ -463,9 +450,8 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms
if provider == "flash":
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
cu_seqlens = torch.zeros(
(BATCH + 1,), device=device, dtype=torch.int32)
lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device)
cu_seqlens = torch.zeros((BATCH + 1, ), device=device, dtype=torch.int32)
cu_seqlens[1:] = lengths.cumsum(0)
qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)

View File

@@ -32,19 +32,30 @@ import triton.language as tl
@triton.jit
def matmul_no_scf_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr
):
a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
offsets=(0, 0), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
def matmul_no_scf_kernel(a_ptr, b_ptr, c_ptr, #
M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr #
):
a_block_ptr = tl.make_block_ptr(
base=a_ptr,
shape=(M, K),
strides=(stride_am, stride_ak),
offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_K),
order=(1, 0),
)
b_block_ptr = tl.make_block_ptr(
base=b_ptr,
shape=(K, N),
strides=(stride_bk, stride_bn),
offsets=(0, 0),
block_shape=(BLOCK_K, BLOCK_N),
order=(0, 1),
)
a = tl.load(a_block_ptr)
b = tl.load(b_block_ptr)
@@ -54,8 +65,8 @@ def matmul_no_scf_kernel(
c = c.to(tl.float16)
if USE_TMA_EPILOGUE:
c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
tl.store(c_block_ptr, c)
else:
offs_m = tl.arange(0, BLOCK_M)
@@ -64,33 +75,30 @@ def matmul_no_scf_kernel(
tl.store(c_ptrs, c)
@pytest.mark.parametrize('M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,ENABLE_WS',
itertools.chain(
*[
[
# numCTAs = 1, no TMA multicast:
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
# static mask, cluster 4x1
[256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
[256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
# dynamic mask, cluster 2x2
[128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
[128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
# small M, N
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
] for USE_TMA_EPILOGUE in [True, False]
for ENABLE_WS in [False, True]
]))
@pytest.mark.parametrize(
'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,ENABLE_WS',
itertools.chain(*[[
# numCTAs = 1, no TMA multicast:
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
# static mask, cluster 4x1
[256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
[256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
# dynamic mask, cluster 2x2
[128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS],
[128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
# small M, N
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS],
] for USE_TMA_EPILOGUE in [True, False] for ENABLE_WS in [False, True]]))
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE, ENABLE_WS):
if (TRANS_A):
@@ -107,46 +115,41 @@ def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE
else:
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
matmul_no_scf_kernel[(1, 1)](a_ptr=a, b_ptr=b, c_ptr=c,
M=M, N=N, K=K,
stride_am=a.stride(0), stride_ak=a.stride(1),
stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_cm=c.stride(0), stride_cn=c.stride(1),
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K,
num_warps=NUM_WARPS,
num_ctas=NUM_CTAS,
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"),
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE,
enable_warp_specialization=ENABLE_WS)
matmul_no_scf_kernel[(1, 1)](
a_ptr=a, b_ptr=b, c_ptr=c, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_cm=c.stride(0), stride_cn=c.stride(1), #
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, #
num_warps=NUM_WARPS, #
num_ctas=NUM_CTAS, #
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), #
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE, #
enable_warp_specialization=ENABLE_WS)
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
golden = torch.matmul(a_f32, b_f32)
torch.set_printoptions(profile="full")
assert_close(
c,
golden,
rtol=1e-2,
atol=1e-3,
check_dtype=False)
assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
@triton.jit
def matmul_kernel(
a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_wm, stride_wn,
stride_zm, stride_zn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr,
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr,
B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr,
W_ORDER_0: tl.constexpr, W_ORDER_1: tl.constexpr,
Z_ORDER_0: tl.constexpr, Z_ORDER_1: tl.constexpr
):
def matmul_kernel(a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr, #
M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_wm, stride_wn, #
stride_zm, stride_zn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, #
out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr, #
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, #
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, #
A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, #
B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr, #
W_ORDER_0: tl.constexpr, W_ORDER_1: tl.constexpr, #
Z_ORDER_0: tl.constexpr, Z_ORDER_1: tl.constexpr #
):
pid = tl.program_id(axis=0)
num_pid_n = tl.cdiv(N, BLOCK_N)
num_pid_m = tl.cdiv(M, BLOCK_M)
@@ -159,13 +162,31 @@ def matmul_kernel(
block_offset_m = pid_m * BLOCK_M
block_offset_n = pid_n * BLOCK_N
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(A_ORDER_0, A_ORDER_1))
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(B_ORDER_0, B_ORDER_1))
a_tile_ptr = tl.make_block_ptr(
base=a_ptr,
shape=(M, K),
strides=(stride_am, stride_ak),
offsets=(block_offset_m, 0),
block_shape=(BLOCK_M, BLOCK_K),
order=(A_ORDER_0, A_ORDER_1),
)
b_tile_ptr = tl.make_block_ptr(
base=b_ptr,
shape=(K, N),
strides=(stride_bk, stride_bn),
offsets=(0, block_offset_n),
block_shape=(BLOCK_K, BLOCK_N),
order=(B_ORDER_0, B_ORDER_1),
)
# for chain-dot, BLOCK_N must always be equal to N, and each program loads the whole W matrix
w_tile_ptr = tl.make_block_ptr(base=w_ptr, shape=(N, N), strides=(stride_wm, stride_wn),
offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_N), order=(W_ORDER_0, W_ORDER_1))
w_tile_ptr = tl.make_block_ptr(
base=w_ptr,
shape=(N, N),
strides=(stride_wm, stride_wn),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_N),
order=(W_ORDER_0, W_ORDER_1),
)
z = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
offs_m = block_offset_m + tl.arange(0, BLOCK_M)
@@ -204,139 +225,146 @@ def matmul_kernel(
if USE_TMA_STORE:
z_block_ptr = tl.make_block_ptr(base=z_ptr, shape=(M, N), strides=(stride_zm, stride_zn),
offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(Z_ORDER_0, Z_ORDER_1))
offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N),
order=(Z_ORDER_0, Z_ORDER_1))
tl.store(z_block_ptr, z, boundary_check=(0, 1))
else:
tl.store(z_ptrs, z, mask=mask)
@pytest.mark.parametrize('BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32']
for use_tma_store in [False, True]
for enable_ws in [False, True]
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
]
for epilogue in ['softmax']
for out_dtype in ['float16', 'float32']
for use_tma_store in [False, True]
for trans_a in [False,]
for trans_b in [True,]
for trans_output in [False,]
for num_stages in [3]
for enable_ws in [False, True]
] + [
# loop over epilogues besides of softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
]
for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot']
for out_dtype in ['float16', 'float32']
for use_tma_store in [False, True]
for trans_a in [False,]
for trans_b in [True,]
for trans_output in [False,]
for num_stages in [3]
for enable_ws in [False, True]
if not (epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
]
for out_dtype in ['float32',]
for use_tma_store in [False,]
for trans_a in [False, True]
for trans_b in [False, True]
for trans_output in [False, True]
for num_stages in [3]
for enable_ws in [False, True]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
for n in [16, 32, 64, 128, 256]
for trans_output in [False,]
for out_dtype in ['float32',]
for use_tma_store in [False,]
for num_stages in [2, 4, 5, 7]
for enable_ws in [False, True]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
]
for shape in [
[512, 360, 1024],
[360, 4096, 512],
]
for trans_output in [False,]
for out_dtype in ['float32',]
for use_tma_store in [False, True]
for num_stages in [3, 4]
for enable_ws in [False, True]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()
[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS):
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
for enable_ws in [False, True]
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
]
for epilogue in ['softmax']
for out_dtype in ['float16', 'float32']
for use_tma_store in [False, True]
for trans_a in [False]
for trans_b in [True]
for trans_output in [False]
for num_stages in [3]
for enable_ws in [False, True]
] + [
# loop over epilogues besides of softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64]
for num_warps in [4, 8]
for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256]
for num_warps in [4, 8]
for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
]
for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot']
for out_dtype in ['float16', 'float32']
for use_tma_store in [False, True]
for trans_a in [False]
for trans_b in [True]
for trans_output in [False]
for num_stages in [3]
for enable_ws in [False, True]
if not (epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
]
for out_dtype in ['float32']
for use_tma_store in [False]
for trans_a in [False, True]
for trans_b in [False, True]
for trans_output in [False, True]
for num_stages in [3]
for enable_ws in [False, True]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages,
enable_ws)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
for enable_ws in [False, True]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
]
for shape in [
[512, 360, 1024],
[360, 4096, 512],
]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False, True]
for num_stages in [3, 4]
for enable_ws in [False, True]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
enable_tma = os.environ.get('ENABLE_TMA', 'not found').lower()
@@ -413,38 +441,38 @@ def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A,
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
pgm = matmul_kernel[grid](a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z,
M=M, N=N, K=K,
stride_am=a.stride(0), stride_ak=a.stride(1),
stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_wm=w.stride(0), stride_wn=w.stride(1),
stride_zm=z.stride(0), stride_zn=z.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8,
out_dtype=out_dtype,
USE_TMA_STORE=USE_TMA_STORE,
ADD_MATRIX=epilogue == 'add-matrix',
ADD_ROWS=epilogue == 'add-rows',
ADD_COLS=epilogue == 'add-cols',
DO_SOFTMAX=epilogue == 'softmax',
CHAIN_DOT=epilogue == 'chain-dot',
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1],
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1],
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1],
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1],
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES,
enable_warp_specialization=ENABLE_WS)
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES, #
enable_warp_specialization=ENABLE_WS)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden,
rtol=1e-2,
atol=1e-3,
check_dtype=False)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:

View File

@@ -27,16 +27,20 @@ import triton.language as tl
@triton.jit
def gemm_fusion_kernel(A, B, C, E,
M, N, K,
stride_am, stride_ak, stride_bn, stride_bk, stride_cn, stride_ck, stride_em, stride_ek,
def gemm_fusion_kernel(A, B, C, E, #
M, N, K, #
stride_am, stride_ak, stride_bn, stride_bk, stride_cn, stride_ck, stride_em, stride_ek, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
pid = tl.program_id(0)
a_tile_ptr = tl.make_block_ptr(base=A, shape=(M, K), strides=(stride_am, stride_ak), offsets=(pid * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
b_tile_ptr = tl.make_block_ptr(base=B, shape=(N, K), strides=(stride_bn, stride_bk), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_K), order=(1, 0))
c_tile_ptr = tl.make_block_ptr(base=C, shape=(N, K), strides=(stride_cn, stride_ck), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_K), order=(1, 0))
e_tile_ptr = tl.make_block_ptr(base=E, shape=(M, K), strides=(stride_em, stride_ek), offsets=(pid * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
a_tile_ptr = tl.make_block_ptr(base=A, shape=(M, K), strides=(stride_am, stride_ak), offsets=(pid * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
b_tile_ptr = tl.make_block_ptr(base=B, shape=(N, K), strides=(stride_bn, stride_bk), offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_K), order=(1, 0))
c_tile_ptr = tl.make_block_ptr(base=C, shape=(N, K), strides=(stride_cn, stride_ck), offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_K), order=(1, 0))
e_tile_ptr = tl.make_block_ptr(base=E, shape=(M, K), strides=(stride_em, stride_ek), offsets=(pid * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
acc_e = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
a = tl.load(a_tile_ptr)
@@ -57,66 +61,70 @@ def gemm_fusion_kernel(A, B, C, E,
def test_gemm_fusion():
M, N, K = 4096, 4096, 64
BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 64
A = torch.empty(
(M, K), dtype=torch.float16, device='cuda').normal_(
mean=0.1, std=0.2)
B = torch.empty(
(N, K), dtype=torch.float16, device='cuda').normal_(
mean=0.1, std=0.2)
C = torch.empty(
(N, K), dtype=torch.float16, device='cuda').normal_(
mean=0.1, std=0.2)
A = torch.empty((M, K), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2)
B = torch.empty((N, K), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2)
C = torch.empty((N, K), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2)
E = torch.empty((M, K), dtype=torch.float16, device='cuda')
ref_out = torch.matmul(torch.matmul(A, B.T), C)
num_warps = 4
grid = (triton.cdiv(M, BLOCK_M), 1)
gemm_fusion_kernel[grid](A, B, C, E, M, N, K,
A.stride(0), A.stride(1), B.stride(0), B.stride(
1), C.stride(0), C.stride(1), E.stride(0), E.stride(1),
BLOCK_M, BLOCK_N, BLOCK_K, num_warps=num_warps)
gemm_fusion_kernel[grid](
A, B, C, E, M, N, K, #
A.stride(0), A.stride(1), #
B.stride(0), B.stride(1), #
C.stride(0), C.stride(1), #
E.stride(0), E.stride(1), #
BLOCK_M, BLOCK_N, BLOCK_K, #
num_warps=num_warps)
torch.testing.assert_close(ref_out, E, atol=1e-2, rtol=0)
@triton.jit
def batched_gemm_fusion(
Q, K, V, Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on,
Z, NH, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
def batched_gemm_fusion(Q, K, V, Out, #
stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vk, stride_vn, #
stride_oz, stride_oh, stride_om, stride_on, #
Z, NH, N_CTX, #
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
BLOCK_N: tl.constexpr):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
q_tile_ptr = tl.make_block_ptr(base=Q,
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
strides=(stride_qz, stride_qh, stride_qm, stride_qk),
offsets=(off_hz // NH, off_hz % NH, start_m, 0),
block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL),
order=(3, 2, 1, 0))
k_tile_ptr = tl.make_block_ptr(base=K,
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
strides=(stride_kz, stride_kh, stride_kn, stride_kk),
offsets=(off_hz // NH, off_hz % NH, 0, 0),
block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL),
order=(3, 2, 1, 0))
v_tile_ptr = tl.make_block_ptr(base=V,
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
strides=(stride_vz, stride_vh, stride_vk, stride_vn),
offsets=(off_hz // NH, off_hz % NH, 0, 0),
block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL),
order=(3, 2, 1, 0))
o_tile_ptr = tl.make_block_ptr(base=Out,
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
strides=(stride_oz, stride_oh, stride_om, stride_on),
offsets=(off_hz // NH, off_hz % NH, start_m, 0),
block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL),
order=(3, 2, 1, 0))
q_tile_ptr = tl.make_block_ptr(
base=Q,
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
strides=(stride_qz, stride_qh, stride_qm, stride_qk),
offsets=(off_hz // NH, off_hz % NH, start_m, 0),
block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL),
order=(3, 2, 1, 0),
)
k_tile_ptr = tl.make_block_ptr(
base=K,
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
strides=(stride_kz, stride_kh, stride_kn, stride_kk),
offsets=(off_hz // NH, off_hz % NH, 0, 0),
block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL),
order=(3, 2, 1, 0),
)
v_tile_ptr = tl.make_block_ptr(
base=V,
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
strides=(stride_vz, stride_vh, stride_vk, stride_vn),
offsets=(off_hz // NH, off_hz % NH, 0, 0),
block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL),
order=(3, 2, 1, 0),
)
o_tile_ptr = tl.make_block_ptr(
base=Out,
shape=(Z, NH, N_CTX, BLOCK_DMODEL),
strides=(stride_oz, stride_oh, stride_om, stride_on),
offsets=(off_hz // NH, off_hz % NH, start_m, 0),
block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL),
order=(3, 2, 1, 0),
)
q = tl.load(q_tile_ptr, boundary_check=(0, 1, 2, 3))
q = tl.view(q, (BLOCK_M, BLOCK_DMODEL))
@@ -155,12 +163,13 @@ def test_batched_gemm_fusion():
ref_out = torch.matmul(torch.matmul(A, BT), C)
num_warps = 4
grid = (triton.cdiv(N_CTX, BLOCK_M), B * NH)
batched_gemm_fusion[grid](A, B, C, E,
A.stride(0), A.stride(1), A.stride(2), A.stride(3),
B.stride(0), B.stride(1), B.stride(2), B.stride(3),
C.stride(0), C.stride(1), C.stride(2), C.stride(3),
E.stride(0), E.stride(1), E.stride(2), E.stride(3),
Z, NH, N_CTX,
BLOCK_M, BLOCK_DMODEL, BLOCK_N, num_warps=num_warps)
batched_gemm_fusion[grid](
A, B, C, E, #
A.stride(0), A.stride(1), A.stride(2), A.stride(3), #
B.stride(0), B.stride(1), B.stride(2), B.stride(3), #
C.stride(0), C.stride(1), C.stride(2), C.stride(3), #
E.stride(0), E.stride(1), E.stride(2), E.stride(3), #
Z, NH, N_CTX, #
BLOCK_M, BLOCK_DMODEL, BLOCK_N, num_warps=num_warps)
torch.testing.assert_close(ref_out, E, atol=1e-2, rtol=0)

View File

@@ -24,10 +24,8 @@ def add_kernel(
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x_block_ptr = tl.make_block_ptr(
base=x_ptr, shape=(n_elements, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),
block_shape=(BLOCK_SIZE, ), order=(0, )
)
x_block_ptr = tl.make_block_ptr(base=x_ptr, shape=(n_elements, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ),
block_shape=(BLOCK_SIZE, ), order=(0, ))
x = tl.load(x_block_ptr, boundary_check=(0, ), padding_option='zero')
y = tl.load(y_ptr + offsets, mask=mask)
@@ -36,9 +34,7 @@ def add_kernel(
@pytest.mark.parametrize('SIZE,BLOCK_SIZE,dtype_str',
[(98432, 1024, dtype_str)
for dtype_str in ['float16', 'float32']
])
[(98432, 1024, dtype_str) for dtype_str in ['float16', 'float32']])
def test_add(SIZE, BLOCK_SIZE, dtype_str):
dtype = dtype_mapping[dtype_str]
output = torch.empty(SIZE, device='cuda', dtype=dtype)
@@ -46,7 +42,8 @@ def test_add(SIZE, BLOCK_SIZE, dtype_str):
y = torch.randn(SIZE, device='cuda', dtype=dtype)
def grid(meta):
return (triton.cdiv(SIZE, meta['BLOCK_SIZE']),)
return (triton.cdiv(SIZE, meta['BLOCK_SIZE']), )
add_kernel[grid](x, y, output, SIZE, BLOCK_SIZE=BLOCK_SIZE)
output_torch = x + y
@@ -64,25 +61,20 @@ def load_reduce_kernel(
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
x_ptr = tl.make_block_ptr(
base=x_ptr, shape=(BLOCK_M, BLOCK_N), strides=(stride_xm, stride_xn),
offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)
)
x_ptr = tl.make_block_ptr(base=x_ptr, shape=(BLOCK_M, BLOCK_N), strides=(stride_xm, stride_xn), offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
x = tl.load(x_ptr)
y = tl.max(x, axis=1)
tl.store(y_ptr + tl.arange(0, BLOCK_M), y)
@pytest.mark.parametrize('BLOCK_M,BLOCK_N,dtype_str',
[(128, 64, dtype_str)
for dtype_str in ['float16']
])
@pytest.mark.parametrize('BLOCK_M,BLOCK_N,dtype_str', [(128, 64, dtype_str) for dtype_str in ['float16']])
def test_load_reduce(BLOCK_M, BLOCK_N, dtype_str):
dtype = dtype_mapping[dtype_str]
x = torch.randn((BLOCK_M, BLOCK_N), device='cuda', dtype=dtype)
y = torch.empty((BLOCK_M, ), device='cuda', dtype=dtype)
load_reduce_kernel[(1,)](x, y, x.stride(0), x.stride(1), y.stride(0), BLOCK_M, BLOCK_N)
load_reduce_kernel[(1, )](x, y, x.stride(0), x.stride(1), y.stride(0), BLOCK_M, BLOCK_N)
golden = x.max(dim=1)[0]
torch.set_printoptions(profile='full')

View File

@@ -18,7 +18,6 @@
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
Fused Attention
===============
@@ -40,18 +39,17 @@ import triton.language as tl
key=['Q', 'K', 'V'],
)
@triton.jit
def _fwd_kernel(
Q, K, V, sm_scale,
L, M,
Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
def _fwd_kernel(Q, K, V, sm_scale, #
L, M, #
Out, #
stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vk, stride_vn, #
stride_oz, stride_oh, stride_om, stride_on, #
Z, H, N_CTX, #
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
BLOCK_N: tl.constexpr #
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
# initialize offsets
@@ -116,11 +114,10 @@ def _fwd_kernel(
@triton.jit
def _bwd_preprocess(
Out, DO, L,
NewDO, Delta,
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
def _bwd_preprocess(Out, DO, L, #
NewDO, Delta, #
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr #
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = tl.arange(0, D_HEAD)
# load
@@ -136,19 +133,18 @@ def _bwd_preprocess(
@triton.jit
def _bwd_kernel(
Q, K, V, sm_scale, Out, DO,
DQ, DK, DV,
L, M,
D,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX,
num_block,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
def _bwd_kernel(Q, K, V, sm_scale, Out, DO, #
DQ, DK, DV, #
L, M, #
D, #
stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vk, stride_vn, #
Z, H, N_CTX, #
num_block, #
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
BLOCK_N: tl.constexpr, #
):
off_hz = tl.program_id(0)
off_z = off_hz // H
off_h = off_hz % H
@@ -240,16 +236,16 @@ class _attention(torch.autograd.Function):
assert num_warps == 4
_fwd_kernel[grid](
q, k, v, sm_scale,
L, m,
o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=Lk,
q, k, v, sm_scale, #
L, m, #
o, #
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
q.shape[0], q.shape[1], q.shape[2], #
BLOCK_M=BLOCK, BLOCK_N=BLOCK, #
BLOCK_DMODEL=Lk #
)
ctx.save_for_backward(q, k, v, o, L, m)
@@ -269,24 +265,23 @@ class _attention(torch.autograd.Function):
do_scaled = torch.empty_like(do)
delta = torch.empty_like(l)
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
o, do, l,
do_scaled, delta,
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
_bwd_kernel[(ctx.grid[1],)](
q, k, v, ctx.sm_scale,
o, do_scaled,
dq, dk, dv,
l, m,
delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
ctx.grid[0],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
num_stages=1,
o, do, l, #
do_scaled, delta, #
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL)
_bwd_kernel[(ctx.grid[1], )](
q, k, v, ctx.sm_scale, #
o, do_scaled, #
dq, dk, dv, #
l, m, #
delta, #
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
q.shape[0], q.shape[1], q.shape[2], #
ctx.grid[0], #
BLOCK_M=BLOCK, BLOCK_N=BLOCK, #
BLOCK_DMODEL=ctx.BLOCK_DMODEL, #
num_warps=8, num_stages=1 #
)
return dq, dk, dv, None
@@ -339,19 +334,19 @@ except BaseException:
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
configs = [triton.testing.Benchmark(
x_names=['N_CTX'],
# x_vals=[2**i for i in range(10, 14)],
x_vals=[2**i for i in range(10, 11)],
line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
# ) for mode in ['fwd', 'bwd']]
) for mode in ['fwd']]
configs = [
triton.testing.Benchmark(
x_names=['N_CTX'],
# x_vals=[2**i for i in range(10, 14)],
x_vals=[2**i
for i in range(10, 11)], line_arg='provider', line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), styles=[('red', '-'), ('blue', '-')], ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
# ) for mode in ['fwd', 'bwd']]
)
for mode in ['fwd']
]
@triton.testing.perf_report(configs)
@@ -374,9 +369,8 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms
if provider == "flash":
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
cu_seqlens = torch.zeros(
(BATCH + 1,), device=device, dtype=torch.int32)
lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device)
cu_seqlens = torch.zeros((BATCH + 1, ), device=device, dtype=torch.int32)
cu_seqlens[1:] = lengths.cumsum(0)
qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)

View File

@@ -29,14 +29,14 @@ import triton.language as tl
@triton.jit
def static_persistent_matmul_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
NUM_SM: tl.constexpr,
def static_persistent_matmul_kernel( #
a_ptr, b_ptr, c_ptr, #
M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
NUM_SM: tl.constexpr #
):
start_tile = tl.program_id(axis=0)
m_tiles = tl.cdiv(M, BLOCK_M)
@@ -68,14 +68,14 @@ def static_persistent_matmul_kernel(
@triton.jit
def static_persistent_tma_matmul_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
NUM_SM: tl.constexpr,
def static_persistent_tma_matmul_kernel( #
a_ptr, b_ptr, c_ptr, #
M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
NUM_SM: tl.constexpr #
):
start_tile = tl.program_id(axis=0)
m_tiles = tl.cdiv(M, BLOCK_M)
@@ -88,8 +88,10 @@ def static_persistent_tma_matmul_kernel(
block_offset_m = pre_pid_m * BLOCK_M
block_offset_n = pre_pid_n * BLOCK_N
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
for tile_id in range(start_tile, num_tiles, NUM_SM):
pid_m = tile_id // n_tiles
pid_n = tile_id % n_tiles
@@ -114,21 +116,23 @@ def static_persistent_tma_matmul_kernel(
pre_pid_n = pid_n
@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA',
[(*shape, use_tma)
for shape in [
[4096, 4096, 64, 64, 64, 16, 4, 1, False, True],
[4096, 4096, 64, 64, 64, 32, 4, 1, False, True],
[4096, 4096, 64, 256, 64, 16, 4, 1, False, True],
[4096, 4096, 64, 128, 128, 16, 4, 1, False, True],
# TODO: fix issue for 8-warp persistent kernel
# [4096, 4096, 64, 128, 128, 16, 8, 1, False, True],
# [4096, 4096, 64, 128, 256, 16, 8, 1, False, True],
]
for use_tma in [False, True]
])
@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA', [(
*shape, use_tma
) for shape in [
[4096, 4096, 64, 64, 64, 16, 4, 1, False, True],
[4096, 4096, 64, 64, 64, 32, 4, 1, False, True
],
[4096, 4096, 64, 256, 64, 16, 4, 1, False, True
],
[4096, 4096, 64, 128, 128, 16, 4, 1, False, True
],
# TODO: fix issue for 8-warp persistent kernel
# [4096, 4096, 64, 128, 128, 16, 8, 1, False, True],
# [4096, 4096, 64, 128, 256, 16, 8, 1, False, True],
] for use_tma in [False, True]])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, TRANS_A, TRANS_B, USE_TMA):
def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS,
TRANS_A, TRANS_B, USE_TMA):
if (TRANS_A):
a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T
else:
@@ -141,25 +145,33 @@ def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLO
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),)
grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), )
if USE_TMA:
static_persistent_tma_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS)
static_persistent_tma_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0),
stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS,
num_ctas=NUM_CTAS)
else:
static_persistent_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS)
static_persistent_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0),
stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS,
num_ctas=NUM_CTAS)
th_c = torch.matmul(a, b)
torch.testing.assert_close(th_c, c, atol=1e-2, rtol=0, check_dtype=False)
@triton.jit
def warp_specialized_matmul_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
def warp_specialized_matmul_kernel( #
a_ptr, b_ptr, c_ptr, #
M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
):
tid = tl.program_id(axis=0)
n_tiles = tl.cdiv(N, BLOCK_N)
@@ -193,13 +205,13 @@ def warp_specialized_matmul_kernel(
@triton.jit
def tma_warp_specialized_matmul_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
def tma_warp_specialized_matmul_kernel( #
a_ptr, b_ptr, c_ptr, #
M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
):
tid = tl.program_id(axis=0)
n_tiles = tl.cdiv(N, BLOCK_N)
@@ -232,8 +244,7 @@ def tma_warp_specialized_matmul_kernel(
@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA',
[(*shape, use_tma)
for shape in [
[(*shape, use_tma) for shape in [
[2048, 2048, 64, 64, 64, 16, 1, False, True],
[4096, 4096, 64, 64, 64, 16, 1, False, True],
[128, 4096, 64, 64, 64, 16, 1, False, True],
@@ -257,9 +268,7 @@ def tma_warp_specialized_matmul_kernel(
[4096, 4096, 128, 256, 128, 64, 4, False, True],
[4096, 4096, 256, 128, 256, 64, 4, False, True],
[4096, 4096, 256, 256, 256, 64, 4, False, True],
]
for use_tma in [False, True]
])
] for use_tma in [False, True]])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B, USE_TMA):
if (TRANS_A):
@@ -274,29 +283,29 @@ def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
grid = lambda META: (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),)
grid = lambda META: (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), )
if USE_TMA:
tma_warp_specialized_matmul_kernel[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
BLOCK_M, BLOCK_N, BLOCK_K,
num_warps=4,
num_ctas=NUM_CTAS,
a, b, c, #
M, N, K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
c.stride(0), c.stride(1), #
BLOCK_M, BLOCK_N, BLOCK_K, #
num_warps=4, #
num_ctas=NUM_CTAS, #
enable_warp_specialization=True)
else:
warp_specialized_matmul_kernel[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
BLOCK_M, BLOCK_N, BLOCK_K,
num_warps=4,
num_ctas=NUM_CTAS,
a, b, c, #
M, N, K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
c.stride(0), c.stride(1), #
BLOCK_M, BLOCK_N, BLOCK_K, #
num_warps=4, #
num_ctas=NUM_CTAS, #
enable_warp_specialization=True)
th_c = torch.matmul(a, b)
@@ -304,14 +313,14 @@ def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K
@triton.jit
def static_persistent_warp_specialized_matmul_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
NUM_SM: tl.constexpr,
def static_persistent_warp_specialized_matmul_kernel( #
a_ptr, b_ptr, c_ptr, #
M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
NUM_SM: tl.constexpr #
):
start_tile = tl.program_id(axis=0)
m_tiles = tl.cdiv(M, BLOCK_M)
@@ -343,14 +352,14 @@ def static_persistent_warp_specialized_matmul_kernel(
@triton.jit
def static_persistent_tma_warp_specialized_matmul_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
NUM_SM: tl.constexpr,
def static_persistent_tma_warp_specialized_matmul_kernel( #
a_ptr, b_ptr, c_ptr, #
M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
NUM_SM: tl.constexpr #
):
start_tile = tl.program_id(axis=0)
m_tiles = tl.cdiv(M, BLOCK_M)
@@ -363,8 +372,10 @@ def static_persistent_tma_warp_specialized_matmul_kernel(
block_offset_m = pre_pid_m * BLOCK_M
block_offset_n = pre_pid_n * BLOCK_N
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
for tile_id in range(start_tile, num_tiles, NUM_SM):
pid_m = tile_id // n_tiles
pid_n = tile_id % n_tiles
@@ -390,8 +401,7 @@ def static_persistent_tma_warp_specialized_matmul_kernel(
@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA',
[(*shape, use_tma)
for shape in [
[(*shape, use_tma) for shape in [
[2048, 2048, 64, 64, 64, 16, 1, False, True],
[4096, 4096, 64, 64, 64, 16, 1, False, True],
[128, 4096, 64, 64, 64, 16, 1, False, True],
@@ -415,11 +425,10 @@ def static_persistent_tma_warp_specialized_matmul_kernel(
[4096, 4096, 128, 256, 128, 64, 4, False, True],
[4096, 4096, 256, 128, 256, 64, 4, False, True],
[4096, 4096, 256, 256, 256, 64, 4, False, True],
]
for use_tma in [False, True]
])
] for use_tma in [False, True]])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B, USE_TMA):
def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B,
USE_TMA):
if (TRANS_A):
a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T
else:
@@ -432,27 +441,22 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),)
grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), )
if USE_TMA:
static_persistent_tma_warp_specialized_matmul_kernel[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
BLOCK_M, BLOCK_N, BLOCK_K, num_SMs,
num_warps=4, num_ctas=NUM_CTAS,
a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M,
BLOCK_N, BLOCK_K, num_SMs, num_warps=4, num_ctas=NUM_CTAS, #
enable_warp_specialization=True)
else:
static_persistent_warp_specialized_matmul_kernel[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
BLOCK_M, BLOCK_N, BLOCK_K, num_SMs,
num_warps=4, num_ctas=NUM_CTAS,
a, b, c, #
M, N, K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
c.stride(0), c.stride(1), #
BLOCK_M, BLOCK_N, BLOCK_K, num_SMs, #
num_warps=4, num_ctas=NUM_CTAS, #
enable_warp_specialization=True)
th_c = torch.matmul(a, b)
@@ -460,16 +464,15 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N
@triton.jit
def static_persistent_matmul_no_scf_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr,
NUM_SM: tl.constexpr, USE_TMA_LOAD: tl.constexpr,
):
def static_persistent_matmul_no_scf_kernel(a_ptr, b_ptr, c_ptr, #
M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr, #
NUM_SM: tl.constexpr, USE_TMA_LOAD: tl.constexpr #
):
start_tile = tl.program_id(axis=0)
m_tiles = tl.cdiv(M, BLOCK_M)
n_tiles = tl.cdiv(N, BLOCK_N)
@@ -487,7 +490,8 @@ def static_persistent_matmul_no_scf_kernel(
offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
if USE_TMA_EPILOGUE:
c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0))
for tile_id in range(start_tile, num_tiles, NUM_SM):
pid_m = tile_id // n_tiles
@@ -524,29 +528,27 @@ def static_persistent_matmul_no_scf_kernel(
pre_pid_n = pid_n
@pytest.mark.parametrize('M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,USE_TMA_LOAD',
itertools.chain(
*[
[
# numCTAs = 1, no TMA multicast:
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
# small M, N
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
] for USE_TMA_EPILOGUE in [True, False]
for USE_TMA_LOAD in [True, False]
]))
@pytest.mark.parametrize(
'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,USE_TMA_LOAD',
itertools.chain(*[[
# numCTAs = 1, no TMA multicast:
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
# small M, N
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD],
] for USE_TMA_EPILOGUE in [True, False] for USE_TMA_LOAD in [True, False]]))
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_static_persistent_matmul_no_scf_kernel(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE, USE_TMA_LOAD):
def test_static_persistent_matmul_no_scf_kernel(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE,
USE_TMA_EPILOGUE, USE_TMA_LOAD):
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
else:
@@ -564,46 +566,42 @@ def test_static_persistent_matmul_no_scf_kernel(M, N, K, NUM_CTAS, NUM_WARPS, TR
num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
# TODO: set `enable_warp_specialization=False` will lead to compilation error.
static_persistent_matmul_no_scf_kernel[(num_SMs,)](a_ptr=a, b_ptr=b, c_ptr=c,
M=M, N=N, K=K,
stride_am=a.stride(0), stride_ak=a.stride(1),
stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_cm=c.stride(0), stride_cn=c.stride(1),
BLOCK_M=M if M < 128 else M // 2, BLOCK_N=N if N < 128 else N // 2, BLOCK_K=K, NUM_SM=num_SMs,
num_warps=NUM_WARPS,
num_ctas=NUM_CTAS,
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"),
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE,
USE_TMA_LOAD=USE_TMA_LOAD,
enable_warp_specialization=True)
static_persistent_matmul_no_scf_kernel[(num_SMs, )](
a_ptr=a, b_ptr=b, c_ptr=c, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_cm=c.stride(0), stride_cn=c.stride(1), #
BLOCK_M=M if M < 128 else M // 2, BLOCK_N=N if N < 128 else N // 2, BLOCK_K=K, NUM_SM=num_SMs, #
num_warps=NUM_WARPS, #
num_ctas=NUM_CTAS, #
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), #
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE, #
USE_TMA_LOAD=USE_TMA_LOAD, #
enable_warp_specialization=True)
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
golden = torch.matmul(a_f32, b_f32)
torch.set_printoptions(profile="full")
assert_close(
c,
golden,
rtol=1e-2,
atol=1e-3,
check_dtype=False)
assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
@triton.jit
def full_static_persistent_matmul_kernel(
a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_wm, stride_wn,
stride_zm, stride_zn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr,
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr,
B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr,
NUM_SM: tl.constexpr
):
def full_static_persistent_matmul_kernel(a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr, #
M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_wm, stride_wn, #
stride_zm, stride_zn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, #
out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr, #
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, #
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, #
A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, #
B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr, #
NUM_SM: tl.constexpr #
):
start_pid = tl.program_id(axis=0)
num_pid_n = tl.cdiv(N, BLOCK_N)
num_pid_m = tl.cdiv(M, BLOCK_M)
@@ -618,15 +616,18 @@ def full_static_persistent_matmul_kernel(
pre_block_offset_m = pre_pid_m * BLOCK_M
pre_block_offset_n = pre_pid_n * BLOCK_N
a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
offsets=(pre_block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(A_ORDER_0, A_ORDER_1))
offsets=(pre_block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K),
order=(A_ORDER_0, A_ORDER_1))
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
offsets=(0, pre_block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(B_ORDER_0, B_ORDER_1))
offsets=(0, pre_block_offset_n), block_shape=(BLOCK_K, BLOCK_N),
order=(B_ORDER_0, B_ORDER_1))
w_tile_ptr = tl.make_block_ptr(base=w_ptr, shape=(N, N), strides=(stride_wm, stride_wn),
offsets=(0, pre_block_offset_n), block_shape=(BLOCK_N, BLOCK_N), order=(0, 1))
if USE_TMA_STORE:
z_block_ptr = tl.make_block_ptr(base=z_ptr, shape=(M, N), strides=(stride_zm, stride_zn),
offsets=(pre_block_offset_m, pre_block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
offsets=(pre_block_offset_m, pre_block_offset_n),
block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
for tile_id in range(start_pid, num_tiles, NUM_SM):
group_id = tile_id // num_pid_in_group
@@ -694,136 +695,120 @@ def full_static_persistent_matmul_kernel(
pre_pid_n = pid_n
@pytest.mark.parametrize('BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws)
for shape_w_c in [
[4096, 1, 1024, False, False],
[2048, 204, 1000, True, False],
[16, 524288, 32, False, True],
]
for out_dtype in ['float16', 'float32']
for use_tma_store in [False, True]
for enable_ws in [True]
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
# softmax works for one CTA
for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
# TODO: enable when num_warps != 4 is supported.
# [64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
]
for epilogue in ['softmax']
for out_dtype in ['float16', 'float32']
for use_tma_store in [False, True]
for trans_a in [False,]
for trans_b in [True,]
for num_stages in [3]
for enable_ws in [True]
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
]
for out_dtype in ['float32',]
for use_tma_store in [False,]
for trans_a in [False, True]
for trans_b in [False, True]
for num_stages in [3]
for enable_ws in [True]
] + [
# loop over epilogues besides of softmax
(*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4] for num_ctas in [1, 2]],
# # TODO: enable when num_warps != 4 is supported.
# # repeat
# # [64, 64, 32, 8, 1, 128, 256, 64],
# # [64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 1, 513, 193, 192],
]
for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot']
for out_dtype in ['float16', 'float32']
for use_tma_store in [False, True]
for trans_a in [False,]
for trans_b in [True,]
for num_stages in [3]
for enable_ws in [True]
if not (epilogue == 'chain-dot' and (shape_w_c[5] is not None or shape_w_c[0] != shape_w_c[1]))
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
for n in [16, 32, 64, 128, 256]
for out_dtype in ['float32']
for use_tma_store in [False,]
for num_stages in [2, 4, 5, 7]
for enable_ws in [True]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2]
]
for shape in [
[512, 360, 1024],
[360, 4096, 512],
]
for out_dtype in ['float32']
for use_tma_store in [False, True]
for num_stages in [3, 4]
for enable_ws in [True]
]
)
@pytest.mark.skipif(torch.cuda.get_device_capability()
[0] < 9, reason="Requires compute capability >= 9")
def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS])) in [
'128-128-128-4-1-256-256-192-none-float32-True-3-True',
]:
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws) for shape_w_c in [
[4096, 1, 1024, False, False],
[2048, 204, 1000, True, False],
[16, 524288, 32, False, True],
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for enable_ws in [True]
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
# softmax works for one CTA
for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
# TODO: enable when num_warps != 4 is supported.
# [64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
]
for epilogue in ['softmax']
for out_dtype in ['float16', 'float32']
for use_tma_store in [False, True]
for trans_a in [False]
for trans_b in [True]
for num_stages in [3]
for enable_ws in [True]
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, 'none', out_dtype, use_tma_store, num_stages, enable_ws) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for num_stages in [3] for enable_ws in [True]
] + [
# loop over epilogues besides of softmax
(*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4] for num_ctas in [1, 2]],
# # TODO: enable when num_warps != 4 is supported.
# # repeat
# # [64, 64, 32, 8, 1, 128, 256, 64],
# # [64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 1, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'] for out_dtype in
['float16', 'float32'] for use_tma_store in [False, True] for trans_a in [False] for trans_b in [True] for
num_stages in [3] for enable_ws in [True] if not (epilogue == 'chain-dot' and
(shape_w_c[5] is not None or shape_w_c[0] != shape_w_c[1]))
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
for n in [16, 32, 64, 128, 256]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
for enable_ws in [True]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
for shape_w_c in [[128, 128, 64, 4, 1], [256, 128, 64, 4, 2], [128, 128, 128, 4, 2]]
for shape in [
[512, 360, 1024],
[360, 4096, 512],
]
for out_dtype in ['float32']
for use_tma_store in [False, True]
for num_stages in [3, 4]
for enable_ws in [True]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B,
epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS):
if '-'.join(
map(str, [
BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES,
ENABLE_WS
])) in [
'128-128-128-4-1-256-256-192-none-float32-True-3-True',
]:
pytest.skip('out of resource: shared memory, Required: 263168')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
@@ -893,37 +878,36 @@ def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WAR
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
def grid(META):
return (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),)
return (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), )
full_static_persistent_matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z,
M=M, N=N, K=K,
stride_am=a.stride(0), stride_ak=a.stride(1),
stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_wm=w.stride(0), stride_wn=w.stride(1),
stride_zm=z.stride(0), stride_zn=z.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8,
out_dtype=out_dtype,
USE_TMA_STORE=USE_TMA_STORE,
ADD_MATRIX=epilogue == 'add-matrix',
ADD_ROWS=epilogue == 'add-rows',
ADD_COLS=epilogue == 'add-cols',
DO_SOFTMAX=epilogue == 'softmax',
CHAIN_DOT=epilogue == 'chain-dot',
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1],
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1],
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES,
enable_warp_specialization=ENABLE_WS,
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES, #
enable_warp_specialization=ENABLE_WS, #
NUM_SM=num_SMs)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden,
rtol=1e-2,
atol=1e-3,
check_dtype=False)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)

View File

@@ -19,7 +19,6 @@
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import pytest
import torch
from torch.testing import assert_close
@@ -29,21 +28,21 @@ import triton.language as tl
@triton.jit
def matmul_tma_load_store(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
OUTPUT_F16: tl.constexpr
def matmul_tma_load_store( #
a_ptr, b_ptr, c_ptr, #
M, N, K, #
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
OUTPUT_F16: tl.constexpr #
):
a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
offsets=(0, 0), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, 0),
block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
a = tl.load(a_block_ptr)
b = tl.load(b_block_ptr)
@@ -78,15 +77,15 @@ def test_tma_load_store(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_F
if OUTPUT_F16:
c = torch.empty((M, N), device=a.device, dtype=torch.float16)
matmul_tma_load_store[(1, 1)](a_ptr=a, b_ptr=b, c_ptr=c,
M=M, N=N, K=K,
stride_am=a.stride(0), stride_ak=a.stride(1),
stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_cm=c.stride(0), stride_cn=c.stride(1),
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K,
num_warps=NUM_WARPS,
num_ctas=NUM_CTAS,
OUTPUT_F16=OUTPUT_F16)
matmul_tma_load_store[(1, 1)](
a_ptr=a, b_ptr=b, c_ptr=c, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_cm=c.stride(0), stride_cn=c.stride(1), #
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, #
OUTPUT_F16=OUTPUT_F16)
golden = torch.matmul(a, b)
torch.set_printoptions(profile="full")
assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False)

View File

@@ -54,17 +54,13 @@ def test_tma_wgmma_64_64_16_f16(TTGIR, TRANS_A, TRANS_B):
ttgir_path = os.path.dirname(__file__) + "/" + TTGIR
kernel = triton.compile(ttgir_path)
kernel[(1, 1, 1)](a.data_ptr(), b.data_ptr(), c.data_ptr(),
SIZE_M, SIZE_N, SIZE_K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0))
kernel[(1, 1, 1)]( #
a.data_ptr(), b.data_ptr(), c.data_ptr(), #
SIZE_M, SIZE_N, SIZE_K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
c.stride(0))
golden = torch.matmul(a, b)
torch.set_printoptions(profile="full", sci_mode=False)
assert_close(
c,
golden,
rtol=1e-2,
atol=1e-3,
check_dtype=False)
assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False)