Mirror of https://github.com/ROCm/ROCm.git
[TUTORIALS] Enable causal=False in FA fwd kernel (#2459)
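This change enables causal=False in the flash-attention forward kernel of the fused-attention tutorial: _attn_fwd_inner gains an N_CTX parameter and a full-range stage for the non-causal case, the host side now launches with STAGE=3 (causal) or STAGE=1 (non-causal), and the benchmark grid gains a forward non-causal configuration. The backward kernel is unchanged and remains causal-only.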
@@ -29,13 +29,17 @@ def _attn_fwd_inner(
                    STAGE: tl.constexpr,
                    offs_m: tl.constexpr,
                    offs_n: tl.constexpr,
+                   N_CTX: tl.constexpr,
                    ):
     # range of values handled by this stage
     if STAGE == 1:
         lo, hi = 0, start_m * BLOCK_M
-    else:
+    elif STAGE == 2:
         lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
         lo = tl.multiple_of(lo, BLOCK_M)
+    # causal = False
+    else:
+        lo, hi = 0, N_CTX
     K_block_ptr = tl.advance(K_block_ptr, (0, lo))
     V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
     # loop over k, v and update accumulator
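The new else branch is what unlocks the non-causal path: instead of splitting the key sequence into an off-band range and an on-band diagonal block, stage 3 sweeps all of 0..N_CTX in a single unmasked pass. A plain-Python sketch of the range selection (the helper name kv_range is ours, for illustration only):

def kv_range(stage, start_m, BLOCK_M, N_CTX):
    if stage == 1:
        # off-band: key blocks strictly before the query block
        return 0, start_m * BLOCK_M
    elif stage == 2:
        # on-band: only the diagonal block; masking happens inside the loop
        return start_m * BLOCK_M, (start_m + 1) * BLOCK_M
    else:
        # causal = False: the full key sequence, no masking needed
        return 0, N_CTX

# With BLOCK_M = 128, N_CTX = 1024 and the fourth query block (start_m = 3):
#   kv_range(1, 3, 128, 1024) == (0, 384)     off-band pass
#   kv_range(2, 3, 128, 1024) == (384, 512)   on-band pass
#   kv_range(3, 3, 128, 1024) == (0, 1024)    single non-causal pass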
@@ -160,23 +164,25 @@ def _attn_fwd(
     # load q: it will stay in SRAM throughout
     q = tl.load(Q_block_ptr)
     # stage 1: off-band
+    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
+    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
     if STAGE & 1:
         acc, l_i, m_i = _attn_fwd_inner(
             acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
             start_m, qk_scale,
             BLOCK_M, BLOCK_DMODEL, BLOCK_N,
-            1, offs_m, offs_n,
+            4 - STAGE, offs_m, offs_n, N_CTX,
         )
-        # barrier makes it easier for compiler to schedule the
-        # two loops independently
-        tl.debug_barrier()
     # stage 2: on-band
     if STAGE & 2:
+        # barrier makes it easier for compiler to schedule the
+        # two loops independently
+        tl.debug_barrier()
         acc, l_i, m_i = _attn_fwd_inner(
             acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
             start_m, qk_scale,
             BLOCK_M, BLOCK_DMODEL, BLOCK_N,
-            2, offs_m, offs_n,
+            2, offs_m, offs_n, N_CTX,
         )
     # epilogue
     m_i += tl.math.log2(l_i)
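The 4 - STAGE expression lets one constexpr flag drive both modes. The host passes STAGE=3 for causal and STAGE=1 for non-causal; the bit tests then decide which inner passes run and which range each one scans. A minimal sketch of that dispatch (plain Python, function name ours):

def inner_stages(STAGE):
    passes = []
    if STAGE & 1:
        # causal: 4 - 3 = 1 (off-band); non-causal: 4 - 1 = 3 (full range)
        passes.append(4 - STAGE)
    if STAGE & 2:
        # reached only for causal = True
        passes.append(2)
    return passes

assert inner_stages(3) == [1, 2]  # causal: off-band, then masked diagonal
assert inner_stages(1) == [3]     # non-causal: one unmasked pass over 0..N_CTX

Moving tl.debug_barrier() into the STAGE & 2 block also means a non-causal launch no longer executes a barrier it has no second loop to schedule against.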
@@ -460,6 +466,7 @@ class _attention(torch.autograd.Function):
         BLOCK_N = 64 if Lk <= 64 else 32
         num_stages = 4 if Lk <= 64 else 3
         num_warps = 4
+        stage = 3 if causal else 1
         # Tuning for H100
         if torch.cuda.get_device_capability()[0] == 9:
             num_warps = 8
@@ -477,7 +484,7 @@ class _attention(torch.autograd.Function):
             BLOCK_M=BLOCK_M,
             BLOCK_N=BLOCK_N,
             BLOCK_DMODEL=Lk,
-            STAGE=3,
+            STAGE=stage,
             num_warps=num_warps,
             num_stages=num_stages,
         )
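With STAGE derived from the causal argument, both modes share one entry point. A usage sketch, assuming the tutorial's attention = _attention.apply wrapper with signature (q, k, v, causal, sm_scale):

import torch

attention = _attention.apply  # as defined in the tutorial

q = torch.randn(4, 48, 1024, 64, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
sm_scale = 0.5

out_causal = attention(q, k, v, True, sm_scale)   # launches with STAGE=3
out_full = attention(q, k, v, False, sm_scale)    # launches with STAGE=1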
@@ -591,28 +598,31 @@ except BaseException:
 TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2')
 BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
 # vary seq length for fixed head and batch=4
-configs = [
-    triton.testing.Benchmark(
-        x_names=["N_CTX"],
-        x_vals=[2**i for i in range(10, 15)],
-        line_arg="provider",
-        line_vals=["triton"] + (["flash"] if HAS_FLASH else []),
-        line_names=["Triton"] + (["Flash-2"] if HAS_FLASH else []),
-        styles=[("red", "-"), ("blue", "-")],
-        ylabel="ms",
-        plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}",
-        args={
-            "H": N_HEADS,
-            "BATCH": BATCH,
-            "D_HEAD": D_HEAD,
-            "dtype": torch.float16,
-            "mode": mode,
-            "causal": causal,
-        },
-    )
-    for mode in ["fwd", "bwd"]
-    for causal in [True]
-]
+configs = []
+for mode in ["fwd", "bwd"]:
+    for causal in [True, False]:
+        if mode == "bwd" and not causal:
+            continue
+        configs.append(
+            triton.testing.Benchmark(
+                x_names=["N_CTX"],
+                x_vals=[2**i for i in range(10, 15)],
+                line_arg="provider",
+                line_vals=["triton"] + (["flash"] if HAS_FLASH else []),
+                line_names=["Triton"] + (["Flash-2"] if HAS_FLASH else []),
+                styles=[("red", "-"), ("blue", "-")],
+                ylabel="ms",
+                plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}",
+                args={
+                    "H": N_HEADS,
+                    "BATCH": BATCH,
+                    "D_HEAD": D_HEAD,
+                    "dtype": torch.float16,
+                    "mode": mode,
+                    "causal": causal,
+                },
+            )
+        )


 @triton.testing.perf_report(configs)
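The rewritten loop reproduces the old causal-only coverage and adds the forward non-causal case, skipping bwd with causal=False because the backward kernel still assumes causality. The resulting (mode, causal) grid has three entries:

pairs = [(mode, causal)
         for mode in ["fwd", "bwd"]
         for causal in [True, False]
         if not (mode == "bwd" and not causal)]
assert pairs == [("fwd", True), ("fwd", False), ("bwd", True)]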