From aa2e0479b3e3e86adb0754f0d941a5fcb634ae89 Mon Sep 17 00:00:00 2001 From: Antoniu Pop Date: Mon, 19 Dec 2022 09:54:59 +0000 Subject: [PATCH] feat(compiler): add a parallel loop coalescing pass. --- .../include/concretelang/Transforms/Passes.h | 2 + .../include/concretelang/Transforms/Passes.td | 9 ++ compiler/lib/Support/Pipeline.cpp | 3 +- compiler/lib/Transforms/CMakeLists.txt | 1 + .../lib/Transforms/CollapseParallelLoops.cpp | 100 ++++++++++++++++++ 5 files changed, 114 insertions(+), 1 deletion(-) create mode 100644 compiler/lib/Transforms/CollapseParallelLoops.cpp diff --git a/compiler/include/concretelang/Transforms/Passes.h b/compiler/include/concretelang/Transforms/Passes.h index 72f94f2e7..c6d827585 100644 --- a/compiler/include/concretelang/Transforms/Passes.h +++ b/compiler/include/concretelang/Transforms/Passes.h @@ -17,6 +17,8 @@ namespace mlir { namespace concretelang { +std::unique_ptr> +createCollapseParallelLoops(); std::unique_ptr> createForLoopToParallel(); std::unique_ptr> createBatchingPass(); } // namespace concretelang diff --git a/compiler/include/concretelang/Transforms/Passes.td b/compiler/include/concretelang/Transforms/Passes.td index b0d99a513..d4af1604d 100644 --- a/compiler/include/concretelang/Transforms/Passes.td +++ b/compiler/include/concretelang/Transforms/Passes.td @@ -3,6 +3,15 @@ include "mlir/Pass/PassBase.td" +def CollapseParallelLoops : Pass<"collapse-parallel-loops", "mlir::ModuleOp"> { + let summary = + "Coalesce nested scf.for operations that are marked with " + "the custom attribute parallel = true into a single scf.for " + "loop which can subsequently be converted to scf.parallel."; + let constructor = "mlir::concretelang::createCollapseParallelLoops()"; + let dependentDialects = ["mlir::scf::SCFDialect"]; +} + def ForLoopToParallel : Pass<"for-loop-to-parallel", "mlir::ModuleOp"> { let summary = "Transform scf.for marked with the custom attribute parallel = true loop " diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index 76e2528ff..fe60c3dff 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -341,7 +341,8 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module, pm, mlir::concretelang::createBufferizeDataflowTaskOpsPass(), enablePass); if (parallelizeLoops) { - addPotentiallyNestedPass(pm, mlir::createLoopCoalescingPass(), enablePass); + addPotentiallyNestedPass( + pm, mlir::concretelang::createCollapseParallelLoops(), enablePass); addPotentiallyNestedPass(pm, mlir::concretelang::createForLoopToParallel(), enablePass); } diff --git a/compiler/lib/Transforms/CMakeLists.txt b/compiler/lib/Transforms/CMakeLists.txt index 81927772b..d1343f988 100644 --- a/compiler/lib/Transforms/CMakeLists.txt +++ b/compiler/lib/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_library( ConcretelangTransforms Batching.cpp + CollapseParallelLoops.cpp ForLoopToParallel.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/concretelang/Transforms diff --git a/compiler/lib/Transforms/CollapseParallelLoops.cpp b/compiler/lib/Transforms/CollapseParallelLoops.cpp new file mode 100644 index 000000000..5b7be6606 --- /dev/null +++ b/compiler/lib/Transforms/CollapseParallelLoops.cpp @@ -0,0 +1,100 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include "concretelang/Transforms/Passes.h" + +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Operation.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include + +namespace { +struct CollapseParallelLoopsPass + : public CollapseParallelLoopsBase { + + /// Walk either an scf.for or an affine.for to find a band to coalesce. + template static void walkLoop(LoopOpTy op) {} + + void runOnOperation() override { + mlir::ModuleOp module = getOperation(); + module.walk([&](mlir::scf::ForOp forOp) { + // Ignore nested loops. + if (forOp->getParentOfType()) + return; + + // Determine which sequences of nested loops can be coalesced + // TODO: add loop interchange and hoisting to find more + // opportunities by getting multiple parallel loops in sequence + mlir::SmallVector loops; + getPerfectlyNestedLoops(loops, forOp); + mlir::SmallVector coalesceableLoopRanges(loops.size()); + for (unsigned i = 0, e = loops.size(); i < e; ++i) { + // Any loop is coalesceable to itself + coalesceableLoopRanges[i] = i; + + // The outermost loop doesn't have any outer loop to collapse into + if (i == 0) + continue; + + // A loop will only be coalesced with another if both are + // parallel. Otherwise it is irrelevant in this pass. + // If this loop itself is not parallel, then nothing we can do. + auto attr = loops[i]->getAttrOfType("parallel"); + if (attr == nullptr || attr.getValue() == false) + continue; + + // Find how many loops are able to be coalesced + for (unsigned j = 0; j < i; ++j) { + if (mlir::areValuesDefinedAbove(loops[i].getOperands(), + loops[j].getRegion())) { + coalesceableLoopRanges[i] = j; + break; + } + } + // Now ensure that all loops in this sequence + // [coalesceableLoopRanges[i], i] are parallel. Otherwise + // update the range's lower bound. + for (int k = i - 1; k >= (int)coalesceableLoopRanges[i]; --k) { + auto attrK = loops[k]->getAttrOfType("parallel"); + if (attrK == nullptr || attrK.getValue() == false) { + coalesceableLoopRanges[i] = k + 1; + break; + } + } + } + + for (unsigned end = loops.size(); end > 0; --end) { + unsigned start = 0; + for (; start < end - 1; ++start) { + auto maxPos = *std::max_element( + std::next(coalesceableLoopRanges.begin(), start), + std::next(coalesceableLoopRanges.begin(), end)); + if (maxPos > start) + continue; + + auto band = + llvm::makeMutableArrayRef(loops.data() + start, end - start); + (void)mlir::coalesceLoops(band); + break; + } + // If a band was found and transformed, keep looking at the loops above + // the outermost transformed loop. + if (start != end - 1) + end = start + 1; + } + }); + } +}; +} // namespace + +std::unique_ptr> +mlir::concretelang::createCollapseParallelLoops() { + return std::make_unique(); +}