[HOPPER] fix ref check failure of flash attention with mma v3 (#2384)

This commit is contained in:
ben-zhang-609
2023-09-26 02:29:49 +08:00
committed by GitHub
parent 6bc1d9e1be
commit d040b58547
2 changed files with 15 additions and 9 deletions

View File

@@ -18,13 +18,6 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
if enable_tma in ["on", "true", "1"]:
if dtype == torch.bfloat16:
pytest.skip('bfloat16 tma not support currently')
if '-'.join(map(str, [seq_par, causal, Z, H, N_CTX, D_HEAD])) in [
"True-True-2-4-512-16",
"True-True-2-4-512-32",
"True-False-2-4-512-16",
"True-False-2-4-512-32",
]:
pytest.skip('backward ref check failed')
capability = torch.cuda.get_device_capability()
interpreter = os.environ.get("TRITON_INTERPRET", 'not found') in ["on", "true", "1"]

View File

@@ -143,6 +143,7 @@ def _bwd_kernel_one_col_block(
BLOCK_N: tl.constexpr,
SEQUENCE_PARALLEL: tl.constexpr,
CAUSAL: tl.constexpr,
MMA_V3: tl.constexpr
):
if SEQUENCE_PARALLEL:
DQ += stride_dqa.to(tl.int64) * start_n
@@ -202,8 +203,11 @@ def _bwd_kernel_one_col_block(
dq += tl.dot(ds, k, allow_tf32=True)
tl.store(dq_ptrs, dq)
elif SEQUENCE_PARALLEL:
# dq = tl.dot(ds, k, allow_tf32=True)
dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds), allow_tf32=True))
if MMA_V3:
dq = tl.dot(ds, k, allow_tf32=True)
else:
# does not work with mma v3, because M % 64 != 0
dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds), allow_tf32=True))
tl.store(dq_ptrs, dq)
# increment pointers
@@ -233,6 +237,7 @@ def _bwd_kernel(
BLOCK_N: tl.constexpr,
SEQUENCE_PARALLEL: tl.constexpr,
CAUSAL: tl.constexpr,
MMA_V3: tl.constexpr
# fmt: on
):
qk_scale = sm_scale * 1.44269504
@@ -265,6 +270,7 @@ def _bwd_kernel(
BLOCK_N=BLOCK_N,
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
CAUSAL=CAUSAL,
MMA_V3=MMA_V3
)
else:
start_n = tl.program_id(1)
@@ -282,6 +288,7 @@ def _bwd_kernel(
BLOCK_N=BLOCK_N,
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
CAUSAL=CAUSAL,
MMA_V3=MMA_V3
)
@@ -328,6 +335,11 @@ class _attention(torch.autograd.Function):
@staticmethod
def backward(ctx, do):
import os
enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower()
MMA_V3 = False
if enable_mmav3 in ["on", "true", "1"]:
MMA_V3 = True
BLOCK = 128
q, k, v, o, L = ctx.saved_tensors
sequence_parallel = ctx.sequence_parallel
@@ -361,6 +373,7 @@ class _attention(torch.autograd.Function):
BLOCK_DMODEL=ctx.BLOCK_DMODEL,
SEQUENCE_PARALLEL=sequence_parallel,
CAUSAL=ctx.causal,
MMA_V3=MMA_V3,
num_warps=8,
num_stages=1,
)