[TUTORIALS] Enable causal=False in FA fwd kernel (#2459)

Lixun Zhang
2023-10-06 19:54:45 -05:00
committed by GitHub
parent fb3c2f3b2b
commit ded79e87ee


@@ -29,13 +29,17 @@ def _attn_fwd_inner(
STAGE: tl.constexpr,
offs_m: tl.constexpr,
offs_n: tl.constexpr,
N_CTX: tl.constexpr,
):
# range of values handled by this stage
if STAGE == 1:
lo, hi = 0, start_m * BLOCK_M
else:
elif STAGE == 2:
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
lo = tl.multiple_of(lo, BLOCK_M)
# causal = False
else:
lo, hi = 0, N_CTX
K_block_ptr = tl.advance(K_block_ptr, (0, lo))
V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
# loop over k, v and update accumulator
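
The new elif/else split means the non-causal path visits the entire key/value sequence in a single pass. A minimal plain-Python sketch of the [lo, hi) selection (no Triton, only the constants that appear in the kernel) for intuition:

def stage_range(stage, start_m, BLOCK_M, N_CTX):
    # Plain-Python mirror of the [lo, hi) selection in _attn_fwd_inner.
    if stage == 1:
        # causal, off-band: every key block strictly before the query block
        return 0, start_m * BLOCK_M
    elif stage == 2:
        # causal, on-band: only the diagonal block (masked element-wise inside the loop)
        return start_m * BLOCK_M, (start_m + 1) * BLOCK_M
    else:
        # causal = False: one pass over the whole key/value sequence
        return 0, N_CTX

# e.g. query block start_m = 3 with BLOCK_M = 128 and N_CTX = 1024:
#   stage 1 -> (0, 384), stage 2 -> (384, 512), stage 3 -> (0, 1024)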
@@ -160,23 +164,25 @@ def _attn_fwd(
# load q: it will stay in SRAM throughout
q = tl.load(Q_block_ptr)
# stage 1: off-band
# For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
# For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
if STAGE & 1:
acc, l_i, m_i = _attn_fwd_inner(
acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
start_m, qk_scale,
BLOCK_M, BLOCK_DMODEL, BLOCK_N,
1, offs_m, offs_n,
4 - STAGE, offs_m, offs_n, N_CTX,
)
# barrier makes it easier for compiler to schedule the
# two loops independently
tl.debug_barrier()
# stage 2: on-band
if STAGE & 2:
# barrier makes it easier for compiler to schedule the
# two loops independently
tl.debug_barrier()
acc, l_i, m_i = _attn_fwd_inner(
acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
start_m, qk_scale,
BLOCK_M, BLOCK_DMODEL, BLOCK_N,
2, offs_m, offs_n,
2, offs_m, offs_n, N_CTX,
)
# epilogue
m_i += tl.math.log2(l_i)
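
The outer STAGE now acts as a two-bit mask, and 4 - STAGE flips it into the inner stage: with causal=True the launcher passes 3, so both the off-band pass (inner stage 1) and the on-band pass (inner stage 2) run; with causal=False it passes 1, so only the first call runs and the inner stage becomes 3, i.e. the full-context range added above. A small sketch of that mapping, assuming nothing beyond the constants visible in this diff:

def inner_passes(outer_stage):
    # Sketch of the STAGE bit-mask dispatch in _attn_fwd.
    passes = []
    if outer_stage & 1:
        # first call: 4 - STAGE gives 1 (causal off-band) or 3 (non-causal full range)
        passes.append(4 - outer_stage)
    if outer_stage & 2:
        # second call: on-band diagonal block, causal only
        passes.append(2)
    return passes

assert inner_passes(3) == [1, 2]  # causal = True
assert inner_passes(1) == [3]     # causal = False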
@@ -460,6 +466,7 @@ class _attention(torch.autograd.Function):
BLOCK_N = 64 if Lk <= 64 else 32
num_stages = 4 if Lk <= 64 else 3
num_warps = 4
stage = 3 if causal else 1
# Tuning for H100
if torch.cuda.get_device_capability()[0] == 9:
num_warps = 8
@@ -477,7 +484,7 @@ class _attention(torch.autograd.Function):
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_DMODEL=Lk,
STAGE=3,
STAGE=stage,
num_warps=num_warps,
num_stages=num_stages,
)
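
The launcher now derives STAGE from the causal flag instead of hard-coding 3, so the same kernel covers both masked and unmasked attention. A hedged eager-mode PyTorch reference that either path could be checked against (the name reference_attention and its tensor layout are illustrative, not part of the tutorial):

import torch

def reference_attention(q, k, v, causal, sm_scale):
    # q, k, v: (batch, heads, seq_len, head_dim) tensors; plain eager reference.
    p = torch.matmul(q, k.transpose(-2, -1)) * sm_scale
    if causal:
        seq_len = q.shape[-2]
        keep = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device))
        p = p.masked_fill(~keep, float("-inf"))  # hide keys after the query position
    p = torch.softmax(p.float(), dim=-1).to(v.dtype)
    return torch.matmul(p, v)

For causal=False the masking branch is simply skipped, which corresponds to the kernel's new single full-range pass.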
@@ -591,28 +598,31 @@ except BaseException:
TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2')
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
configs = [
triton.testing.Benchmark(
x_names=["N_CTX"],
x_vals=[2**i for i in range(10, 15)],
line_arg="provider",
line_vals=["triton"] + (["flash"] if HAS_FLASH else []),
line_names=["Triton"] + (["Flash-2"] if HAS_FLASH else []),
styles=[("red", "-"), ("blue", "-")],
ylabel="ms",
plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}",
args={
"H": N_HEADS,
"BATCH": BATCH,
"D_HEAD": D_HEAD,
"dtype": torch.float16,
"mode": mode,
"causal": causal,
},
)
for mode in ["fwd", "bwd"]
for causal in [True]
]
configs = []
for mode in ["fwd", "bwd"]:
for causal in [True, False]:
if mode == "bwd" and not causal:
continue
configs.append(
triton.testing.Benchmark(
x_names=["N_CTX"],
x_vals=[2**i for i in range(10, 15)],
line_arg="provider",
line_vals=["triton"] + (["flash"] if HAS_FLASH else []),
line_names=["Triton"] + (["Flash-2"] if HAS_FLASH else []),
styles=[("red", "-"), ("blue", "-")],
ylabel="ms",
plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}",
args={
"H": N_HEADS,
"BATCH": BATCH,
"D_HEAD": D_HEAD,
"dtype": torch.float16,
"mode": mode,
"causal": causal,
},
)
)
@triton.testing.perf_report(configs)
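
With configs built per (mode, causal) pair (and the unsupported bwd, causal=False combination skipped), each setting now gets its own plot name. Running the report works as before, e.g. assuming the decorated benchmark keeps the tutorial's usual bench_flash_attention name:

# Produce one plot/CSV per benchmark config; requires a CUDA device
# (and the flash-attn package if the "flash" provider is enabled).
bench_flash_attention.run(save_path=".", print_data=True)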