mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[OPTIMIZATION] Enable pipelining for bwd flash attention (#2590)
This allow pipelining when a load is used by multiple dot in a loop. Relax the condition to pipeline dot operands for mma v3 case. This improves performance for the bwd pass from 260TF to 275TF. However this expose a performance problem due to the wmma pipelining as ptxas will now fall back to serial wgmma. A follow up PR will fix a bug in how we emit wgmma_wait during pipelining and will bring performance to 335TF
This commit is contained in:
@@ -111,6 +111,7 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, e
|
||||
if optimize_epilogue:
|
||||
pm.add_tritongpu_optimize_epilogue_pass()
|
||||
pm.add_tritongpu_optimize_dot_operands_pass()
|
||||
pm.add_cse_pass()
|
||||
ws_enabled = False
|
||||
# `num_warps` does not mean the total number of warps of a CTA when
|
||||
# warp specialization is enabled.
|
||||
|
||||
Reference in New Issue
Block a user