mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
feat(compiler): add a key cache
This commit is contained in:
12
.github/workflows/conformance.yml
vendored
12
.github/workflows/conformance.yml
vendored
@@ -22,6 +22,15 @@ jobs:
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: "KeySetCache"
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: ${{ github.workspace }}/KeySetCache
|
||||
# actions/cache does not permit to update a cache entry
|
||||
key: ${{ runner.os }}-KeySetCache-2021-12-02
|
||||
restore-keys: |
|
||||
${{ runner.os }}-KeySetCache-
|
||||
|
||||
- name: Build and test compiler
|
||||
uses: addnab/docker-run-action@v3
|
||||
with:
|
||||
@@ -29,7 +38,7 @@ jobs:
|
||||
image: ghcr.io/zama-ai/zamalang-compiler:latest
|
||||
username: ${{ secrets.GHCR_LOGIN }}
|
||||
password: ${{ secrets.GHCR_PASSWORD }}
|
||||
options: -v ${{ github.workspace }}/compiler:/compiler
|
||||
options: -v ${{ github.workspace }}/compiler:/compiler -v ${{ github.workspace }}/KeySetCache:/tmp/KeySetCache
|
||||
shell: bash
|
||||
run: |
|
||||
set -e
|
||||
@@ -42,3 +51,4 @@ jobs:
|
||||
make CCACHE=ON BUILD_DIR=/build test
|
||||
echo "Debug: ccache statistics (after the build):"
|
||||
ccache -s
|
||||
chmod -R ugo+rwx /tmp/KeySetCache
|
||||
|
||||
@@ -43,7 +43,7 @@ test-check: zamacompiler file-check not
|
||||
$(BUILD_DIR)/bin/llvm-lit -v tests/
|
||||
|
||||
test-python: python-bindings zamacompiler
|
||||
PYTHONPATH=${PYTHONPATH}:$(BUILD_DIR)/tools/zamalang/python_packages/zamalang_core LD_PRELOAD=$(BUILD_DIR)/lib/libZamalangRuntime.so pytest -vs tests/python
|
||||
PATH=$(BUILD_DIR)/bin:${PATH} PYTHONPATH=${PYTHONPATH}:$(BUILD_DIR)/tools/zamalang/python_packages/zamalang_core LD_PRELOAD=$(BUILD_DIR)/lib/libZamalangRuntime.so pytest -vs tests/python
|
||||
|
||||
test: test-check test-end-to-end-jit test-python
|
||||
|
||||
|
||||
@@ -32,10 +32,12 @@ typedef struct executionArguments executionArguments;
|
||||
|
||||
// Build lambda from a textual representation of an MLIR module
|
||||
// The lambda will have `funcName` as entrypoint, and use runtimeLibPath (if not
|
||||
// null) as a shared library during compilation
|
||||
// null) as a shared library during compilation,
|
||||
// a path to activate the use a cache for encryption keys for test purpose
|
||||
// (unsecure).
|
||||
MLIR_CAPI_EXPORTED mlir::zamalang::JitCompilerEngine::Lambda
|
||||
buildLambda(const char *module, const char *funcName,
|
||||
const char *runtimeLibPath);
|
||||
const char *runtimeLibPath, const char *keySetCachePath);
|
||||
|
||||
// Parse then print a textual representation of an MLIR module
|
||||
MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);
|
||||
|
||||
@@ -24,6 +24,8 @@ typedef uint64_t GlweDimension;
|
||||
typedef std::string LweSecretKeyID;
|
||||
struct LweSecretKeyParam {
|
||||
LweSize size;
|
||||
|
||||
void hash(size_t &seed);
|
||||
};
|
||||
|
||||
typedef std::string BootstrapKeyID;
|
||||
@@ -34,6 +36,8 @@ struct BootstrapKeyParam {
|
||||
DecompositionBaseLog baseLog;
|
||||
GlweDimension glweDimension;
|
||||
Variance variance;
|
||||
|
||||
void hash(size_t &seed);
|
||||
};
|
||||
|
||||
typedef std::string KeyswitchKeyID;
|
||||
@@ -43,6 +47,8 @@ struct KeyswitchKeyParam {
|
||||
DecompositionLevelCount level;
|
||||
DecompositionBaseLog baseLog;
|
||||
Variance variance;
|
||||
|
||||
void hash(size_t &seed);
|
||||
};
|
||||
|
||||
struct Encoding {
|
||||
@@ -75,11 +81,13 @@ struct ClientParameters {
|
||||
std::map<KeyswitchKeyID, KeyswitchKeyParam> keyswitchKeys;
|
||||
std::vector<CircuitGate> inputs;
|
||||
std::vector<CircuitGate> outputs;
|
||||
size_t hash();
|
||||
};
|
||||
|
||||
llvm::Expected<ClientParameters>
|
||||
createClientParametersForV0(V0FHEContext context, llvm::StringRef name,
|
||||
mlir::ModuleOp module);
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#ifndef ZAMALANG_SUPPORT_JIT_COMPILER_ENGINE_H
|
||||
#define ZAMALANG_SUPPORT_JIT_COMPILER_ENGINE_H
|
||||
|
||||
#include "zamalang/Support/KeySetCache.h"
|
||||
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
||||
#include <zamalang/Support/CompilerEngine.h>
|
||||
#include <zamalang/Support/Error.h>
|
||||
@@ -363,15 +364,18 @@ public:
|
||||
/// Use runtimeLibPath as a shared library if specified.
|
||||
llvm::Expected<Lambda>
|
||||
buildLambda(llvm::StringRef src, llvm::StringRef funcName = "main",
|
||||
llvm::Optional<KeySetCache> cachePath = {},
|
||||
llvm::Optional<llvm::StringRef> runtimeLibPath = {});
|
||||
|
||||
llvm::Expected<Lambda>
|
||||
buildLambda(std::unique_ptr<llvm::MemoryBuffer> buffer,
|
||||
llvm::StringRef funcName = "main",
|
||||
llvm::Optional<KeySetCache> cachePath = {},
|
||||
llvm::Optional<llvm::StringRef> runtimeLibPath = {});
|
||||
|
||||
llvm::Expected<Lambda>
|
||||
buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName = "main",
|
||||
llvm::Optional<KeySetCache> cachePath = {},
|
||||
llvm::Optional<llvm::StringRef> runtimeLibPath = {});
|
||||
|
||||
protected:
|
||||
|
||||
@@ -10,6 +10,7 @@ extern "C" {
|
||||
}
|
||||
|
||||
#include "zamalang/Support/ClientParameters.h"
|
||||
#include "zamalang/Support/KeySetCache.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
@@ -17,6 +18,15 @@ namespace zamalang {
|
||||
class KeySet {
|
||||
public:
|
||||
~KeySet();
|
||||
|
||||
static std::unique_ptr<KeySet> uninitialized();
|
||||
|
||||
llvm::Error generateKeysFromParams(ClientParameters ¶ms,
|
||||
uint64_t seed_msb, uint64_t seed_lsb);
|
||||
|
||||
llvm::Error setupEncryptionMaterial(ClientParameters ¶ms,
|
||||
uint64_t seed_msb, uint64_t seed_lsb);
|
||||
|
||||
// allocate a KeySet according the ClientParameters.
|
||||
static llvm::Expected<std::unique_ptr<KeySet>>
|
||||
generate(ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb);
|
||||
@@ -46,6 +56,18 @@ public:
|
||||
context.bsk = std::get<1>(this->bootstrapKeys["bsk_v0"]);
|
||||
}
|
||||
|
||||
const std::map<LweSecretKeyID,
|
||||
std::pair<LweSecretKeyParam, LweSecretKey_u64 *>> &
|
||||
getSecretKeys();
|
||||
|
||||
const std::map<LweSecretKeyID,
|
||||
std::pair<BootstrapKeyParam, LweBootstrapKey_u64 *>> &
|
||||
getBootstrapKeys();
|
||||
|
||||
const std::map<LweSecretKeyID,
|
||||
std::pair<KeyswitchKeyParam, LweKeyswitchKey_u64 *>> &
|
||||
getKeyswitchKeys();
|
||||
|
||||
protected:
|
||||
llvm::Error generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param,
|
||||
SecretRandomGenerator *generator);
|
||||
@@ -54,6 +76,8 @@ protected:
|
||||
llvm::Error generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param,
|
||||
EncryptionRandomGenerator *generator);
|
||||
|
||||
friend class KeySetCache;
|
||||
|
||||
private:
|
||||
EncryptionRandomGenerator *encryptionRandomGenerator;
|
||||
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey_u64 *>>
|
||||
@@ -66,6 +90,16 @@ private:
|
||||
inputs;
|
||||
std::vector<std::tuple<CircuitGate, LweSecretKeyParam *, LweSecretKey_u64 *>>
|
||||
outputs;
|
||||
|
||||
void setKeys(
|
||||
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey_u64 *>>
|
||||
secretKeys,
|
||||
std::map<LweSecretKeyID,
|
||||
std::pair<BootstrapKeyParam, LweBootstrapKey_u64 *>>
|
||||
bootstrapKeys,
|
||||
std::map<LweSecretKeyID,
|
||||
std::pair<KeyswitchKeyParam, LweKeyswitchKey_u64 *>>
|
||||
keyswitchKeys);
|
||||
};
|
||||
|
||||
} // namespace zamalang
|
||||
|
||||
31
compiler/include/zamalang/Support/KeySetCache.h
Normal file
31
compiler/include/zamalang/Support/KeySetCache.h
Normal file
@@ -0,0 +1,31 @@
|
||||
#ifndef ZAMALANG_SUPPORT_KEYSETCACHE_H_
|
||||
#define ZAMALANG_SUPPORT_KEYSETCACHE_H_
|
||||
|
||||
#include "zamalang/Support/KeySet.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
class KeySet;
|
||||
|
||||
class KeySetCache {
|
||||
std::string backingDirectoryPath;
|
||||
|
||||
public:
|
||||
KeySetCache(std::string backingDirectoryPath)
|
||||
: backingDirectoryPath(backingDirectoryPath) {}
|
||||
|
||||
llvm::Expected<std::unique_ptr<KeySet>>
|
||||
tryLoadOrGenerateSave(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb);
|
||||
|
||||
private:
|
||||
static llvm::Expected<std::unique_ptr<KeySet>>
|
||||
tryLoadKeys(ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb,
|
||||
llvm::SmallString<0> &folderPath);
|
||||
};
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -17,6 +17,10 @@
|
||||
using mlir::zamalang::JitCompilerEngine;
|
||||
using mlir::zamalang::LambdaArgument;
|
||||
|
||||
const char *noEmptyStringPtr(std::string &s) {
|
||||
return (s.empty()) ? nullptr : s.c_str();
|
||||
}
|
||||
|
||||
/// Populate the compiler API python module.
|
||||
void mlir::zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
|
||||
m.doc() = "Zamalang compiler python API";
|
||||
@@ -31,14 +35,14 @@ void mlir::zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
|
||||
|
||||
pybind11::class_<JitCompilerEngine>(m, "JitCompilerEngine")
|
||||
.def(pybind11::init())
|
||||
.def_static("build_lambda", [](std::string mlir_input,
|
||||
std::string func_name,
|
||||
std::string runtime_lib_path) {
|
||||
if (runtime_lib_path.empty())
|
||||
return buildLambda(mlir_input.c_str(), func_name.c_str(), nullptr);
|
||||
return buildLambda(mlir_input.c_str(), func_name.c_str(),
|
||||
runtime_lib_path.c_str());
|
||||
});
|
||||
.def_static("build_lambda",
|
||||
[](std::string mlir_input, std::string func_name,
|
||||
std::string runtime_lib_path,
|
||||
std::string keysetcache_path) {
|
||||
return buildLambda(mlir_input.c_str(), func_name.c_str(),
|
||||
noEmptyStringPtr(runtime_lib_path),
|
||||
noEmptyStringPtr(keysetcache_path));
|
||||
});
|
||||
|
||||
pybind11::class_<lambdaArgument>(m, "LambdaArgument")
|
||||
.def_static("from_tensor", lambdaArgumentFromTensor)
|
||||
|
||||
@@ -118,7 +118,8 @@ class CompilerEngine:
|
||||
self.compile_fhe(mlir_str)
|
||||
|
||||
def compile_fhe(
|
||||
self, mlir_str: str, func_name: str = "main", runtime_lib_path: str = None
|
||||
self, mlir_str: str, func_name: str = "main", runtime_lib_path: str = None,
|
||||
unsecure_key_set_cache_path: str = None,
|
||||
):
|
||||
"""Compile the MLIR input.
|
||||
|
||||
@@ -126,6 +127,7 @@ class CompilerEngine:
|
||||
mlir_str (str): MLIR to compile.
|
||||
func_name (str): name of the function to set as entrypoint (default: main).
|
||||
runtime_lib_path (str): path to the runtime lib (default: None).
|
||||
unsecure_key_set_cache_path (str): path to the activate keyset caching (default: None).
|
||||
|
||||
Raises:
|
||||
TypeError: if the argument is not an str.
|
||||
@@ -140,7 +142,14 @@ class CompilerEngine:
|
||||
raise TypeError(
|
||||
"runtime_lib_path must be an str representing the path to the runtime lib"
|
||||
)
|
||||
self._lambda = self._engine.build_lambda(mlir_str, func_name, runtime_lib_path)
|
||||
unsecure_key_set_cache_path = unsecure_key_set_cache_path or ""
|
||||
if not isinstance(unsecure_key_set_cache_path, str):
|
||||
raise TypeError(
|
||||
"unsecure_key_set_cache_path must be a str"
|
||||
)
|
||||
self._lambda = self._engine.build_lambda(
|
||||
mlir_str, func_name, runtime_lib_path,
|
||||
unsecure_key_set_cache_path)
|
||||
|
||||
def run(self, *args: List[Union[int, np.ndarray]]) -> Union[int, np.ndarray]:
|
||||
"""Run the compiled code.
|
||||
|
||||
@@ -1,20 +1,31 @@
|
||||
#include "llvm/ADT/SmallString.h"
|
||||
|
||||
#include "zamalang-c/Support/CompilerEngine.h"
|
||||
#include "zamalang/Support/CompilerEngine.h"
|
||||
#include "zamalang/Support/Jit.h"
|
||||
#include "zamalang/Support/JitCompilerEngine.h"
|
||||
#include "zamalang/Support/KeySetCache.h"
|
||||
|
||||
using mlir::zamalang::JitCompilerEngine;
|
||||
|
||||
mlir::zamalang::JitCompilerEngine::Lambda
|
||||
buildLambda(const char *module, const char *funcName,
|
||||
const char *runtimeLibPath) {
|
||||
const char *runtimeLibPath, const char *keySetCachePath) {
|
||||
// Set the runtime library path if not nullptr
|
||||
llvm::Optional<llvm::StringRef> runtimeLibPathOptional = {};
|
||||
if (runtimeLibPath != nullptr)
|
||||
runtimeLibPathOptional = runtimeLibPath;
|
||||
mlir::zamalang::JitCompilerEngine engine;
|
||||
|
||||
using KeySetCache = mlir::zamalang::KeySetCache;
|
||||
using optKeySetCache = llvm::Optional<mlir::zamalang::KeySetCache>;
|
||||
auto cacheOpt = optKeySetCache();
|
||||
if (keySetCachePath != nullptr) {
|
||||
cacheOpt = KeySetCache(std::string(keySetCachePath));
|
||||
}
|
||||
|
||||
llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> lambdaOrErr =
|
||||
engine.buildLambda(module, funcName, runtimeLibPathOptional);
|
||||
engine.buildLambda(module, funcName, cacheOpt, runtimeLibPathOptional);
|
||||
if (!lambdaOrErr) {
|
||||
std::string backingString;
|
||||
llvm::raw_string_ostream os(backingString);
|
||||
|
||||
@@ -11,6 +11,7 @@ add_mlir_library(ZamalangSupport
|
||||
KeySet.cpp
|
||||
logging.cpp
|
||||
Jit.cpp
|
||||
KeySetCache.cpp
|
||||
LLVMEmitFile.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
|
||||
@@ -145,5 +145,44 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name,
|
||||
return c;
|
||||
}
|
||||
|
||||
// https://stackoverflow.com/a/38140932
|
||||
static inline void hash(std::size_t &seed) {}
|
||||
template <typename T, typename... Rest>
|
||||
static inline void hash(std::size_t &seed, const T &v, Rest... rest) {
|
||||
// See https://softwareengineering.stackexchange.com/a/402543
|
||||
const auto GOLDEN_RATIO = 0x9e3779b97f4a7c15; // pseudo random bits
|
||||
const std::hash<T> hasher;
|
||||
seed ^= hasher(v) + GOLDEN_RATIO + (seed << 6) + (seed >> 2);
|
||||
hash(seed, rest...);
|
||||
}
|
||||
|
||||
void LweSecretKeyParam::hash(size_t &seed) { mlir::zamalang::hash(seed, size); }
|
||||
|
||||
void BootstrapKeyParam::hash(size_t &seed) {
|
||||
mlir::zamalang::hash(seed, inputSecretKeyID, outputSecretKeyID, level,
|
||||
baseLog, glweDimension, variance);
|
||||
}
|
||||
|
||||
void KeyswitchKeyParam::hash(size_t &seed) {
|
||||
mlir::zamalang::hash(seed, inputSecretKeyID, outputSecretKeyID, level,
|
||||
baseLog, variance);
|
||||
}
|
||||
|
||||
std::size_t ClientParameters::hash() {
|
||||
std::size_t currentHash = 1;
|
||||
for (auto secretKeyParam : secretKeys) {
|
||||
mlir::zamalang::hash(currentHash, secretKeyParam.first);
|
||||
secretKeyParam.second.hash(currentHash);
|
||||
}
|
||||
for (auto bootstrapKeyParam : bootstrapKeys) {
|
||||
mlir::zamalang::hash(currentHash, bootstrapKeyParam.first);
|
||||
bootstrapKeyParam.second.hash(currentHash);
|
||||
}
|
||||
for (auto keyswitchParam : keyswitchKeys) {
|
||||
mlir::zamalang::hash(currentHash, keyswitchParam.first);
|
||||
keyswitchParam.second.hash(currentHash);
|
||||
}
|
||||
return currentHash;
|
||||
}
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -37,25 +37,22 @@ JitCompilerEngine::findLLVMFuncOp(mlir::ModuleOp module, llvm::StringRef name) {
|
||||
llvm::Expected<JitCompilerEngine::Lambda>
|
||||
JitCompilerEngine::buildLambda(std::unique_ptr<llvm::MemoryBuffer> buffer,
|
||||
llvm::StringRef funcName,
|
||||
llvm::Optional<KeySetCache> cache,
|
||||
llvm::Optional<llvm::StringRef> runtimeLibPath) {
|
||||
llvm::SourceMgr sm;
|
||||
|
||||
sm.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
|
||||
|
||||
llvm::Expected<JitCompilerEngine::Lambda> res =
|
||||
this->buildLambda(sm, funcName, runtimeLibPath);
|
||||
|
||||
return std::move(res);
|
||||
return this->buildLambda(sm, funcName, cache, runtimeLibPath);
|
||||
}
|
||||
|
||||
// Build a lambda from the function with the name given in `funcName`
|
||||
// from the source string `s`.
|
||||
llvm::Expected<JitCompilerEngine::Lambda>
|
||||
JitCompilerEngine::buildLambda(llvm::StringRef s, llvm::StringRef funcName,
|
||||
llvm::Optional<KeySetCache> cache,
|
||||
llvm::Optional<llvm::StringRef> runtimeLibPath) {
|
||||
std::unique_ptr<llvm::MemoryBuffer> mb = llvm::MemoryBuffer::getMemBuffer(s);
|
||||
llvm::Expected<JitCompilerEngine::Lambda> res =
|
||||
this->buildLambda(std::move(mb), funcName, runtimeLibPath);
|
||||
this->buildLambda(std::move(mb), funcName, cache, runtimeLibPath);
|
||||
|
||||
return std::move(res);
|
||||
}
|
||||
@@ -64,6 +61,7 @@ JitCompilerEngine::buildLambda(llvm::StringRef s, llvm::StringRef funcName,
|
||||
// `funcName` from the sources managed by the source manager `sm`.
|
||||
llvm::Expected<JitCompilerEngine::Lambda>
|
||||
JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName,
|
||||
llvm::Optional<KeySetCache> cache,
|
||||
llvm::Optional<llvm::StringRef> runtimeLibPath) {
|
||||
MLIRContext &mlirContext = *this->compilationContext->getMLIRContext();
|
||||
|
||||
@@ -77,14 +75,17 @@ JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName,
|
||||
if (!compResOrErr)
|
||||
return std::move(compResOrErr.takeError());
|
||||
|
||||
mlir::ModuleOp module = compResOrErr->mlirModuleRef->get();
|
||||
auto compRes = std::move(compResOrErr.get());
|
||||
|
||||
mlir::ModuleOp module = compRes.mlirModuleRef->get();
|
||||
|
||||
// Locate function to JIT-compile
|
||||
llvm::Expected<mlir::LLVM::LLVMFuncOp> funcOrError =
|
||||
this->findLLVMFuncOp(compResOrErr->mlirModuleRef->get(), funcName);
|
||||
this->findLLVMFuncOp(compRes.mlirModuleRef->get(), funcName);
|
||||
|
||||
if (!funcOrError)
|
||||
return std::move(funcOrError.takeError());
|
||||
return StreamStringError() << "Cannot find function \"" << funcName
|
||||
<< "\": " << std::move(funcOrError.takeError());
|
||||
|
||||
// Prepare LLVM infrastructure for JIT compilation
|
||||
llvm::InitializeNativeTarget();
|
||||
@@ -98,24 +99,32 @@ JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName,
|
||||
mlir::zamalang::JITLambda::create(funcName, module, optPipeline,
|
||||
runtimeLibPath);
|
||||
|
||||
if (!lambdaOrErr) {
|
||||
return StreamStringError()
|
||||
<< "Cannot create lambda: " << lambdaOrErr.takeError();
|
||||
}
|
||||
|
||||
auto lambda = std::move(lambdaOrErr.get());
|
||||
|
||||
// Generate the KeySet for encrypting lambda arguments, decrypting lambda
|
||||
// results
|
||||
if (!compResOrErr->clientParameters.hasValue()) {
|
||||
if (!compRes.clientParameters.hasValue()) {
|
||||
return StreamStringError("Cannot generate the keySet since client "
|
||||
"parameters has not been computed");
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<mlir::zamalang::KeySet>> keySetOrErr =
|
||||
mlir::zamalang::KeySet::generate(*compResOrErr->clientParameters, 0, 0);
|
||||
(cache.hasValue())
|
||||
? cache->tryLoadOrGenerateSave(*compRes.clientParameters, 0, 0)
|
||||
: KeySet::generate(*compRes.clientParameters, 0, 0);
|
||||
|
||||
if (auto err = keySetOrErr.takeError())
|
||||
return std::move(err);
|
||||
if (!keySetOrErr) {
|
||||
return keySetOrErr.takeError();
|
||||
}
|
||||
|
||||
if (!lambdaOrErr)
|
||||
return std::move(lambdaOrErr.takeError());
|
||||
auto keySet = std::move(keySetOrErr.get());
|
||||
|
||||
return Lambda{this->compilationContext, std::move(lambdaOrErr.get()),
|
||||
std::move(*keySetOrErr)};
|
||||
return Lambda{this->compilationContext, std::move(lambda), std::move(keySet)};
|
||||
}
|
||||
|
||||
} // namespace zamalang
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "zamalang/Support/KeySet.h"
|
||||
#include "zamalang/Support/Error.h"
|
||||
|
||||
#define CAPI_ERR_TO_LLVM_ERROR(s, msg) \
|
||||
{ \
|
||||
@@ -30,7 +31,79 @@ KeySet::~KeySet() {
|
||||
llvm::Expected<std::unique_ptr<KeySet>>
|
||||
KeySet::generate(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
auto keySet = std::make_unique<KeySet>();
|
||||
|
||||
auto a = uninitialized();
|
||||
|
||||
auto fillError = a->generateKeysFromParams(params, seed_msb, seed_lsb);
|
||||
|
||||
if (fillError) {
|
||||
return StreamStringError()
|
||||
<< "Cannot fill keys from params: " << std::move(fillError);
|
||||
}
|
||||
|
||||
fillError = a->setupEncryptionMaterial(params, seed_msb, seed_lsb);
|
||||
|
||||
if (fillError) {
|
||||
return StreamStringError()
|
||||
<< "Cannot setup encryption material: " << std::move(fillError);
|
||||
}
|
||||
|
||||
return a;
|
||||
}
|
||||
|
||||
std::unique_ptr<KeySet> KeySet::uninitialized() {
|
||||
return std::make_unique<KeySet>();
|
||||
}
|
||||
|
||||
llvm::Error KeySet::setupEncryptionMaterial(ClientParameters ¶ms,
|
||||
uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
// Set inputs and outputs LWE secret keys
|
||||
{
|
||||
for (auto param : params.inputs) {
|
||||
std::tuple<CircuitGate, LweSecretKeyParam *, LweSecretKey_u64 *> input = {
|
||||
param, nullptr, nullptr};
|
||||
if (param.encryption.hasValue()) {
|
||||
auto inputSk = this->secretKeys.find(param.encryption->secretKeyID);
|
||||
if (inputSk == this->secretKeys.end()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"input encryption secret key (" + param.encryption->secretKeyID +
|
||||
") does not exist ",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
std::get<1>(input) = &inputSk->second.first;
|
||||
std::get<2>(input) = inputSk->second.second;
|
||||
}
|
||||
this->inputs.push_back(input);
|
||||
}
|
||||
for (auto param : params.outputs) {
|
||||
std::tuple<CircuitGate, LweSecretKeyParam *, LweSecretKey_u64 *> output =
|
||||
{param, nullptr, nullptr};
|
||||
if (param.encryption.hasValue()) {
|
||||
auto outputSk = this->secretKeys.find(param.encryption->secretKeyID);
|
||||
if (outputSk == this->secretKeys.end()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"cannot find output key to generate bootstrap key",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
std::get<1>(output) = &outputSk->second.first;
|
||||
std::get<2>(output) = outputSk->second.second;
|
||||
}
|
||||
this->outputs.push_back(output);
|
||||
}
|
||||
}
|
||||
int err;
|
||||
CAPI_ERR_TO_LLVM_ERROR(
|
||||
this->encryptionRandomGenerator =
|
||||
allocate_encryption_generator(&err, seed_msb, seed_lsb),
|
||||
"cannot allocate encryption generator");
|
||||
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Error KeySet::generateKeysFromParams(ClientParameters ¶ms,
|
||||
uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
|
||||
{
|
||||
// Generate LWE secret keys
|
||||
@@ -39,8 +112,8 @@ KeySet::generate(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
generator = allocate_secret_generator(&err, seed_msb, seed_lsb),
|
||||
"cannot allocate random generator");
|
||||
for (auto secretKeyParam : params.secretKeys) {
|
||||
auto e = keySet->generateSecretKey(secretKeyParam.first,
|
||||
secretKeyParam.second, generator);
|
||||
auto e = this->generateSecretKey(secretKeyParam.first,
|
||||
secretKeyParam.second, generator);
|
||||
if (e) {
|
||||
return std::move(e);
|
||||
}
|
||||
@@ -50,62 +123,43 @@ KeySet::generate(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
}
|
||||
// Allocate the encryption random generator
|
||||
CAPI_ERR_TO_LLVM_ERROR(
|
||||
keySet->encryptionRandomGenerator =
|
||||
this->encryptionRandomGenerator =
|
||||
allocate_encryption_generator(&err, seed_msb, seed_lsb),
|
||||
"cannot allocate encryption generator");
|
||||
// Generate bootstrap and keyswitch keys
|
||||
{
|
||||
for (auto bootstrapKeyParam : params.bootstrapKeys) {
|
||||
auto e = keySet->generateBootstrapKey(bootstrapKeyParam.first,
|
||||
bootstrapKeyParam.second,
|
||||
keySet->encryptionRandomGenerator);
|
||||
auto e = this->generateBootstrapKey(bootstrapKeyParam.first,
|
||||
bootstrapKeyParam.second,
|
||||
this->encryptionRandomGenerator);
|
||||
if (e) {
|
||||
return std::move(e);
|
||||
}
|
||||
}
|
||||
for (auto keyswitchParam : params.keyswitchKeys) {
|
||||
auto e = keySet->generateKeyswitchKey(keyswitchParam.first,
|
||||
keyswitchParam.second,
|
||||
keySet->encryptionRandomGenerator);
|
||||
auto e = this->generateKeyswitchKey(keyswitchParam.first,
|
||||
keyswitchParam.second,
|
||||
this->encryptionRandomGenerator);
|
||||
if (e) {
|
||||
return std::move(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Set inputs and outputs LWE secret keys
|
||||
{
|
||||
for (auto param : params.inputs) {
|
||||
std::tuple<CircuitGate, LweSecretKeyParam *, LweSecretKey_u64 *> input = {
|
||||
param, nullptr, nullptr};
|
||||
if (param.encryption.hasValue()) {
|
||||
auto inputSk = keySet->secretKeys.find(param.encryption->secretKeyID);
|
||||
if (inputSk == keySet->secretKeys.end()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"cannot find input key to generate bootstrap key",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
std::get<1>(input) = &inputSk->second.first;
|
||||
std::get<2>(input) = inputSk->second.second;
|
||||
}
|
||||
keySet->inputs.push_back(input);
|
||||
}
|
||||
for (auto param : params.outputs) {
|
||||
std::tuple<CircuitGate, LweSecretKeyParam *, LweSecretKey_u64 *> output =
|
||||
{param, nullptr, nullptr};
|
||||
if (param.encryption.hasValue()) {
|
||||
auto outputSk = keySet->secretKeys.find(param.encryption->secretKeyID);
|
||||
if (outputSk == keySet->secretKeys.end()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"cannot find output key to generate bootstrap key",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
std::get<1>(output) = &outputSk->second.first;
|
||||
std::get<2>(output) = outputSk->second.second;
|
||||
}
|
||||
keySet->outputs.push_back(output);
|
||||
}
|
||||
}
|
||||
return std::move(keySet);
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
void KeySet::setKeys(
|
||||
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey_u64 *>>
|
||||
secretKeys,
|
||||
std::map<LweSecretKeyID,
|
||||
std::pair<BootstrapKeyParam, LweBootstrapKey_u64 *>>
|
||||
bootstrapKeys,
|
||||
std::map<LweSecretKeyID,
|
||||
std::pair<KeyswitchKeyParam, LweKeyswitchKey_u64 *>>
|
||||
keyswitchKeys) {
|
||||
this->secretKeys = secretKeys;
|
||||
this->bootstrapKeys = bootstrapKeys;
|
||||
this->keyswitchKeys = keyswitchKeys;
|
||||
}
|
||||
|
||||
llvm::Error KeySet::generateSecretKey(LweSecretKeyID id,
|
||||
@@ -120,6 +174,7 @@ llvm::Error KeySet::generateSecretKey(LweSecretKeyID id,
|
||||
"cannot fill secret key with random generator");
|
||||
|
||||
secretKeys[id] = {param, sk};
|
||||
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
@@ -136,7 +191,7 @@ llvm::Error KeySet::generateBootstrapKey(BootstrapKeyID id,
|
||||
auto outputSk = secretKeys.find(param.outputSecretKeyID);
|
||||
if (outputSk == secretKeys.end()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"cannot find input key to generate bootstrap key",
|
||||
"cannot find output key to generate bootstrap key",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
// Allocate the bootstrap key
|
||||
@@ -291,5 +346,22 @@ llvm::Error KeySet::decrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
const std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey_u64 *>>
|
||||
&KeySet::getSecretKeys() {
|
||||
return secretKeys;
|
||||
}
|
||||
|
||||
const std::map<LweSecretKeyID,
|
||||
std::pair<BootstrapKeyParam, LweBootstrapKey_u64 *>> &
|
||||
KeySet::getBootstrapKeys() {
|
||||
return bootstrapKeys;
|
||||
}
|
||||
|
||||
const std::map<LweSecretKeyID,
|
||||
std::pair<KeyswitchKeyParam, LweKeyswitchKey_u64 *>> &
|
||||
KeySet::getKeyswitchKeys() {
|
||||
return keyswitchKeys;
|
||||
}
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
193
compiler/lib/Support/KeySetCache.cpp
Normal file
193
compiler/lib/Support/KeySetCache.cpp
Normal file
@@ -0,0 +1,193 @@
|
||||
#include "zamalang/Support/KeySetCache.h"
|
||||
#include "zamalang/Support/Error.h"
|
||||
#include "llvm/Support/FileSystem.h"
|
||||
#include "llvm/Support/Path.h"
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
extern "C" {
|
||||
#include "concrete-ffi.h"
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
static std::string readFile(llvm::SmallString<0> &path) {
|
||||
std::ifstream in((std::string)path, std::ofstream::binary);
|
||||
std::stringstream sbuffer;
|
||||
sbuffer << in.rdbuf();
|
||||
return sbuffer.str();
|
||||
}
|
||||
|
||||
static void writeFile(llvm::SmallString<0> &path, Buffer content) {
|
||||
std::ofstream out((std::string)path, std::ofstream::binary);
|
||||
out.write((const char *)content.pointer, content.length);
|
||||
out.close();
|
||||
}
|
||||
|
||||
LweSecretKey_u64 *loadSecretKey(llvm::SmallString<0> &path) {
|
||||
std::string content = readFile(path);
|
||||
BufferView buffer = {(const uint8_t *)content.c_str(), content.length()};
|
||||
return deserialize_lwe_secret_key_u64(buffer);
|
||||
}
|
||||
|
||||
LweKeyswitchKey_u64 *loadKeyswitchKey(llvm::SmallString<0> &path) {
|
||||
std::string content = readFile(path);
|
||||
BufferView buffer = {(const uint8_t *)content.c_str(), content.length()};
|
||||
return deserialize_lwe_keyswitching_key_u64(buffer);
|
||||
}
|
||||
|
||||
LweBootstrapKey_u64 *loadBootstrapKey(llvm::SmallString<0> &path) {
|
||||
std::string content = readFile(path);
|
||||
BufferView buffer = {(const uint8_t *)content.c_str(), content.length()};
|
||||
return deserialize_lwe_bootstrap_key_u64(buffer);
|
||||
}
|
||||
|
||||
void saveSecretKey(llvm::SmallString<0> &path, LweSecretKey_u64 *key) {
|
||||
Buffer buffer = serialize_lwe_secret_key_u64(key);
|
||||
writeFile(path, buffer);
|
||||
free(buffer.pointer);
|
||||
}
|
||||
|
||||
void saveBootstrapKey(llvm::SmallString<0> &path, LweBootstrapKey_u64 *key) {
|
||||
Buffer buffer = serialize_lwe_bootstrap_key_u64(key);
|
||||
writeFile(path, buffer);
|
||||
free(buffer.pointer);
|
||||
}
|
||||
|
||||
void saveKeyswitchKey(llvm::SmallString<0> &path, LweKeyswitchKey_u64 *key) {
|
||||
Buffer buffer = serialize_lwe_keyswitching_key_u64(key);
|
||||
writeFile(path, buffer);
|
||||
free(buffer.pointer);
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<KeySet>>
|
||||
KeySetCache::tryLoadKeys(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb, llvm::SmallString<0> &folderPath) {
|
||||
// TODO: text dump of all parameter in /hash
|
||||
auto key_set = KeySet::uninitialized();
|
||||
|
||||
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey_u64 *>>
|
||||
secretKeys;
|
||||
std::map<LweSecretKeyID, std::pair<BootstrapKeyParam, LweBootstrapKey_u64 *>>
|
||||
bootstrapKeys;
|
||||
std::map<LweSecretKeyID, std::pair<KeyswitchKeyParam, LweKeyswitchKey_u64 *>>
|
||||
keyswitchKeys;
|
||||
|
||||
// Load LWE secret keys
|
||||
for (auto secretKeyParam : params.secretKeys) {
|
||||
auto id = secretKeyParam.first;
|
||||
auto param = secretKeyParam.second;
|
||||
llvm::SmallString<0> path = folderPath;
|
||||
llvm::sys::path::append(path, "secretKey_" + id);
|
||||
LweSecretKey_u64 *sk = loadSecretKey(path);
|
||||
secretKeys[id] = {param, sk};
|
||||
}
|
||||
// Load bootstrap keys
|
||||
for (auto bootstrapKeyParam : params.bootstrapKeys) {
|
||||
auto id = bootstrapKeyParam.first;
|
||||
auto param = bootstrapKeyParam.second;
|
||||
llvm::SmallString<0> path = folderPath;
|
||||
llvm::sys::path::append(path, "pbsKey_" + id);
|
||||
LweBootstrapKey_u64 *bsk = loadBootstrapKey(path);
|
||||
bootstrapKeys[id] = {param, bsk};
|
||||
}
|
||||
// Load keyswitch keys
|
||||
for (auto keyswitchParam : params.keyswitchKeys) {
|
||||
auto id = keyswitchParam.first;
|
||||
auto param = keyswitchParam.second;
|
||||
llvm::SmallString<0> path = folderPath;
|
||||
llvm::sys::path::append(path, "ksKey_" + id);
|
||||
LweKeyswitchKey_u64 *ksk = loadKeyswitchKey(path);
|
||||
keyswitchKeys[id] = {param, ksk};
|
||||
}
|
||||
|
||||
key_set->setKeys(secretKeys, bootstrapKeys, keyswitchKeys);
|
||||
|
||||
auto err = key_set->setupEncryptionMaterial(params, seed_msb, seed_lsb);
|
||||
if (err) {
|
||||
return StreamStringError() << "Cannot setup encryption material: " << err;
|
||||
}
|
||||
|
||||
return key_set;
|
||||
}
|
||||
|
||||
llvm::Error saveKeys(KeySet &key_set, llvm::SmallString<0> &folderPath) {
|
||||
llvm::SmallString<0> folderIncompletePath = folderPath;
|
||||
|
||||
folderIncompletePath.append(".incomplete");
|
||||
|
||||
auto err = llvm::sys::fs::create_directories(folderIncompletePath);
|
||||
if (err) {
|
||||
return StreamStringError()
|
||||
<< "Cannot create directory \"" << folderIncompletePath
|
||||
<< "\": " << err.message();
|
||||
}
|
||||
|
||||
// Save LWE secret keys
|
||||
for (auto secretKeyParam : key_set.getSecretKeys()) {
|
||||
auto id = secretKeyParam.first;
|
||||
auto key = secretKeyParam.second.second;
|
||||
llvm::SmallString<0> path = folderIncompletePath;
|
||||
llvm::sys::path::append(path, "secretKey_" + id);
|
||||
saveSecretKey(path, key);
|
||||
}
|
||||
// Save bootstrap keys
|
||||
for (auto bootstrapKeyParam : key_set.getBootstrapKeys()) {
|
||||
auto id = bootstrapKeyParam.first;
|
||||
auto key = bootstrapKeyParam.second.second;
|
||||
llvm::SmallString<0> path = folderIncompletePath;
|
||||
llvm::sys::path::append(path, "pbsKey_" + id);
|
||||
saveBootstrapKey(path, key);
|
||||
}
|
||||
// Save keyswitch keys
|
||||
for (auto keyswitchParam : key_set.getKeyswitchKeys()) {
|
||||
auto id = keyswitchParam.first;
|
||||
auto key = keyswitchParam.second.second;
|
||||
llvm::SmallString<0> path = folderIncompletePath;
|
||||
llvm::sys::path::append(path, "ksKey_" + id);
|
||||
saveKeyswitchKey(path, key);
|
||||
}
|
||||
|
||||
err = llvm::sys::fs::rename(folderIncompletePath, folderPath);
|
||||
if (err) {
|
||||
return StreamStringError()
|
||||
<< "Cannot rename directory \"" << folderIncompletePath << "\" \""
|
||||
<< folderPath << "\": " << err.message();
|
||||
}
|
||||
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<KeySet>>
|
||||
KeySetCache::tryLoadOrGenerateSave(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
|
||||
llvm::SmallString<0> folderPath =
|
||||
llvm::SmallString<0>(this->backingDirectoryPath);
|
||||
|
||||
llvm::sys::path::append(folderPath, std::to_string(params.hash()));
|
||||
|
||||
llvm::sys::path::append(folderPath, std::to_string(seed_msb) + "_" +
|
||||
std::to_string(seed_lsb));
|
||||
|
||||
if (llvm::sys::fs::exists(folderPath)) {
|
||||
return tryLoadKeys(params, seed_msb, seed_lsb, folderPath);
|
||||
} else {
|
||||
auto key_set = KeySet::generate(params, seed_msb, seed_lsb);
|
||||
|
||||
if (!key_set) {
|
||||
return StreamStringError()
|
||||
<< "Cannot generate key set: " << key_set.takeError();
|
||||
}
|
||||
|
||||
auto savedErr = saveKeys(*(key_set.get()), folderPath);
|
||||
if (savedErr) {
|
||||
return StreamStringError() << "Cannot save key set: " << savedErr;
|
||||
}
|
||||
|
||||
return key_set;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
@@ -1,9 +1,11 @@
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from zamalang import CompilerEngine, library
|
||||
|
||||
KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache')
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mlir_input, args, expected_result",
|
||||
@@ -104,7 +106,7 @@ from zamalang import CompilerEngine, library
|
||||
)
|
||||
def test_compile_and_run(mlir_input, args, expected_result):
|
||||
engine = CompilerEngine()
|
||||
engine.compile_fhe(mlir_input)
|
||||
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path = KEY_SET_CACHE_PATH)
|
||||
if isinstance(expected_result, int):
|
||||
assert engine.run(*args) == expected_result
|
||||
else:
|
||||
@@ -129,7 +131,7 @@ def test_compile_and_run(mlir_input, args, expected_result):
|
||||
)
|
||||
def test_compile_and_run_invalid_arg_number(mlir_input, args):
|
||||
engine = CompilerEngine()
|
||||
engine.compile_fhe(mlir_input)
|
||||
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path = KEY_SET_CACHE_PATH)
|
||||
with pytest.raises(ValueError, match=r"wrong number of arguments"):
|
||||
engine.run(*args)
|
||||
|
||||
@@ -154,7 +156,7 @@ def test_compile_and_run_invalid_arg_number(mlir_input, args):
|
||||
)
|
||||
def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size):
|
||||
engine = CompilerEngine()
|
||||
engine.compile_fhe(mlir_input)
|
||||
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path = KEY_SET_CACHE_PATH)
|
||||
assert abs(engine.run(*args) - expected_result) / tab_size < 0.1
|
||||
|
||||
|
||||
@@ -177,7 +179,7 @@ def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size):
|
||||
def test_compile_invalid(mlir_input):
|
||||
engine = CompilerEngine()
|
||||
with pytest.raises(RuntimeError, match=r"Compilation failed:"):
|
||||
engine.compile_fhe(mlir_input)
|
||||
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path = KEY_SET_CACHE_PATH)
|
||||
|
||||
|
||||
MODULE_1 = """
|
||||
|
||||
@@ -5,6 +5,9 @@
|
||||
|
||||
#include "zamalang/Support/CompilerEngine.h"
|
||||
#include "zamalang/Support/JitCompilerEngine.h"
|
||||
#include "zamalang/Support/KeySetCache.h"
|
||||
#include "llvm/Support/Path.h"
|
||||
|
||||
#include "globals.h"
|
||||
|
||||
#define ASSERT_LLVM_ERROR(err) \
|
||||
@@ -80,18 +83,32 @@ static bool assert_expected_value(llvm::Expected<T> &&val, const V &exp) {
|
||||
// and reult in abnormal termination.
|
||||
template <typename F>
|
||||
mlir::zamalang::JitCompilerEngine::Lambda
|
||||
internalCheckedJit(F checkfunc, llvm::StringRef src,
|
||||
internalCheckedJit(F checkFunc, llvm::StringRef src,
|
||||
llvm::StringRef func = "main",
|
||||
bool useDefaultFHEConstraints = false) {
|
||||
|
||||
llvm::SmallString<0> cachePath;
|
||||
|
||||
llvm::sys::path::system_temp_directory(true, cachePath);
|
||||
|
||||
llvm::sys::path::append(cachePath, "KeySetCache");
|
||||
|
||||
auto cachePathStr = std::string(cachePath);
|
||||
auto optCache = llvm::Optional<mlir::zamalang::KeySetCache>(
|
||||
mlir::zamalang::KeySetCache(cachePathStr));
|
||||
|
||||
mlir::zamalang::JitCompilerEngine engine;
|
||||
|
||||
if (useDefaultFHEConstraints)
|
||||
engine.setFHEConstraints(defaultV0Constraints);
|
||||
|
||||
llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> lambdaOrErr =
|
||||
engine.buildLambda(src, func);
|
||||
engine.buildLambda(src, func, optCache);
|
||||
|
||||
checkfunc(lambdaOrErr);
|
||||
if (!lambdaOrErr) {
|
||||
std::cout << llvm::toString(lambdaOrErr.takeError()) << std::endl;
|
||||
}
|
||||
checkFunc(lambdaOrErr);
|
||||
|
||||
return std::move(*lambdaOrErr);
|
||||
}
|
||||
@@ -116,4 +133,4 @@ static inline uint64_t operator"" _u64(unsigned long long int v) { return v; }
|
||||
}, \
|
||||
__VA_ARGS__)
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user