// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

// Pass that runs a partial dialect conversion which rewrites the types used by
// function signatures/bodies, function constants and a fixed list of RT
// dialect operations through a bufferization-based type converter.
//
// NOTE(review): in this copy of the file every `#include <...>` target and
// several template argument lists (e.g. `dyn_cast()`, `mlir::SmallVector`,
// `BufferizeDataflowTaskOpsBase`, `patterns.add>(`, `std::make_unique()`)
// appear to have been stripped, so it does not compile as shown -- restore
// them from version control before building.

#include
#include
#include
#include
#include
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

#define GEN_PASS_CLASSES
#include

namespace mlir {
namespace concretelang {

namespace {
/// Type converter used by this pass: the bufferization type converter
/// extended with rules for RT future types, RT pointer types and function
/// types.
class BufferizeRTTypesConverter
    : public mlir::bufferization::BufferizeTypeConverter {
public:
  BufferizeRTTypesConverter() {
    // RT::FutureType: rebuild the future around the converted element type.
    // The callback re-enters this converter (`this->convertType`), so the
    // element type goes through all registered rules, including this one.
    addConversion([&](mlir::concretelang::RT::FutureType type) {
      return mlir::concretelang::RT::FutureType::get(
          this->convertType(type.dyn_cast()
                                .getElementType()));
    });
    // RT::PointerType: same rule, re-wrap the converted element type.
    addConversion([&](mlir::concretelang::RT::PointerType type) {
      return mlir::concretelang::RT::PointerType::get(
          this->convertType(type.dyn_cast()
                                .getElementType()));
    });
    // mlir::FunctionType: convert every input type and every result type and
    // build a new function type from the converted lists.
    addConversion([&](mlir::FunctionType type) {
      SignatureConversion result(type.getNumInputs());
      mlir::SmallVector newResults;
      // NOTE(review): if either conversion fails, the original unconverted
      // type is returned instead of signalling a conversion failure, so the
      // function type is silently kept as-is -- confirm this is intentional.
      if (failed(this->convertSignatureArgs(type.getInputs(), result)) ||
          failed(this->convertTypes(type.getResults(), newResults)))
        return type;
      return mlir::FunctionType::get(type.getContext(),
                                     result.getConvertedTypes(), newResults);
    });
  }
};
} // namespace

namespace {
/// For documentation see Autopar.td
///
/// Applies a partial conversion to the operation the pass runs on, using
/// BufferizeRTTypesConverter for function signatures/bodies, function
/// constants and the RT dialect operations listed in runOnOperation().
struct BufferizeDataflowTaskOpsPass : public BufferizeDataflowTaskOpsBase {
  void runOnOperation() override {
    auto module = getOperation();
    auto *context = &getContext();
    BufferizeRTTypesConverter typeConverter;

    RewritePatternSet patterns(context);
    ConversionTarget target(*context);

    // Convert the argument/result types of function-like ops with the
    // converter (op type given by a template argument not visible here).
    populateFunctionOpInterfaceTypeConversionPattern(
        patterns, typeConverter);
    // Pattern for function constants; the template argument list is not
    // visible here -- presumably FunctionConstantOpConversion instantiated
    // for this converter, matching the isLegal call below (confirm).
    patterns.add>(
        context, typeConverter);
    // Dynamic legality for a dialect whose template argument is not visible
    // here (presumably the func dialect -- confirm):
    //  - the first matched op kind is legal once its signature and the types
    //    in its body are legal;
    //  - the second matched op kind defers to
    //    FunctionConstantOpConversion::isLegal;
    //  - any other op is legal iff the converter considers its types legal.
    target.addDynamicallyLegalDialect([&](Operation *op) {
      if (auto fun = dyn_cast_or_null(op))
        return typeConverter.isSignatureLegal(fun.getFunctionType()) &&
               typeConverter.isLegal(&fun.getBody());
      if (auto fun = dyn_cast_or_null(op))
        return FunctionConstantOpConversion::isLegal(
            fun, typeConverter);
      return typeConverter.isLegal(op);
    });

    // One generic type-converting rewrite pattern per RT operation.
    // NOTE(review): keep this list in sync with the legality list below.
    patterns.add<
        mlir::concretelang::GenericTypeConverterPattern<
            mlir::concretelang::RT::DataflowTaskOp>,
        mlir::concretelang::GenericTypeConverterPattern<
            mlir::concretelang::RT::DataflowYieldOp>,
        mlir::concretelang::GenericTypeConverterPattern<
            mlir::concretelang::RT::MakeReadyFutureOp>,
        mlir::concretelang::GenericTypeConverterPattern<
            mlir::concretelang::RT::AwaitFutureOp>,
        mlir::concretelang::GenericTypeConverterPattern<
            mlir::concretelang::RT::CreateAsyncTaskOp>,
        mlir::concretelang::GenericTypeConverterPattern<
            mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>,
        mlir::concretelang::GenericTypeConverterPattern<
            mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
        mlir::concretelang::GenericTypeConverterPattern<
            mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>,
        mlir::concretelang::GenericTypeConverterPattern<
            mlir::concretelang::RT::WorkFunctionReturnOp>,
        mlir::concretelang::GenericTypeConverterPattern<
            mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
                                                                 typeConverter);

    // Conversion of RT Dialect Ops
    // Register each RT op with the target via addDynamicallyLegalTypeOp
    // (defined elsewhere), using the same type converter.
    mlir::concretelang::addDynamicallyLegalTypeOp<
        mlir::concretelang::RT::DataflowTaskOp>(target, typeConverter);
    mlir::concretelang::addDynamicallyLegalTypeOp<
        mlir::concretelang::RT::DataflowYieldOp>(target, typeConverter);
    mlir::concretelang::addDynamicallyLegalTypeOp<
        mlir::concretelang::RT::MakeReadyFutureOp>(target, typeConverter);
    mlir::concretelang::addDynamicallyLegalTypeOp<
        mlir::concretelang::RT::AwaitFutureOp>(target, typeConverter);
    mlir::concretelang::addDynamicallyLegalTypeOp<
        mlir::concretelang::RT::CreateAsyncTaskOp>(target, typeConverter);
    mlir::concretelang::addDynamicallyLegalTypeOp<
        mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>(target,
                                                             typeConverter);
    mlir::concretelang::addDynamicallyLegalTypeOp<
        mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>(
        target, typeConverter);
    mlir::concretelang::addDynamicallyLegalTypeOp<
        mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>(target,
                                                             typeConverter);
    mlir::concretelang::addDynamicallyLegalTypeOp<
        mlir::concretelang::RT::WorkFunctionReturnOp>(target, typeConverter);
    mlir::concretelang::addDynamicallyLegalTypeOp<
        mlir::concretelang::RT::RegisterTaskWorkFunctionOp>(target,
                                                            typeConverter);

    // Run the conversion; if it fails, mark the pass as failed.
    if (failed(applyPartialConversion(module, target, std::move(patterns))))
      signalPassFailure();
  }

  BufferizeDataflowTaskOpsPass(bool debug) : debug(debug){};

protected:
  // Set by the constructor; not read anywhere in the code visible in this
  // file.
  bool debug;
};
} // end anonymous namespace

/// Creates the dataflow-task bufferization pass; `debug` is forwarded to the
/// pass constructor (template arguments of the return type and of
/// make_unique are not visible in this copy).
std::unique_ptr createBufferizeDataflowTaskOpsPass(bool debug) {
  return std::make_unique(debug);
}

} // namespace concretelang
} // namespace mlir