fix: don't pass ref to runtimeLibPath

We were keeping a reference to the path which led to its use after the
string was freed
This commit is contained in:
youben11
2022-03-25 16:55:07 +01:00
committed by Ayoub Benaissa
parent 9de1776753
commit ad79aa627f
8 changed files with 16 additions and 23 deletions

View File

@@ -38,7 +38,7 @@ struct JITLambdaSupport_C {
typedef struct JITLambdaSupport_C JITLambdaSupport_C;
MLIR_CAPI_EXPORTED JITLambdaSupport_C
jit_lambda_support(const char *runtimeLibPath);
jit_lambda_support(std::string runtimeLibPath);
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::JitCompilationResult>
jit_compile(JITLambdaSupport_C support, const char *module,

View File

@@ -32,7 +32,7 @@ public:
static llvm::Expected<std::unique_ptr<JITLambda>>
create(llvm::StringRef name, mlir::ModuleOp &module,
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline,
llvm::Optional<llvm::StringRef> runtimeLibPath = {});
llvm::Optional<std::string> runtimeLibPath = {});
/// Call the JIT lambda with the public arguments.
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>

View File

@@ -33,7 +33,7 @@ class JitLambdaSupport
JitCompilationResult> {
public:
JitLambdaSupport(llvm::Optional<llvm::StringRef> runtimeLibPath = llvm::None);
JitLambdaSupport(llvm::Optional<std::string> runtimeLibPath = llvm::None);
llvm::Expected<std::unique_ptr<JitCompilationResult>>
compile(llvm::SourceMgr &program, CompilationOptions options) override;
@@ -56,7 +56,7 @@ public:
}
private:
llvm::Optional<llvm::StringRef> runtimeLibPath;
llvm::Optional<std::string> runtimeLibPath;
llvm::function_ref<llvm::Error(llvm::Module *)> llvmOptPipeline;
};

View File

@@ -59,7 +59,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
"JITLambda");
pybind11::class_<JITLambdaSupport_C>(m, "JITLambdaSupport")
.def(pybind11::init([](std::string runtimeLibPath) {
return jit_lambda_support(runtimeLibPath.c_str());
return jit_lambda_support(runtimeLibPath);
}))
.def("compile",
[](JITLambdaSupport_C &support, std::string mlir_program,

View File

@@ -96,7 +96,6 @@ class CompilerEngine:
self,
mlir_str: str,
func_name: str = "main",
runtime_lib_path: str = None,
unsecure_key_set_cache_path: str = None,
auto_parallelize: bool = False,
loop_parallelize: bool = False,
@@ -107,7 +106,6 @@ class CompilerEngine:
Args:
mlir_str (str): MLIR to compile.
func_name (str): name of the function to set as entrypoint (default: main).
runtime_lib_path (str): path to the runtime lib (default: None).
unsecure_key_set_cache_path (str): path to the activate keyset caching (default: None).
auto_parallelize (bool): whether to activate auto-parallelization or not (default: False),
loop_parallelize (bool): whether to activate loop-parallelization or not (default: False),
@@ -116,16 +114,6 @@ class CompilerEngine:
Raises:
TypeError: if the argument is not an str.
"""
if not isinstance(mlir_str, str):
raise TypeError("input must be an `str`")
if runtime_lib_path is None:
# Set to empty string if not found
runtime_lib_path = _lookup_runtime_lib()
else:
if not isinstance(runtime_lib_path, str):
raise TypeError(
"runtime_lib_path must be an str representing the path to the runtime lib"
)
if not all(
isinstance(flag, bool)
for flag in [auto_parallelize, loop_parallelize, df_parallelize]
@@ -263,6 +251,11 @@ class JITCompilerSupport:
def __init__(self, runtime_lib_path=None):
if runtime_lib_path is None:
runtime_lib_path = _lookup_runtime_lib()
else:
if not isinstance(runtime_lib_path, str):
raise TypeError(
"runtime_lib_path must be an str representing the path to the runtime lib"
)
self._support = _JITLambdaSupport(runtime_lib_path)
def compile(self, mlir_program: str, options: CompilationOptions = CompilationOptions("main")) -> JitCompilationResult:

View File

@@ -21,9 +21,10 @@
// JIT Support bindings ///////////////////////////////////////////////////////
MLIR_CAPI_EXPORTED JITLambdaSupport_C
jit_lambda_support(const char *runtimeLibPath) {
llvm::StringRef str(runtimeLibPath);
auto opt = str.empty() ? llvm::None : llvm::Optional<llvm::StringRef>(str);
jit_lambda_support(std::string runtimeLibPath) {
auto opt = runtimeLibPath.empty()
? llvm::None
: llvm::Optional<std::string>(runtimeLibPath);
return JITLambdaSupport_C{mlir::concretelang::JitLambdaSupport(opt)};
}

View File

@@ -23,7 +23,7 @@ namespace concretelang {
llvm::Expected<std::unique_ptr<JITLambda>>
JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module,
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline,
llvm::Optional<llvm::StringRef> runtimeLibPath) {
llvm::Optional<std::string> runtimeLibPath) {
// Looking for the function
auto rangeOps = module.getOps<mlir::LLVM::LLVMFuncOp>();

View File

@@ -10,8 +10,7 @@
namespace mlir {
namespace concretelang {
JitLambdaSupport::JitLambdaSupport(
llvm::Optional<llvm::StringRef> runtimeLibPath)
JitLambdaSupport::JitLambdaSupport(llvm::Optional<std::string> runtimeLibPath)
: runtimeLibPath(runtimeLibPath) {}
llvm::Expected<std::unique_ptr<JitCompilationResult>>