// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete/blob/main/LICENSE.txt
// for license information.

// NOTE(review): in this copy the bare `#include` directives below have no
// header name. Their `<...>` targets appear to have been lost and must be
// restored before this file can compile.
#include
#include
#include
#include
#include
#include
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

// Request the tablegen-generated pass base classes from the include below.
#define GEN_PASS_CLASSES
#include

namespace mlir {
namespace concretelang {

namespace {

/// Type converter used by the dataflow-task bufferization pass.
///
/// It rewrites tensor types to memref types and recursively converts element
/// types nested inside:
/// - memrefs,
/// - RT future/pointer types,
/// - function signatures.
///
/// Any type not matched by a more specific conversion is handed to the
/// embedded `bufferization::BufferizeTypeConverter`.
///
/// MLIR's TypeConverter tries the most recently added conversion first. So
/// the catch-all `mlir::Type` conversion registered first acts as the
/// fallback.
///
/// NOTE(review): every lambda captures `this` via `[&]`. If this converter
/// were ever copied, the copy's conversions would still refer to the
/// original object. Keep instances non-copied and local, as in
/// runOnOperation below.
class BufferizeRTTypesConverter : public mlir::TypeConverter {
protected:
  // Fallback converter for types not handled by the conversions below.
  bufferization::BufferizeTypeConverter btc;

public:
  BufferizeRTTypesConverter() {
    // Fallback: delegate any other type to the bufferization converter.
    addConversion([&](mlir::Type type) { return btc.convertType(type); });

    // Ranked tensor -> memref with the same shape and a converted element
    // type. Only shape and element type are passed, so MemRefType::get's
    // default layout and memory space are used.
    addConversion([&](mlir::RankedTensorType type) {
      return mlir::MemRefType::get(type.getShape(),
                                   this->convertType(type.getElementType()));
    });

    // Unranked tensor -> unranked memref in memory space 0, with a
    // converted element type.
    addConversion([&](mlir::UnrankedTensorType type) {
      return mlir::UnrankedMemRefType::get(
          this->convertType(type.getElementType()), 0);
    });

    // Memref -> memref: keep shape, layout and memory space, and convert
    // only the element type.
    addConversion([&](mlir::MemRefType type) {
      return mlir::MemRefType::get(type.getShape(),
                                   this->convertType(type.getElementType()),
                                   type.getLayout(), type.getMemorySpace());
    });

    // Unranked memref: keep the memory space and convert the element type.
    addConversion([&](mlir::UnrankedMemRefType type) {
      return mlir::UnrankedMemRefType::get(
          this->convertType(type.getElementType()), type.getMemorySpace());
    });

    // RT future: rebuild the future around the converted element type.
    addConversion([&](mlir::concretelang::RT::FutureType type) {
      return mlir::concretelang::RT::FutureType::get(
          this->convertType(type.getElementType()));
    });

    // RT pointer: rebuild the pointer around the converted element type.
    addConversion([&](mlir::concretelang::RT::PointerType type) {
      return mlir::concretelang::RT::PointerType::get(
          this->convertType(type.getElementType()));
    });

    // Function type: convert every input and result type.
    //
    // If any input or result fails to convert, the ORIGINAL function type is
    // returned unchanged, rather than reporting a conversion failure.
    // NOTE(review): confirm this best-effort fallback is intended; it can
    // leave unconverted types in a signature without any diagnostic.
    //
    // NOTE(review): `mlir::SmallVector` has lost its template arguments in
    // this copy; restore them before building.
    addConversion([&](mlir::FunctionType type) {
      SignatureConversion result(type.getNumInputs());
      mlir::SmallVector newResults;
      if (failed(this->convertSignatureArgs(type.getInputs(), result)) ||
          failed(this->convertTypes(type.getResults(), newResults))) {
        return type;
      }
      return mlir::FunctionType::get(type.getContext(),
                                     result.getConvertedTypes(), newResults);
    });
  }
};

} // namespace

namespace {
/// For documentation see Autopar.td
///
/// Module pass that re-types operations using BufferizeRTTypesConverter and
/// MLIR partial dialect conversion.
///
/// NOTE(review): the base class's template argument is missing in this copy
/// (presumably the usual tablegen CRTP form with this struct as argument;
/// confirm).
struct BufferizeDataflowTaskOpsPass : public BufferizeDataflowTaskOpsBase {

