From af0789f1287c08020d8d2c5fa62b4dd52c50f486 Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Tue, 24 Aug 2021 15:02:45 +0200 Subject: [PATCH] enhance(compiler/runtime): Add runtime tools to handle tensor inputs and outputs --- .../Conversion/Utils/GlobalFHEContext.h | 4 +- .../zamalang/Support/ClientParameters.h | 5 +- .../include/zamalang/Support/CompilerEngine.h | 15 +- .../include/zamalang/Support/CompilerTools.h | 41 +- compiler/include/zamalang/Support/KeySet.h | 3 + compiler/lib/Support/ClientParameters.cpp | 25 +- compiler/lib/Support/CompilerEngine.cpp | 54 +-- compiler/lib/Support/CompilerTools.cpp | 258 +++++++++++-- compiler/lib/Support/KeySet.cpp | 15 +- compiler/tests/unittest/CMakeLists.txt | 4 + compiler/tests/unittest/hello_test.cc | 350 +++++++++++++++++- 11 files changed, 701 insertions(+), 73 deletions(-) diff --git a/compiler/include/zamalang/Conversion/Utils/GlobalFHEContext.h b/compiler/include/zamalang/Conversion/Utils/GlobalFHEContext.h index 85aa2a0eb..eea6d3b65 100644 --- a/compiler/include/zamalang/Conversion/Utils/GlobalFHEContext.h +++ b/compiler/include/zamalang/Conversion/Utils/GlobalFHEContext.h @@ -1,5 +1,5 @@ -#ifndef ZAMALANG_CONVERSION_GLOBALFHECONTEXT_PATTERNS_H_ -#define ZAMALANG_CONVERSION_GLOBALFHECONTEXT_PATTERNS_H_ +#ifndef ZAMALANG_CONVERSION_GLOBALFHECONTEXT_H_ +#define ZAMALANG_CONVERSION_GLOBALFHECONTEXT_H_ #include namespace mlir { diff --git a/compiler/include/zamalang/Support/ClientParameters.h b/compiler/include/zamalang/Support/ClientParameters.h index af94647f7..69c2138f8 100644 --- a/compiler/include/zamalang/Support/ClientParameters.h +++ b/compiler/include/zamalang/Support/ClientParameters.h @@ -56,7 +56,10 @@ struct EncryptionGate { }; struct CircuitGateShape { - uint64_t size; + // Width of the scalar value + size_t width; + // Size of the buffer + size_t size; }; struct CircuitGate { diff --git a/compiler/include/zamalang/Support/CompilerEngine.h b/compiler/include/zamalang/Support/CompilerEngine.h index 5bdf7171d..b380a850b 100644 --- a/compiler/include/zamalang/Support/CompilerEngine.h +++ b/compiler/include/zamalang/Support/CompilerEngine.h @@ -15,6 +15,9 @@ namespace mlir { namespace zamalang { + +/// CompilerEngine is an tools that provides tools to implements the compilation +/// flow and manage the compilation flow state. class CompilerEngine { public: CompilerEngine() { @@ -26,10 +29,16 @@ public: delete context; } - // Compile an MLIR input - llvm::Expected compileFHE(std::string mlir_input); + // Compile an mlir programs from it's textual representation. + llvm::Error compile(std::string mlirStr); - // Run the compiled module + // Build the jit lambda argument. + llvm::Expected> buildArgument(); + + // Call the compiled function with and argument object. + llvm::Error invoke(JITLambda::Argument &arg); + + // Call the compiled function with a list of integer arguments. llvm::Expected run(std::vector args); // Get a printable representation of the compiled module diff --git a/compiler/include/zamalang/Support/CompilerTools.h b/compiler/include/zamalang/Support/CompilerTools.h index 6a33dd44c..c69e27246 100644 --- a/compiler/include/zamalang/Support/CompilerTools.h +++ b/compiler/include/zamalang/Support/CompilerTools.h @@ -51,17 +51,54 @@ public: // and decryption operations. static llvm::Expected> create(KeySet &keySet); - // Set the argument at the given pos as a uint64_t. + // Set a scalar argument at the given pos as a uint64_t. llvm::Error setArg(size_t pos, uint64_t arg); + // Set a argument at the given pos as a tensor of int64. + llvm::Error setArg(size_t pos, uint64_t *data, size_t size) { + return setArg(pos, 64, (void *)data, size); + } + + // Set a argument at the given pos as a tensor of int32. + llvm::Error setArg(size_t pos, uint32_t *data, size_t size) { + return setArg(pos, 32, (void *)data, size); + } + + // Set a argument at the given pos as a tensor of int32. + llvm::Error setArg(size_t pos, uint16_t *data, size_t size) { + return setArg(pos, 16, (void *)data, size); + } + + // Set a tensor argument at the given pos as a uint64_t. + llvm::Error setArg(size_t pos, uint8_t *data, size_t size) { + return setArg(pos, 8, (void *)data, size); + } + // Get the result at the given pos as an uint64_t. llvm::Error getResult(size_t pos, uint64_t &res); + // Fill the result. + llvm::Error getResult(size_t pos, uint64_t *res, size_t size); + private: + llvm::Error setArg(size_t pos, size_t width, void *data, size_t size); + friend JITLambda; + // Store the pointer on inputs values and outputs values std::vector rawArg; + // Store the values of inputs std::vector inputs; - std::vector results; + // Store the values of outputs + std::vector outputs; + // Store the input gates description and the offset of the argument. + std::vector> inputGates; + // Store the outputs gates description and the offset of the argument. + std::vector> outputGates; + // Store allocated lwe ciphertexts (for free) + std::vector allocatedCiphertexts; + // Store buffers of ciphertexts + std::vector ciphertextBuffers; + KeySet &keySet; }; JITLambda(mlir::LLVM::LLVMFunctionType type, llvm::StringRef name) diff --git a/compiler/include/zamalang/Support/KeySet.h b/compiler/include/zamalang/Support/KeySet.h index dcb290d50..88ce5364c 100644 --- a/compiler/include/zamalang/Support/KeySet.h +++ b/compiler/include/zamalang/Support/KeySet.h @@ -37,6 +37,9 @@ public: size_t numInputs() { return inputs.size(); } size_t numOutputs() { return outputs.size(); } + CircuitGate inputGate(size_t pos) { return std::get<0>(inputs[pos]); } + CircuitGate outputGate(size_t pos) { return std::get<0>(outputs[pos]); } + protected: llvm::Error generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param, SecretRandomGenerator *generator); diff --git a/compiler/lib/Support/ClientParameters.cpp b/compiler/lib/Support/ClientParameters.cpp index 841dbb1f2..8eea22d7f 100644 --- a/compiler/lib/Support/ClientParameters.cpp +++ b/compiler/lib/Support/ClientParameters.cpp @@ -14,12 +14,25 @@ llvm::Expected gateFromMLIRType(std::string secretKeyID, Precision precision, mlir::Type type) { if (type.isIntOrIndex()) { + // TODO - The index type is dependant of the target architecture, so + // actually we assume we target only 64 bits, we need to have some the size + // of the word of the target system. + size_t width = 64; + if (!type.isIndex()) { + width = type.getIntOrFloatBitWidth(); + } return CircuitGate{ .encryption = llvm::None, - .shape = {.size = 0}, + .shape = + { + .width = width, + .size = 0, + }, }; } if (type.isa()) { + // TODO - Get the width from the LWECiphertextType instead of global + // precision (could be possible after merge lowlfhe-ciphertext-parameter) return CircuitGate{ .encryption = llvm::Optional({ .secretKeyID = secretKeyID, @@ -27,17 +40,17 @@ llvm::Expected gateFromMLIRType(std::string secretKeyID, .variance = 0., .encoding = {.precision = precision}, }), - .shape = {.size = 0}, + .shape = {.width = precision, .size = 0}, }; } - auto memref = type.dyn_cast_or_null(); - if (memref != nullptr) { + auto tensor = type.dyn_cast_or_null(); + if (tensor != nullptr) { auto gate = - gateFromMLIRType(secretKeyID, precision, memref.getElementType()); + gateFromMLIRType(secretKeyID, precision, tensor.getElementType()); if (auto err = gate.takeError()) { return std::move(err); } - gate->shape.size = memref.getDimSize(0); + gate->shape.size = tensor.getDimSize(0); return gate; } return llvm::make_error( diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 69477f65f..8a015291c 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -23,9 +23,8 @@ std::string CompilerEngine::getCompiledModule() { return os.str(); } -llvm::Expected -CompilerEngine::compileFHE(std::string mlir_input) { - module_ref = mlir::parseSourceString(mlir_input, context); +llvm::Error CompilerEngine::compile(std::string mlirStr) { + module_ref = mlir::parseSourceString(mlirStr, context); if (!module_ref) { return llvm::make_error("mlir parsing failed", llvm::inconvertibleErrorCode()); @@ -60,29 +59,44 @@ CompilerEngine::compileFHE(std::string mlir_input) { return llvm::make_error( "failed to lower to LLVM dialect", llvm::inconvertibleErrorCode()); } - return mlir::success(); + return llvm::Error::success(); } -llvm::Expected CompilerEngine::run(std::vector args) { +llvm::Expected> +CompilerEngine::buildArgument() { + if (keySet.get() == nullptr) { + return llvm::make_error( + "CompilerEngine::buildArgument: invalid engine state, the keySet has " + "not be generated", + llvm::inconvertibleErrorCode()); + } + return JITLambda::Argument::create(*keySet); +} + +llvm::Error CompilerEngine::invoke(JITLambda::Argument &arg) { // Create the JIT lambda auto defaultOptPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr); auto module = module_ref.get(); auto maybeLambda = mlir::zamalang::JITLambda::create("main", module, defaultOptPipeline); - if (!maybeLambda) { - return llvm::make_error("couldn't create lambda", - llvm::inconvertibleErrorCode()); + if (auto err = maybeLambda.takeError()) { + return std::move(err); } - auto lambda = std::move(maybeLambda.get()); + // Invoke the lambda + if (auto err = maybeLambda.get()->invoke(arg)) { + return std::move(err); + } + return llvm::Error::success(); +} - // Create the arguments of the JIT lambda - auto maybeArguments = mlir::zamalang::JITLambda::Argument::create(*keySet); - if (auto err = maybeArguments.takeError()) { - return llvm::make_error("cannot create lambda args", - llvm::inconvertibleErrorCode()); +llvm::Expected CompilerEngine::run(std::vector args) { + // Build the argument of the JIT lambda. + auto maybeArgument = buildArgument(); + if (auto err = maybeArgument.takeError()) { + return std::move(err); } - // Set the arguments - auto arguments = std::move(maybeArguments.get()); + // Set the integer arguments + auto arguments = std::move(maybeArgument.get()); for (auto i = 0; i < args.size(); i++) { if (auto err = arguments->setArg(i, args[i])) { return llvm::make_error( @@ -90,14 +104,12 @@ llvm::Expected CompilerEngine::run(std::vector args) { } } // Invoke the lambda - if (lambda->invoke(*arguments)) { - return llvm::make_error("failed execution", - llvm::inconvertibleErrorCode()); + if (auto err = invoke(*arguments)) { + return std::move(err); } uint64_t res = 0; if (auto err = arguments->getResult(0, res)) { - return llvm::make_error("cannot get result", - llvm::inconvertibleErrorCode()); + return std::move(err); } return res; } diff --git a/compiler/lib/Support/CompilerTools.cpp b/compiler/lib/Support/CompilerTools.cpp index 829945d95..99849468c 100644 --- a/compiler/lib/Support/CompilerTools.cpp +++ b/compiler/lib/Support/CompilerTools.cpp @@ -141,10 +141,16 @@ JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module, } llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef args) { - if (this->type.getNumParams() != args.size() - 1) { - return llvm::make_error( - "invokeRaw: wrong number of argument", llvm::inconvertibleErrorCode()); - } + size_t nbReturn = 0; + // TODO - This check break with memref as we have 5 returns args. + // if (!this->type.getReturnType().isa()) { + // nbReturn = 1; + // } + // if (this->type.getNumParams() != args.size() - nbReturn) { + // return llvm::make_error( + // "invokeRaw: wrong number of argument", + // llvm::inconvertibleErrorCode()); + // } if (llvm::find(args, nullptr) != args.end()) { return llvm::make_error( "invoke: some arguments are null", llvm::inconvertibleErrorCode()); @@ -157,24 +163,58 @@ llvm::Error JITLambda::invoke(Argument &args) { } JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) { - inputs = std::vector(keySet.numInputs()); - results = std::vector(keySet.numOutputs()); + // Setting the inputs + { + auto numInputs = 0; + for (size_t i = 0; i < keySet.numInputs(); i++) { + auto offset = numInputs; + auto gate = keySet.inputGate(i); + inputGates.push_back({gate, offset}); + if (keySet.inputGate(i).shape.size == 0) { + // scalar gate + numInputs = numInputs + 1; + continue; + } + // memref gate, as we follow the standard calling convention + numInputs = numInputs + 5; + } + inputs = std::vector(numInputs); + } + + // Setting the outputs + { + auto numOutputs = 0; + for (size_t i = 0; i < keySet.numOutputs(); i++) { + auto offset = numOutputs; + auto gate = keySet.outputGate(i); + outputGates.push_back({gate, offset}); + if (gate.shape.size == 0) { + // scalar gate + numOutputs = numOutputs + 1; + continue; + } + // memref gate, as we follow the standard calling convention + numOutputs = numOutputs + 5; + } + outputs = std::vector(numOutputs); + } + // The raw argument contains pointers to inputs and pointers to store the // results - rawArg = - std::vector(keySet.numInputs() + keySet.numOutputs(), nullptr); - // Set the results pointer on the rawArg - for (auto i = keySet.numInputs(); i < rawArg.size(); i++) { - rawArg[i] = &results[i - keySet.numInputs()]; + rawArg = std::vector(inputs.size() + outputs.size(), nullptr); + // Set the pointer on outputs on rawArg + for (auto i = inputs.size(); i < rawArg.size(); i++) { + rawArg[i] = &outputs[i - inputs.size()]; } } JITLambda::Argument::~Argument() { int err; - for (auto i = 0; i < keySet.numInputs(); i++) { - if (keySet.isInputEncrypted(i)) { - free_lwe_ciphertext_u64(&err, (LweCiphertext_u64 *)(inputs[i])); - } + for (auto ct : allocatedCiphertexts) { + free_lwe_ciphertext_u64(&err, ct); + } + for (auto buffer : ciphertextBuffers) { + free(buffer); } } @@ -185,38 +225,206 @@ JITLambda::Argument::create(KeySet &keySet) { } llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) { + if (pos >= inputGates.size()) { + return llvm::make_error( + llvm::Twine("argument index out of bound: pos=") + .concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + auto gate = inputGates[pos]; + auto info = std::get<0>(gate); + auto offset = std::get<1>(gate); + + // Check is the argument is a scalar + if (info.shape.size != 0) { + return llvm::make_error( + llvm::Twine("argument is not a scalar: pos=").concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + // If argument is not encrypted, just save. - if (!keySet.isInputEncrypted(pos)) { - inputs[pos] = (void *)arg; - rawArg[pos] = &inputs[pos]; + if (!info.encryption.hasValue()) { + inputs[offset] = (void *)arg; + rawArg[offset] = &inputs[offset]; return llvm::Error::success(); } - // Else if is encryted, allocate ciphertext. + // Else if is encryted, allocate ciphertext and encrypt. LweCiphertext_u64 *ctArg; if (auto err = this->keySet.allocate_lwe(pos, &ctArg)) { return std::move(err); } + allocatedCiphertexts.push_back(ctArg); if (auto err = this->keySet.encrypt_lwe(pos, ctArg, arg)) { return std::move(err); } - inputs[pos] = ctArg; - rawArg[pos] = &inputs[pos]; + inputs[offset] = ctArg; + rawArg[offset] = &inputs[offset]; + return llvm::Error::success(); +} + +llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, void *data, + size_t size) { + auto gate = inputGates[pos]; + auto info = std::get<0>(gate); + auto offset = std::get<1>(gate); + // Check if the width is compatible + // TODO - I found this rules empirically, they are a spec somewhere? + if (info.shape.width <= 8 && width != 8) { + return llvm::make_error( + llvm::Twine("argument width should be 8: pos=") + .concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + if (info.shape.width > 8 && info.shape.width <= 16 && width != 16) { + return llvm::make_error( + llvm::Twine("argument width should be 16: pos=") + .concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + if (info.shape.width > 16 && info.shape.width <= 32 && width != 32) { + return llvm::make_error( + llvm::Twine("argument width should be 32: pos=") + .concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + if (info.shape.width > 32 && info.shape.width <= 64 && width != 64) { + return llvm::make_error( + llvm::Twine("argument width should be 64: pos=") + .concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + if (info.shape.width > 64) { + return llvm::make_error( + llvm::Twine("argument width not supported: pos=") + .concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + // Check the size + if (info.shape.size == 0) { + return llvm::make_error( + llvm::Twine("argument is not a vector: pos=").concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + if (info.shape.size != size) { + return llvm::make_error( + llvm::Twine("vector argument has not the expected size") + .concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + // If argument is not encrypted, just save with the right calling convention. + if (info.encryption.hasValue()) { + // Else if is encrypted + // For moment we support only 8 bits inputs + uint8_t *data8 = (uint8_t *)data; + if (width != 8) { + return llvm::make_error( + llvm::Twine( + "argument width > 8 for encrypted gates are not supported: pos=") + .concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + + // Allocate a buffer for ciphertexts. + auto ctBuffer = + (LweCiphertext_u64 **)malloc(size * sizeof(LweCiphertext_u64 *)); + ciphertextBuffers.push_back(ctBuffer); + // Allocate ciphertexts and encrypt + for (auto i = 0; i < size; i++) { + if (auto err = this->keySet.allocate_lwe(pos, &ctBuffer[i])) { + return std::move(err); + } + allocatedCiphertexts.push_back(ctBuffer[i]); + if (auto err = this->keySet.encrypt_lwe(pos, ctBuffer[i], data8[i])) { + return std::move(err); + } + } + // Replace the data by the buffer to ciphertext + data = (void *)ctBuffer; + } + // Set the buffer as the memref calling convention expect. + // allocated + inputs[offset] = (void *)0; // TODO - Better understand how it is used. + rawArg[offset] = &inputs[offset]; + // aligned + inputs[offset + 1] = data; + rawArg[offset + 1] = &inputs[offset + 1]; + // offset + inputs[offset + 2] = (void *)0; + rawArg[offset + 2] = &inputs[offset + 2]; + // size + inputs[offset + 3] = (void *)size; + rawArg[offset + 3] = &inputs[offset + 3]; + // stride + inputs[offset + 4] = (void *)0; + rawArg[offset + 4] = &inputs[offset + 4]; return llvm::Error::success(); } llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t &res) { + auto gate = outputGates[pos]; + auto info = std::get<0>(gate); + auto offset = std::get<1>(gate); + + // Check is the argument is a scalar + if (info.shape.size != 0) { + return llvm::make_error( + llvm::Twine("output is not a scalar, pos=").concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } // If result is not encrypted, just set the result - if (!keySet.isOutputEncrypted(pos)) { - res = (uint64_t)(results[pos]); + if (!info.encryption.hasValue()) { + res = (uint64_t)(outputs[offset]); return llvm::Error::success(); } // Else if is encryted, decrypt - LweCiphertext_u64 *ct = (LweCiphertext_u64 *)(results[pos]); + LweCiphertext_u64 *ct = (LweCiphertext_u64 *)(outputs[offset]); if (auto err = this->keySet.decrypt_lwe(pos, ct, res)) { return std::move(err); } return llvm::Error::success(); } +llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t *res, + size_t size) { + auto gate = outputGates[pos]; + auto info = std::get<0>(gate); + auto offset = std::get<1>(gate); + + // Check is the argument is a scalar + if (info.shape.size == 0) { + return llvm::make_error( + llvm::Twine("output is not a tensor, pos=").concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + if (!info.encryption.hasValue()) { + return llvm::make_error( + "unencrypted result as tensor output NYI", + llvm::inconvertibleErrorCode()); + } + // Get the values as the memref calling convention expect. + void *allocated = outputs[offset]; // TODO - Better understand how it is used. + // aligned + void *aligned = outputs[offset + 1]; + // offset + size_t offset_r = (size_t)outputs[offset + 2]; + // size + size_t size_r = (size_t)outputs[offset + 3]; + // stride + size_t stride = (size_t)outputs[offset + 4]; + // Check the sizes + if (info.shape.size != size || size_r != size) { + return llvm::make_error("output bad result buffer size", + llvm::inconvertibleErrorCode()); + } + // decrypt and fill the result buffer + for (auto i = 0; i < size_r; i++) { + LweCiphertext_u64 *ct = ((LweCiphertext_u64 **)(aligned))[i]; + if (auto err = this->keySet.decrypt_lwe(pos, ct, res[i])) { + return std::move(err); + } + } + return llvm::Error::success(); +} + } // namespace zamalang -} // namespace mlir \ No newline at end of file +} // namespace mlir diff --git a/compiler/lib/Support/KeySet.cpp b/compiler/lib/Support/KeySet.cpp index 739083c8d..f44b34181 100644 --- a/compiler/lib/Support/KeySet.cpp +++ b/compiler/lib/Support/KeySet.cpp @@ -42,7 +42,7 @@ KeySet::generate(ClientParameters ¶ms, uint64_t seed_msb, auto e = keySet->generateSecretKey(secretKeyParam.first, secretKeyParam.second, generator); if (e) { - return e; + return std::move(e); } } CAPI_ERR_TO_LLVM_ERROR(free_secret_generator(&err, generator), @@ -60,7 +60,7 @@ KeySet::generate(ClientParameters ¶ms, uint64_t seed_msb, bootstrapKeyParam.second, keySet->encryptionRandomGenerator); if (e) { - return e; + return std::move(e); } } for (auto keyswitchParam : params.keyswitchKeys) { @@ -68,7 +68,7 @@ KeySet::generate(ClientParameters ¶ms, uint64_t seed_msb, keyswitchParam.second, keySet->encryptionRandomGenerator); if (e) { - return e; + return std::move(e); } } } @@ -112,9 +112,8 @@ llvm::Error KeySet::generateSecretKey(LweSecretKeyID id, LweSecretKeyParam param, SecretRandomGenerator *generator) { LweSecretKey_u64 *sk; - CAPI_ERR_TO_LLVM_ERROR( - sk = allocate_lwe_secret_key_u64(&err, {_0 : param.size}), - "cannot allocate secret key"); + CAPI_ERR_TO_LLVM_ERROR(sk = allocate_lwe_secret_key_u64(&err, {param.size}), + "cannot allocate secret key"); CAPI_ERR_TO_LLVM_ERROR(fill_lwe_secret_key_u64(&err, sk, generator), "cannot fill secret key with random generator") secretKeys[id] = {param, sk}; @@ -250,6 +249,7 @@ llvm::Error KeySet::encrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext, llvm::Error KeySet::decrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext, uint64_t &output) { + if (argPos >= outputs.size()) { return llvm::make_error( "decrypt_lwe: position of argument is too high", @@ -262,13 +262,14 @@ llvm::Error KeySet::decrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext, llvm::inconvertibleErrorCode()); } // Decrypt - Plaintext_u64 plaintext; + Plaintext_u64 plaintext = {0}; CAPI_ERR_TO_LLVM_ERROR( decrypt_lwe_u64(&err, std::get<2>(outputSk), ciphertext, &plaintext), "cannot decrypt"); // Decode output = plaintext._0 >> (64 - (std::get<0>(outputSk).encryption->encoding.precision + 1)); + return llvm::Error::success(); } diff --git a/compiler/tests/unittest/CMakeLists.txt b/compiler/tests/unittest/CMakeLists.txt index 2aeff5bad..8fa00b8e6 100644 --- a/compiler/tests/unittest/CMakeLists.txt +++ b/compiler/tests/unittest/CMakeLists.txt @@ -1,5 +1,8 @@ enable_testing() +include_directories(${PROJECT_SOURCE_DIR}/include) + + add_executable( hello_test hello_test.cc @@ -7,6 +10,7 @@ add_executable( target_link_libraries( hello_test gtest_main + ZamalangSupport ) include(GoogleTest) diff --git a/compiler/tests/unittest/hello_test.cc b/compiler/tests/unittest/hello_test.cc index 5a57e138f..429b346d5 100644 --- a/compiler/tests/unittest/hello_test.cc +++ b/compiler/tests/unittest/hello_test.cc @@ -1,9 +1,347 @@ #include -// Demonstrate some basic assertions. -TEST(HelloTest, BasicAssertions) { - // Expect two strings not to be equal. - EXPECT_STRNE("hello", "world"); - // Expect equality. - EXPECT_EQ(7 * 6, 42); +#include "zamalang/Support/CompilerEngine.h" + +#define ASSERT_LLVM_ERROR(err) \ + if (err) { \ + llvm::errs() << "error: " << std::move(err) << "\n"; \ + ASSERT_TRUE(false); \ + } + +TEST(CompileAndRunHLFHE, add_eint) { + mlir::zamalang::CompilerEngine engine; + auto mlirStr = R"XXX( +func @main(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> { + %1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>) + return %1: !HLFHE.eint<7> } +)XXX"; + ASSERT_FALSE(engine.compile(mlirStr)); + auto maybeResult = engine.run({1, 2}); + ASSERT_TRUE((bool)maybeResult); + uint64_t result = maybeResult.get(); + ASSERT_EQ(result, 3); +} + +TEST(CompileAndRunTensorStd, extract_64) { + mlir::zamalang::CompilerEngine engine; + auto mlirStr = R"XXX( +func @main(%t: tensor<10xi64>, %i: index) -> i64{ + %c = tensor.extract %t[%i] : tensor<10xi64> + return %c : i64 +} +)XXX"; + ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + const size_t size = 10; + uint64_t t_arg[size]{0xFFFFFFFFFFFFFFFF, + 0, + 8978, + 2587490, + 90, + 197864, + 698735, + 72132, + 87474, + 42}; + for (size_t i = 0; i < size; i++) { + auto maybeArgument = engine.buildArgument(); + ASSERT_LLVM_ERROR(maybeArgument.takeError()); + auto argument = std::move(maybeArgument.get()); + // Set the %t argument + ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size)); + // Set the %i argument + ASSERT_LLVM_ERROR(argument->setArg(1, i)); + // Invoke the function + ASSERT_LLVM_ERROR(engine.invoke(*argument)); + // Get and assert the result + uint64_t res = 0; + ASSERT_LLVM_ERROR(argument->getResult(0, res)); + ASSERT_EQ(res, t_arg[i]); + } +} + +TEST(CompileAndRunTensorStd, extract_32) { + mlir::zamalang::CompilerEngine engine; + auto mlirStr = R"XXX( +func @main(%t: tensor<10xi32>, %i: index) -> i32{ + %c = tensor.extract %t[%i] : tensor<10xi32> + return %c : i32 +} +)XXX"; + ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + const size_t size = 10; + uint32_t t_arg[size]{0xFFFFFFFF, 0, 8978, 2587490, 90, + 197864, 698735, 72132, 87474, 42}; + for (size_t i = 0; i < size; i++) { + auto maybeArgument = engine.buildArgument(); + ASSERT_LLVM_ERROR(maybeArgument.takeError()); + auto argument = std::move(maybeArgument.get()); + // Set the %t argument + ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size)); + // Set the %i argument + ASSERT_LLVM_ERROR(argument->setArg(1, i)); + // Invoke the function + ASSERT_LLVM_ERROR(engine.invoke(*argument)); + // Get and assert the result + uint64_t res = 0; + ASSERT_LLVM_ERROR(argument->getResult(0, res)); + ASSERT_EQ(res, t_arg[i]); + } +} + +TEST(CompileAndRunTensorStd, extract_16) { + mlir::zamalang::CompilerEngine engine; + auto mlirStr = R"XXX( +func @main(%t: tensor<10xi16>, %i: index) -> i16{ + %c = tensor.extract %t[%i] : tensor<10xi16> + return %c : i16 +} +)XXX"; + ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + const size_t size = 10; + uint16_t t_arg[size]{0xFFFF, 0, 59589, 47826, 16227, + 63269, 36435, 52380, 7401, 13313}; + for (size_t i = 0; i < size; i++) { + auto maybeArgument = engine.buildArgument(); + ASSERT_LLVM_ERROR(maybeArgument.takeError()); + auto argument = std::move(maybeArgument.get()); + // Set the %t argument + ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size)); + // Set the %i argument + ASSERT_LLVM_ERROR(argument->setArg(1, i)); + // Invoke the function + ASSERT_LLVM_ERROR(engine.invoke(*argument)); + // Get and assert the result + uint64_t res = 0; + ASSERT_LLVM_ERROR(argument->getResult(0, res)); + ASSERT_EQ(res, t_arg[i]); + } +} + +TEST(CompileAndRunTensorStd, extract_8) { + mlir::zamalang::CompilerEngine engine; + auto mlirStr = R"XXX( +func @main(%t: tensor<10xi8>, %i: index) -> i8{ + %c = tensor.extract %t[%i] : tensor<10xi8> + return %c : i8 +} +)XXX"; + ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + const size_t size = 10; + uint8_t t_arg[size]{0xFF, 0, 120, 225, 14, 177, 131, 84, 174, 93}; + for (size_t i = 0; i < size; i++) { + auto maybeArgument = engine.buildArgument(); + ASSERT_LLVM_ERROR(maybeArgument.takeError()); + auto argument = std::move(maybeArgument.get()); + // Set the %t argument + ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size)); + // Set the %i argument + ASSERT_LLVM_ERROR(argument->setArg(1, i)); + // Invoke the function + ASSERT_LLVM_ERROR(engine.invoke(*argument)); + // Get and assert the result + uint64_t res = 0; + ASSERT_LLVM_ERROR(argument->getResult(0, res)); + ASSERT_EQ(res, t_arg[i]); + } +} + +TEST(CompileAndRunTensorStd, extract_5) { + mlir::zamalang::CompilerEngine engine; + auto mlirStr = R"XXX( +func @main(%t: tensor<10xi5>, %i: index) -> i5{ + %c = tensor.extract %t[%i] : tensor<10xi5> + return %c : i5 +} +)XXX"; + ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + const size_t size = 10; + uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7}; + for (size_t i = 0; i < size; i++) { + auto maybeArgument = engine.buildArgument(); + ASSERT_LLVM_ERROR(maybeArgument.takeError()); + auto argument = std::move(maybeArgument.get()); + // Set the %t argument + ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size)); + // Set the %i argument + ASSERT_LLVM_ERROR(argument->setArg(1, i)); + // Invoke the function + ASSERT_LLVM_ERROR(engine.invoke(*argument)); + // Get and assert the result + uint64_t res = 0; + ASSERT_LLVM_ERROR(argument->getResult(0, res)); + ASSERT_EQ(res, t_arg[i]); + } +} + +TEST(CompileAndRunTensorStd, extract_1) { + mlir::zamalang::CompilerEngine engine; + auto mlirStr = R"XXX( +func @main(%t: tensor<10xi1>, %i: index) -> i1{ + %c = tensor.extract %t[%i] : tensor<10xi1> + return %c : i1 +} +)XXX"; + ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + const size_t size = 10; + uint8_t t_arg[size]{0, 0, 1, 0, 1, 1, 0, 1, 1, 0}; + for (size_t i = 0; i < size; i++) { + auto maybeArgument = engine.buildArgument(); + ASSERT_LLVM_ERROR(maybeArgument.takeError()); + auto argument = std::move(maybeArgument.get()); + // Set the %t argument + ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size)); + // Set the %i argument + ASSERT_LLVM_ERROR(argument->setArg(1, i)); + // Invoke the function + ASSERT_LLVM_ERROR(engine.invoke(*argument)); + // Get and assert the result + uint64_t res = 0; + ASSERT_LLVM_ERROR(argument->getResult(0, res)); + ASSERT_EQ(res, t_arg[i]); + } +} + +TEST(CompileAndRunTensorEncrypted, extract_5) { + mlir::zamalang::CompilerEngine engine; + auto mlirStr = R"XXX( +func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index) -> !HLFHE.eint<5>{ + %c = tensor.extract %t[%i] : tensor<10x!HLFHE.eint<5>> + return %c : !HLFHE.eint<5> +} +)XXX"; + ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + const size_t size = 10; + uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7}; + for (size_t i = 0; i < size; i++) { + auto maybeArgument = engine.buildArgument(); + ASSERT_LLVM_ERROR(maybeArgument.takeError()); + auto argument = std::move(maybeArgument.get()); + // Set the %t argument + ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size)); + // Set the %i argument + ASSERT_LLVM_ERROR(argument->setArg(1, i)); + // Invoke the function + ASSERT_LLVM_ERROR(engine.invoke(*argument)); + // Get and assert the result + uint64_t res = 0; + ASSERT_LLVM_ERROR(argument->getResult(0, res)); + ASSERT_EQ(res, t_arg[i]); + } +} + +TEST(CompileAndRunTensorEncrypted, extract_twice_and_add_5) { + mlir::zamalang::CompilerEngine engine; + auto mlirStr = R"XXX( +func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index, %j: index) -> !HLFHE.eint<5>{ + %ti = tensor.extract %t[%i] : tensor<10x!HLFHE.eint<5>> + %tj = tensor.extract %t[%j] : tensor<10x!HLFHE.eint<5>> + %c = "HLFHE.add_eint"(%ti, %tj) : (!HLFHE.eint<5>, !HLFHE.eint<5>) -> !HLFHE.eint<5> + return %c : !HLFHE.eint<5> +} +)XXX"; + ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + const size_t size = 10; + uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7}; + for (size_t i = 0; i < size; i++) { + for (size_t j = 0; j < size; j++) { + auto maybeArgument = engine.buildArgument(); + ASSERT_LLVM_ERROR(maybeArgument.takeError()); + auto argument = std::move(maybeArgument.get()); + // Set the %t argument + ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size)); + // Set the %i argument + ASSERT_LLVM_ERROR(argument->setArg(1, i)); + // Set the %j argument + ASSERT_LLVM_ERROR(argument->setArg(2, j)); + // Invoke the function + ASSERT_LLVM_ERROR(engine.invoke(*argument)); + // Get and assert the result + uint64_t res = 0; + ASSERT_LLVM_ERROR(argument->getResult(0, res)); + ASSERT_EQ(res, t_arg[i] + t_arg[j]); + } + } +} + +TEST(CompileAndRunTensorEncrypted, dim_5) { + mlir::zamalang::CompilerEngine engine; + auto mlirStr = R"XXX( +func @main(%t: tensor<10x!HLFHE.eint<5>>) -> index{ + %c0 = constant 0 : index + %c = tensor.dim %t, %c0 : tensor<10x!HLFHE.eint<5>> + return %c : index +} +)XXX"; + ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + const size_t size = 10; + uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7}; + auto maybeArgument = engine.buildArgument(); + ASSERT_LLVM_ERROR(maybeArgument.takeError()); + auto argument = std::move(maybeArgument.get()); + // Set the %t argument + ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size)); + // Invoke the function + ASSERT_LLVM_ERROR(engine.invoke(*argument)); + // Get and assert the result + uint64_t res = 0; + ASSERT_LLVM_ERROR(argument->getResult(0, res)); + ASSERT_EQ(res, size); +} + +TEST(CompileAndRunTensorEncrypted, from_elements_5) { + mlir::zamalang::CompilerEngine engine; + auto mlirStr = R"XXX( +func @main(%0: !HLFHE.eint<5>) -> tensor<1x!HLFHE.eint<5>> { + %t = tensor.from_elements %0 : tensor<1x!HLFHE.eint<5>> + return %t: tensor<1x!HLFHE.eint<5>> +} +)XXX"; + ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + auto maybeArgument = engine.buildArgument(); + ASSERT_LLVM_ERROR(maybeArgument.takeError()); + auto argument = std::move(maybeArgument.get()); + // Set the %t argument + ASSERT_LLVM_ERROR(argument->setArg(0, 10)); + // Invoke the function + ASSERT_LLVM_ERROR(engine.invoke(*argument)); + // Get and assert the result + size_t size_res = 1; + uint64_t t_res[size_res]; + ASSERT_LLVM_ERROR(argument->getResult(0, t_res, size_res)); + ASSERT_EQ(t_res[0], 10); +} + +TEST(CompileAndRunTensorEncrypted, in_out_tensor_with_op_5) { + mlir::zamalang::CompilerEngine engine; + auto mlirStr = R"XXX( +func @main(%in: tensor<2x!HLFHE.eint<5>>) -> tensor<3x!HLFHE.eint<5>> { + %c_0 = constant 0 : index + %c_1 = constant 1 : index + %a = tensor.extract %in[%c_0] : tensor<2x!HLFHE.eint<5>> + %b = tensor.extract %in[%c_1] : tensor<2x!HLFHE.eint<5>> + %aplusa = "HLFHE.add_eint"(%a, %a): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>) + %aplusb = "HLFHE.add_eint"(%a, %b): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>) + %bplusb = "HLFHE.add_eint"(%b, %b): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>) + %out = tensor.from_elements %aplusa, %aplusb, %bplusb : tensor<3x!HLFHE.eint<5>> + return %out: tensor<3x!HLFHE.eint<5>> +} +)XXX"; + ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + auto maybeArgument = engine.buildArgument(); + ASSERT_LLVM_ERROR(maybeArgument.takeError()); + auto argument = std::move(maybeArgument.get()); + // Set the argument + const size_t in_size = 2; + uint8_t in[in_size] = {2, 16}; + ASSERT_LLVM_ERROR(argument->setArg(0, in, in_size)); + // Invoke the function + ASSERT_LLVM_ERROR(engine.invoke(*argument)); + // Get and assert the result + const size_t size_res = 3; + uint64_t t_res[size_res]; + ASSERT_LLVM_ERROR(argument->getResult(0, t_res, size_res)); + ASSERT_EQ(t_res[0], in[0] + in[0]); + ASSERT_EQ(t_res[1], in[0] + in[1]); + ASSERT_EQ(t_res[2], in[1] + in[1]); +} \ No newline at end of file