// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include "concretelang/Transforms/Passes.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Transforms/DialectConversion.h" namespace { struct LinalgFillToLinalgGenericPattern : public mlir::OpRewritePattern { LinalgFillToLinalgGenericPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) : ::mlir::OpRewritePattern(context, benefit) {} ::mlir::LogicalResult matchAndRewrite(mlir::linalg::FillOp fillOp, ::mlir::PatternRewriter &rewriter) const override { if (fillOp.getOutputs().size() != 1) return ::mlir::failure(); mlir::RankedTensorType outputTensorType = fillOp.getOutputs()[0].getType().cast(); llvm::SmallVector iteratorTypes( outputTensorType.getRank(), mlir::utils::IteratorType::parallel); mlir::AffineMap map = mlir::AffineMap::getMultiDimIdentityMap( outputTensorType.getRank(), this->getContext()); mlir::SmallVector maps(1, map); auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder, mlir::Location nestedLoc, mlir::ValueRange blockArgs) { nestedBuilder.create(nestedLoc, fillOp.getInputs()[0]); }; rewriter.replaceOpWithNewOp( fillOp, fillOp.getOutputs().getTypes(), mlir::ValueRange{}, fillOp.getOutputs(), maps, iteratorTypes, bodyBuilder); return ::mlir::success(); }; }; struct LinalgFillToLinalgGenericPass : public LinalgFillToLinalgGenericBase { LinalgFillToLinalgGenericPass() {} void runOnOperation() override { auto op = this->getOperation(); mlir::ConversionTarget target(getContext()); mlir::RewritePatternSet patterns(&getContext()); patterns.add(&getContext()); target.addIllegalOp(); target.markUnknownOpDynamicallyLegal( [](mlir::Operation *) { return true; }); // Apply conversion if (mlir::applyPartialConversion(op, target, std::move(patterns)) .failed()) { this->signalPassFailure(); } } }; } // namespace namespace mlir { namespace concretelang { std::unique_ptr> createLinalgFillToLinalgGenericPass() { return std::make_unique(); } } // namespace concretelang } // namespace mlir