From ded79e87ee375daf333ce2746dbe27e90cd31ca2 Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Fri, 6 Oct 2023 19:54:45 -0500 Subject: [PATCH] [TUTORIALS] Enable causal=False in FA fwd kernel (#2459) --- python/tutorials/06-fused-attention.py | 68 +++++++++++++++----------- 1 file changed, 39 insertions(+), 29 deletions(-) diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index 7582433d9..dbe1e4c26 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -29,13 +29,17 @@ def _attn_fwd_inner( STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, + N_CTX: tl.constexpr, ): # range of values handled by this stage if STAGE == 1: lo, hi = 0, start_m * BLOCK_M - else: + elif STAGE == 2: lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M lo = tl.multiple_of(lo, BLOCK_M) + # causal = False + else: + lo, hi = 0, N_CTX K_block_ptr = tl.advance(K_block_ptr, (0, lo)) V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) # loop over k, v and update accumulator @@ -160,23 +164,25 @@ def _attn_fwd( # load q: it will stay in SRAM throughout q = tl.load(Q_block_ptr) # stage 1: off-band + # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE + # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE if STAGE & 1: acc, l_i, m_i = _attn_fwd_inner( acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, qk_scale, BLOCK_M, BLOCK_DMODEL, BLOCK_N, - 1, offs_m, offs_n, + 4 - STAGE, offs_m, offs_n, N_CTX, ) - # barrier makes it easier for compielr to schedule the - # two loops independently - tl.debug_barrier() # stage 2: on-band if STAGE & 2: + # barrier makes it easier for compiler to schedule the + # two loops independently + tl.debug_barrier() acc, l_i, m_i = _attn_fwd_inner( acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, qk_scale, BLOCK_M, BLOCK_DMODEL, BLOCK_N, - 2, offs_m, offs_n, + 2, offs_m, offs_n, N_CTX, ) # epilogue m_i += tl.math.log2(l_i) @@ -460,6 
+466,7 @@ class _attention(torch.autograd.Function): BLOCK_N = 64 if Lk <= 64 else 32 num_stages = 4 if Lk <= 64 else 3 num_warps = 4 + stage = 3 if causal else 1 # Tuning for H100 if torch.cuda.get_device_capability()[0] == 9: num_warps = 8 @@ -477,7 +484,7 @@ class _attention(torch.autograd.Function): BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, - STAGE=3, + STAGE=stage, num_warps=num_warps, num_stages=num_stages, ) @@ -591,28 +598,31 @@ except BaseException: TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2') BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 # vary seq length for fixed head and batch=4 -configs = [ - triton.testing.Benchmark( - x_names=["N_CTX"], - x_vals=[2**i for i in range(10, 15)], - line_arg="provider", - line_vals=["triton"] + (["flash"] if HAS_FLASH else []), - line_names=["Triton"] + (["Flash-2"] if HAS_FLASH else []), - styles=[("red", "-"), ("blue", "-")], - ylabel="ms", - plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}", - args={ - "H": N_HEADS, - "BATCH": BATCH, - "D_HEAD": D_HEAD, - "dtype": torch.float16, - "mode": mode, - "causal": causal, - }, - ) - for mode in ["fwd", "bwd"] - for causal in [True] -] +configs = [] +for mode in ["fwd", "bwd"]: + for causal in [True, False]: + if mode == "bwd" and not causal: + continue + configs.append( + triton.testing.Benchmark( + x_names=["N_CTX"], + x_vals=[2**i for i in range(10, 15)], + line_arg="provider", + line_vals=["triton"] + (["flash"] if HAS_FLASH else []), + line_names=["Triton"] + (["Flash-2"] if HAS_FLASH else []), + styles=[("red", "-"), ("blue", "-")], + ylabel="ms", + plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}", + args={ + "H": N_HEADS, + "BATCH": BATCH, + "D_HEAD": D_HEAD, + "dtype": torch.float16, + "mode": mode, + "causal": causal, + }, + ) + ) @triton.testing.perf_report(configs)