[TRITON][OPS] add Flash Attention v2 to Ops (#1970)

I also dropped the `do_scaled` argument, as it is no longer needed (no scaling
is applied to `do` in v2).

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
Izzy Putterman
2023-07-23 14:07:15 -07:00
committed by GitHub
parent c9ab44888e
commit de6f053c0f
3 changed files with 86 additions and 131 deletions

View File

@@ -155,27 +155,27 @@ def test_elementwise(N, dtype_str):
flash_attention_data = {
"a100": {
(4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.433,
(4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.392,
(4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.106,
(4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.532,
(4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.471,
(4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.150,
(4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.204,
(4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.202,
(4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.089,
(4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.242,
(4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.220,
(4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.069,
(4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.298,
(4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.263,
(4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.095,
(4, 48, 4096, 64, True, False, 'backward', 'float16'): 0.136,
(4, 48, 4096, 64, True, False, 'backward', 'bfloat16'): 0.135,
(4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.052,
(4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.432,
(4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.392,
(4, 48, 1024, 16, False, True, 'forward', 'float32'): 0.107,
(4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.525,
(4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.471,
(4, 48, 1024, 16, False, True, 'forward', 'float32'): 0.150,
(4, 48, 4096, 64, False, True, 'backward', 'float16'): 0.265,
(4, 48, 4096, 64, False, True, 'backward', 'bfloat16'): 0.257,
(4, 48, 1024, 16, False, True, 'backward', 'float32'): 0.128,
(4, 48, 4096, 64, False, False, 'forward', 'float16'): 0.251,
(4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.220,
(4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.069,
(4, 48, 4096, 64, False, False, 'forward', 'float16'): 0.297,
(4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.263,
(4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.095,
(4, 48, 4096, 64, False, False, 'backward', 'float16'): 0.159,
(4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.138,
(4, 48, 1024, 16, False, False, 'backward', 'float32'): 0.076,