diff --git a/compiler/include/concretelang-c/Support/CompilerEngine.h b/compiler/include/concretelang-c/Support/CompilerEngine.h index b08e3e973..0a53d3541 100644 --- a/compiler/include/concretelang-c/Support/CompilerEngine.h +++ b/compiler/include/concretelang-c/Support/CompilerEngine.h @@ -48,13 +48,13 @@ MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters jit_load_client_parameters(JITLambdaSupport_C support, mlir::concretelang::JitCompilationResult &); -MLIR_CAPI_EXPORTED mlir::concretelang::JITLambda * +MLIR_CAPI_EXPORTED std::shared_ptr jit_load_server_lambda(JITLambdaSupport_C support, mlir::concretelang::JitCompilationResult &); MLIR_CAPI_EXPORTED std::unique_ptr jit_server_call(JITLambdaSupport_C support, - mlir::concretelang::JITLambda *lambda, + mlir::concretelang::JITLambda &lambda, concretelang::clientlib::PublicArguments &args); // Library Support bindings /////////////////////////////////////////////////// diff --git a/compiler/include/concretelang/Support/JitLambdaSupport.h b/compiler/include/concretelang/Support/JitLambdaSupport.h index 397e74813..bc9a7da92 100644 --- a/compiler/include/concretelang/Support/JitLambdaSupport.h +++ b/compiler/include/concretelang/Support/JitLambdaSupport.h @@ -23,13 +23,14 @@ namespace clientlib = ::concretelang::clientlib; /// JitCompilationResult is the result of a Jit compilation, the server JIT /// lambda and the clientParameters. struct JitCompilationResult { - std::unique_ptr lambda; + std::shared_ptr lambda; clientlib::ClientParameters clientParameters; }; /// JitLambdaSupport is the instantiated LambdaSupport for the Jit Compilation. class JitLambdaSupport - : public LambdaSupport { + : public LambdaSupport, + JitCompilationResult> { public: JitLambdaSupport(llvm::Optional runtimeLibPath = llvm::None); @@ -38,9 +39,9 @@ public: compile(llvm::SourceMgr &program, CompilationOptions options) override; using LambdaSupport::compile; - llvm::Expected + llvm::Expected> loadServerLambda(JitCompilationResult &result) override { - return result.lambda.get(); + return result.lambda; } llvm::Expected @@ -49,7 +50,7 @@ public: } llvm::Expected> - serverCall(concretelang::JITLambda *lambda, + serverCall(std::shared_ptr lambda, clientlib::PublicArguments &args) override { return lambda->call(args); } diff --git a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index ddcbe739d..a34a65d68 100644 --- a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -54,7 +54,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( pybind11::class_( m, "JitCompilationResult"); - pybind11::class_(m, "JITLambda"); + pybind11::class_>(m, + "JITLambda"); pybind11::class_(m, "JITLambdaSupport") .def(pybind11::init([](std::string runtimeLibPath) { return jit_lambda_support(runtimeLibPath.c_str()); @@ -77,7 +79,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( }, pybind11::return_value_policy::reference) .def("server_call", - [](JITLambdaSupport_C &support, concretelang::JITLambda *lambda, + [](JITLambdaSupport_C &support, concretelang::JITLambda &lambda, clientlib::PublicArguments &publicArguments) { return jit_server_call(support, lambda, publicArguments); }); diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index a7800fa6b..de592ef2f 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -43,7 +43,7 @@ jit_load_client_parameters(JITLambdaSupport_C support, return *clientParameters; } -MLIR_CAPI_EXPORTED mlir::concretelang::JITLambda * +MLIR_CAPI_EXPORTED std::shared_ptr jit_load_server_lambda(JITLambdaSupport_C support, mlir::concretelang::JitCompilationResult &result) { GET_OR_THROW_LLVM_EXPECTED(serverLambda, @@ -53,10 +53,9 @@ jit_load_server_lambda(JITLambdaSupport_C support, MLIR_CAPI_EXPORTED std::unique_ptr jit_server_call(JITLambdaSupport_C support, - mlir::concretelang::JITLambda *lambda, + mlir::concretelang::JITLambda &lambda, concretelang::clientlib::PublicArguments &args) { - GET_OR_THROW_LLVM_EXPECTED(publicResult, - support.support.serverCall(lambda, args)); + GET_OR_THROW_LLVM_EXPECTED(publicResult, lambda.call(args)); return std::move(*publicResult); } diff --git a/compiler/lib/Support/JitLambdaSupport.cpp b/compiler/lib/Support/JitLambdaSupport.cpp index a07519732..e53a5b546 100644 --- a/compiler/lib/Support/JitLambdaSupport.cpp +++ b/compiler/lib/Support/JitLambdaSupport.cpp @@ -46,7 +46,7 @@ JitLambdaSupport::compile(llvm::SourceMgr &program, return StreamStringError("No client parameters has been generated"); } auto result = std::make_unique(); - result->lambda = std::move(*lambda); + result->lambda = std::shared_ptr(std::move(*lambda)); result->clientParameters = compilationResult.get().clientParameters.getValue(); return std::move(result);