// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include #include #include #include namespace { struct LinalgGenericOpWithTensorsToLoopsPass : public LinalgGenericOpWithTensorsToLoopsBase< LinalgGenericOpWithTensorsToLoopsPass> { LinalgGenericOpWithTensorsToLoopsPass() = delete; LinalgGenericOpWithTensorsToLoopsPass(bool parallelizeLoops) : parallelizeLoops(parallelizeLoops){}; void runOnOperation() final; private: bool parallelizeLoops; }; } // namespace template class LinalgRewritePattern : public mlir::OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LinalgRewritePattern(::mlir::MLIRContext *context, bool parallelizeLoops, mlir::PatternBenefit benefit = 0) : parallelizeLoops(parallelizeLoops), ::mlir::OpRewritePattern(context, benefit) {} mlir::LogicalResult matchAndRewrite(mlir::linalg::GenericOp linalgOp, mlir::PatternRewriter &rewriter) const override { mlir::FailureOr loops = mlir::concretelang::linalgextras::linalgTensorOpToLoopsImpl( rewriter, linalgOp, parallelizeLoops); if (((mlir::LogicalResult)loops).failed() || loops->size() == 0) return mlir::failure(); rewriter.replaceOp(linalgOp, loops.getValue()[0]->getResult(0)); return mlir::success(); }; private: bool parallelizeLoops; }; void LinalgGenericOpWithTensorsToLoopsPass::runOnOperation() { auto op = this->getOperation(); mlir::RewritePatternSet patterns(&getContext()); patterns.insert>(&getContext(), parallelizeLoops); (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); } namespace mlir { namespace concretelang { std::unique_ptr> createLinalgGenericOpWithTensorsToLoopsPass(bool parallelizeLoops) { return std::make_unique( parallelizeLoops); } } // namespace concretelang } // namespace mlir