mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
fix: don't pass ref to runtimeLibPath
We were keeping a reference to the path which led to its use after the string was freed
This commit is contained in:
@@ -38,7 +38,7 @@ struct JITLambdaSupport_C {
|
||||
typedef struct JITLambdaSupport_C JITLambdaSupport_C;
|
||||
|
||||
MLIR_CAPI_EXPORTED JITLambdaSupport_C
|
||||
jit_lambda_support(const char *runtimeLibPath);
|
||||
jit_lambda_support(std::string runtimeLibPath);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::JitCompilationResult>
|
||||
jit_compile(JITLambdaSupport_C support, const char *module,
|
||||
|
||||
@@ -32,7 +32,7 @@ public:
|
||||
static llvm::Expected<std::unique_ptr<JITLambda>>
|
||||
create(llvm::StringRef name, mlir::ModuleOp &module,
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline,
|
||||
llvm::Optional<llvm::StringRef> runtimeLibPath = {});
|
||||
llvm::Optional<std::string> runtimeLibPath = {});
|
||||
|
||||
/// Call the JIT lambda with the public arguments.
|
||||
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
|
||||
|
||||
@@ -33,7 +33,7 @@ class JitLambdaSupport
|
||||
JitCompilationResult> {
|
||||
|
||||
public:
|
||||
JitLambdaSupport(llvm::Optional<llvm::StringRef> runtimeLibPath = llvm::None);
|
||||
JitLambdaSupport(llvm::Optional<std::string> runtimeLibPath = llvm::None);
|
||||
|
||||
llvm::Expected<std::unique_ptr<JitCompilationResult>>
|
||||
compile(llvm::SourceMgr &program, CompilationOptions options) override;
|
||||
@@ -56,7 +56,7 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
llvm::Optional<llvm::StringRef> runtimeLibPath;
|
||||
llvm::Optional<std::string> runtimeLibPath;
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> llvmOptPipeline;
|
||||
};
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
"JITLambda");
|
||||
pybind11::class_<JITLambdaSupport_C>(m, "JITLambdaSupport")
|
||||
.def(pybind11::init([](std::string runtimeLibPath) {
|
||||
return jit_lambda_support(runtimeLibPath.c_str());
|
||||
return jit_lambda_support(runtimeLibPath);
|
||||
}))
|
||||
.def("compile",
|
||||
[](JITLambdaSupport_C &support, std::string mlir_program,
|
||||
|
||||
@@ -96,7 +96,6 @@ class CompilerEngine:
|
||||
self,
|
||||
mlir_str: str,
|
||||
func_name: str = "main",
|
||||
runtime_lib_path: str = None,
|
||||
unsecure_key_set_cache_path: str = None,
|
||||
auto_parallelize: bool = False,
|
||||
loop_parallelize: bool = False,
|
||||
@@ -107,7 +106,6 @@ class CompilerEngine:
|
||||
Args:
|
||||
mlir_str (str): MLIR to compile.
|
||||
func_name (str): name of the function to set as entrypoint (default: main).
|
||||
runtime_lib_path (str): path to the runtime lib (default: None).
|
||||
unsecure_key_set_cache_path (str): path to the activate keyset caching (default: None).
|
||||
auto_parallelize (bool): whether to activate auto-parallelization or not (default: False),
|
||||
loop_parallelize (bool): whether to activate loop-parallelization or not (default: False),
|
||||
@@ -116,16 +114,6 @@ class CompilerEngine:
|
||||
Raises:
|
||||
TypeError: if the argument is not an str.
|
||||
"""
|
||||
if not isinstance(mlir_str, str):
|
||||
raise TypeError("input must be an `str`")
|
||||
if runtime_lib_path is None:
|
||||
# Set to empty string if not found
|
||||
runtime_lib_path = _lookup_runtime_lib()
|
||||
else:
|
||||
if not isinstance(runtime_lib_path, str):
|
||||
raise TypeError(
|
||||
"runtime_lib_path must be an str representing the path to the runtime lib"
|
||||
)
|
||||
if not all(
|
||||
isinstance(flag, bool)
|
||||
for flag in [auto_parallelize, loop_parallelize, df_parallelize]
|
||||
@@ -263,6 +251,11 @@ class JITCompilerSupport:
|
||||
def __init__(self, runtime_lib_path=None):
|
||||
if runtime_lib_path is None:
|
||||
runtime_lib_path = _lookup_runtime_lib()
|
||||
else:
|
||||
if not isinstance(runtime_lib_path, str):
|
||||
raise TypeError(
|
||||
"runtime_lib_path must be an str representing the path to the runtime lib"
|
||||
)
|
||||
self._support = _JITLambdaSupport(runtime_lib_path)
|
||||
|
||||
def compile(self, mlir_program: str, options: CompilationOptions = CompilationOptions("main")) -> JitCompilationResult:
|
||||
|
||||
@@ -21,9 +21,10 @@
|
||||
// JIT Support bindings ///////////////////////////////////////////////////////
|
||||
|
||||
MLIR_CAPI_EXPORTED JITLambdaSupport_C
|
||||
jit_lambda_support(const char *runtimeLibPath) {
|
||||
llvm::StringRef str(runtimeLibPath);
|
||||
auto opt = str.empty() ? llvm::None : llvm::Optional<llvm::StringRef>(str);
|
||||
jit_lambda_support(std::string runtimeLibPath) {
|
||||
auto opt = runtimeLibPath.empty()
|
||||
? llvm::None
|
||||
: llvm::Optional<std::string>(runtimeLibPath);
|
||||
return JITLambdaSupport_C{mlir::concretelang::JitLambdaSupport(opt)};
|
||||
}
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ namespace concretelang {
|
||||
llvm::Expected<std::unique_ptr<JITLambda>>
|
||||
JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module,
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline,
|
||||
llvm::Optional<llvm::StringRef> runtimeLibPath) {
|
||||
llvm::Optional<std::string> runtimeLibPath) {
|
||||
|
||||
// Looking for the function
|
||||
auto rangeOps = module.getOps<mlir::LLVM::LLVMFuncOp>();
|
||||
|
||||
@@ -10,8 +10,7 @@
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
JitLambdaSupport::JitLambdaSupport(
|
||||
llvm::Optional<llvm::StringRef> runtimeLibPath)
|
||||
JitLambdaSupport::JitLambdaSupport(llvm::Optional<std::string> runtimeLibPath)
|
||||
: runtimeLibPath(runtimeLibPath) {}
|
||||
|
||||
llvm::Expected<std::unique_ptr<JitCompilationResult>>
|
||||
|
||||
Reference in New Issue
Block a user