mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[OPTIMIZER] Tweak warpsPerCTA based on the shape of MMA output (#2581)
In current implementation, warpsPerCTA is always set to [numWarps, 1] for 2 tt.dot fusion scenario. But, it is not optimal for cases such that tt.dot doesn't have enough parallelism on row dimension but on column dimension.
This commit is contained in:
@@ -143,15 +143,15 @@ flash_attention_data = {
|
||||
(4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.542,
|
||||
(4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.471,
|
||||
(4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.155,
|
||||
(4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.203,
|
||||
(4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.202,
|
||||
(4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.108,
|
||||
(4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.232,
|
||||
(4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.231,
|
||||
(4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.138,
|
||||
(4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.306,
|
||||
(4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.266,
|
||||
(4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.098,
|
||||
(4, 48, 4096, 64, True, False, 'backward', 'float16'): 0.134,
|
||||
(4, 48, 4096, 64, True, False, 'backward', 'bfloat16'): 0.135,
|
||||
(4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.066,
|
||||
(4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.092,
|
||||
(4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.541,
|
||||
(4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.471,
|
||||
(4, 48, 1024, 16, False, True, 'forward', 'float32'): 0.150,
|
||||
|
||||
Reference in New Issue
Block a user