mirror of https://github.com/ROCm/ROCm.git, synced 2026-04-05 03:01:17 -04:00
[HOPPER] fix ref check failure of flash attention with mma v3 (#2384)
@@ -18,13 +18,6 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
     if enable_tma in ["on", "true", "1"]:
         if dtype == torch.bfloat16:
             pytest.skip('bfloat16 tma not support currently')
-        if '-'.join(map(str, [seq_par, causal, Z, H, N_CTX, D_HEAD])) in [
-                "True-True-2-4-512-16",
-                "True-True-2-4-512-32",
-                "True-False-2-4-512-16",
-                "True-False-2-4-512-32",
-        ]:
-            pytest.skip('backward ref check failed')
 
     capability = torch.cuda.get_device_capability()
     interpreter = os.environ.get("TRITON_INTERPRET", 'not found') in ["on", "true", "1"]
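Note: the skip keys removed above are simply the test parameters joined with '-'. A minimal sketch, using the values from the first removed entry, showing how such a key is formed:

    # illustrative only: reproduces the key construction from the removed skip list
    seq_par, causal, Z, H, N_CTX, D_HEAD = True, True, 2, 4, 512, 16
    key = '-'.join(map(str, [seq_par, causal, Z, H, N_CTX, D_HEAD]))
    assert key == "True-True-2-4-512-16"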
@@ -143,6 +143,7 @@ def _bwd_kernel_one_col_block(
     BLOCK_N: tl.constexpr,
     SEQUENCE_PARALLEL: tl.constexpr,
     CAUSAL: tl.constexpr,
+    MMA_V3: tl.constexpr
 ):
     if SEQUENCE_PARALLEL:
         DQ += stride_dqa.to(tl.int64) * start_n
@@ -202,8 +203,11 @@ def _bwd_kernel_one_col_block(
             dq += tl.dot(ds, k, allow_tf32=True)
             tl.store(dq_ptrs, dq)
         elif SEQUENCE_PARALLEL:
-            # dq = tl.dot(ds, k, allow_tf32=True)
-            dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds), allow_tf32=True))
+            if MMA_V3:
+                dq = tl.dot(ds, k, allow_tf32=True)
+            else:
+                # does not work with mma v3, because M % 64 != 0
+                dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds), allow_tf32=True))
             tl.store(dq_ptrs, dq)
 
         # increment pointers
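The two dq formulations above are mathematically equivalent by the transpose identity (A·B)^T = B^T·A^T; the MMA v3 path computes the product directly, while the non-MMA-v3 path keeps the double-transpose form. A minimal PyTorch sketch (arbitrary small shapes, names chosen only for illustration) checking the identity:

    import torch

    ds = torch.randn(64, 32)   # stand-in for the (BLOCK_M, BLOCK_N) ds tile
    k = torch.randn(32, 16)    # stand-in for the (BLOCK_N, BLOCK_DMODEL) k tile

    dq_direct = ds @ k                    # tl.dot(ds, k) path, used when MMA_V3 is set
    dq_via_trans = (k.t() @ ds.t()).t()   # tl.trans(tl.dot(tl.trans(k), tl.trans(ds))) path

    assert torch.allclose(dq_direct, dq_via_trans, atol=1e-5)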
@@ -233,6 +237,7 @@ def _bwd_kernel(
     BLOCK_N: tl.constexpr,
     SEQUENCE_PARALLEL: tl.constexpr,
     CAUSAL: tl.constexpr,
+    MMA_V3: tl.constexpr
     # fmt: on
 ):
     qk_scale = sm_scale * 1.44269504
@@ -265,6 +270,7 @@ def _bwd_kernel(
             BLOCK_N=BLOCK_N,
             SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
             CAUSAL=CAUSAL,
+            MMA_V3=MMA_V3
         )
     else:
         start_n = tl.program_id(1)
@@ -282,6 +288,7 @@ def _bwd_kernel(
             BLOCK_N=BLOCK_N,
             SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
             CAUSAL=CAUSAL,
+            MMA_V3=MMA_V3
         )
@@ -328,6 +335,11 @@ class _attention(torch.autograd.Function):
 
     @staticmethod
     def backward(ctx, do):
+        import os
+        enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower()
+        MMA_V3 = False
+        if enable_mmav3 in ["on", "true", "1"]:
+            MMA_V3 = True
         BLOCK = 128
         q, k, v, o, L = ctx.saved_tensors
         sequence_parallel = ctx.sequence_parallel
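The backward pass now derives MMA_V3 from the ENABLE_MMA_V3 environment variable, lower-casing the value so the flag is case-insensitive. A self-contained sketch of the same gate (setting the variable in-process here is purely for illustration):

    import os

    os.environ['ENABLE_MMA_V3'] = 'ON'   # illustrative; normally set in the shell
    enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower()
    MMA_V3 = False
    if enable_mmav3 in ["on", "true", "1"]:
        MMA_V3 = True
    assert MMA_V3 is True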
@@ -361,6 +373,7 @@ class _attention(torch.autograd.Function):
             BLOCK_DMODEL=ctx.BLOCK_DMODEL,
             SEQUENCE_PARALLEL=sequence_parallel,
             CAUSAL=ctx.causal,
+            MMA_V3=MMA_V3,
             num_warps=8,
             num_stages=1,
         )
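MMA_V3 is threaded through the kernels as a tl.constexpr, so Triton specializes the compiled kernel per flag value instead of branching at run time. A toy sketch of that pattern (hypothetical kernel and names, not part of the diff; requires a GPU):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _toy_kernel(X, Y, N, DOUBLE: tl.constexpr, BLOCK: tl.constexpr):
        # DOUBLE is resolved at compile time, so each value produces its own
        # specialization -- the same mechanism the diff uses for MMA_V3
        offs = tl.arange(0, BLOCK)
        mask = offs < N
        x = tl.load(X + offs, mask=mask)
        if DOUBLE:
            y = x * 2.0
        else:
            y = x
        tl.store(Y + offs, y, mask=mask)

    x = torch.randn(128, device='cuda')
    y = torch.empty_like(x)
    _toy_kernel[(1,)](x, y, x.numel(), DOUBLE=True, BLOCK=128)
    assert torch.allclose(y, 2 * x)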