diff --git a/compiler/include/zamalang-c/Support/CompilerEngine.h b/compiler/include/zamalang-c/Support/CompilerEngine.h index 25f2dd931..df2e90560 100644 --- a/compiler/include/zamalang-c/Support/CompilerEngine.h +++ b/compiler/include/zamalang-c/Support/CompilerEngine.h @@ -31,9 +31,11 @@ struct executionArguments { typedef struct executionArguments exectuionArguments; // Build lambda from a textual representation of an MLIR module -// The lambda will have `funcName` as entrypoint +// The lambda will have `funcName` as entrypoint, and use runtimeLibPath (if not +// null) as a shared library during compilation MLIR_CAPI_EXPORTED mlir::zamalang::JitCompilerEngine::Lambda -buildLambda(const char *module, const char *funcName); +buildLambda(const char *module, const char *funcName, + const char *runtimeLibPath); // Parse then print a textual representation of an MLIR module MLIR_CAPI_EXPORTED std::string roundTrip(const char *module); diff --git a/compiler/include/zamalang/Support/Jit.h b/compiler/include/zamalang/Support/Jit.h index 35ef9cf8c..61bfff18e 100644 --- a/compiler/include/zamalang/Support/Jit.h +++ b/compiler/include/zamalang/Support/Jit.h @@ -101,9 +101,11 @@ public: : type(type), name(name){}; /// create a JITLambda that point to the function name of the given module. + /// Use runtimeLibPath as a shared library if specified. static llvm::Expected> create(llvm::StringRef name, mlir::ModuleOp &module, - llvm::function_ref optPipeline); + llvm::function_ref optPipeline, + llvm::Optional runtimeLibPath = {}); /// invokeRaw execute the jit lambda with a list of Argument, the last one is /// used to store the result of the computation. diff --git a/compiler/include/zamalang/Support/JitCompilerEngine.h b/compiler/include/zamalang/Support/JitCompilerEngine.h index 49de95e2f..018575bea 100644 --- a/compiler/include/zamalang/Support/JitCompilerEngine.h +++ b/compiler/include/zamalang/Support/JitCompilerEngine.h @@ -359,14 +359,20 @@ public: CompilationContext::createShared(), unsigned int optimizationLevel = 3); - llvm::Expected buildLambda(llvm::StringRef src, - llvm::StringRef funcName = "main"); + /// Build a Lambda from a source MLIR, with `funcName` as entrypoint. + /// Use runtimeLibPath as a shared library if specified. + llvm::Expected + buildLambda(llvm::StringRef src, llvm::StringRef funcName = "main", + llvm::Optional runtimeLibPath = {}); - llvm::Expected buildLambda(std::unique_ptr buffer, - llvm::StringRef funcName = "main"); + llvm::Expected + buildLambda(std::unique_ptr buffer, + llvm::StringRef funcName = "main", + llvm::Optional runtimeLibPath = {}); - llvm::Expected buildLambda(llvm::SourceMgr &sm, - llvm::StringRef funcName = "main"); + llvm::Expected + buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName = "main", + llvm::Optional runtimeLibPath = {}); protected: llvm::Expected findLLVMFuncOp(mlir::ModuleOp module, diff --git a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index f3507e427..3416e67cb 100644 --- a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -26,10 +26,14 @@ void mlir::zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) { pybind11::class_(m, "JitCompilerEngine") .def(pybind11::init()) - .def_static("build_lambda", - [](std::string mlir_input, std::string func_name) { - return buildLambda(mlir_input.c_str(), func_name.c_str()); - }); + .def_static("build_lambda", [](std::string mlir_input, + std::string func_name, + std::string runtime_lib_path) { + if (runtime_lib_path.empty()) + return buildLambda(mlir_input.c_str(), func_name.c_str(), nullptr); + return buildLambda(mlir_input.c_str(), func_name.c_str(), + runtime_lib_path.c_str()); + }); pybind11::class_(m, "LambdaArgument") .def_static("from_tensor", lambdaArgumentFromTensor) diff --git a/compiler/lib/Bindings/Python/zamalang/compiler.py b/compiler/lib/Bindings/Python/zamalang/compiler.py index f2350a7d9..30e06f3e5 100644 --- a/compiler/lib/Bindings/Python/zamalang/compiler.py +++ b/compiler/lib/Bindings/Python/zamalang/compiler.py @@ -1,4 +1,5 @@ """Compiler submodule""" +import os from typing import List, Union from mlir._mlir_libs._zamalang._compiler import JitCompilerEngine as _JitCompilerEngine from mlir._mlir_libs._zamalang._compiler import LambdaArgument as _LambdaArgument @@ -6,6 +7,30 @@ from mlir._mlir_libs._zamalang._compiler import round_trip as _round_trip import numpy as np +def _lookup_runtime_lib() -> str: + """Try to find the absolute path to the runtime library. + + Returns: + str: absolute path to the runtime library, or empty str if unsuccessful. + """ + # Go up to site-packages level + cwd = os.path.abspath(__file__) + cwd = os.path.abspath(os.path.join(cwd, os.pardir)) + cwd = os.path.abspath(os.path.join(cwd, os.pardir)) + package_name = "concretefhe_compiler" + libs_path = os.path.join(cwd, f"{package_name}.libs") + # Can be because it's not a properly installed package + if not os.path.exists(libs_path): + return "" + runtime_library_paths = [ + filename + for filename in os.listdir(libs_path) + if filename.startswith("libZamalangRuntime") + ] + assert len(runtime_library_paths) == 1, "should be one and only one runtime library" + return os.path.join(libs_path, runtime_library_paths[0]) + + def round_trip(mlir_str: str) -> str: """Parse the MLIR input, then return it back. @@ -39,7 +64,9 @@ def create_execution_argument(value: Union[int, np.ndarray]) -> "_LambdaArgument raise TypeError("value of execution argument must be either int or numpy.array") if isinstance(value, int): if not (0 <= value < (2 ** 64 - 1)): - raise TypeError("single integer must be in the range [0, 2**64 - 1] (uint64)") + raise TypeError( + "single integer must be in the range [0, 2**64 - 1] (uint64)" + ) return _LambdaArgument.from_scalar(value) else: assert isinstance(value, np.ndarray) @@ -55,19 +82,30 @@ class CompilerEngine: if mlir_str is not None: self.compile_fhe(mlir_str) - def compile_fhe(self, mlir_str: str, func_name: str = "main"): + def compile_fhe( + self, mlir_str: str, func_name: str = "main", runtime_lib_path: str = None + ): """Compile the MLIR input. Args: mlir_str (str): MLIR to compile. - func_name (str): name of the function to set as entrypoint. + func_name (str): name of the function to set as entrypoint (default: main). + runtime_lib_path (str): path to the runtime lib (default: None). Raises: TypeError: if the argument is not an str. """ if not isinstance(mlir_str, str): raise TypeError("input must be an `str`") - self._lambda = self._engine.build_lambda(mlir_str, func_name) + if runtime_lib_path is None: + # Set to empty string if not found + runtime_lib_path = _lookup_runtime_lib() + else: + if not isinstance(runtime_lib_path, str): + raise TypeError( + "runtime_lib_path must be an str representing the path to the runtime lib" + ) + self._lambda = self._engine.build_lambda(mlir_str, func_name, runtime_lib_path) def run(self, *args: List[Union[int, np.ndarray]]) -> Union[int, np.ndarray]: """Run the compiled code. diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index e1ff52d63..d60ad515d 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -5,11 +5,16 @@ using mlir::zamalang::JitCompilerEngine; -mlir::zamalang::JitCompilerEngine::Lambda buildLambda(const char *module, - const char *funcName) { +mlir::zamalang::JitCompilerEngine::Lambda +buildLambda(const char *module, const char *funcName, + const char *runtimeLibPath) { + // Set the runtime library path if not nullptr + llvm::Optional runtimeLibPathOptional = {}; + if (runtimeLibPath != nullptr) + runtimeLibPathOptional = runtimeLibPath; mlir::zamalang::JitCompilerEngine engine; llvm::Expected lambdaOrErr = - engine.buildLambda(module, funcName); + engine.buildLambda(module, funcName, runtimeLibPathOptional); if (!lambdaOrErr) { std::string backingString; llvm::raw_string_ostream os(backingString); diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index 7278aa0a7..181ba654b 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -15,7 +15,8 @@ namespace zamalang { llvm::Expected> JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module, - llvm::function_ref optPipeline) { + llvm::function_ref optPipeline, + llvm::Optional runtimeLibPath) { // Looking for the function auto rangeOps = module.getOps(); @@ -33,9 +34,14 @@ JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module, mlir::registerLLVMDialectTranslation(*module->getContext()); // Create an MLIR execution engine. The execution engine eagerly - // JIT-compiles the module. + // JIT-compiles the module. If runtimeLibPath is specified, it's passed as a + // shared library to the JIT compiler. + std::vector sharedLibPaths; + if (runtimeLibPath.hasValue()) + sharedLibPaths.push_back(runtimeLibPath.getValue()); auto maybeEngine = mlir::ExecutionEngine::create( - module, /*llvmModuleBuilder=*/nullptr, optPipeline); + module, /*llvmModuleBuilder=*/nullptr, optPipeline, + /*jitCodeGenOptLevel=*/llvm::None, sharedLibPaths); if (!maybeEngine) { return llvm::make_error( "failed to construct the MLIR ExecutionEngine", diff --git a/compiler/lib/Support/JitCompilerEngine.cpp b/compiler/lib/Support/JitCompilerEngine.cpp index b1d8deef9..eb3859631 100644 --- a/compiler/lib/Support/JitCompilerEngine.cpp +++ b/compiler/lib/Support/JitCompilerEngine.cpp @@ -36,13 +36,14 @@ JitCompilerEngine::findLLVMFuncOp(mlir::ModuleOp module, llvm::StringRef name) { // `funcName` from the sources in `buffer`. llvm::Expected JitCompilerEngine::buildLambda(std::unique_ptr buffer, - llvm::StringRef funcName) { + llvm::StringRef funcName, + llvm::Optional runtimeLibPath) { llvm::SourceMgr sm; sm.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc()); llvm::Expected res = - this->buildLambda(sm, funcName); + this->buildLambda(sm, funcName, runtimeLibPath); return std::move(res); } @@ -50,10 +51,11 @@ JitCompilerEngine::buildLambda(std::unique_ptr buffer, // Build a lambda from the function with the name given in `funcName` // from the source string `s`. llvm::Expected -JitCompilerEngine::buildLambda(llvm::StringRef s, llvm::StringRef funcName) { +JitCompilerEngine::buildLambda(llvm::StringRef s, llvm::StringRef funcName, + llvm::Optional runtimeLibPath) { std::unique_ptr mb = llvm::MemoryBuffer::getMemBuffer(s); llvm::Expected res = - this->buildLambda(std::move(mb), funcName); + this->buildLambda(std::move(mb), funcName, runtimeLibPath); return std::move(res); } @@ -61,7 +63,8 @@ JitCompilerEngine::buildLambda(llvm::StringRef s, llvm::StringRef funcName) { // Build a lambda from the function with the name given in // `funcName` from the sources managed by the source manager `sm`. llvm::Expected -JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName) { +JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName, + llvm::Optional runtimeLibPath) { MLIRContext &mlirContext = *this->compilationContext->getMLIRContext(); this->setGenerateClientParameters(true); @@ -92,7 +95,8 @@ JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName) { mlir::makeOptimizingTransformer(3, 0, nullptr); llvm::Expected> lambdaOrErr = - mlir::zamalang::JITLambda::create(funcName, module, optPipeline); + mlir::zamalang::JITLambda::create(funcName, module, optPipeline, + runtimeLibPath); // Generate the KeySet for encrypting lambda arguments, decrypting lambda // results