[OPTIMIZATION] Enable pipelining for bwd flash attention (#2590)

This allows pipelining when a load is used by multiple dots in a loop.

Relax the condition for pipelining dot operands in the mma v3 case. This
improves performance for the bwd pass from 260TF to 275TF. However, this
exposes a performance problem due to the wgmma pipelining, as ptxas will
now fall back to serial wgmma. A follow-up PR will fix a bug in how we
emit wgmma_wait during pipelining and will bring performance to 335TF.
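To illustrate the pattern this change enables, here is a minimal plain-Python sketch (not Triton, and not this PR's implementation) of a double-buffered loop in which one loaded tile feeds two dot consumers, as happens in the flash attention bwd pass. The names `pipelined_loop`, `dot_a`, and `dot_b` are hypothetical:

```python
def pipelined_loop(tiles, dot_a, dot_b):
    """Hypothetical model of a software-pipelined loop.

    The load of the next tile is issued (conceptually, asynchronously)
    while the current tile is consumed by TWO dot operations. Before this
    change, the pipeliner bailed out when a load had more than one dot
    user in the loop.
    """
    results = []
    if not tiles:
        return results
    cur = tiles[0]  # prologue: first load issued ahead of the loop body
    for i in range(len(tiles)):
        # "async prefetch" of the next iteration's tile
        nxt = tiles[i + 1] if i + 1 < len(tiles) else None
        # both dot consumers read the same loaded tile
        results.append((dot_a(cur), dot_b(cur)))
        cur = nxt  # rotate buffers for the next stage
    return results

# Example: two distinct consumers of each loaded tile.
out = pipelined_loop([1, 2, 3], lambda x: x * 2, lambda x: x * 3)
# out == [(2, 3), (4, 6), (6, 9)]
```

In the real pipeliner the prefetch is an asynchronous copy into shared memory and the two consumers are `tl.dot` (wgmma) operations; this sketch only models the dependence structure.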
Thomas Raoux
2023-11-03 11:46:51 -07:00
committed by GitHub
parent df08301e76
commit 6ac9d51ff0
4 changed files with 86 additions and 24 deletions


@@ -111,6 +111,7 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, e
    if optimize_epilogue:
        pm.add_tritongpu_optimize_epilogue_pass()
    pm.add_tritongpu_optimize_dot_operands_pass()
    pm.add_cse_pass()
    ws_enabled = False
    # `num_warps` does not mean the total number of warps of a CTA when
    # warp specialization is enabled.