Enable split kernel in bwd pass (#303)

* Add fwd and bwd v2

Changes are largely from upstream.

* Split bwd kernel in dq and dk+dv

Only adds the split kernels. They are not enabled yet.
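
For orientation, these are the standard flash-attention backward recurrences
the two kernels implement (a summary for reference, not taken verbatim from
the diff). With P = softmax(sm_scale * Q K^T) and O = P V:

    dV = P^T dO
    dP = dO V^T
    dS = P ∘ (dP − D),  where D_i = Σ_j O_ij · dO_ij
    dQ = sm_scale · dS K
    dK = sm_scale · dS^T Q

dQ accumulates over key blocks for a fixed query block, while dK and dV
accumulate over query blocks for a fixed key block, which is why the two
halves can run as independent kernels.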

* Pull scalar multiplies out of the loop

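The hoisted scalar here is the softmax scale. A minimal, runnable Python
sketch of the idea (toy lists stand in for Triton tensors; names are
illustrative only, not from the kernel):

    LOG2E = 1.44269504  # log2(e): lets exp2 stand in for exp

    def probs_before(qk, sm_scale, l_i):
        # before: the scalar product is re-evaluated on every iteration
        return [2.0 ** (x * sm_scale * LOG2E - l_i) for x in qk]

    def probs_after(qk, sm_scale, l_i):
        # after: the loop-invariant product is computed once, up front
        qk_scale = sm_scale * LOG2E
        return [2.0 ** (x * qk_scale - l_i) for x in qk]
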
* Enable split kernel for bwd pass

* Put back P_SEQ=128 in fwd test

It is not used for the bwd test.

* Address review comments

* Address comments

Conditionally set causal/split_kernel to False for bwd.

* Add block pointer semantics to bwd pass

This significantly increases perf for bwd, similar to fwd.
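
Block pointers describe a 2-D tile of a tensor once (base, shape, strides,
block shape, order) and are then slid along with tl.advance, which gives the
compiler enough layout information to emit wide, vectorized loads. A minimal
standalone sketch of the pattern (a toy copy kernel, not the attention kernel;
assumes M and N are multiples of the block sizes):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def copy_tiles(X, Y, M, N, stride_m,
                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
        pid = tl.program_id(0)
        # Describe one BLOCK_M x BLOCK_N tile of each tensor.
        x_ptr = tl.make_block_ptr(base=X, shape=(M, N), strides=(stride_m, 1),
                                  offsets=(pid * BLOCK_M, 0),
                                  block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
        y_ptr = tl.make_block_ptr(base=Y, shape=(M, N), strides=(stride_m, 1),
                                  offsets=(pid * BLOCK_M, 0),
                                  block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
        for _ in range(0, N, BLOCK_N):
            tl.store(y_ptr, tl.load(x_ptr))  # move a whole tile at a time
            # Slide both tiles BLOCK_N columns to the right.
            x_ptr = tl.advance(x_ptr, (0, BLOCK_N))
            y_ptr = tl.advance(y_ptr, (0, BLOCK_N))

    # usage: X = torch.randn(256, 256, device="cuda"); Y = torch.empty_like(X)
    # copy_tiles[(256 // 64,)](X, Y, 256, 256, X.stride(0), BLOCK_M=64, BLOCK_N=64)
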
Author: Vinayak Gokhale
Date: 2023-08-29 13:51:29 -05:00
Committed by: GitHub
Parent: b834f42ae4
Commit: 9cdf3a58c3

@@ -113,7 +113,7 @@ def _fwd_kernel(
@triton.jit
def _bwd_preprocess(
Out, DO, L,
Out, DO,
NewDO, Delta,
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
@@ -122,9 +122,6 @@ def _bwd_preprocess(
# load
o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
denom = tl.load(L + off_m).to(tl.float32)
# compute
do = do / denom[:, None]
delta = tl.sum(o * do, axis=1)
# write-back
tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
@@ -135,7 +132,7 @@ def _bwd_preprocess(
def _bwd_kernel(
Q, K, V, sm_scale, Out, DO,
DQ, DK, DV,
L, M,
L,
D,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
@@ -156,6 +153,8 @@ def _bwd_kernel(
DQ += off_z * stride_qz + off_h * stride_qh
DK += off_z * stride_qz + off_h * stride_qh
DV += off_z * stride_qz + off_h * stride_qh
# See fwd pass above for explanation.
qk_scale = sm_scale * 1.44269504
for start_n in range(0, num_block):
lo = start_n * BLOCK_M
# initialize row/col offsets
@@ -166,12 +165,12 @@ def _bwd_kernel(
# initialize pointers to value-like data
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
v_ptrs = V + (offs_n[None, :] * stride_qm + offs_k[:, None] * stride_qk)
do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
# pointer to row-wise quantities in value-like data
D_ptrs = D + off_hz * N_CTX
m_ptrs = M + off_hz * N_CTX
l_ptrs = L + off_hz * N_CTX
# initialize dv and dk
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
@@ -187,15 +186,15 @@ def _bwd_kernel(
# NOTE: `do` is pre-divided by `l`; no normalization here
qk = tl.dot(q, tl.trans(k))
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
m = tl.load(m_ptrs + offs_m_curr)
p = tl.exp(qk * sm_scale - m[:, None])
l_i = tl.load(l_ptrs + offs_m_curr)
p = tl.math.exp2(qk * qk_scale - l_i[:, None])
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
# compute dp = dot(do, v) - Di (the delta term is folded in up front)
Di = tl.load(D_ptrs + offs_m_curr)
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
dp += tl.dot(do, tl.trans(v))
dp += tl.dot(do, v)
# compute ds = p * (dp - delta) * sm_scale; delta was already folded into dp above
ds = p * dp * sm_scale
# compute dk = dot(ds.T, q)
@@ -214,6 +213,217 @@ def _bwd_kernel(
tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk)
@triton.jit
def _bwd_kernel_dk_dv(
Q, K, V, sm_scale, Out, DO,
DK, DV,
L,
D,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
# Q is consumed starting at an offset that depends on the block ID:
# each block begins BLOCK_M x D_HEAD past the previous one.
qvk_offset = off_hz * stride_qh
qdo_offset = qvk_offset + start_m * BLOCK_M * stride_qm
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
# Initialize pointers to Q, K, V
Q_block_ptr = tl.make_block_ptr(
base=Q + qdo_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0)
)
K_block_ptr = tl.make_block_ptr(
base=K + qvk_offset,
shape=(BLOCK_DMODEL, N_CTX),
strides=(stride_kk, stride_kn),
offsets=(0, start_m * BLOCK_M),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
base=V + qvk_offset,
shape=(BLOCK_DMODEL, N_CTX),
strides=(stride_vn, stride_vk),
offsets=(0, start_m * BLOCK_M),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1)
)
DO_block_ptr = tl.make_block_ptr(
base=DO + qdo_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0)
)
# pointer to row-wise quantities in value-like data
D_ptrs = D + off_hz * N_CTX
l_ptrs = L + off_hz * N_CTX
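# 1.44269504 = log2(e): folding it into the scale lets tl.math.exp2 below replace exp.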
qk_scale = sm_scale * 1.44269504
# load k and v: they will stay in SRAM throughout
k = tl.load(K_block_ptr)
k = (k * qk_scale).to(tl.float16)
v = tl.load(V_block_ptr)
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# The lower loop bound comes from the causal mask: the result is lower
# triangular, and the upper triangle is -inf (0 after exponentiation),
# so those blocks can be skipped in the GEMM.
lo = start_m * BLOCK_M
hi = N_CTX
# loop over q, do
for start_n in range(lo, hi, BLOCK_N):
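# absolute indices of the query rows visited in this iteration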
offs_m_curr = offs_n[:, None] + start_n
# -- load q, do --
q = tl.load(Q_block_ptr)
do = tl.load(DO_block_ptr)
# -- compute qk ----
qk = tl.dot(q, k)
qk = tl.where(offs_m_curr >= offs_m[None, :], qk, float("-inf"))
l_i = tl.load(l_ptrs + offs_m_curr)
p = tl.math.exp2(qk - l_i)
# -- compute dv ----
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
# compute dp = dot(do, v) - Di (the delta term is folded in up front)
Di = tl.load(D_ptrs + offs_m_curr)
dp = tl.zeros([BLOCK_N, BLOCK_M], dtype=tl.float32) - Di
dp += tl.dot(do, v)
# compute ds = p * (dp - delta); delta was already folded into dp above
ds = p * dp
# compute dk
dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
# update pointers
Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_N, 0))
DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_N, 0))
# initialize pointers to output
DK_block_ptr = tl.make_block_ptr(
base=DK + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_kn, stride_kk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
DV_block_ptr = tl.make_block_ptr(
base=DV + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
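# dk was accumulated without sm_scale (hoisted out of the loop); apply it once here.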
tl.store(DK_block_ptr, (dk * sm_scale).to(tl.float16))
tl.store(DV_block_ptr, dv.to(tl.float16))
@triton.jit
def _bwd_kernel_dq(
Q, K, V, sm_scale, Out, DO,
DQ,
L,
D,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
qvk_offset = off_hz * stride_qh
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
# Initialize pointers to Q, K, V
Q_block_ptr = tl.make_block_ptr(
base=Q + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
K_block_ptr = tl.make_block_ptr(
base=K + qvk_offset,
shape=(BLOCK_DMODEL, N_CTX),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
base=V + qvk_offset,
shape=(BLOCK_DMODEL, N_CTX),
strides=(stride_vn, stride_vk),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1)
)
DO_block_ptr = tl.make_block_ptr(
base=DO + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
# pointer to row-wise quantities in value-like data
D_ptrs = D + off_hz * N_CTX
l_ptrs = L + off_hz * N_CTX
qk_scale = sm_scale * 1.44269504
# load q and do: they will stay in SRAM throughout
q = tl.load(Q_block_ptr)
q = (q * qk_scale).to(tl.float16)
do = tl.load(DO_block_ptr)
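# Di = rowsum(o * do) from _bwd_preprocess; l_i is the fwd softmax
# log-normalizer (log2 space, matching the exp2 below).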
Di = tl.load(D_ptrs + offs_m)
l_i = tl.load(l_ptrs + offs_m)
dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# loop over k, v
lo = 0
hi = (start_m + 1) * BLOCK_M
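# Causal mask: this query block only attends to key blocks up to the
# diagonal, so the loop stops there.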
for start_n in range(lo, hi, BLOCK_N):
# -- load k, v --
k = tl.load(K_block_ptr)
v = tl.load(V_block_ptr)
# -- compute qk ----
qk = tl.dot(q, k)
qk = tl.where(offs_m[:, None] >= (offs_n[None, :] + start_n), qk, float("-inf"))
p = tl.math.exp2(qk - l_i[:, None])
# compute dp = dot(do, v) - Di (the delta term is folded in up front)
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
dp += tl.dot(do, v)
# compute ds = p * (dp - delta); delta was already folded into dp above
ds = p * dp
# compute dq. The transpose cannot be avoided here: this loop needs k in
# both its normal and transposed forms.
dq += tl.dot(ds.to(Q.dtype.element_ty), tl.trans(k))
# update pointers
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (0, BLOCK_N))
# initialize pointers to output
DQ_block_ptr = tl.make_block_ptr(
base=DQ + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
tl.store(DQ_block_ptr, (dq * sm_scale).to(tl.float16))
empty = torch.empty(128, device="cuda")
@@ -221,7 +431,7 @@ empty = torch.empty(128, device="cuda")
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal, sm_scale):
def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False):
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
@@ -258,45 +468,79 @@ class _attention(torch.autograd.Function):
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = Lk
ctx.causal = causal
ctx.split_kernel = split_kernel
ctx.P_SEQ = P_SEQ
return o
@staticmethod
def backward(ctx, do):
if torch.version.hip is not None:
BLOCK = 64
else:
BLOCK = 128
q, k, v, o, l, m = ctx.saved_tensors
BLOCK = 64
q, k, v, o, l = ctx.saved_tensors
do = do.contiguous()
dq = torch.zeros_like(q, dtype=torch.float32)
dq = torch.zeros_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
do_scaled = torch.empty_like(do)
delta = torch.empty_like(l)
# Figure out what BLOCK size the fwd pass used and adjust num_blocks accordingly.
# If fwd and bwd used the same block size this would be unnecessary, but the
# bwd block size is smaller than fwd's, so we scale up to ensure the loop
# covers all values and no blocks are skipped. Alternatively we could compute
# a new grid, but this keeps it consistent with fwd and easier to reason about.
block_scale = (q.shape[2] // ctx.grid[0]) // BLOCK
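# e.g. if fwd used BLOCK_M = 128 with N_CTX = 1024: ctx.grid[0] = 8, so
# block_scale = (1024 // 8) // 64 = 2.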
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
o, do, l,
o, do,
do_scaled, delta,
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
_bwd_kernel[(ctx.grid[1],)](
q, k, v, ctx.sm_scale,
o, do_scaled,
dq, dk, dv,
l, m,
delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
ctx.grid[0],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
num_stages=1,
BLOCK_M=block_scale * BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
if not ctx.split_kernel:
_bwd_kernel[(ctx.grid[1],)](
q, k, v, ctx.sm_scale,
o, do_scaled,
dq, dk, dv,
l,
delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
block_scale * ctx.grid[0],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4,
num_stages=1,
)
else:
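# dk/dv parallelizes over key blocks, dq over query blocks; the two
# launches are independent.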
_bwd_kernel_dk_dv[(block_scale * ctx.grid[0], ctx.grid[1])](
q, k, v, ctx.sm_scale,
o, do_scaled,
dk, dv,
l,
delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4,
num_stages=1,
)
_bwd_kernel_dq[ctx.grid](
q, k, v, ctx.sm_scale,
o, do_scaled,
dq,
l,
delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=2*BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4,
num_stages=1,
)
# print(h.asm["ttgir"])
return dq, dk, dv, None
return dq, dk, dv, None, None, None
attention = _attention.apply
@@ -309,12 +553,12 @@ attention = _attention.apply
(4, 48, 16384, 64, 128)
])
@pytest.mark.parametrize('causal', [False, True])
def test_op(Z, H, N_CTX, D_HEAD, P_SEQ, causal, dtype=torch.float16):
def test_op_fwd(Z, H, N_CTX, D_HEAD, P_SEQ, causal, dtype=torch.float16):
torch.manual_seed(20)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
k = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
v = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
sm_scale = 0.5
sm_scale = q.shape[-1] ** (-0.5)
dout = torch.randn_like(q)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX + P_SEQ), device="cuda"), diagonal=P_SEQ)
@@ -322,23 +566,55 @@ def test_op(Z, H, N_CTX, D_HEAD, P_SEQ, causal, dtype=torch.float16):
if causal:
p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).half()
# p = torch.exp(p)
ref_out = torch.matmul(p, v)
#ref_out.backward(dout)
#ref_dv, v.grad = v.grad.clone(), None
#ref_dk, k.grad = k.grad.clone(), None
#ref_dq, q.grad = q.grad.clone(), None
# triton implementation
tri_out = attention(q, k, v, causal, sm_scale).half()
#tri_out.backward(dout)
#tri_dv, v.grad = v.grad.clone(), None
#tri_dk, k.grad = k.grad.clone(), None
#tri_dq, q.grad = q.grad.clone(), None
tri_out = attention(q, k, v, causal, sm_scale)
# compare
assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
#assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=0)
#assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=0)
#assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=0)
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, P_SEQ',
[(4, 48, 1024, 64, 0),
(4, 48, 2048, 64, 0),
(4, 48, 4096, 64, 0),
(1, 16, 8192, 64, 0),
])
def test_op_bwd(Z, H, N_CTX, D_HEAD, P_SEQ, dtype=torch.float16):
torch.manual_seed(20)
causal = True
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
k = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
v = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
sm_scale = q.shape[-1] ** (-0.5)
split_kernel = True
dout = torch.randn_like(q)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX + P_SEQ), device="cuda"), diagonal=P_SEQ)
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
if causal:
p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).half()
ref_out = torch.matmul(p, v)
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
# triton implementation
tri_out = attention(q, k, v, causal, sm_scale, split_kernel)
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
# compare
assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
if torch.version.hip is None:
assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=0)
# The current block size for MI200 series is 64x64. This results in
# larger differences in float results due to rounding.
else:
assert torch.allclose(ref_dv, tri_dv, atol=5e-2, rtol=0)
assert torch.allclose(ref_dk, tri_dk, atol=5e-2, rtol=0)
assert torch.allclose(ref_dq, tri_dq, atol=5e-2, rtol=0)
try:
@@ -365,7 +641,7 @@ configs = [triton.testing.Benchmark(
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode, 'causal': causal}
) for mode in ['fwd'] for causal in [False]]
) for mode in ['fwd', 'bwd'] for causal in [True, False]]
@triton.testing.perf_report(configs)
@@ -373,12 +649,17 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype
assert mode in ['fwd', 'bwd']
warmup = 25
rep = 100
split_kernel = False
# Bwd pass only supports causal=True right now
if mode == 'bwd':
causal = True
split_kernel = True
if provider == "triton":
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
sm_scale = 1.3
fn = lambda: attention(q, k, v, causal, sm_scale)
fn = lambda: attention(q, k, v, causal, sm_scale, split_kernel)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)