feat(compiler): add a key cache

This commit is contained in:
Mayeul@Zama
2021-11-25 18:26:42 +01:00
committed by mayeul-zama
parent f193fd71a2
commit dad4390518
17 changed files with 533 additions and 87 deletions

View File

@@ -22,6 +22,15 @@ jobs:
with:
submodules: recursive
- name: "KeySetCache"
uses: actions/cache@v2
with:
path: ${{ github.workspace }}/KeySetCache
# actions/cache does not permit to update a cache entry
key: ${{ runner.os }}-KeySetCache-2021-12-02
restore-keys: |
${{ runner.os }}-KeySetCache-
- name: Build and test compiler
uses: addnab/docker-run-action@v3
with:
@@ -29,7 +38,7 @@ jobs:
image: ghcr.io/zama-ai/zamalang-compiler:latest
username: ${{ secrets.GHCR_LOGIN }}
password: ${{ secrets.GHCR_PASSWORD }}
options: -v ${{ github.workspace }}/compiler:/compiler
options: -v ${{ github.workspace }}/compiler:/compiler -v ${{ github.workspace }}/KeySetCache:/tmp/KeySetCache
shell: bash
run: |
set -e
@@ -42,3 +51,4 @@ jobs:
make CCACHE=ON BUILD_DIR=/build test
echo "Debug: ccache statistics (after the build):"
ccache -s
chmod -R ugo+rwx /tmp/KeySetCache

View File

@@ -43,7 +43,7 @@ test-check: zamacompiler file-check not
$(BUILD_DIR)/bin/llvm-lit -v tests/
test-python: python-bindings zamacompiler
PYTHONPATH=${PYTHONPATH}:$(BUILD_DIR)/tools/zamalang/python_packages/zamalang_core LD_PRELOAD=$(BUILD_DIR)/lib/libZamalangRuntime.so pytest -vs tests/python
PATH=$(BUILD_DIR)/bin:${PATH} PYTHONPATH=${PYTHONPATH}:$(BUILD_DIR)/tools/zamalang/python_packages/zamalang_core LD_PRELOAD=$(BUILD_DIR)/lib/libZamalangRuntime.so pytest -vs tests/python
test: test-check test-end-to-end-jit test-python

View File

@@ -32,10 +32,12 @@ typedef struct executionArguments executionArguments;
// Build lambda from a textual representation of an MLIR module
// The lambda will have `funcName` as entrypoint, and use runtimeLibPath (if not
// null) as a shared library during compilation
// null) as a shared library during compilation,
// a path to activate the use a cache for encryption keys for test purpose
// (unsecure).
MLIR_CAPI_EXPORTED mlir::zamalang::JitCompilerEngine::Lambda
buildLambda(const char *module, const char *funcName,
const char *runtimeLibPath);
const char *runtimeLibPath, const char *keySetCachePath);
// Parse then print a textual representation of an MLIR module
MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);

View File

@@ -24,6 +24,8 @@ typedef uint64_t GlweDimension;
typedef std::string LweSecretKeyID;
struct LweSecretKeyParam {
LweSize size;
void hash(size_t &seed);
};
typedef std::string BootstrapKeyID;
@@ -34,6 +36,8 @@ struct BootstrapKeyParam {
DecompositionBaseLog baseLog;
GlweDimension glweDimension;
Variance variance;
void hash(size_t &seed);
};
typedef std::string KeyswitchKeyID;
@@ -43,6 +47,8 @@ struct KeyswitchKeyParam {
DecompositionLevelCount level;
DecompositionBaseLog baseLog;
Variance variance;
void hash(size_t &seed);
};
struct Encoding {
@@ -75,11 +81,13 @@ struct ClientParameters {
std::map<KeyswitchKeyID, KeyswitchKeyParam> keyswitchKeys;
std::vector<CircuitGate> inputs;
std::vector<CircuitGate> outputs;
size_t hash();
};
llvm::Expected<ClientParameters>
createClientParametersForV0(V0FHEContext context, llvm::StringRef name,
mlir::ModuleOp module);
} // namespace zamalang
} // namespace mlir

View File

@@ -1,6 +1,7 @@
#ifndef ZAMALANG_SUPPORT_JIT_COMPILER_ENGINE_H
#define ZAMALANG_SUPPORT_JIT_COMPILER_ENGINE_H
#include "zamalang/Support/KeySetCache.h"
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <zamalang/Support/CompilerEngine.h>
#include <zamalang/Support/Error.h>
@@ -363,15 +364,18 @@ public:
/// Use runtimeLibPath as a shared library if specified.
llvm::Expected<Lambda>
buildLambda(llvm::StringRef src, llvm::StringRef funcName = "main",
llvm::Optional<KeySetCache> cachePath = {},
llvm::Optional<llvm::StringRef> runtimeLibPath = {});
llvm::Expected<Lambda>
buildLambda(std::unique_ptr<llvm::MemoryBuffer> buffer,
llvm::StringRef funcName = "main",
llvm::Optional<KeySetCache> cachePath = {},
llvm::Optional<llvm::StringRef> runtimeLibPath = {});
llvm::Expected<Lambda>
buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName = "main",
llvm::Optional<KeySetCache> cachePath = {},
llvm::Optional<llvm::StringRef> runtimeLibPath = {});
protected:

