[OPTIMIZER] Tweak warpsPerCTA based on the shape of MMA output (#2581)

In the current implementation, warpsPerCTA is always set to [numWarps, 1]
for the scenario where 2 tt.dot ops are fused. However, this is not optimal
for cases where tt.dot doesn't have enough parallelism on the row dimension
but does have it on the column dimension.
This commit is contained in:
Weixing Zhang
2023-11-03 13:40:03 -07:00
committed by GitHub
parent 6ac9d51ff0
commit 34b89a1173
2 changed files with 11 additions and 6 deletions

View File

@@ -143,15 +143,15 @@ flash_attention_data = {
(4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.542,
(4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.471,
(4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.155,
(4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.203,
(4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.202,
(4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.108,
(4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.232,
(4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.231,
(4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.138,
(4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.306,
(4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.266,
(4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.098,
(4, 48, 4096, 64, True, False, 'backward', 'float16'): 0.134,
(4, 48, 4096, 64, True, False, 'backward', 'bfloat16'): 0.135,
(4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.066,
(4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.092,
(4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.541,
(4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.471,
(4, 48, 1024, 16, False, True, 'forward', 'float32'): 0.150,