mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
fix(support): Use shared_ptr intead of raw ptr on JitLambdaSupport that allows the JitLambda to be used after that the compilation result is freed
This commit is contained in:
@@ -48,13 +48,13 @@ MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
|
||||
jit_load_client_parameters(JITLambdaSupport_C support,
|
||||
mlir::concretelang::JitCompilationResult &);
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::JITLambda *
|
||||
MLIR_CAPI_EXPORTED std::shared_ptr<mlir::concretelang::JITLambda>
|
||||
jit_load_server_lambda(JITLambdaSupport_C support,
|
||||
mlir::concretelang::JitCompilationResult &);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
|
||||
jit_server_call(JITLambdaSupport_C support,
|
||||
mlir::concretelang::JITLambda *lambda,
|
||||
mlir::concretelang::JITLambda &lambda,
|
||||
concretelang::clientlib::PublicArguments &args);
|
||||
|
||||
// Library Support bindings ///////////////////////////////////////////////////
|
||||
|
||||
@@ -23,13 +23,14 @@ namespace clientlib = ::concretelang::clientlib;
|
||||
/// JitCompilationResult is the result of a Jit compilation, the server JIT
|
||||
/// lambda and the clientParameters.
|
||||
struct JitCompilationResult {
|
||||
std::unique_ptr<concretelang::JITLambda> lambda;
|
||||
std::shared_ptr<concretelang::JITLambda> lambda;
|
||||
clientlib::ClientParameters clientParameters;
|
||||
};
|
||||
|
||||
/// JitLambdaSupport is the instantiated LambdaSupport for the Jit Compilation.
|
||||
class JitLambdaSupport
|
||||
: public LambdaSupport<concretelang::JITLambda *, JitCompilationResult> {
|
||||
: public LambdaSupport<std::shared_ptr<concretelang::JITLambda>,
|
||||
JitCompilationResult> {
|
||||
|
||||
public:
|
||||
JitLambdaSupport(llvm::Optional<llvm::StringRef> runtimeLibPath = llvm::None);
|
||||
@@ -38,9 +39,9 @@ public:
|
||||
compile(llvm::SourceMgr &program, CompilationOptions options) override;
|
||||
using LambdaSupport::compile;
|
||||
|
||||
llvm::Expected<concretelang::JITLambda *>
|
||||
llvm::Expected<std::shared_ptr<concretelang::JITLambda>>
|
||||
loadServerLambda(JitCompilationResult &result) override {
|
||||
return result.lambda.get();
|
||||
return result.lambda;
|
||||
}
|
||||
|
||||
llvm::Expected<clientlib::ClientParameters>
|
||||
@@ -49,7 +50,7 @@ public:
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
|
||||
serverCall(concretelang::JITLambda *lambda,
|
||||
serverCall(std::shared_ptr<concretelang::JITLambda> lambda,
|
||||
clientlib::PublicArguments &args) override {
|
||||
return lambda->call(args);
|
||||
}
|
||||
|
||||
@@ -54,7 +54,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
|
||||
pybind11::class_<mlir::concretelang::JitCompilationResult>(
|
||||
m, "JitCompilationResult");
|
||||
pybind11::class_<mlir::concretelang::JITLambda>(m, "JITLambda");
|
||||
pybind11::class_<mlir::concretelang::JITLambda,
|
||||
std::shared_ptr<mlir::concretelang::JITLambda>>(m,
|
||||
"JITLambda");
|
||||
pybind11::class_<JITLambdaSupport_C>(m, "JITLambdaSupport")
|
||||
.def(pybind11::init([](std::string runtimeLibPath) {
|
||||
return jit_lambda_support(runtimeLibPath.c_str());
|
||||
@@ -77,7 +79,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
},
|
||||
pybind11::return_value_policy::reference)
|
||||
.def("server_call",
|
||||
[](JITLambdaSupport_C &support, concretelang::JITLambda *lambda,
|
||||
[](JITLambdaSupport_C &support, concretelang::JITLambda &lambda,
|
||||
clientlib::PublicArguments &publicArguments) {
|
||||
return jit_server_call(support, lambda, publicArguments);
|
||||
});
|
||||
|
||||
@@ -43,7 +43,7 @@ jit_load_client_parameters(JITLambdaSupport_C support,
|
||||
return *clientParameters;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::JITLambda *
|
||||
MLIR_CAPI_EXPORTED std::shared_ptr<mlir::concretelang::JITLambda>
|
||||
jit_load_server_lambda(JITLambdaSupport_C support,
|
||||
mlir::concretelang::JitCompilationResult &result) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(serverLambda,
|
||||
@@ -53,10 +53,9 @@ jit_load_server_lambda(JITLambdaSupport_C support,
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
|
||||
jit_server_call(JITLambdaSupport_C support,
|
||||
mlir::concretelang::JITLambda *lambda,
|
||||
mlir::concretelang::JITLambda &lambda,
|
||||
concretelang::clientlib::PublicArguments &args) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(publicResult,
|
||||
support.support.serverCall(lambda, args));
|
||||
GET_OR_THROW_LLVM_EXPECTED(publicResult, lambda.call(args));
|
||||
return std::move(*publicResult);
|
||||
}
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ JitLambdaSupport::compile(llvm::SourceMgr &program,
|
||||
return StreamStringError("No client parameters has been generated");
|
||||
}
|
||||
auto result = std::make_unique<JitCompilationResult>();
|
||||
result->lambda = std::move(*lambda);
|
||||
result->lambda = std::shared_ptr<concretelang::JITLambda>(std::move(*lambda));
|
||||
result->clientParameters =
|
||||
compilationResult.get().clientParameters.getValue();
|
||||
return std::move(result);
|
||||
|
||||
Reference in New Issue
Block a user