View File

@@ -10,6 +10,7 @@ extern "C" {
}
#include "zamalang/Support/ClientParameters.h"
#include "zamalang/Support/KeySetCache.h"
namespace mlir {
namespace zamalang {
@@ -17,6 +18,15 @@ namespace zamalang {
class KeySet {
public:
~KeySet();
static std::unique_ptr<KeySet> uninitialized();
llvm::Error generateKeysFromParams(ClientParameters &params,
uint64_t seed_msb, uint64_t seed_lsb);
llvm::Error setupEncryptionMaterial(ClientParameters &params,
uint64_t seed_msb, uint64_t seed_lsb);
// allocate a KeySet according the ClientParameters.
static llvm::Expected<std::unique_ptr<KeySet>>
generate(ClientParameters &params, uint64_t seed_msb, uint64_t seed_lsb);
@@ -46,6 +56,18 @@ public:
context.bsk = std::get<1>(this->bootstrapKeys["bsk_v0"]);
}
const std::map<LweSecretKeyID,
std::pair<LweSecretKeyParam, LweSecretKey_u64 *>> &
getSecretKeys();
const std::map<LweSecretKeyID,
std::pair<BootstrapKeyParam, LweBootstrapKey_u64 *>> &
getBootstrapKeys();
const std::map<LweSecretKeyID,
std::pair<KeyswitchKeyParam, LweKeyswitchKey_u64 *>> &
getKeyswitchKeys();
protected:
llvm::Error generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param,
SecretRandomGenerator *generator);
@@ -54,6 +76,8 @@ protected:
llvm::Error generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param,
EncryptionRandomGenerator *generator);
friend class KeySetCache;
private:
EncryptionRandomGenerator *encryptionRandomGenerator;
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey_u64 *>>
@@ -66,6 +90,16 @@ private:
inputs;
std::vector<std::tuple<CircuitGate, LweSecretKeyParam *, LweSecretKey_u64 *>>
outputs;
void setKeys(
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey_u64 *>>
secretKeys,
std::map<LweSecretKeyID,
std::pair<BootstrapKeyParam, LweBootstrapKey_u64 *>>
bootstrapKeys,
std::map<LweSecretKeyID,
std::pair<KeyswitchKeyParam, LweKeyswitchKey_u64 *>>
keyswitchKeys);
};
} // namespace zamalang

View File

@@ -0,0 +1,31 @@
#ifndef ZAMALANG_SUPPORT_KEYSETCACHE_H_
#define ZAMALANG_SUPPORT_KEYSETCACHE_H_
#include "zamalang/Support/KeySet.h"
namespace mlir {
namespace zamalang {
class KeySet;
class KeySetCache {
std::string backingDirectoryPath;
public:
KeySetCache(std::string backingDirectoryPath)
: backingDirectoryPath(backingDirectoryPath) {}
llvm::Expected<std::unique_ptr<KeySet>>
tryLoadOrGenerateSave(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb);
private:
static llvm::Expected<std::unique_ptr<KeySet>>
tryLoadKeys(ClientParameters &params, uint64_t seed_msb, uint64_t seed_lsb,
llvm::SmallString<0> &folderPath);
};
} // namespace zamalang
} // namespace mlir
#endif

View File

@@ -17,6 +17,10 @@
using mlir::zamalang::JitCompilerEngine;
using mlir::zamalang::LambdaArgument;
const char *noEmptyStringPtr(std::string &s) {
return (s.empty()) ? nullptr : s.c_str();
}
/// Populate the compiler API python module.
void mlir::zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
m.doc() = "Zamalang compiler python API";
@@ -31,14 +35,14 @@ void mlir::zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
pybind11::class_<JitCompilerEngine>(m, "JitCompilerEngine")
.def(pybind11::init())
.def_static("build_lambda", [](std::string mlir_input,
std::string func_name,
std::string runtime_lib_path) {
if (runtime_lib_path.empty())
return buildLambda(mlir_input.c_str(), func_name.c_str(), nullptr);
return buildLambda(mlir_input.c_str(), func_name.c_str(),
runtime_lib_path.c_str());
});
.def_static("build_lambda",
[](std::string mlir_input, std::string func_name,
std::string runtime_lib_path,
std::string keysetcache_path) {
return buildLambda(mlir_input.c_str(), func_name.c_str(),
noEmptyStringPtr(runtime_lib_path),
noEmptyStringPtr(keysetcache_path));
});
pybind11::class_<lambdaArgument>(m, "LambdaArgument")
.def_static("from_tensor", lambdaArgumentFromTensor)

