// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include #include #include #include #include #include #include #include #include #include #include #include #include #include #define GEN_PASS_CLASSES #include namespace mlir { namespace concretelang { namespace { class BufferizeDataflowYieldOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(RT::DataflowYieldOp op, RT::DataflowYieldOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, mlir::TypeRange(), adaptor.getOperands()); return success(); } }; } // namespace namespace { class BufferizeDataflowTaskOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(RT::DataflowTaskOp op, RT::DataflowTaskOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { mlir::OpBuilder::InsertionGuard guard(rewriter); SmallVector newResults; (void)getTypeConverter()->convertTypes(op.getResultTypes(), newResults); auto newop = rewriter.create(op.getLoc(), newResults, adaptor.getOperands()); // We cannot clone here as cloned ops must be legalized (so this // would break on the YieldOp). Instead use mergeBlocks which // moves the ops instead of cloning. rewriter.mergeBlocks(op.getBody(), newop.getBody(), newop.getBody()->getArguments()); // Because of previous bufferization there are buffer cast ops // that have been generated for the previously tensor results of // some tasks. These cannot just be replaced directly as the // task's results would still be live. for (auto res : llvm::enumerate(op.getResults())) { // If this result is getting bufferized ... if (res.value().getType() != getTypeConverter()->convertType(res.value().getType())) { for (auto &use : llvm::make_early_inc_range(res.value().getUses())) { // ... and its uses are in `ToMemrefOp`s, then we // replace further uses of the buffer cast. if (isa(use.getOwner())) { rewriter.replaceOp(use.getOwner(), {newop.getResult(res.index())}); } } } } rewriter.replaceOp(op, {newop.getResults()}); return success(); } }; } // namespace void populateRTBufferizePatterns( mlir::bufferization::BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add( typeConverter, patterns.getContext()); } namespace { /// For documentation see Autopar.td struct BufferizeDataflowTaskOpsPass : public BufferizeDataflowTaskOpsBase { void runOnOperation() override { auto module = getOperation(); auto *context = &getContext(); mlir::bufferization::BufferizeTypeConverter typeConverter; RewritePatternSet patterns(context); ConversionTarget target(*context); populateRTBufferizePatterns(typeConverter, patterns); // Forbid all RT ops that still use/return tensors target.addDynamicallyLegalDialect( [&](Operation *op) { return typeConverter.isLegal(op); }); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } BufferizeDataflowTaskOpsPass(bool debug) : debug(debug){}; protected: bool debug; }; } // end anonymous namespace std::unique_ptr createBufferizeDataflowTaskOpsPass(bool debug) { return std::make_unique(debug); } } // namespace concretelang } // namespace mlir