mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[OPS] enable flash_attention_v2 TMA (#2544)
This commit is contained in:
@@ -165,7 +165,7 @@ flash_attention_data = {
|
||||
(4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.266,
|
||||
(4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.098,
|
||||
(4, 48, 4096, 64, False, False, 'backward', 'float16'): 0.159,
|
||||
(4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.136,
|
||||
(4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.159,
|
||||
(4, 48, 1024, 16, False, False, 'backward', 'float32'): 0.088,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,8 +43,6 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
|
||||
ref_dq, q.grad = q.grad.clone(), None
|
||||
# # triton implementation
|
||||
tri_out = triton.ops.attention(q, k, v, causal, sm_scale, seq_par)
|
||||
# temporary env var control begin
|
||||
os.putenv("ENABLE_TMA", "0")
|
||||
tri_out.backward(dout)
|
||||
tri_dv, v.grad = v.grad.clone(), None
|
||||
tri_dk, k.grad = k.grad.clone(), None
|
||||
@@ -55,5 +53,3 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
|
||||
torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0)
|
||||
torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0)
|
||||
torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0)
|
||||
# temporary env var control end
|
||||
os.putenv("ENABLE_TMA", enable_tma)
|
||||
|
||||
Reference in New Issue
Block a user