View File

@@ -118,7 +118,8 @@ class CompilerEngine:
self.compile_fhe(mlir_str)
def compile_fhe(
self, mlir_str: str, func_name: str = "main", runtime_lib_path: str = None
self, mlir_str: str, func_name: str = "main", runtime_lib_path: str = None,
unsecure_key_set_cache_path: str = None,
):
"""Compile the MLIR input.
@@ -126,6 +127,7 @@ class CompilerEngine:
mlir_str (str): MLIR to compile.
func_name (str): name of the function to set as entrypoint (default: main).
runtime_lib_path (str): path to the runtime lib (default: None).
unsecure_key_set_cache_path (str): path to the activate keyset caching (default: None).
Raises:
TypeError: if the argument is not an str.
@@ -140,7 +142,14 @@ class CompilerEngine:
raise TypeError(
"runtime_lib_path must be an str representing the path to the runtime lib"
)
self._lambda = self._engine.build_lambda(mlir_str, func_name, runtime_lib_path)
unsecure_key_set_cache_path = unsecure_key_set_cache_path or ""
if not isinstance(unsecure_key_set_cache_path, str):
raise TypeError(
"unsecure_key_set_cache_path must be a str"
)
self._lambda = self._engine.build_lambda(
mlir_str, func_name, runtime_lib_path,
unsecure_key_set_cache_path)
def run(self, *args: List[Union[int, np.ndarray]]) -> Union[int, np.ndarray]:
"""Run the compiled code.

View File

@@ -1,20 +1,31 @@
#include "llvm/ADT/SmallString.h"
#include "zamalang-c/Support/CompilerEngine.h"
#include "zamalang/Support/CompilerEngine.h"
#include "zamalang/Support/Jit.h"
#include "zamalang/Support/JitCompilerEngine.h"
#include "zamalang/Support/KeySetCache.h"
using mlir::zamalang::JitCompilerEngine;
mlir::zamalang::JitCompilerEngine::Lambda
buildLambda(const char *module, const char *funcName,
const char *runtimeLibPath) {
const char *runtimeLibPath, const char *keySetCachePath) {
// Set the runtime library path if not nullptr
llvm::Optional<llvm::StringRef> runtimeLibPathOptional = {};
if (runtimeLibPath != nullptr)
runtimeLibPathOptional = runtimeLibPath;
mlir::zamalang::JitCompilerEngine engine;
using KeySetCache = mlir::zamalang::KeySetCache;
using optKeySetCache = llvm::Optional<mlir::zamalang::KeySetCache>;
auto cacheOpt = optKeySetCache();
if (keySetCachePath != nullptr) {
cacheOpt = KeySetCache(std::string(keySetCachePath));
}
llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> lambdaOrErr =
engine.buildLambda(module, funcName, runtimeLibPathOptional);
engine.buildLambda(module, funcName, cacheOpt, runtimeLibPathOptional);
if (!lambdaOrErr) {
std::string backingString;
llvm::raw_string_ostream os(backingString);

View File

@@ -11,6 +11,7 @@ add_mlir_library(ZamalangSupport
KeySet.cpp
logging.cpp
Jit.cpp
KeySetCache.cpp
LLVMEmitFile.cpp
ADDITIONAL_HEADER_DIRS

View File

@@ -145,5 +145,44 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name,
return c;
}
// https://stackoverflow.com/a/38140932
static inline void hash(std::size_t &seed) {}
template <typename T, typename... Rest>
static inline void hash(std::size_t &seed, const T &v, Rest... rest) {
// See https://softwareengineering.stackexchange.com/a/402543
const auto GOLDEN_RATIO = 0x9e3779b97f4a7c15; // pseudo random bits
const std::hash<T> hasher;
seed ^= hasher(v) + GOLDEN_RATIO + (seed << 6) + (seed >> 2);
hash(seed, rest...);
}
void LweSecretKeyParam::hash(size_t &seed) { mlir::zamalang::hash(seed, size); }
void BootstrapKeyParam::hash(size_t &seed) {
mlir::zamalang::hash(seed, inputSecretKeyID, outputSecretKeyID, level,
baseLog, glweDimension, variance);
}
void KeyswitchKeyParam::hash(size_t &seed) {
mlir::zamalang::hash(seed, inputSecretKeyID, outputSecretKeyID, level,
baseLog, variance);
}
std::size_t ClientParameters::hash() {
std::size_t currentHash = 1;
for (auto secretKeyParam : secretKeys) {
mlir::zamalang::hash(currentHash, secretKeyParam.first);
secretKeyParam.second.hash(currentHash);
}
for (auto bootstrapKeyParam : bootstrapKeys) {
mlir::zamalang::hash(currentHash, bootstrapKeyParam.first);
bootstrapKeyParam.second.hash(currentHash);
}
for (auto keyswitchParam : keyswitchKeys) {
mlir::zamalang::hash(currentHash, keyswitchParam.first);
keyswitchParam.second.hash(currentHash);
}
return currentHash;
}
} // namespace zamalang
} // namespace mlir

