fix(support): Use shared_ptr intead of raw ptr on JitLambdaSupport that allows the JitLambda to be used after that the compilation result is freed

This commit is contained in:
Quentin Bourgerie
2022-03-24 10:19:58 +01:00
parent 6717c4f5ff
commit c70ef1dcda
5 changed files with 16 additions and 14 deletions

View File

@@ -48,13 +48,13 @@ MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
jit_load_client_parameters(JITLambdaSupport_C support,
mlir::concretelang::JitCompilationResult &);
MLIR_CAPI_EXPORTED mlir::concretelang::JITLambda *
MLIR_CAPI_EXPORTED std::shared_ptr<mlir::concretelang::JITLambda>
jit_load_server_lambda(JITLambdaSupport_C support,
mlir::concretelang::JitCompilationResult &);
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
jit_server_call(JITLambdaSupport_C support,
mlir::concretelang::JITLambda *lambda,
mlir::concretelang::JITLambda &lambda,
concretelang::clientlib::PublicArguments &args);
// Library Support bindings ///////////////////////////////////////////////////

View File

@@ -23,13 +23,14 @@ namespace clientlib = ::concretelang::clientlib;
/// JitCompilationResult is the result of a Jit compilation, the server JIT
/// lambda and the clientParameters.
struct JitCompilationResult {
std::unique_ptr<concretelang::JITLambda> lambda;
std::shared_ptr<concretelang::JITLambda> lambda;
clientlib::ClientParameters clientParameters;
};
/// JitLambdaSupport is the instantiated LambdaSupport for the Jit Compilation.
class JitLambdaSupport
: public LambdaSupport<concretelang::JITLambda *, JitCompilationResult> {
: public LambdaSupport<std::shared_ptr<concretelang::JITLambda>,
JitCompilationResult> {
public:
JitLambdaSupport(llvm::Optional<llvm::StringRef> runtimeLibPath = llvm::None);
@@ -38,9 +39,9 @@ public:
compile(llvm::SourceMgr &program, CompilationOptions options) override;
using LambdaSupport::compile;
llvm::Expected<concretelang::JITLambda *>
llvm::Expected<std::shared_ptr<concretelang::JITLambda>>
loadServerLambda(JitCompilationResult &result) override {
return result.lambda.get();
return result.lambda;
}
llvm::Expected<clientlib::ClientParameters>
@@ -49,7 +50,7 @@ public:
}
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
serverCall(concretelang::JITLambda *lambda,
serverCall(std::shared_ptr<concretelang::JITLambda> lambda,
clientlib::PublicArguments &args) override {
return lambda->call(args);
}

View File

@@ -54,7 +54,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
pybind11::class_<mlir::concretelang::JitCompilationResult>(
m, "JitCompilationResult");
pybind11::class_<mlir::concretelang::JITLambda>(m, "JITLambda");
pybind11::class_<mlir::concretelang::JITLambda,
std::shared_ptr<mlir::concretelang::JITLambda>>(m,
"JITLambda");
pybind11::class_<JITLambdaSupport_C>(m, "JITLambdaSupport")
.def(pybind11::init([](std::string runtimeLibPath) {
return jit_lambda_support(runtimeLibPath.c_str());
@@ -77,7 +79,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
},
pybind11::return_value_policy::reference)
.def("server_call",
[](JITLambdaSupport_C &support, concretelang::JITLambda *lambda,
[](JITLambdaSupport_C &support, concretelang::JITLambda &lambda,
clientlib::PublicArguments &publicArguments) {
return jit_server_call(support, lambda, publicArguments);
});

View File

@@ -43,7 +43,7 @@ jit_load_client_parameters(JITLambdaSupport_C support,
return *clientParameters;
}
MLIR_CAPI_EXPORTED mlir::concretelang::JITLambda *
MLIR_CAPI_EXPORTED std::shared_ptr<mlir::concretelang::JITLambda>
jit_load_server_lambda(JITLambdaSupport_C support,
mlir::concretelang::JitCompilationResult &result) {
GET_OR_THROW_LLVM_EXPECTED(serverLambda,
@@ -53,10 +53,9 @@ jit_load_server_lambda(JITLambdaSupport_C support,
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
jit_server_call(JITLambdaSupport_C support,
mlir::concretelang::JITLambda *lambda,
mlir::concretelang::JITLambda &lambda,
concretelang::clientlib::PublicArguments &args) {
GET_OR_THROW_LLVM_EXPECTED(publicResult,
support.support.serverCall(lambda, args));
GET_OR_THROW_LLVM_EXPECTED(publicResult, lambda.call(args));
return std::move(*publicResult);
}

View File

@@ -46,7 +46,7 @@ JitLambdaSupport::compile(llvm::SourceMgr &program,
return StreamStringError("No client parameters has been generated");
}
auto result = std::make_unique<JitCompilationResult>();
result->lambda = std::move(*lambda);
result->lambda = std::shared_ptr<concretelang::JITLambda>(std::move(*lambda));
result->clientParameters =
compilationResult.get().clientParameters.getValue();
return std::move(result);