// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include "concretelang/Transforms/Passes.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Operation.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" #include namespace { class ForOpPattern : public mlir::OpRewritePattern { public: ForOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) : ::mlir::OpRewritePattern(context, benefit) {} mlir::LogicalResult matchAndRewrite(mlir::scf::ForOp forOp, mlir::PatternRewriter &rewriter) const override { auto attr = forOp->getAttrOfType("parallel"); if (attr == nullptr) { return mlir::failure(); } assert(forOp.getRegionIterArgs().size() == 0 && "unexpecting iter args when loops are bufferized"); if (attr.getValue()) { rewriter.replaceOpWithNewOp( forOp, mlir::ValueRange{forOp.getLowerBound()}, mlir::ValueRange{forOp.getUpperBound()}, forOp.getStep(), std::nullopt, [&](mlir::OpBuilder &builder, mlir::Location location, mlir::ValueRange indVar, mlir::ValueRange iterArgs) { mlir::IRMapping map; map.map(forOp.getInductionVar(), indVar.front()); for (auto &op : forOp.getRegion().front()) { auto newOp = builder.clone(op, map); map.map(op.getResults(), newOp->getResults()); } }); } else { rewriter.replaceOpWithNewOp( forOp, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), std::nullopt, [&](mlir::OpBuilder &builder, mlir::Location location, mlir::Value indVar, mlir::ValueRange iterArgs) { mlir::IRMapping map; map.map(forOp.getInductionVar(), indVar); for (auto &op : forOp.getRegion().front()) { auto newOp = builder.clone(op, map); map.map(op.getResults(), newOp->getResults()); } }); } return mlir::success(); } }; } // namespace namespace { struct ForLoopToParallelPass : public ForLoopToParallelBase { void runOnOperation() override { auto func = getOperation(); auto *context = &getContext(); mlir::RewritePatternSet patterns(context); mlir::ConversionTarget target(*context); patterns.add(context); target.addDynamicallyLegalOp([&](mlir::scf::ForOp op) { auto r = op->getAttrOfType("parallel") == nullptr; return r; }); target.markUnknownOpDynamicallyLegal( [&](mlir::Operation *op) { return true; }); if (mlir::applyPatternsAndFoldGreedily(func, std::move(patterns)) .failed()) { this->signalPassFailure(); }; } }; } // namespace std::unique_ptr> mlir::concretelang::createForLoopToParallel() { return std::make_unique(); }