mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[TRITON][OPS] add Flash Attention v2 to Ops (#1970)
I also dropped the do_scaled as it is no longer needed (no scaling done to the do in v2). --------- Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
@@ -155,27 +155,27 @@ def test_elementwise(N, dtype_str):
|
||||
|
||||
flash_attention_data = {
|
||||
"a100": {
|
||||
(4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.433,
|
||||
(4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.392,
|
||||
(4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.106,
|
||||
(4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.532,
|
||||
(4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.471,
|
||||
(4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.150,
|
||||
(4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.204,
|
||||
(4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.202,
|
||||
(4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.089,
|
||||
(4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.242,
|
||||
(4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.220,
|
||||
(4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.069,
|
||||
(4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.298,
|
||||
(4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.263,
|
||||
(4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.095,
|
||||
(4, 48, 4096, 64, True, False, 'backward', 'float16'): 0.136,
|
||||
(4, 48, 4096, 64, True, False, 'backward', 'bfloat16'): 0.135,
|
||||
(4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.052,
|
||||
(4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.432,
|
||||
(4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.392,
|
||||
(4, 48, 1024, 16, False, True, 'forward', 'float32'): 0.107,
|
||||
(4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.525,
|
||||
(4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.471,
|
||||
(4, 48, 1024, 16, False, True, 'forward', 'float32'): 0.150,
|
||||
(4, 48, 4096, 64, False, True, 'backward', 'float16'): 0.265,
|
||||
(4, 48, 4096, 64, False, True, 'backward', 'bfloat16'): 0.257,
|
||||
(4, 48, 1024, 16, False, True, 'backward', 'float32'): 0.128,
|
||||
(4, 48, 4096, 64, False, False, 'forward', 'float16'): 0.251,
|
||||
(4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.220,
|
||||
(4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.069,
|
||||
(4, 48, 4096, 64, False, False, 'forward', 'float16'): 0.297,
|
||||
(4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.263,
|
||||
(4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.095,
|
||||
(4, 48, 4096, 64, False, False, 'backward', 'float16'): 0.159,
|
||||
(4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.138,
|
||||
(4, 48, 1024, 16, False, False, 'backward', 'float32'): 0.076,
|
||||
|
||||
Reference in New Issue
Block a user