From dad43905186abd9f4a8cdda8feded497a95c2a22 Mon Sep 17 00:00:00 2001 From: "Mayeul@Zama" Date: Thu, 25 Nov 2021 18:26:42 +0100 Subject: [PATCH] feat(compiler): add a key cache --- .github/workflows/conformance.yml | 12 +- compiler/Makefile | 2 +- .../zamalang-c/Support/CompilerEngine.h | 6 +- .../zamalang/Support/ClientParameters.h | 8 + .../zamalang/Support/JitCompilerEngine.h | 4 + compiler/include/zamalang/Support/KeySet.h | 34 +++ .../include/zamalang/Support/KeySetCache.h | 31 +++ .../lib/Bindings/Python/CompilerAPIModule.cpp | 20 +- .../lib/Bindings/Python/zamalang/compiler.py | 13 +- compiler/lib/CAPI/Support/CompilerEngine.cpp | 15 +- compiler/lib/Support/CMakeLists.txt | 1 + compiler/lib/Support/ClientParameters.cpp | 39 ++++ compiler/lib/Support/JitCompilerEngine.cpp | 45 ++-- compiler/lib/Support/KeySet.cpp | 162 +++++++++++---- compiler/lib/Support/KeySetCache.cpp | 193 ++++++++++++++++++ compiler/tests/python/test_compiler_engine.py | 10 +- compiler/tests/unittest/end_to_end_jit_test.h | 25 ++- 17 files changed, 533 insertions(+), 87 deletions(-) create mode 100644 compiler/include/zamalang/Support/KeySetCache.h create mode 100644 compiler/lib/Support/KeySetCache.cpp diff --git a/.github/workflows/conformance.yml b/.github/workflows/conformance.yml index 63ab8da66..6496f159b 100644 --- a/.github/workflows/conformance.yml +++ b/.github/workflows/conformance.yml @@ -22,6 +22,15 @@ jobs: with: submodules: recursive + - name: "KeySetCache" + uses: actions/cache@v2 + with: + path: ${{ github.workspace }}/KeySetCache + # actions/cache does not permit to update a cache entry + key: ${{ runner.os }}-KeySetCache-2021-12-02 + restore-keys: | + ${{ runner.os }}-KeySetCache- + - name: Build and test compiler uses: addnab/docker-run-action@v3 with: @@ -29,7 +38,7 @@ jobs: image: ghcr.io/zama-ai/zamalang-compiler:latest username: ${{ secrets.GHCR_LOGIN }} password: ${{ secrets.GHCR_PASSWORD }} - options: -v ${{ github.workspace }}/compiler:/compiler + options: -v ${{ github.workspace }}/compiler:/compiler -v ${{ github.workspace }}/KeySetCache:/tmp/KeySetCache shell: bash run: | set -e @@ -42,3 +51,4 @@ jobs: make CCACHE=ON BUILD_DIR=/build test echo "Debug: ccache statistics (after the build):" ccache -s + chmod -R ugo+rwx /tmp/KeySetCache diff --git a/compiler/Makefile b/compiler/Makefile index c9535b645..9d4c89ded 100644 --- a/compiler/Makefile +++ b/compiler/Makefile @@ -43,7 +43,7 @@ test-check: zamacompiler file-check not $(BUILD_DIR)/bin/llvm-lit -v tests/ test-python: python-bindings zamacompiler - PYTHONPATH=${PYTHONPATH}:$(BUILD_DIR)/tools/zamalang/python_packages/zamalang_core LD_PRELOAD=$(BUILD_DIR)/lib/libZamalangRuntime.so pytest -vs tests/python + PATH=$(BUILD_DIR)/bin:${PATH} PYTHONPATH=${PYTHONPATH}:$(BUILD_DIR)/tools/zamalang/python_packages/zamalang_core LD_PRELOAD=$(BUILD_DIR)/lib/libZamalangRuntime.so pytest -vs tests/python test: test-check test-end-to-end-jit test-python diff --git a/compiler/include/zamalang-c/Support/CompilerEngine.h b/compiler/include/zamalang-c/Support/CompilerEngine.h index 8c9a422f1..d813406cb 100644 --- a/compiler/include/zamalang-c/Support/CompilerEngine.h +++ b/compiler/include/zamalang-c/Support/CompilerEngine.h @@ -32,10 +32,12 @@ typedef struct executionArguments executionArguments; // Build lambda from a textual representation of an MLIR module // The lambda will have `funcName` as entrypoint, and use runtimeLibPath (if not -// null) as a shared library during compilation +// null) as a shared library during compilation, +// a path to activate the use a cache for encryption keys for test purpose +// (unsecure). MLIR_CAPI_EXPORTED mlir::zamalang::JitCompilerEngine::Lambda buildLambda(const char *module, const char *funcName, - const char *runtimeLibPath); + const char *runtimeLibPath, const char *keySetCachePath); // Parse then print a textual representation of an MLIR module MLIR_CAPI_EXPORTED std::string roundTrip(const char *module); diff --git a/compiler/include/zamalang/Support/ClientParameters.h b/compiler/include/zamalang/Support/ClientParameters.h index 1431da930..a90a645e4 100644 --- a/compiler/include/zamalang/Support/ClientParameters.h +++ b/compiler/include/zamalang/Support/ClientParameters.h @@ -24,6 +24,8 @@ typedef uint64_t GlweDimension; typedef std::string LweSecretKeyID; struct LweSecretKeyParam { LweSize size; + + void hash(size_t &seed); }; typedef std::string BootstrapKeyID; @@ -34,6 +36,8 @@ struct BootstrapKeyParam { DecompositionBaseLog baseLog; GlweDimension glweDimension; Variance variance; + + void hash(size_t &seed); }; typedef std::string KeyswitchKeyID; @@ -43,6 +47,8 @@ struct KeyswitchKeyParam { DecompositionLevelCount level; DecompositionBaseLog baseLog; Variance variance; + + void hash(size_t &seed); }; struct Encoding { @@ -75,11 +81,13 @@ struct ClientParameters { std::map keyswitchKeys; std::vector inputs; std::vector outputs; + size_t hash(); }; llvm::Expected createClientParametersForV0(V0FHEContext context, llvm::StringRef name, mlir::ModuleOp module); + } // namespace zamalang } // namespace mlir diff --git a/compiler/include/zamalang/Support/JitCompilerEngine.h b/compiler/include/zamalang/Support/JitCompilerEngine.h index 018575bea..f132f2678 100644 --- a/compiler/include/zamalang/Support/JitCompilerEngine.h +++ b/compiler/include/zamalang/Support/JitCompilerEngine.h @@ -1,6 +1,7 @@ #ifndef ZAMALANG_SUPPORT_JIT_COMPILER_ENGINE_H #define ZAMALANG_SUPPORT_JIT_COMPILER_ENGINE_H +#include "zamalang/Support/KeySetCache.h" #include #include #include @@ -363,15 +364,18 @@ public: /// Use runtimeLibPath as a shared library if specified. llvm::Expected buildLambda(llvm::StringRef src, llvm::StringRef funcName = "main", + llvm::Optional cachePath = {}, llvm::Optional runtimeLibPath = {}); llvm::Expected buildLambda(std::unique_ptr buffer, llvm::StringRef funcName = "main", + llvm::Optional cachePath = {}, llvm::Optional runtimeLibPath = {}); llvm::Expected buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName = "main", + llvm::Optional cachePath = {}, llvm::Optional runtimeLibPath = {}); protected: diff --git a/compiler/include/zamalang/Support/KeySet.h b/compiler/include/zamalang/Support/KeySet.h index aab217caf..824c48eba 100644 --- a/compiler/include/zamalang/Support/KeySet.h +++ b/compiler/include/zamalang/Support/KeySet.h @@ -10,6 +10,7 @@ extern "C" { } #include "zamalang/Support/ClientParameters.h" +#include "zamalang/Support/KeySetCache.h" namespace mlir { namespace zamalang { @@ -17,6 +18,15 @@ namespace zamalang { class KeySet { public: ~KeySet(); + + static std::unique_ptr uninitialized(); + + llvm::Error generateKeysFromParams(ClientParameters ¶ms, + uint64_t seed_msb, uint64_t seed_lsb); + + llvm::Error setupEncryptionMaterial(ClientParameters ¶ms, + uint64_t seed_msb, uint64_t seed_lsb); + // allocate a KeySet according the ClientParameters. static llvm::Expected> generate(ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb); @@ -46,6 +56,18 @@ public: context.bsk = std::get<1>(this->bootstrapKeys["bsk_v0"]); } + const std::map> & + getSecretKeys(); + + const std::map> & + getBootstrapKeys(); + + const std::map> & + getKeyswitchKeys(); + protected: llvm::Error generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param, SecretRandomGenerator *generator); @@ -54,6 +76,8 @@ protected: llvm::Error generateKeyswitchKey(KeyswitchKeyID id, KeyswitchKeyParam param, EncryptionRandomGenerator *generator); + friend class KeySetCache; + private: EncryptionRandomGenerator *encryptionRandomGenerator; std::map> @@ -66,6 +90,16 @@ private: inputs; std::vector> outputs; + + void setKeys( + std::map> + secretKeys, + std::map> + bootstrapKeys, + std::map> + keyswitchKeys); }; } // namespace zamalang diff --git a/compiler/include/zamalang/Support/KeySetCache.h b/compiler/include/zamalang/Support/KeySetCache.h new file mode 100644 index 000000000..64eaea52e --- /dev/null +++ b/compiler/include/zamalang/Support/KeySetCache.h @@ -0,0 +1,31 @@ +#ifndef ZAMALANG_SUPPORT_KEYSETCACHE_H_ +#define ZAMALANG_SUPPORT_KEYSETCACHE_H_ + +#include "zamalang/Support/KeySet.h" + +namespace mlir { +namespace zamalang { + +class KeySet; + +class KeySetCache { + std::string backingDirectoryPath; + +public: + KeySetCache(std::string backingDirectoryPath) + : backingDirectoryPath(backingDirectoryPath) {} + + llvm::Expected> + tryLoadOrGenerateSave(ClientParameters ¶ms, uint64_t seed_msb, + uint64_t seed_lsb); + +private: + static llvm::Expected> + tryLoadKeys(ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb, + llvm::SmallString<0> &folderPath); +}; + +} // namespace zamalang +} // namespace mlir + +#endif \ No newline at end of file diff --git a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index 86543b0a2..8b6b54ef2 100644 --- a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -17,6 +17,10 @@ using mlir::zamalang::JitCompilerEngine; using mlir::zamalang::LambdaArgument; +const char *noEmptyStringPtr(std::string &s) { + return (s.empty()) ? nullptr : s.c_str(); +} + /// Populate the compiler API python module. void mlir::zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) { m.doc() = "Zamalang compiler python API"; @@ -31,14 +35,14 @@ void mlir::zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) { pybind11::class_(m, "JitCompilerEngine") .def(pybind11::init()) - .def_static("build_lambda", [](std::string mlir_input, - std::string func_name, - std::string runtime_lib_path) { - if (runtime_lib_path.empty()) - return buildLambda(mlir_input.c_str(), func_name.c_str(), nullptr); - return buildLambda(mlir_input.c_str(), func_name.c_str(), - runtime_lib_path.c_str()); - }); + .def_static("build_lambda", + [](std::string mlir_input, std::string func_name, + std::string runtime_lib_path, + std::string keysetcache_path) { + return buildLambda(mlir_input.c_str(), func_name.c_str(), + noEmptyStringPtr(runtime_lib_path), + noEmptyStringPtr(keysetcache_path)); + }); pybind11::class_(m, "LambdaArgument") .def_static("from_tensor", lambdaArgumentFromTensor) diff --git a/compiler/lib/Bindings/Python/zamalang/compiler.py b/compiler/lib/Bindings/Python/zamalang/compiler.py index 665fd0ba0..927ee5cef 100644 --- a/compiler/lib/Bindings/Python/zamalang/compiler.py +++ b/compiler/lib/Bindings/Python/zamalang/compiler.py @@ -118,7 +118,8 @@ class CompilerEngine: self.compile_fhe(mlir_str) def compile_fhe( - self, mlir_str: str, func_name: str = "main", runtime_lib_path: str = None + self, mlir_str: str, func_name: str = "main", runtime_lib_path: str = None, + unsecure_key_set_cache_path: str = None, ): """Compile the MLIR input. @@ -126,6 +127,7 @@ class CompilerEngine: mlir_str (str): MLIR to compile. func_name (str): name of the function to set as entrypoint (default: main). runtime_lib_path (str): path to the runtime lib (default: None). + unsecure_key_set_cache_path (str): path to the activate keyset caching (default: None). Raises: TypeError: if the argument is not an str. @@ -140,7 +142,14 @@ class CompilerEngine: raise TypeError( "runtime_lib_path must be an str representing the path to the runtime lib" ) - self._lambda = self._engine.build_lambda(mlir_str, func_name, runtime_lib_path) + unsecure_key_set_cache_path = unsecure_key_set_cache_path or "" + if not isinstance(unsecure_key_set_cache_path, str): + raise TypeError( + "unsecure_key_set_cache_path must be a str" + ) + self._lambda = self._engine.build_lambda( + mlir_str, func_name, runtime_lib_path, + unsecure_key_set_cache_path) def run(self, *args: List[Union[int, np.ndarray]]) -> Union[int, np.ndarray]: """Run the compiled code. diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index 560968bcd..1f13206a9 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -1,20 +1,31 @@ +#include "llvm/ADT/SmallString.h" + #include "zamalang-c/Support/CompilerEngine.h" #include "zamalang/Support/CompilerEngine.h" #include "zamalang/Support/Jit.h" #include "zamalang/Support/JitCompilerEngine.h" +#include "zamalang/Support/KeySetCache.h" using mlir::zamalang::JitCompilerEngine; mlir::zamalang::JitCompilerEngine::Lambda buildLambda(const char *module, const char *funcName, - const char *runtimeLibPath) { + const char *runtimeLibPath, const char *keySetCachePath) { // Set the runtime library path if not nullptr llvm::Optional runtimeLibPathOptional = {}; if (runtimeLibPath != nullptr) runtimeLibPathOptional = runtimeLibPath; mlir::zamalang::JitCompilerEngine engine; + + using KeySetCache = mlir::zamalang::KeySetCache; + using optKeySetCache = llvm::Optional; + auto cacheOpt = optKeySetCache(); + if (keySetCachePath != nullptr) { + cacheOpt = KeySetCache(std::string(keySetCachePath)); + } + llvm::Expected lambdaOrErr = - engine.buildLambda(module, funcName, runtimeLibPathOptional); + engine.buildLambda(module, funcName, cacheOpt, runtimeLibPathOptional); if (!lambdaOrErr) { std::string backingString; llvm::raw_string_ostream os(backingString); diff --git a/compiler/lib/Support/CMakeLists.txt b/compiler/lib/Support/CMakeLists.txt index 94516d8ab..51632a41c 100644 --- a/compiler/lib/Support/CMakeLists.txt +++ b/compiler/lib/Support/CMakeLists.txt @@ -11,6 +11,7 @@ add_mlir_library(ZamalangSupport KeySet.cpp logging.cpp Jit.cpp + KeySetCache.cpp LLVMEmitFile.cpp ADDITIONAL_HEADER_DIRS diff --git a/compiler/lib/Support/ClientParameters.cpp b/compiler/lib/Support/ClientParameters.cpp index 21bad8035..10c29b984 100644 --- a/compiler/lib/Support/ClientParameters.cpp +++ b/compiler/lib/Support/ClientParameters.cpp @@ -145,5 +145,44 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name, return c; } +// https://stackoverflow.com/a/38140932 +static inline void hash(std::size_t &seed) {} +template +static inline void hash(std::size_t &seed, const T &v, Rest... rest) { + // See https://softwareengineering.stackexchange.com/a/402543 + const auto GOLDEN_RATIO = 0x9e3779b97f4a7c15; // pseudo random bits + const std::hash hasher; + seed ^= hasher(v) + GOLDEN_RATIO + (seed << 6) + (seed >> 2); + hash(seed, rest...); +} + +void LweSecretKeyParam::hash(size_t &seed) { mlir::zamalang::hash(seed, size); } + +void BootstrapKeyParam::hash(size_t &seed) { + mlir::zamalang::hash(seed, inputSecretKeyID, outputSecretKeyID, level, + baseLog, glweDimension, variance); +} + +void KeyswitchKeyParam::hash(size_t &seed) { + mlir::zamalang::hash(seed, inputSecretKeyID, outputSecretKeyID, level, + baseLog, variance); +} + +std::size_t ClientParameters::hash() { + std::size_t currentHash = 1; + for (auto secretKeyParam : secretKeys) { + mlir::zamalang::hash(currentHash, secretKeyParam.first); + secretKeyParam.second.hash(currentHash); + } + for (auto bootstrapKeyParam : bootstrapKeys) { + mlir::zamalang::hash(currentHash, bootstrapKeyParam.first); + bootstrapKeyParam.second.hash(currentHash); + } + for (auto keyswitchParam : keyswitchKeys) { + mlir::zamalang::hash(currentHash, keyswitchParam.first); + keyswitchParam.second.hash(currentHash); + } + return currentHash; +} } // namespace zamalang } // namespace mlir diff --git a/compiler/lib/Support/JitCompilerEngine.cpp b/compiler/lib/Support/JitCompilerEngine.cpp index eb3859631..6e04991a5 100644 --- a/compiler/lib/Support/JitCompilerEngine.cpp +++ b/compiler/lib/Support/JitCompilerEngine.cpp @@ -37,25 +37,22 @@ JitCompilerEngine::findLLVMFuncOp(mlir::ModuleOp module, llvm::StringRef name) { llvm::Expected JitCompilerEngine::buildLambda(std::unique_ptr buffer, llvm::StringRef funcName, + llvm::Optional cache, llvm::Optional runtimeLibPath) { llvm::SourceMgr sm; - sm.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc()); - - llvm::Expected res = - this->buildLambda(sm, funcName, runtimeLibPath); - - return std::move(res); + return this->buildLambda(sm, funcName, cache, runtimeLibPath); } // Build a lambda from the function with the name given in `funcName` // from the source string `s`. llvm::Expected JitCompilerEngine::buildLambda(llvm::StringRef s, llvm::StringRef funcName, + llvm::Optional cache, llvm::Optional runtimeLibPath) { std::unique_ptr mb = llvm::MemoryBuffer::getMemBuffer(s); llvm::Expected res = - this->buildLambda(std::move(mb), funcName, runtimeLibPath); + this->buildLambda(std::move(mb), funcName, cache, runtimeLibPath); return std::move(res); } @@ -64,6 +61,7 @@ JitCompilerEngine::buildLambda(llvm::StringRef s, llvm::StringRef funcName, // `funcName` from the sources managed by the source manager `sm`. llvm::Expected JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName, + llvm::Optional cache, llvm::Optional runtimeLibPath) { MLIRContext &mlirContext = *this->compilationContext->getMLIRContext(); @@ -77,14 +75,17 @@ JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName, if (!compResOrErr) return std::move(compResOrErr.takeError()); - mlir::ModuleOp module = compResOrErr->mlirModuleRef->get(); + auto compRes = std::move(compResOrErr.get()); + + mlir::ModuleOp module = compRes.mlirModuleRef->get(); // Locate function to JIT-compile llvm::Expected funcOrError = - this->findLLVMFuncOp(compResOrErr->mlirModuleRef->get(), funcName); + this->findLLVMFuncOp(compRes.mlirModuleRef->get(), funcName); if (!funcOrError) - return std::move(funcOrError.takeError()); + return StreamStringError() << "Cannot find function \"" << funcName + << "\": " << std::move(funcOrError.takeError()); // Prepare LLVM infrastructure for JIT compilation llvm::InitializeNativeTarget(); @@ -98,24 +99,32 @@ JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName, mlir::zamalang::JITLambda::create(funcName, module, optPipeline, runtimeLibPath); + if (!lambdaOrErr) { + return StreamStringError() + << "Cannot create lambda: " << lambdaOrErr.takeError(); + } + + auto lambda = std::move(lambdaOrErr.get()); + // Generate the KeySet for encrypting lambda arguments, decrypting lambda // results - if (!compResOrErr->clientParameters.hasValue()) { + if (!compRes.clientParameters.hasValue()) { return StreamStringError("Cannot generate the keySet since client " "parameters has not been computed"); } llvm::Expected> keySetOrErr = - mlir::zamalang::KeySet::generate(*compResOrErr->clientParameters, 0, 0); + (cache.hasValue()) + ? cache->tryLoadOrGenerateSave(*compRes.clientParameters, 0, 0) + : KeySet::generate(*compRes.clientParameters, 0, 0); - if (auto err = keySetOrErr.takeError()) - return std::move(err); + if (!keySetOrErr) { + return keySetOrErr.takeError(); + } - if (!lambdaOrErr) - return std::move(lambdaOrErr.takeError()); + auto keySet = std::move(keySetOrErr.get()); - return Lambda{this->compilationContext, std::move(lambdaOrErr.get()), - std::move(*keySetOrErr)}; + return Lambda{this->compilationContext, std::move(lambda), std::move(keySet)}; } } // namespace zamalang diff --git a/compiler/lib/Support/KeySet.cpp b/compiler/lib/Support/KeySet.cpp index b82186c74..87a420d9f 100644 --- a/compiler/lib/Support/KeySet.cpp +++ b/compiler/lib/Support/KeySet.cpp @@ -1,4 +1,5 @@ #include "zamalang/Support/KeySet.h" +#include "zamalang/Support/Error.h" #define CAPI_ERR_TO_LLVM_ERROR(s, msg) \ { \ @@ -30,7 +31,79 @@ KeySet::~KeySet() { llvm::Expected> KeySet::generate(ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb) { - auto keySet = std::make_unique(); + + auto a = uninitialized(); + + auto fillError = a->generateKeysFromParams(params, seed_msb, seed_lsb); + + if (fillError) { + return StreamStringError() + << "Cannot fill keys from params: " << std::move(fillError); + } + + fillError = a->setupEncryptionMaterial(params, seed_msb, seed_lsb); + + if (fillError) { + return StreamStringError() + << "Cannot setup encryption material: " << std::move(fillError); + } + + return a; +} + +std::unique_ptr KeySet::uninitialized() { + return std::make_unique(); +} + +llvm::Error KeySet::setupEncryptionMaterial(ClientParameters ¶ms, + uint64_t seed_msb, + uint64_t seed_lsb) { + // Set inputs and outputs LWE secret keys + { + for (auto param : params.inputs) { + std::tuple input = { + param, nullptr, nullptr}; + if (param.encryption.hasValue()) { + auto inputSk = this->secretKeys.find(param.encryption->secretKeyID); + if (inputSk == this->secretKeys.end()) { + return llvm::make_error( + "input encryption secret key (" + param.encryption->secretKeyID + + ") does not exist ", + llvm::inconvertibleErrorCode()); + } + std::get<1>(input) = &inputSk->second.first; + std::get<2>(input) = inputSk->second.second; + } + this->inputs.push_back(input); + } + for (auto param : params.outputs) { + std::tuple output = + {param, nullptr, nullptr}; + if (param.encryption.hasValue()) { + auto outputSk = this->secretKeys.find(param.encryption->secretKeyID); + if (outputSk == this->secretKeys.end()) { + return llvm::make_error( + "cannot find output key to generate bootstrap key", + llvm::inconvertibleErrorCode()); + } + std::get<1>(output) = &outputSk->second.first; + std::get<2>(output) = outputSk->second.second; + } + this->outputs.push_back(output); + } + } + int err; + CAPI_ERR_TO_LLVM_ERROR( + this->encryptionRandomGenerator = + allocate_encryption_generator(&err, seed_msb, seed_lsb), + "cannot allocate encryption generator"); + + return llvm::Error::success(); +} + +llvm::Error KeySet::generateKeysFromParams(ClientParameters ¶ms, + uint64_t seed_msb, + uint64_t seed_lsb) { { // Generate LWE secret keys @@ -39,8 +112,8 @@ KeySet::generate(ClientParameters ¶ms, uint64_t seed_msb, generator = allocate_secret_generator(&err, seed_msb, seed_lsb), "cannot allocate random generator"); for (auto secretKeyParam : params.secretKeys) { - auto e = keySet->generateSecretKey(secretKeyParam.first, - secretKeyParam.second, generator); + auto e = this->generateSecretKey(secretKeyParam.first, + secretKeyParam.second, generator); if (e) { return std::move(e); } @@ -50,62 +123,43 @@ KeySet::generate(ClientParameters ¶ms, uint64_t seed_msb, } // Allocate the encryption random generator CAPI_ERR_TO_LLVM_ERROR( - keySet->encryptionRandomGenerator = + this->encryptionRandomGenerator = allocate_encryption_generator(&err, seed_msb, seed_lsb), "cannot allocate encryption generator"); // Generate bootstrap and keyswitch keys { for (auto bootstrapKeyParam : params.bootstrapKeys) { - auto e = keySet->generateBootstrapKey(bootstrapKeyParam.first, - bootstrapKeyParam.second, - keySet->encryptionRandomGenerator); + auto e = this->generateBootstrapKey(bootstrapKeyParam.first, + bootstrapKeyParam.second, + this->encryptionRandomGenerator); if (e) { return std::move(e); } } for (auto keyswitchParam : params.keyswitchKeys) { - auto e = keySet->generateKeyswitchKey(keyswitchParam.first, - keyswitchParam.second, - keySet->encryptionRandomGenerator); + auto e = this->generateKeyswitchKey(keyswitchParam.first, + keyswitchParam.second, + this->encryptionRandomGenerator); if (e) { return std::move(e); } } } - // Set inputs and outputs LWE secret keys - { - for (auto param : params.inputs) { - std::tuple input = { - param, nullptr, nullptr}; - if (param.encryption.hasValue()) { - auto inputSk = keySet->secretKeys.find(param.encryption->secretKeyID); - if (inputSk == keySet->secretKeys.end()) { - return llvm::make_error( - "cannot find input key to generate bootstrap key", - llvm::inconvertibleErrorCode()); - } - std::get<1>(input) = &inputSk->second.first; - std::get<2>(input) = inputSk->second.second; - } - keySet->inputs.push_back(input); - } - for (auto param : params.outputs) { - std::tuple output = - {param, nullptr, nullptr}; - if (param.encryption.hasValue()) { - auto outputSk = keySet->secretKeys.find(param.encryption->secretKeyID); - if (outputSk == keySet->secretKeys.end()) { - return llvm::make_error( - "cannot find output key to generate bootstrap key", - llvm::inconvertibleErrorCode()); - } - std::get<1>(output) = &outputSk->second.first; - std::get<2>(output) = outputSk->second.second; - } - keySet->outputs.push_back(output); - } - } - return std::move(keySet); + return llvm::Error::success(); +} + +void KeySet::setKeys( + std::map> + secretKeys, + std::map> + bootstrapKeys, + std::map> + keyswitchKeys) { + this->secretKeys = secretKeys; + this->bootstrapKeys = bootstrapKeys; + this->keyswitchKeys = keyswitchKeys; } llvm::Error KeySet::generateSecretKey(LweSecretKeyID id, @@ -120,6 +174,7 @@ llvm::Error KeySet::generateSecretKey(LweSecretKeyID id, "cannot fill secret key with random generator"); secretKeys[id] = {param, sk}; + return llvm::Error::success(); } @@ -136,7 +191,7 @@ llvm::Error KeySet::generateBootstrapKey(BootstrapKeyID id, auto outputSk = secretKeys.find(param.outputSecretKeyID); if (outputSk == secretKeys.end()) { return llvm::make_error( - "cannot find input key to generate bootstrap key", + "cannot find output key to generate bootstrap key", llvm::inconvertibleErrorCode()); } // Allocate the bootstrap key @@ -291,5 +346,22 @@ llvm::Error KeySet::decrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext, return llvm::Error::success(); } +const std::map> + &KeySet::getSecretKeys() { + return secretKeys; +} + +const std::map> & +KeySet::getBootstrapKeys() { + return bootstrapKeys; +} + +const std::map> & +KeySet::getKeyswitchKeys() { + return keyswitchKeys; +} + } // namespace zamalang } // namespace mlir \ No newline at end of file diff --git a/compiler/lib/Support/KeySetCache.cpp b/compiler/lib/Support/KeySetCache.cpp new file mode 100644 index 000000000..e433b40b1 --- /dev/null +++ b/compiler/lib/Support/KeySetCache.cpp @@ -0,0 +1,193 @@ +#include "zamalang/Support/KeySetCache.h" +#include "zamalang/Support/Error.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/Path.h" +#include +#include +extern "C" { +#include "concrete-ffi.h" +} + +namespace mlir { +namespace zamalang { + +static std::string readFile(llvm::SmallString<0> &path) { + std::ifstream in((std::string)path, std::ofstream::binary); + std::stringstream sbuffer; + sbuffer << in.rdbuf(); + return sbuffer.str(); +} + +static void writeFile(llvm::SmallString<0> &path, Buffer content) { + std::ofstream out((std::string)path, std::ofstream::binary); + out.write((const char *)content.pointer, content.length); + out.close(); +} + +LweSecretKey_u64 *loadSecretKey(llvm::SmallString<0> &path) { + std::string content = readFile(path); + BufferView buffer = {(const uint8_t *)content.c_str(), content.length()}; + return deserialize_lwe_secret_key_u64(buffer); +} + +LweKeyswitchKey_u64 *loadKeyswitchKey(llvm::SmallString<0> &path) { + std::string content = readFile(path); + BufferView buffer = {(const uint8_t *)content.c_str(), content.length()}; + return deserialize_lwe_keyswitching_key_u64(buffer); +} + +LweBootstrapKey_u64 *loadBootstrapKey(llvm::SmallString<0> &path) { + std::string content = readFile(path); + BufferView buffer = {(const uint8_t *)content.c_str(), content.length()}; + return deserialize_lwe_bootstrap_key_u64(buffer); +} + +void saveSecretKey(llvm::SmallString<0> &path, LweSecretKey_u64 *key) { + Buffer buffer = serialize_lwe_secret_key_u64(key); + writeFile(path, buffer); + free(buffer.pointer); +} + +void saveBootstrapKey(llvm::SmallString<0> &path, LweBootstrapKey_u64 *key) { + Buffer buffer = serialize_lwe_bootstrap_key_u64(key); + writeFile(path, buffer); + free(buffer.pointer); +} + +void saveKeyswitchKey(llvm::SmallString<0> &path, LweKeyswitchKey_u64 *key) { + Buffer buffer = serialize_lwe_keyswitching_key_u64(key); + writeFile(path, buffer); + free(buffer.pointer); +} + +llvm::Expected> +KeySetCache::tryLoadKeys(ClientParameters ¶ms, uint64_t seed_msb, + uint64_t seed_lsb, llvm::SmallString<0> &folderPath) { + // TODO: text dump of all parameter in /hash + auto key_set = KeySet::uninitialized(); + + std::map> + secretKeys; + std::map> + bootstrapKeys; + std::map> + keyswitchKeys; + + // Load LWE secret keys + for (auto secretKeyParam : params.secretKeys) { + auto id = secretKeyParam.first; + auto param = secretKeyParam.second; + llvm::SmallString<0> path = folderPath; + llvm::sys::path::append(path, "secretKey_" + id); + LweSecretKey_u64 *sk = loadSecretKey(path); + secretKeys[id] = {param, sk}; + } + // Load bootstrap keys + for (auto bootstrapKeyParam : params.bootstrapKeys) { + auto id = bootstrapKeyParam.first; + auto param = bootstrapKeyParam.second; + llvm::SmallString<0> path = folderPath; + llvm::sys::path::append(path, "pbsKey_" + id); + LweBootstrapKey_u64 *bsk = loadBootstrapKey(path); + bootstrapKeys[id] = {param, bsk}; + } + // Load keyswitch keys + for (auto keyswitchParam : params.keyswitchKeys) { + auto id = keyswitchParam.first; + auto param = keyswitchParam.second; + llvm::SmallString<0> path = folderPath; + llvm::sys::path::append(path, "ksKey_" + id); + LweKeyswitchKey_u64 *ksk = loadKeyswitchKey(path); + keyswitchKeys[id] = {param, ksk}; + } + + key_set->setKeys(secretKeys, bootstrapKeys, keyswitchKeys); + + auto err = key_set->setupEncryptionMaterial(params, seed_msb, seed_lsb); + if (err) { + return StreamStringError() << "Cannot setup encryption material: " << err; + } + + return key_set; +} + +llvm::Error saveKeys(KeySet &key_set, llvm::SmallString<0> &folderPath) { + llvm::SmallString<0> folderIncompletePath = folderPath; + + folderIncompletePath.append(".incomplete"); + + auto err = llvm::sys::fs::create_directories(folderIncompletePath); + if (err) { + return StreamStringError() + << "Cannot create directory \"" << folderIncompletePath + << "\": " << err.message(); + } + + // Save LWE secret keys + for (auto secretKeyParam : key_set.getSecretKeys()) { + auto id = secretKeyParam.first; + auto key = secretKeyParam.second.second; + llvm::SmallString<0> path = folderIncompletePath; + llvm::sys::path::append(path, "secretKey_" + id); + saveSecretKey(path, key); + } + // Save bootstrap keys + for (auto bootstrapKeyParam : key_set.getBootstrapKeys()) { + auto id = bootstrapKeyParam.first; + auto key = bootstrapKeyParam.second.second; + llvm::SmallString<0> path = folderIncompletePath; + llvm::sys::path::append(path, "pbsKey_" + id); + saveBootstrapKey(path, key); + } + // Save keyswitch keys + for (auto keyswitchParam : key_set.getKeyswitchKeys()) { + auto id = keyswitchParam.first; + auto key = keyswitchParam.second.second; + llvm::SmallString<0> path = folderIncompletePath; + llvm::sys::path::append(path, "ksKey_" + id); + saveKeyswitchKey(path, key); + } + + err = llvm::sys::fs::rename(folderIncompletePath, folderPath); + if (err) { + return StreamStringError() + << "Cannot rename directory \"" << folderIncompletePath << "\" \"" + << folderPath << "\": " << err.message(); + } + + return llvm::Error::success(); +} + +llvm::Expected> +KeySetCache::tryLoadOrGenerateSave(ClientParameters ¶ms, uint64_t seed_msb, + uint64_t seed_lsb) { + + llvm::SmallString<0> folderPath = + llvm::SmallString<0>(this->backingDirectoryPath); + + llvm::sys::path::append(folderPath, std::to_string(params.hash())); + + llvm::sys::path::append(folderPath, std::to_string(seed_msb) + "_" + + std::to_string(seed_lsb)); + + if (llvm::sys::fs::exists(folderPath)) { + return tryLoadKeys(params, seed_msb, seed_lsb, folderPath); + } else { + auto key_set = KeySet::generate(params, seed_msb, seed_lsb); + + if (!key_set) { + return StreamStringError() + << "Cannot generate key set: " << key_set.takeError(); + } + + auto savedErr = saveKeys(*(key_set.get()), folderPath); + if (savedErr) { + return StreamStringError() << "Cannot save key set: " << savedErr; + } + + return key_set; + } +} + +} // namespace zamalang +} // namespace mlir \ No newline at end of file diff --git a/compiler/tests/python/test_compiler_engine.py b/compiler/tests/python/test_compiler_engine.py index 206ae0259..7ad500d20 100644 --- a/compiler/tests/python/test_compiler_engine.py +++ b/compiler/tests/python/test_compiler_engine.py @@ -1,9 +1,11 @@ import os +import tempfile import pytest import numpy as np from zamalang import CompilerEngine, library +KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache') @pytest.mark.parametrize( "mlir_input, args, expected_result", @@ -104,7 +106,7 @@ from zamalang import CompilerEngine, library ) def test_compile_and_run(mlir_input, args, expected_result): engine = CompilerEngine() - engine.compile_fhe(mlir_input) + engine.compile_fhe(mlir_input, unsecure_key_set_cache_path = KEY_SET_CACHE_PATH) if isinstance(expected_result, int): assert engine.run(*args) == expected_result else: @@ -129,7 +131,7 @@ def test_compile_and_run(mlir_input, args, expected_result): ) def test_compile_and_run_invalid_arg_number(mlir_input, args): engine = CompilerEngine() - engine.compile_fhe(mlir_input) + engine.compile_fhe(mlir_input, unsecure_key_set_cache_path = KEY_SET_CACHE_PATH) with pytest.raises(ValueError, match=r"wrong number of arguments"): engine.run(*args) @@ -154,7 +156,7 @@ def test_compile_and_run_invalid_arg_number(mlir_input, args): ) def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size): engine = CompilerEngine() - engine.compile_fhe(mlir_input) + engine.compile_fhe(mlir_input, unsecure_key_set_cache_path = KEY_SET_CACHE_PATH) assert abs(engine.run(*args) - expected_result) / tab_size < 0.1 @@ -177,7 +179,7 @@ def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size): def test_compile_invalid(mlir_input): engine = CompilerEngine() with pytest.raises(RuntimeError, match=r"Compilation failed:"): - engine.compile_fhe(mlir_input) + engine.compile_fhe(mlir_input, unsecure_key_set_cache_path = KEY_SET_CACHE_PATH) MODULE_1 = """ diff --git a/compiler/tests/unittest/end_to_end_jit_test.h b/compiler/tests/unittest/end_to_end_jit_test.h index e3695d5ae..94e23de6a 100644 --- a/compiler/tests/unittest/end_to_end_jit_test.h +++ b/compiler/tests/unittest/end_to_end_jit_test.h @@ -5,6 +5,9 @@ #include "zamalang/Support/CompilerEngine.h" #include "zamalang/Support/JitCompilerEngine.h" +#include "zamalang/Support/KeySetCache.h" +#include "llvm/Support/Path.h" + #include "globals.h" #define ASSERT_LLVM_ERROR(err) \ @@ -80,18 +83,32 @@ static bool assert_expected_value(llvm::Expected &&val, const V &exp) { // and reult in abnormal termination. template mlir::zamalang::JitCompilerEngine::Lambda -internalCheckedJit(F checkfunc, llvm::StringRef src, +internalCheckedJit(F checkFunc, llvm::StringRef src, llvm::StringRef func = "main", bool useDefaultFHEConstraints = false) { + + llvm::SmallString<0> cachePath; + + llvm::sys::path::system_temp_directory(true, cachePath); + + llvm::sys::path::append(cachePath, "KeySetCache"); + + auto cachePathStr = std::string(cachePath); + auto optCache = llvm::Optional( + mlir::zamalang::KeySetCache(cachePathStr)); + mlir::zamalang::JitCompilerEngine engine; if (useDefaultFHEConstraints) engine.setFHEConstraints(defaultV0Constraints); llvm::Expected lambdaOrErr = - engine.buildLambda(src, func); + engine.buildLambda(src, func, optCache); - checkfunc(lambdaOrErr); + if (!lambdaOrErr) { + std::cout << llvm::toString(lambdaOrErr.takeError()) << std::endl; + } + checkFunc(lambdaOrErr); return std::move(*lambdaOrErr); } @@ -116,4 +133,4 @@ static inline uint64_t operator"" _u64(unsigned long long int v) { return v; } }, \ __VA_ARGS__) -#endif \ No newline at end of file +#endif