mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: get RT lib path from py and use as sharedlib
Try to find the runtime library automatically (should only work on proper installation of the package), and fail silently by not passing any RT lib. The RT lib can also be specified manually. The RT lib will be used as a shared library by the JIT compiler.
This commit is contained in:
@@ -31,9 +31,11 @@ struct executionArguments {
|
||||
typedef struct executionArguments exectuionArguments;
|
||||
|
||||
// Build lambda from a textual representation of an MLIR module
|
||||
// The lambda will have `funcName` as entrypoint
|
||||
// The lambda will have `funcName` as entrypoint, and use runtimeLibPath (if not
|
||||
// null) as a shared library during compilation
|
||||
MLIR_CAPI_EXPORTED mlir::zamalang::JitCompilerEngine::Lambda
|
||||
buildLambda(const char *module, const char *funcName);
|
||||
buildLambda(const char *module, const char *funcName,
|
||||
const char *runtimeLibPath);
|
||||
|
||||
// Parse then print a textual representation of an MLIR module
|
||||
MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);
|
||||
|
||||
@@ -101,9 +101,11 @@ public:
|
||||
: type(type), name(name){};
|
||||
|
||||
/// create a JITLambda that point to the function name of the given module.
|
||||
/// Use runtimeLibPath as a shared library if specified.
|
||||
static llvm::Expected<std::unique_ptr<JITLambda>>
|
||||
create(llvm::StringRef name, mlir::ModuleOp &module,
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline);
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline,
|
||||
llvm::Optional<llvm::StringRef> runtimeLibPath = {});
|
||||
|
||||
/// invokeRaw execute the jit lambda with a list of Argument, the last one is
|
||||
/// used to store the result of the computation.
|
||||
|
||||
@@ -359,14 +359,20 @@ public:
|
||||
CompilationContext::createShared(),
|
||||
unsigned int optimizationLevel = 3);
|
||||
|
||||
llvm::Expected<Lambda> buildLambda(llvm::StringRef src,
|
||||
llvm::StringRef funcName = "main");
|
||||
/// Build a Lambda from a source MLIR, with `funcName` as entrypoint.
|
||||
/// Use runtimeLibPath as a shared library if specified.
|
||||
llvm::Expected<Lambda>
|
||||
buildLambda(llvm::StringRef src, llvm::StringRef funcName = "main",
|
||||
llvm::Optional<llvm::StringRef> runtimeLibPath = {});
|
||||
|
||||
llvm::Expected<Lambda> buildLambda(std::unique_ptr<llvm::MemoryBuffer> buffer,
|
||||
llvm::StringRef funcName = "main");
|
||||
llvm::Expected<Lambda>
|
||||
buildLambda(std::unique_ptr<llvm::MemoryBuffer> buffer,
|
||||
llvm::StringRef funcName = "main",
|
||||
llvm::Optional<llvm::StringRef> runtimeLibPath = {});
|
||||
|
||||
llvm::Expected<Lambda> buildLambda(llvm::SourceMgr &sm,
|
||||
llvm::StringRef funcName = "main");
|
||||
llvm::Expected<Lambda>
|
||||
buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName = "main",
|
||||
llvm::Optional<llvm::StringRef> runtimeLibPath = {});
|
||||
|
||||
protected:
|
||||
llvm::Expected<mlir::LLVM::LLVMFuncOp> findLLVMFuncOp(mlir::ModuleOp module,
|
||||
|
||||
@@ -26,10 +26,14 @@ void mlir::zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
|
||||
|
||||
pybind11::class_<JitCompilerEngine>(m, "JitCompilerEngine")
|
||||
.def(pybind11::init())
|
||||
.def_static("build_lambda",
|
||||
[](std::string mlir_input, std::string func_name) {
|
||||
return buildLambda(mlir_input.c_str(), func_name.c_str());
|
||||
});
|
||||
.def_static("build_lambda", [](std::string mlir_input,
|
||||
std::string func_name,
|
||||
std::string runtime_lib_path) {
|
||||
if (runtime_lib_path.empty())
|
||||
return buildLambda(mlir_input.c_str(), func_name.c_str(), nullptr);
|
||||
return buildLambda(mlir_input.c_str(), func_name.c_str(),
|
||||
runtime_lib_path.c_str());
|
||||
});
|
||||
|
||||
pybind11::class_<lambdaArgument>(m, "LambdaArgument")
|
||||
.def_static("from_tensor", lambdaArgumentFromTensor)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Compiler submodule"""
|
||||
import os
|
||||
from typing import List, Union
|
||||
from mlir._mlir_libs._zamalang._compiler import JitCompilerEngine as _JitCompilerEngine
|
||||
from mlir._mlir_libs._zamalang._compiler import LambdaArgument as _LambdaArgument
|
||||
@@ -6,6 +7,30 @@ from mlir._mlir_libs._zamalang._compiler import round_trip as _round_trip
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _lookup_runtime_lib() -> str:
|
||||
"""Try to find the absolute path to the runtime library.
|
||||
|
||||
Returns:
|
||||
str: absolute path to the runtime library, or empty str if unsuccessful.
|
||||
"""
|
||||
# Go up to site-packages level
|
||||
cwd = os.path.abspath(__file__)
|
||||
cwd = os.path.abspath(os.path.join(cwd, os.pardir))
|
||||
cwd = os.path.abspath(os.path.join(cwd, os.pardir))
|
||||
package_name = "concretefhe_compiler"
|
||||
libs_path = os.path.join(cwd, f"{package_name}.libs")
|
||||
# Can be because it's not a properly installed package
|
||||
if not os.path.exists(libs_path):
|
||||
return ""
|
||||
runtime_library_paths = [
|
||||
filename
|
||||
for filename in os.listdir(libs_path)
|
||||
if filename.startswith("libZamalangRuntime")
|
||||
]
|
||||
assert len(runtime_library_paths) == 1, "should be one and only one runtime library"
|
||||
return os.path.join(libs_path, runtime_library_paths[0])
|
||||
|
||||
|
||||
def round_trip(mlir_str: str) -> str:
|
||||
"""Parse the MLIR input, then return it back.
|
||||
|
||||
@@ -39,7 +64,9 @@ def create_execution_argument(value: Union[int, np.ndarray]) -> "_LambdaArgument
|
||||
raise TypeError("value of execution argument must be either int or numpy.array")
|
||||
if isinstance(value, int):
|
||||
if not (0 <= value < (2 ** 64 - 1)):
|
||||
raise TypeError("single integer must be in the range [0, 2**64 - 1] (uint64)")
|
||||
raise TypeError(
|
||||
"single integer must be in the range [0, 2**64 - 1] (uint64)"
|
||||
)
|
||||
return _LambdaArgument.from_scalar(value)
|
||||
else:
|
||||
assert isinstance(value, np.ndarray)
|
||||
@@ -55,19 +82,30 @@ class CompilerEngine:
|
||||
if mlir_str is not None:
|
||||
self.compile_fhe(mlir_str)
|
||||
|
||||
def compile_fhe(self, mlir_str: str, func_name: str = "main"):
|
||||
def compile_fhe(
|
||||
self, mlir_str: str, func_name: str = "main", runtime_lib_path: str = None
|
||||
):
|
||||
"""Compile the MLIR input.
|
||||
|
||||
Args:
|
||||
mlir_str (str): MLIR to compile.
|
||||
func_name (str): name of the function to set as entrypoint.
|
||||
func_name (str): name of the function to set as entrypoint (default: main).
|
||||
runtime_lib_path (str): path to the runtime lib (default: None).
|
||||
|
||||
Raises:
|
||||
TypeError: if the argument is not an str.
|
||||
"""
|
||||
if not isinstance(mlir_str, str):
|
||||
raise TypeError("input must be an `str`")
|
||||
self._lambda = self._engine.build_lambda(mlir_str, func_name)
|
||||
if runtime_lib_path is None:
|
||||
# Set to empty string if not found
|
||||
runtime_lib_path = _lookup_runtime_lib()
|
||||
else:
|
||||
if not isinstance(runtime_lib_path, str):
|
||||
raise TypeError(
|
||||
"runtime_lib_path must be an str representing the path to the runtime lib"
|
||||
)
|
||||
self._lambda = self._engine.build_lambda(mlir_str, func_name, runtime_lib_path)
|
||||
|
||||
def run(self, *args: List[Union[int, np.ndarray]]) -> Union[int, np.ndarray]:
|
||||
"""Run the compiled code.
|
||||
|
||||
@@ -5,11 +5,16 @@
|
||||
|
||||
using mlir::zamalang::JitCompilerEngine;
|
||||
|
||||
mlir::zamalang::JitCompilerEngine::Lambda buildLambda(const char *module,
|
||||
const char *funcName) {
|
||||
mlir::zamalang::JitCompilerEngine::Lambda
|
||||
buildLambda(const char *module, const char *funcName,
|
||||
const char *runtimeLibPath) {
|
||||
// Set the runtime library path if not nullptr
|
||||
llvm::Optional<llvm::StringRef> runtimeLibPathOptional = {};
|
||||
if (runtimeLibPath != nullptr)
|
||||
runtimeLibPathOptional = runtimeLibPath;
|
||||
mlir::zamalang::JitCompilerEngine engine;
|
||||
llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> lambdaOrErr =
|
||||
engine.buildLambda(module, funcName);
|
||||
engine.buildLambda(module, funcName, runtimeLibPathOptional);
|
||||
if (!lambdaOrErr) {
|
||||
std::string backingString;
|
||||
llvm::raw_string_ostream os(backingString);
|
||||
|
||||
@@ -15,7 +15,8 @@ namespace zamalang {
|
||||
|
||||
llvm::Expected<std::unique_ptr<JITLambda>>
|
||||
JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module,
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline) {
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline,
|
||||
llvm::Optional<llvm::StringRef> runtimeLibPath) {
|
||||
|
||||
// Looking for the function
|
||||
auto rangeOps = module.getOps<mlir::LLVM::LLVMFuncOp>();
|
||||
@@ -33,9 +34,14 @@ JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module,
|
||||
mlir::registerLLVMDialectTranslation(*module->getContext());
|
||||
|
||||
// Create an MLIR execution engine. The execution engine eagerly
|
||||
// JIT-compiles the module.
|
||||
// JIT-compiles the module. If runtimeLibPath is specified, it's passed as a
|
||||
// shared library to the JIT compiler.
|
||||
std::vector<llvm::StringRef> sharedLibPaths;
|
||||
if (runtimeLibPath.hasValue())
|
||||
sharedLibPaths.push_back(runtimeLibPath.getValue());
|
||||
auto maybeEngine = mlir::ExecutionEngine::create(
|
||||
module, /*llvmModuleBuilder=*/nullptr, optPipeline);
|
||||
module, /*llvmModuleBuilder=*/nullptr, optPipeline,
|
||||
/*jitCodeGenOptLevel=*/llvm::None, sharedLibPaths);
|
||||
if (!maybeEngine) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"failed to construct the MLIR ExecutionEngine",
|
||||
|
||||
@@ -36,13 +36,14 @@ JitCompilerEngine::findLLVMFuncOp(mlir::ModuleOp module, llvm::StringRef name) {
|
||||
// `funcName` from the sources in `buffer`.
|
||||
llvm::Expected<JitCompilerEngine::Lambda>
|
||||
JitCompilerEngine::buildLambda(std::unique_ptr<llvm::MemoryBuffer> buffer,
|
||||
llvm::StringRef funcName) {
|
||||
llvm::StringRef funcName,
|
||||
llvm::Optional<llvm::StringRef> runtimeLibPath) {
|
||||
llvm::SourceMgr sm;
|
||||
|
||||
sm.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
|
||||
|
||||
llvm::Expected<JitCompilerEngine::Lambda> res =
|
||||
this->buildLambda(sm, funcName);
|
||||
this->buildLambda(sm, funcName, runtimeLibPath);
|
||||
|
||||
return std::move(res);
|
||||
}
|
||||
@@ -50,10 +51,11 @@ JitCompilerEngine::buildLambda(std::unique_ptr<llvm::MemoryBuffer> buffer,
|
||||
// Build a lambda from the function with the name given in `funcName`
|
||||
// from the source string `s`.
|
||||
llvm::Expected<JitCompilerEngine::Lambda>
|
||||
JitCompilerEngine::buildLambda(llvm::StringRef s, llvm::StringRef funcName) {
|
||||
JitCompilerEngine::buildLambda(llvm::StringRef s, llvm::StringRef funcName,
|
||||
llvm::Optional<llvm::StringRef> runtimeLibPath) {
|
||||
std::unique_ptr<llvm::MemoryBuffer> mb = llvm::MemoryBuffer::getMemBuffer(s);
|
||||
llvm::Expected<JitCompilerEngine::Lambda> res =
|
||||
this->buildLambda(std::move(mb), funcName);
|
||||
this->buildLambda(std::move(mb), funcName, runtimeLibPath);
|
||||
|
||||
return std::move(res);
|
||||
}
|
||||
@@ -61,7 +63,8 @@ JitCompilerEngine::buildLambda(llvm::StringRef s, llvm::StringRef funcName) {
|
||||
// Build a lambda from the function with the name given in
|
||||
// `funcName` from the sources managed by the source manager `sm`.
|
||||
llvm::Expected<JitCompilerEngine::Lambda>
|
||||
JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName) {
|
||||
JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName,
|
||||
llvm::Optional<llvm::StringRef> runtimeLibPath) {
|
||||
MLIRContext &mlirContext = *this->compilationContext->getMLIRContext();
|
||||
|
||||
this->setGenerateClientParameters(true);
|
||||
@@ -92,7 +95,8 @@ JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName) {
|
||||
mlir::makeOptimizingTransformer(3, 0, nullptr);
|
||||
|
||||
llvm::Expected<std::unique_ptr<JITLambda>> lambdaOrErr =
|
||||
mlir::zamalang::JITLambda::create(funcName, module, optPipeline);
|
||||
mlir::zamalang::JITLambda::create(funcName, module, optPipeline,
|
||||
runtimeLibPath);
|
||||
|
||||
// Generate the KeySet for encrypting lambda arguments, decrypting lambda
|
||||
// results
|
||||
|
||||
Reference in New Issue
Block a user