[OPTIMIZER] cleaned, renamed and simplified some optimization passes (#1232)

This shouldn't actually change the behavior of Triton -- only clean things up.
This commit is contained in:
Philippe Tillet
2023-02-22 13:54:55 -08:00
committed by GitHub
parent ba0198326e
commit 0ec277efc5
18 changed files with 599 additions and 652 deletions

View File

@@ -1375,7 +1375,7 @@ void init_triton_ir(py::module &&m) {
.def(
"add_sccp_pass",
[](mlir::PassManager &self) { self.addPass(mlir::createSCCPPass()); })
.def("add_coalesce_pass",
.def("add_tritongpu_coalesce_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUCoalescePass());
})
@@ -1414,10 +1414,18 @@ void init_triton_ir(py::module &&m) {
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUPrefetchPass());
})
.def("add_tritongpu_combine_pass",
.def("add_tritongpu_accelerate_matmul_pass",
[](mlir::PassManager &self, int computeCapability) {
self.addPass(
mlir::createTritonGPUCombineOpsPass(computeCapability));
mlir::createTritonGPUAccelerateMatmulPass(computeCapability));
})
.def("add_tritongpu_fuse_transpositions_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUFuseTranspositionsPass());
})
.def("add_tritongpu_remove_layout_conversions_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPURemoveLayoutConversionsPass());
})
.def("add_tritongpu_update_mma_for_volta_pass",
[](mlir::PassManager &self) {

View File

@@ -975,32 +975,32 @@ def ast_to_ttir(fn, signature, specialization, constants):
return optimize_triton_ir(mod)
def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
def ttir_to_ttgir(mod, num_warps):
pm = _triton.ir.pass_manager(mod.context)
pm.add_convert_triton_to_tritongpu_pass(num_warps)
pm.run(mod)
return mod
def optimize_ttgir(mod, num_stages, compute_capability):
pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_coalesce_pass()
# The combine pass converts blocked layout to mma layout
# for dot ops so that pipeline can get shared memory swizzled correctly.
pm.add_tritongpu_combine_pass(compute_capability)
pm.add_tritongpu_coalesce_pass()
pm.add_tritongpu_accelerate_matmul_pass(compute_capability)
pm.add_tritongpu_remove_layout_conversions_pass()
pm.add_tritongpu_fuse_transpositions_pass()
pm.add_tritongpu_pipeline_pass(num_stages)
# Prefetch must be done after pipeline pass because pipeline pass
# extracts slices from the original tensor.
pm.add_tritongpu_prefetch_pass()
pm.add_canonicalizer_pass()
pm.add_cse_pass()
pm.add_tritongpu_combine_pass(compute_capability)
pm.add_licm_pass()
pm.add_tritongpu_combine_pass(compute_capability)
pm.add_cse_pass()
pm.add_tritongpu_fuse_transpositions_pass()
pm.add_tritongpu_remove_layout_conversions_pass()
pm.add_tritongpu_decompose_conversions_pass()
if compute_capability // 10 == 7:
# The update_mma_for_volta pass helps to compute some information for MMA encoding specifically for MMAv1
# NOTE this pass should be placed after all the passes that modify the mma layout
pm.add_tritongpu_update_mma_for_volta_pass()
pm.add_tritongpu_reorder_instructions_pass()
pm.add_cse_pass()
pm.add_symbol_dce_pass()
pm.add_tritongpu_reorder_instructions_pass()
pm.run(mod)
return mod
@@ -1565,7 +1565,7 @@ def compile(fn, **kwargs):
"ttir": (lambda path: parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
"ttgir": (lambda path: parse_mlir_module(path, context),
lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, capability)),
"llir": (lambda path: Path(path).read_text(),
lambda src: ttgir_to_llir(src, extern_libs, capability)),
"ptx": (lambda path: Path(path).read_text(),

View File

@@ -42,7 +42,8 @@ if __name__ == '__main__':
raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
# triton-ir -> triton-gpu-ir
module = triton.compiler.ttir_to_ttgir(module, num_warps=4, num_stages=3, compute_capability=args.sm)
module = triton.compiler.ttir_to_ttgir(module, num_warps=4)
module = triton.compiler.optimize_ttgir(module, num_stages=3, compute_capability=args.sm)
if args.target == 'triton-gpu-ir':
print(module.str())
sys.exit(0)

View File

@@ -223,6 +223,7 @@ class _attention(torch.autograd.Function):
BLOCK_DMODEL=Lk, num_warps=num_warps,
num_stages=2,
)
# print(h.asm["ttgir"])
ctx.save_for_backward(q, k, v, o, L, m)
ctx.grid = grid
@@ -260,6 +261,7 @@ class _attention(torch.autograd.Function):
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
num_stages=1,
)
# print(h.asm["ttgir"])
return dq, dk, dv, None