[OPTIMIZER] cleaned, renamed and simplified some optimization passes (#1232)
This shouldn't actually change the behavior of Triton -- only clean things up.
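
At a glance, the renames in this commit:

- add_coalesce_pass -> add_tritongpu_coalesce_pass
- add_tritongpu_combine_pass -> split into add_tritongpu_accelerate_matmul_pass, add_tritongpu_fuse_transpositions_pass, and add_tritongpu_remove_layout_conversions_pass
- ttir_to_ttgir(mod, num_warps, num_stages, compute_capability) -> ttir_to_ttgir(mod, num_warps), with the optimization pipeline moving to a new optimize_ttgir(mod, num_stages, compute_capability)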

@@ -1375,7 +1375,7 @@ void init_triton_ir(py::module &&m) {
       .def(
           "add_sccp_pass",
           [](mlir::PassManager &self) { self.addPass(mlir::createSCCPPass()); })
-      .def("add_coalesce_pass",
+      .def("add_tritongpu_coalesce_pass",
           [](mlir::PassManager &self) {
             self.addPass(mlir::createTritonGPUCoalescePass());
           })

@@ -1414,10 +1414,18 @@ void init_triton_ir(py::module &&m) {
           [](mlir::PassManager &self) {
             self.addPass(mlir::createTritonGPUPrefetchPass());
           })
-      .def("add_tritongpu_combine_pass",
+      .def("add_tritongpu_accelerate_matmul_pass",
           [](mlir::PassManager &self, int computeCapability) {
             self.addPass(
-                mlir::createTritonGPUCombineOpsPass(computeCapability));
+                mlir::createTritonGPUAccelerateMatmulPass(computeCapability));
           })
+      .def("add_tritongpu_fuse_transpositions_pass",
+          [](mlir::PassManager &self) {
+            self.addPass(mlir::createTritonGPUFuseTranspositionsPass());
+          })
+      .def("add_tritongpu_remove_layout_conversions_pass",
+          [](mlir::PassManager &self) {
+            self.addPass(mlir::createTritonGPURemoveLayoutConversionsPass());
+          })
       .def("add_tritongpu_update_mma_for_volta_pass",
           [](mlir::PassManager &self) {
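
These bindings back the _triton.ir.pass_manager calls in the Python hunks below. As a minimal sketch of driving the renamed passes by hand -- assuming mod is an already-converted TritonGPU IR module and that _triton is the compiled binding module imported by compiler.py (the exact import path is an assumption):

    import triton._C.libtriton.triton as _triton  # assumed binding import path

    pm = _triton.ir.pass_manager(mod.context)
    pm.enable_debug()
    pm.add_tritongpu_coalesce_pass()             # formerly add_coalesce_pass
    pm.add_tritongpu_accelerate_matmul_pass(80)  # formerly add_tritongpu_combine_pass; 80 = sm_80, illustrative
    pm.add_tritongpu_remove_layout_conversions_pass()
    pm.run(mod)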

@@ -975,32 +975,32 @@ def ast_to_ttir(fn, signature, specialization, constants):
     return optimize_triton_ir(mod)


-def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
+def ttir_to_ttgir(mod, num_warps):
     pm = _triton.ir.pass_manager(mod.context)
     pm.add_convert_triton_to_tritongpu_pass(num_warps)
+    pm.run(mod)
+    return mod
+
+
+def optimize_ttgir(mod, num_stages, compute_capability):
+    pm = _triton.ir.pass_manager(mod.context)
     pm.enable_debug()
-    pm.add_coalesce_pass()
-    # The combine pass converts blocked layout to mma layout
-    # for dot ops so that pipeline can get shared memory swizzled correctly.
-    pm.add_tritongpu_combine_pass(compute_capability)
+    pm.add_tritongpu_coalesce_pass()
+    pm.add_tritongpu_accelerate_matmul_pass(compute_capability)
+    pm.add_tritongpu_remove_layout_conversions_pass()
+    pm.add_tritongpu_fuse_transpositions_pass()
     pm.add_tritongpu_pipeline_pass(num_stages)
     # Prefetch must be done after pipeline pass because pipeline pass
     # extracts slices from the original tensor.
     pm.add_tritongpu_prefetch_pass()
-    pm.add_canonicalizer_pass()
-    pm.add_cse_pass()
-    pm.add_tritongpu_combine_pass(compute_capability)
-    pm.add_licm_pass()
-    pm.add_tritongpu_combine_pass(compute_capability)
-    pm.add_cse_pass()
+    pm.add_tritongpu_fuse_transpositions_pass()
+    pm.add_tritongpu_remove_layout_conversions_pass()
     pm.add_tritongpu_decompose_conversions_pass()
     if compute_capability // 10 == 7:
         # The update_mma_for_volta pass helps to compute some information for MMA encoding specifically for MMAv1
         # NOTE this pass should be placed after all the passes those modifies mma layout
         pm.add_tritongpu_update_mma_for_volta_pass()
+    pm.add_tritongpu_reorder_instructions_pass()
     pm.add_cse_pass()
     pm.add_symbol_dce_pass()
-    pm.add_tritongpu_reorder_instructions_pass()
     pm.run(mod)
     return mod
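
Net effect of the hunk above: ttir_to_ttgir is now a pure Triton-to-TritonGPU dialect conversion, and every TTGIR-level optimization lives in the new optimize_ttgir. Callers compose the two, as the compile() hunk below shows; schematically (argument values illustrative):

    ttgir = ttir_to_ttgir(ttir_module, num_warps=4)
    ttgir = optimize_ttgir(ttgir, num_stages=3, compute_capability=80)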

@@ -1565,7 +1565,7 @@ def compile(fn, **kwargs):
         "ttir": (lambda path: parse_mlir_module(path, context),
                  lambda src: ast_to_ttir(src, signature, configs[0], constants)),
         "ttgir": (lambda path: parse_mlir_module(path, context),
-                  lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
+                  lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, capability)),
         "llir": (lambda path: Path(path).read_text(),
                  lambda src: ttgir_to_llir(src, extern_libs, capability)),
         "ptx": (lambda path: Path(path).read_text(),

@@ -42,7 +42,8 @@ if __name__ == '__main__':
         raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")

     # triton-ir -> triton-gpu-ir
-    module = triton.compiler.ttir_to_ttgir(module, num_warps=4, num_stages=3, compute_capability=args.sm)
+    module = triton.compiler.ttir_to_ttgir(module, num_warps=4)
+    module = triton.compiler.optimize_ttgir(module, num_stages=3, compute_capability=args.sm)
     if args.target == 'triton-gpu-ir':
         print(module.str())
         sys.exit(0)

@@ -223,6 +223,7 @@ class _attention(torch.autograd.Function):
             BLOCK_DMODEL=Lk, num_warps=num_warps,
+            num_stages=2,
         )
         # print(h.asm["ttgir"])

         ctx.save_for_backward(q, k, v, o, L, m)
         ctx.grid = grid

@@ -260,6 +261,7 @@ class _attention(torch.autograd.Function):
             BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
+            num_stages=1,
         )
         # print(h.asm["ttgir"])
         return dq, dk, dv, None

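One consequence for users: the fused-attention tutorial now pins num_stages explicitly at each kernel launch (2 for the forward kernel, 1 for backward). Per the compile() hunk above, that value reaches the software-pipelining stage count through optimize_ttgir rather than through ttir_to_ttgir, which presumably is why the launches now spell it out.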