  /// Builds the type converter, rewrite patterns and legality rules, then
  /// runs applyPartialConversion on the module. Signals pass failure if the
  /// conversion fails.
  void runOnOperation() override {
    auto module = getOperation();
    auto *context = &getContext();

    BufferizeRTTypesConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);

    // Patterns that rewrite function-like op signatures and function
    // constants.
    // NOTE(review): the template arguments of
    // populateFunctionOpInterfaceTypeConversionPattern and of `patterns.add`
    // (the FunctionConstantOpConversion instantiation) are missing in this
    // copy.
    populateFunctionOpInterfaceTypeConversionPattern(
        patterns, typeConverter);
    patterns.add>(
        context, typeConverter);

    // Legality for the dialect named in the (missing) template argument,
    // presumably the func dialect -- confirm.
    // - Function ops are legal once their signature and body are legal.
    // - Function-constant ops defer to FunctionConstantOpConversion.
    // - Any other op is legal when its types are legal.
    // NOTE(review): both dyn_cast_or_null calls lost their target op types.
    target.addDynamicallyLegalDialect([&](Operation *op) {
      if (auto fun = dyn_cast_or_null(op))
        return typeConverter.isSignatureLegal(fun.getFunctionType()) &&
               typeConverter.isLegal(&fun.getBody());
      if (auto fun = dyn_cast_or_null(op))
        return FunctionConstantOpConversion::isLegal(
            fun, typeConverter);
      return typeConverter.isLegal(op);
    });

    // Add the RT-dialect type-conversion patterns. The helper also receives
    // `target`, so it may register legality rules as well.
    mlir::concretelang::populateWithRTTypeConverterPatterns(patterns, target,
                                                            typeConverter);

    // Re-instantiate memref ops with converted operand/result types.
    // NOTE(review): the opening of this template-argument list, including
    // the first pattern, appears to have been lost in this copy (the
    // statement starts `patterns.add,`).
    patterns.add,
        mlir::concretelang::TypeConvertingReinstantiationPattern<
            mlir::memref::LoadOp>,
        mlir::concretelang::TypeConvertingReinstantiationPattern<
            mlir::memref::StoreOp>,
        mlir::concretelang::TypeConvertingReinstantiationPattern<
            mlir::memref::CopyOp>,
        mlir::concretelang::TypeConvertingReinstantiationPattern<
            mlir::memref::SubViewOp, true>>(&getContext(), typeConverter);

    // The listed ops are legal only when all result and operand types are
    // legal for the converter.
    // NOTE(review): the op list for addDynamicallyLegalOp is missing in this
    // copy; presumably the memref ops handled above -- confirm.
    target.addDynamicallyLegalOp(
        [&](mlir::Operation *op) {
          return typeConverter.isLegal(op->getResultTypes()) &&
                 typeConverter.isLegal(op->getOperandTypes());
        });

    if (failed(applyPartialConversion(module, target, std::move(patterns))))
      signalPassFailure();
  }

  // NOTE(review): `debug` is stored but not read anywhere in the code shown
  // here.
  BufferizeDataflowTaskOpsPass(bool debug) : debug(debug){};

protected:
  bool debug;
};
} // end anonymous namespace

/// Creates a new BufferizeDataflowTaskOpsPass carrying the given debug flag.
/// NOTE(review): the template arguments of std::unique_ptr and
/// std::make_unique are missing in this copy.
std::unique_ptr createBufferizeDataflowTaskOpsPass(bool debug) {
  return std::make_unique(debug);
}

} // namespace concretelang
} // namespace mlir