View File

@@ -37,25 +37,22 @@ JitCompilerEngine::findLLVMFuncOp(mlir::ModuleOp module, llvm::StringRef name) {
llvm::Expected<JitCompilerEngine::Lambda>
JitCompilerEngine::buildLambda(std::unique_ptr<llvm::MemoryBuffer> buffer,
llvm::StringRef funcName,
llvm::Optional<KeySetCache> cache,
llvm::Optional<llvm::StringRef> runtimeLibPath) {
llvm::SourceMgr sm;
sm.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
llvm::Expected<JitCompilerEngine::Lambda> res =
this->buildLambda(sm, funcName, runtimeLibPath);
return std::move(res);
return this->buildLambda(sm, funcName, cache, runtimeLibPath);
}
// Build a lambda from the function with the name given in `funcName`
// from the source string `s`.
llvm::Expected<JitCompilerEngine::Lambda>
JitCompilerEngine::buildLambda(llvm::StringRef s, llvm::StringRef funcName,
llvm::Optional<KeySetCache> cache,
llvm::Optional<llvm::StringRef> runtimeLibPath) {
std::unique_ptr<llvm::MemoryBuffer> mb = llvm::MemoryBuffer::getMemBuffer(s);
llvm::Expected<JitCompilerEngine::Lambda> res =
this->buildLambda(std::move(mb), funcName, runtimeLibPath);
this->buildLambda(std::move(mb), funcName, cache, runtimeLibPath);
return std::move(res);
}
@@ -64,6 +61,7 @@ JitCompilerEngine::buildLambda(llvm::StringRef s, llvm::StringRef funcName,
// `funcName` from the sources managed by the source manager `sm`.
llvm::Expected<JitCompilerEngine::Lambda>
JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName,
llvm::Optional<KeySetCache> cache,
llvm::Optional<llvm::StringRef> runtimeLibPath) {
MLIRContext &mlirContext = *this->compilationContext->getMLIRContext();
@@ -77,14 +75,17 @@ JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName,
if (!compResOrErr)
return std::move(compResOrErr.takeError());
mlir::ModuleOp module = compResOrErr->mlirModuleRef->get();
auto compRes = std::move(compResOrErr.get());
mlir::ModuleOp module = compRes.mlirModuleRef->get();
// Locate function to JIT-compile
llvm::Expected<mlir::LLVM::LLVMFuncOp> funcOrError =
this->findLLVMFuncOp(compResOrErr->mlirModuleRef->get(), funcName);
this->findLLVMFuncOp(compRes.mlirModuleRef->get(), funcName);
if (!funcOrError)
return std::move(funcOrError.takeError());
return StreamStringError() << "Cannot find function \"" << funcName
<< "\": " << std::move(funcOrError.takeError());
// Prepare LLVM infrastructure for JIT compilation
llvm::InitializeNativeTarget();
@@ -98,24 +99,32 @@ JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName,
mlir::zamalang::JITLambda::create(funcName, module, optPipeline,
runtimeLibPath);
if (!lambdaOrErr) {
return StreamStringError()
<< "Cannot create lambda: " << lambdaOrErr.takeError();
}
auto lambda = std::move(lambdaOrErr.get());
// Generate the KeySet for encrypting lambda arguments, decrypting lambda
// results
if (!compResOrErr->clientParameters.hasValue()) {
if (!compRes.clientParameters.hasValue()) {
return StreamStringError("Cannot generate the keySet since client "
"parameters has not been computed");
}
llvm::Expected<std::unique_ptr<mlir::zamalang::KeySet>> keySetOrErr =
mlir::zamalang::KeySet::generate(*compResOrErr->clientParameters, 0, 0);
(cache.hasValue())
? cache->tryLoadOrGenerateSave(*compRes.clientParameters, 0, 0)
: KeySet::generate(*compRes.clientParameters, 0, 0);
if (auto err = keySetOrErr.takeError())
return std::move(err);
if (!keySetOrErr) {
return keySetOrErr.takeError();
}
if (!lambdaOrErr)
return std::move(lambdaOrErr.takeError());
auto keySet = std::move(keySetOrErr.get());
return Lambda{this->compilationContext, std::move(lambdaOrErr.get()),
std::move(*keySetOrErr)};
return Lambda{this->compilationContext, std::move(lambda), std::move(keySet)};
}
} // namespace zamalang

View File

@@ -1,4 +1,5 @@
#include "zamalang/Support/KeySet.h"
#include "zamalang/Support/Error.h"
#define CAPI_ERR_TO_LLVM_ERROR(s, msg) \
{ \
@@ -30,7 +31,79 @@ KeySet::~KeySet() {
llvm::Expected<std::unique_ptr<KeySet>>
KeySet::generate(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb) {
auto keySet = std::make_unique<KeySet>();
auto a = uninitialized();
auto fillError = a->generateKeysFromParams(params, seed_msb, seed_lsb);
if (fillError) {
return StreamStringError()
<< "Cannot fill keys from params: " << std::move(fillError);
}
fillError = a->setupEncryptionMaterial(params, seed_msb, seed_lsb);
if (fillError) {
return StreamStringError()
<< "Cannot setup encryption material: " << std::move(fillError);
}
return a;
}
std::unique_ptr<KeySet> KeySet::uninitialized() {
return std::make_unique<KeySet>();
}
llvm::Error KeySet::setupEncryptionMaterial(ClientParameters &params,
uint64_t seed_msb,
uint64_t seed_lsb) {
// Set inputs and outputs LWE secret keys
{
for (auto param : params.inputs) {
std::tuple<CircuitGate, LweSecretKeyParam *, LweSecretKey_u64 *> input = {
param, nullptr, nullptr};
if (param.encryption.hasValue()) {
auto inputSk = this->secretKeys.find(param.encryption->secretKeyID);
if (inputSk == this->secretKeys.end()) {
return llvm::make_error<llvm::StringError>(
"input encryption secret key (" + param.encryption->secretKeyID +
") does not exist ",
llvm::inconvertibleErrorCode());
}
std::get<1>(input) = &inputSk->second.first;
std::get<2>(input) = inputSk->second.second;
}
this->inputs.push_back(input);
}
for (auto param : params.outputs) {
std::tuple<CircuitGate, LweSecretKeyParam *, LweSecretKey_u64 *> output =
{param, nullptr, nullptr};
if (param.encryption.hasValue()) {
auto outputSk = this->secretKeys.find(param.encryption->secretKeyID);
if (outputSk == this->secretKeys.end()) {
return llvm::make_error<llvm::StringError>(
"cannot find output key to generate bootstrap key",
llvm::inconvertibleErrorCode());
}
std::get<1>(output) = &outputSk->second.first;
std::get<2>(output) = outputSk->second.second;
}
this->outputs.push_back(output);
}
}
int err;
CAPI_ERR_TO_LLVM_ERROR(
this->encryptionRandomGenerator =
allocate_encryption_generator(&err, seed_msb, seed_lsb),
"cannot allocate encryption generator");
return llvm::Error::success();
}
llvm::Error KeySet::generateKeysFromParams(ClientParameters &params,
uint64_t seed_msb,
uint64_t seed_lsb) {
{
// Generate LWE secret keys
@@ -39,8 +112,8 @@ KeySet::generate(ClientParameters &params, uint64_t seed_msb,
generator = allocate_secret_generator(&err, seed_msb, seed_lsb),
"cannot allocate random generator");
for (auto secretKeyParam : params.secretKeys) {
auto e = keySet->generateSecretKey(secretKeyParam.first,
secretKeyParam.second, generator);
auto e = this->generateSecretKey(secretKeyParam.first,
secretKeyParam.second, generator);
if (e) {
return std::move(e);
}
@@ -50,62 +123,43 @@ KeySet::generate(ClientParameters &params, uint64_t seed_msb,
}
// Allocate the encryption random generator
CAPI_ERR_TO_LLVM_ERROR(
keySet->encryptionRandomGenerator =
this->encryptionRandomGenerator =
allocate_encryption_generator(&err, seed_msb, seed_lsb),
"cannot allocate encryption generator");
// Generate bootstrap and keyswitch keys
{
for (auto bootstrapKeyParam : params.bootstrapKeys) {
auto e = keySet->generateBootstrapKey(bootstrapKeyParam.first,
bootstrapKeyParam.second,
keySet->encryptionRandomGenerator);
auto e = this->generateBootstrapKey(bootstrapKeyParam.first,
bootstrapKeyParam.second,
this->encryptionRandomGenerator);
if (e) {
return std::move(e);
}
}
for (auto keyswitchParam : params.keyswitchKeys) {
auto e = keySet->generateKeyswitchKey(keyswitchParam.first,
keyswitchParam.second,
keySet->encryptionRandomGenerator);
auto e = this->generateKeyswitchKey(keyswitchParam.first,
keyswitchParam.second,
this->encryptionRandomGenerator);
if (e) {
return std::move(e);
}
}
}
// Set inputs and outputs LWE secret keys
{
for (auto param : params.inputs) {
std::tuple<CircuitGate, LweSecretKeyParam *, LweSecretKey_u64 *> input = {
param, nullptr, nullptr};
if (param.encryption.hasValue()) {
auto inputSk = keySet->secretKeys.find(param.encryption->secretKeyID);
if (inputSk == keySet->secretKeys.end()) {
return llvm::make_error<llvm::StringError>(
"cannot find input key to generate bootstrap key",
llvm::inconvertibleErrorCode());
}
std::get<1>(input) = &inputSk->second.first;
std::get<2>(input) = inputSk->second.second;
}
keySet->inputs.push_back(input);
}
for (auto param : params.outputs) {
std::tuple<CircuitGate, LweSecretKeyParam *, LweSecretKey_u64 *> output =
{param, nullptr, nullptr};
if (param.encryption.hasValue()) {
auto outputSk = keySet->secretKeys.find(param.encryption->secretKeyID);
if (outputSk == keySet->secretKeys.end()) {
return llvm::make_error<llvm::StringError>(
"cannot find output key to generate bootstrap key",
llvm::inconvertibleErrorCode());
}
std::get<1>(output) = &outputSk->second.first;
std::get<2>(output) = outputSk->second.second;
}
keySet->outputs.push_back(output);
}
}
return std::move(keySet);
return llvm::Error::success();
}
void KeySet::setKeys(
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey_u64 *>>
secretKeys,
std::map<LweSecretKeyID,
std::pair<BootstrapKeyParam, LweBootstrapKey_u64 *>>
bootstrapKeys,
std::map<LweSecretKeyID,
std::pair<KeyswitchKeyParam, LweKeyswitchKey_u64 *>>
keyswitchKeys) {
this->secretKeys = secretKeys;
this->bootstrapKeys = bootstrapKeys;
this->keyswitchKeys = keyswitchKeys;
}
llvm::Error KeySet::generateSecretKey(LweSecretKeyID id,
@@ -120,6 +174,7 @@ llvm::Error KeySet::generateSecretKey(LweSecretKeyID id,
"cannot fill secret key with random generator");
secretKeys[id] = {param, sk};
return llvm::Error::success();
}
@@ -136,7 +191,7 @@ llvm::Error KeySet::generateBootstrapKey(BootstrapKeyID id,
auto outputSk = secretKeys.find(param.outputSecretKeyID);
if (outputSk == secretKeys.end()) {
return llvm::make_error<llvm::StringError>(
"cannot find input key to generate bootstrap key",
"cannot find output key to generate bootstrap key",
llvm::inconvertibleErrorCode());
}
// Allocate the bootstrap key
@@ -291,5 +346,22 @@ llvm::Error KeySet::decrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
return llvm::Error::success();
}
const std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey_u64 *>>
&KeySet::getSecretKeys() {
return secretKeys;
}
const std::map<LweSecretKeyID,
std::pair<BootstrapKeyParam, LweBootstrapKey_u64 *>> &
KeySet::getBootstrapKeys() {
return bootstrapKeys;
}
const std::map<LweSecretKeyID,
std::pair<KeyswitchKeyParam, LweKeyswitchKey_u64 *>> &
KeySet::getKeyswitchKeys() {
return keyswitchKeys;
}
} // namespace zamalang
} // namespace mlir

View File

@@ -0,0 +1,193 @@
#include "zamalang/Support/KeySetCache.h"
#include "zamalang/Support/Error.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/Path.h"
#include <fstream>
#include <string>
extern "C" {
#include "concrete-ffi.h"
}
namespace mlir {
namespace zamalang {
static std::string readFile(llvm::SmallString<0> &path) {
std::ifstream in((std::string)path, std::ofstream::binary);
std::stringstream sbuffer;
sbuffer << in.rdbuf();
return sbuffer.str();
}
static void writeFile(llvm::SmallString<0> &path, Buffer content) {
std::ofstream out((std::string)path, std::ofstream::binary);
out.write((const char *)content.pointer, content.length);
out.close();
}
LweSecretKey_u64 *loadSecretKey(llvm::SmallString<0> &path) {
std::string content = readFile(path);
BufferView buffer = {(const uint8_t *)content.c_str(), content.length()};
return deserialize_lwe_secret_key_u64(buffer);
}
LweKeyswitchKey_u64 *loadKeyswitchKey(llvm::SmallString<0> &path) {
std::string content = readFile(path);
BufferView buffer = {(const uint8_t *)content.c_str(), content.length()};
return deserialize_lwe_keyswitching_key_u64(buffer);
}
LweBootstrapKey_u64 *loadBootstrapKey(llvm::SmallString<0> &path) {
std::string content = readFile(path);
BufferView buffer = {(const uint8_t *)content.c_str(), content.length()};
return deserialize_lwe_bootstrap_key_u64(buffer);
}
void saveSecretKey(llvm::SmallString<0> &path, LweSecretKey_u64 *key) {
Buffer buffer = serialize_lwe_secret_key_u64(key);
writeFile(path, buffer);
free(buffer.pointer);
}
void saveBootstrapKey(llvm::SmallString<0> &path, LweBootstrapKey_u64 *key) {
Buffer buffer = serialize_lwe_bootstrap_key_u64(key);
writeFile(path, buffer);
free(buffer.pointer);
}
void saveKeyswitchKey(llvm::SmallString<0> &path, LweKeyswitchKey_u64 *key) {
Buffer buffer = serialize_lwe_keyswitching_key_u64(key);
writeFile(path, buffer);
free(buffer.pointer);
}
llvm::Expected<std::unique_ptr<KeySet>>
KeySetCache::tryLoadKeys(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb, llvm::SmallString<0> &folderPath) {
// TODO: text dump of all parameter in /hash
auto key_set = KeySet::uninitialized();
std::map<LweSecretKeyID, std::pair<LweSecretKeyParam, LweSecretKey_u64 *>>
secretKeys;
std::map<LweSecretKeyID, std::pair<BootstrapKeyParam, LweBootstrapKey_u64 *>>
bootstrapKeys;
std::map<LweSecretKeyID, std::pair<KeyswitchKeyParam, LweKeyswitchKey_u64 *>>
keyswitchKeys;
// Load LWE secret keys
for (auto secretKeyParam : params.secretKeys) {
auto id = secretKeyParam.first;
auto param = secretKeyParam.second;
llvm::SmallString<0> path = folderPath;
llvm::sys::path::append(path, "secretKey_" + id);
LweSecretKey_u64 *sk = loadSecretKey(path);
secretKeys[id] = {param, sk};
}
// Load bootstrap keys
for (auto bootstrapKeyParam : params.bootstrapKeys) {
auto id = bootstrapKeyParam.first;
auto param = bootstrapKeyParam.second;
llvm::SmallString<0> path = folderPath;
llvm::sys::path::append(path, "pbsKey_" + id);
LweBootstrapKey_u64 *bsk = loadBootstrapKey(path);
bootstrapKeys[id] = {param, bsk};
}
// Load keyswitch keys
for (auto keyswitchParam : params.keyswitchKeys) {
auto id = keyswitchParam.first;
auto param = keyswitchParam.second;
llvm::SmallString<0> path = folderPath;
llvm::sys::path::append(path, "ksKey_" + id);
LweKeyswitchKey_u64 *ksk = loadKeyswitchKey(path);
keyswitchKeys[id] = {param, ksk};
}
key_set->setKeys(secretKeys, bootstrapKeys, keyswitchKeys);
auto err = key_set->setupEncryptionMaterial(params, seed_msb, seed_lsb);
if (err) {
return StreamStringError() << "Cannot setup encryption material: " << err;
}
return key_set;
}
llvm::Error saveKeys(KeySet &key_set, llvm::SmallString<0> &folderPath) {
llvm::SmallString<0> folderIncompletePath = folderPath;
folderIncompletePath.append(".incomplete");
auto err = llvm::sys::fs::create_directories(folderIncompletePath);
if (err) {
return StreamStringError()
<< "Cannot create directory \"" << folderIncompletePath
<< "\": " << err.message();
}
// Save LWE secret keys
for (auto secretKeyParam : key_set.getSecretKeys()) {
auto id = secretKeyParam.first;
auto key = secretKeyParam.second.second;
llvm::SmallString<0> path = folderIncompletePath;
llvm::sys::path::append(path, "secretKey_" + id);
saveSecretKey(path, key);
}
// Save bootstrap keys
for (auto bootstrapKeyParam : key_set.getBootstrapKeys()) {
auto id = bootstrapKeyParam.first;
auto key = bootstrapKeyParam.second.second;
llvm::SmallString<0> path = folderIncompletePath;
llvm::sys::path::append(path, "pbsKey_" + id);
saveBootstrapKey(path, key);
}
// Save keyswitch keys
for (auto keyswitchParam : key_set.getKeyswitchKeys()) {
auto id = keyswitchParam.first;
auto key = keyswitchParam.second.second;
llvm::SmallString<0> path = folderIncompletePath;
llvm::sys::path::append(path, "ksKey_" + id);
saveKeyswitchKey(path, key);
}
err = llvm::sys::fs::rename(folderIncompletePath, folderPath);
if (err) {
return StreamStringError()
<< "Cannot rename directory \"" << folderIncompletePath << "\" \""
<< folderPath << "\": " << err.message();
}
return llvm::Error::success();
}
llvm::Expected<std::unique_ptr<KeySet>>
KeySetCache::tryLoadOrGenerateSave(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb) {
llvm::SmallString<0> folderPath =
llvm::SmallString<0>(this->backingDirectoryPath);
llvm::sys::path::append(folderPath, std::to_string(params.hash()));
llvm::sys::path::append(folderPath, std::to_string(seed_msb) + "_" +
std::to_string(seed_lsb));
if (llvm::sys::fs::exists(folderPath)) {
return tryLoadKeys(params, seed_msb, seed_lsb, folderPath);
} else {
auto key_set = KeySet::generate(params, seed_msb, seed_lsb);
if (!key_set) {
return StreamStringError()
<< "Cannot generate key set: " << key_set.takeError();
}
auto savedErr = saveKeys(*(key_set.get()), folderPath);
if (savedErr) {
return StreamStringError() << "Cannot save key set: " << savedErr;
}
return key_set;
}
}
} // namespace zamalang
} // namespace mlir

View File

@@ -1,9 +1,11 @@
import os
import tempfile
import pytest
import numpy as np
from zamalang import CompilerEngine, library
KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache')
@pytest.mark.parametrize(
"mlir_input, args, expected_result",
@@ -104,7 +106,7 @@ from zamalang import CompilerEngine, library
)
def test_compile_and_run(mlir_input, args, expected_result):
engine = CompilerEngine()
engine.compile_fhe(mlir_input)
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path = KEY_SET_CACHE_PATH)
if isinstance(expected_result, int):
assert engine.run(*args) == expected_result
else:
@@ -129,7 +131,7 @@ def test_compile_and_run(mlir_input, args, expected_result):
)
def test_compile_and_run_invalid_arg_number(mlir_input, args):
engine = CompilerEngine()
engine.compile_fhe(mlir_input)
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path = KEY_SET_CACHE_PATH)
with pytest.raises(ValueError, match=r"wrong number of arguments"):
engine.run(*args)
@@ -154,7 +156,7 @@ def test_compile_and_run_invalid_arg_number(mlir_input, args):
)
def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size):
engine = CompilerEngine()
engine.compile_fhe(mlir_input)
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path = KEY_SET_CACHE_PATH)
assert abs(engine.run(*args) - expected_result) / tab_size < 0.1
@@ -177,7 +179,7 @@ def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size):
def test_compile_invalid(mlir_input):
engine = CompilerEngine()
with pytest.raises(RuntimeError, match=r"Compilation failed:"):
engine.compile_fhe(mlir_input)
engine.compile_fhe(mlir_input, unsecure_key_set_cache_path = KEY_SET_CACHE_PATH)
MODULE_1 = """

View File

@@ -5,6 +5,9 @@
#include "zamalang/Support/CompilerEngine.h"
#include "zamalang/Support/JitCompilerEngine.h"
#include "zamalang/Support/KeySetCache.h"
#include "llvm/Support/Path.h"
#include "globals.h"
#define ASSERT_LLVM_ERROR(err) \
@@ -80,18 +83,32 @@ static bool assert_expected_value(llvm::Expected<T> &&val, const V &exp) {
// and reult in abnormal termination.
template <typename F>
mlir::zamalang::JitCompilerEngine::Lambda
internalCheckedJit(F checkfunc, llvm::StringRef src,
internalCheckedJit(F checkFunc, llvm::StringRef src,
llvm::StringRef func = "main",
bool useDefaultFHEConstraints = false) {
llvm::SmallString<0> cachePath;
llvm::sys::path::system_temp_directory(true, cachePath);
llvm::sys::path::append(cachePath, "KeySetCache");
auto cachePathStr = std::string(cachePath);
auto optCache = llvm::Optional<mlir::zamalang::KeySetCache>(
mlir::zamalang::KeySetCache(cachePathStr));
mlir::zamalang::JitCompilerEngine engine;
if (useDefaultFHEConstraints)
engine.setFHEConstraints(defaultV0Constraints);
llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> lambdaOrErr =
engine.buildLambda(src, func);
engine.buildLambda(src, func, optCache);
checkfunc(lambdaOrErr);
if (!lambdaOrErr) {
std::cout << llvm::toString(lambdaOrErr.takeError()) << std::endl;
}
checkFunc(lambdaOrErr);
return std::move(*lambdaOrErr);
}
@@ -116,4 +133,4 @@ static inline uint64_t operator"" _u64(unsigned long long int v) { return v; }
}, \
__VA_ARGS__)
#endif
#endif