diff --git a/compiler/include/concretelang-c/Support/CompilerEngine.h b/compiler/include/concretelang-c/Support/CompilerEngine.h index 0a53d3541..1b076713f 100644 --- a/compiler/include/concretelang-c/Support/CompilerEngine.h +++ b/compiler/include/concretelang-c/Support/CompilerEngine.h @@ -38,7 +38,7 @@ struct JITLambdaSupport_C { typedef struct JITLambdaSupport_C JITLambdaSupport_C; MLIR_CAPI_EXPORTED JITLambdaSupport_C -jit_lambda_support(const char *runtimeLibPath); +jit_lambda_support(std::string runtimeLibPath); MLIR_CAPI_EXPORTED std::unique_ptr jit_compile(JITLambdaSupport_C support, const char *module, diff --git a/compiler/include/concretelang/Support/Jit.h b/compiler/include/concretelang/Support/Jit.h index 1f0df39e4..76613c937 100644 --- a/compiler/include/concretelang/Support/Jit.h +++ b/compiler/include/concretelang/Support/Jit.h @@ -32,7 +32,7 @@ public: static llvm::Expected> create(llvm::StringRef name, mlir::ModuleOp &module, llvm::function_ref optPipeline, - llvm::Optional runtimeLibPath = {}); + llvm::Optional runtimeLibPath = {}); /// Call the JIT lambda with the public arguments. llvm::Expected> diff --git a/compiler/include/concretelang/Support/JitLambdaSupport.h b/compiler/include/concretelang/Support/JitLambdaSupport.h index bc9a7da92..be185453e 100644 --- a/compiler/include/concretelang/Support/JitLambdaSupport.h +++ b/compiler/include/concretelang/Support/JitLambdaSupport.h @@ -33,7 +33,7 @@ class JitLambdaSupport JitCompilationResult> { public: - JitLambdaSupport(llvm::Optional runtimeLibPath = llvm::None); + JitLambdaSupport(llvm::Optional runtimeLibPath = llvm::None); llvm::Expected> compile(llvm::SourceMgr &program, CompilationOptions options) override; @@ -56,7 +56,7 @@ public: } private: - llvm::Optional runtimeLibPath; + llvm::Optional runtimeLibPath; llvm::function_ref llvmOptPipeline; }; diff --git a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index a34a65d68..d3af32d43 100644 --- a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -59,7 +59,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( "JITLambda"); pybind11::class_(m, "JITLambdaSupport") .def(pybind11::init([](std::string runtimeLibPath) { - return jit_lambda_support(runtimeLibPath.c_str()); + return jit_lambda_support(runtimeLibPath); })) .def("compile", [](JITLambdaSupport_C &support, std::string mlir_program, diff --git a/compiler/lib/Bindings/Python/concrete/compiler.py b/compiler/lib/Bindings/Python/concrete/compiler.py index d9812bb1c..4f009dac1 100644 --- a/compiler/lib/Bindings/Python/concrete/compiler.py +++ b/compiler/lib/Bindings/Python/concrete/compiler.py @@ -96,7 +96,6 @@ class CompilerEngine: self, mlir_str: str, func_name: str = "main", - runtime_lib_path: str = None, unsecure_key_set_cache_path: str = None, auto_parallelize: bool = False, loop_parallelize: bool = False, @@ -107,7 +106,6 @@ class CompilerEngine: Args: mlir_str (str): MLIR to compile. func_name (str): name of the function to set as entrypoint (default: main). - runtime_lib_path (str): path to the runtime lib (default: None). unsecure_key_set_cache_path (str): path to the activate keyset caching (default: None). auto_parallelize (bool): whether to activate auto-parallelization or not (default: False), loop_parallelize (bool): whether to activate loop-parallelization or not (default: False), @@ -116,16 +114,6 @@ class CompilerEngine: Raises: TypeError: if the argument is not an str. """ - if not isinstance(mlir_str, str): - raise TypeError("input must be an `str`") - if runtime_lib_path is None: - # Set to empty string if not found - runtime_lib_path = _lookup_runtime_lib() - else: - if not isinstance(runtime_lib_path, str): - raise TypeError( - "runtime_lib_path must be an str representing the path to the runtime lib" - ) if not all( isinstance(flag, bool) for flag in [auto_parallelize, loop_parallelize, df_parallelize] @@ -263,6 +251,11 @@ class JITCompilerSupport: def __init__(self, runtime_lib_path=None): if runtime_lib_path is None: runtime_lib_path = _lookup_runtime_lib() + else: + if not isinstance(runtime_lib_path, str): + raise TypeError( + "runtime_lib_path must be an str representing the path to the runtime lib" + ) self._support = _JITLambdaSupport(runtime_lib_path) def compile(self, mlir_program: str, options: CompilationOptions = CompilationOptions("main")) -> JitCompilationResult: diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index 63f279523..b04b8e83d 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -21,9 +21,10 @@ // JIT Support bindings /////////////////////////////////////////////////////// MLIR_CAPI_EXPORTED JITLambdaSupport_C -jit_lambda_support(const char *runtimeLibPath) { - llvm::StringRef str(runtimeLibPath); - auto opt = str.empty() ? llvm::None : llvm::Optional(str); +jit_lambda_support(std::string runtimeLibPath) { + auto opt = runtimeLibPath.empty() + ? llvm::None + : llvm::Optional(runtimeLibPath); return JITLambdaSupport_C{mlir::concretelang::JitLambdaSupport(opt)}; } diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index 379b5e001..d5027f337 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -23,7 +23,7 @@ namespace concretelang { llvm::Expected> JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module, llvm::function_ref optPipeline, - llvm::Optional runtimeLibPath) { + llvm::Optional runtimeLibPath) { // Looking for the function auto rangeOps = module.getOps(); diff --git a/compiler/lib/Support/JitLambdaSupport.cpp b/compiler/lib/Support/JitLambdaSupport.cpp index e53a5b546..72c94dcec 100644 --- a/compiler/lib/Support/JitLambdaSupport.cpp +++ b/compiler/lib/Support/JitLambdaSupport.cpp @@ -10,8 +10,7 @@ namespace mlir { namespace concretelang { -JitLambdaSupport::JitLambdaSupport( - llvm::Optional runtimeLibPath) +JitLambdaSupport::JitLambdaSupport(llvm::Optional runtimeLibPath) : runtimeLibPath(runtimeLibPath) {} llvm::Expected>