From 8cd3a3a5997dba95504617416c28a5d09941972a Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Mon, 20 Jun 2022 11:01:06 +0200 Subject: [PATCH] feat(compiler): First draft to support FHE.eint up to 16bits For now what it works are only levelled ops with user parameters. (take a look to the tests) Done: - Add parameters to the fhe parameters to support CRT-based large integers - Add command line options and tests options to allows the user to give those new parameters - Update the dialects and pipeline to handle new fhe parameters for CRT-based large integers - Update the client parameters and the client library to handle the CRT-based large integers Todo: - Plug the optimizer to compute the CRT-based large interger parameters - Plug the pbs for the CRT-based large integer --- compiler/include/concretelang/ClientLib/CRT.h | 46 ++ .../concretelang/ClientLib/ClientParameters.h | 49 ++ .../ClientLib/EncryptedArguments.h | 126 +++- .../include/concretelang/ClientLib/KeySet.h | 7 +- .../concretelang/ClientLib/PublicArguments.h | 10 +- .../concretelang/Conversion/FHEToTFHE/Pass.h | 11 +- .../Conversion/FHEToTFHE/Patterns.h | 3 +- .../Conversion/TFHEToConcrete/Patterns.h | 26 +- .../Conversion/Utils/GlobalFHEContext.h | 31 + .../Conversion/Utils/TensorOpTypeConversion.h | 5 + .../Dialect/BConcrete/IR/BConcreteOps.td | 58 ++ .../Dialect/BConcrete/Transforms/Passes.h | 2 + .../Dialect/BConcrete/Transforms/Passes.td | 5 + .../Dialect/Concrete/IR/ConcreteOps.td | 38 +- .../Dialect/Concrete/IR/ConcreteTypes.td | 5 +- .../concretelang/Dialect/TFHE/IR/TFHEOps.td | 25 + .../concretelang/Dialect/TFHE/IR/TFHETypes.td | 40 +- .../include/concretelang/Runtime/wrappers.h | 18 + .../concretelang/Support/CompilerEngine.h | 5 + .../concretelang/Support/LambdaSupport.h | 9 +- .../include/concretelang/Support/Pipeline.h | 1 + compiler/lib/ClientLib/CMakeLists.txt | 13 +- compiler/lib/ClientLib/CRT.cpp | 99 ++++ compiler/lib/ClientLib/ClientParameters.cpp | 15 + compiler/lib/ClientLib/EncryptedArguments.cpp | 111 +--- compiler/lib/ClientLib/KeySet.cpp | 72 ++- compiler/lib/ClientLib/KeySetCache.cpp | 1 - .../ConcreteToBConcrete.cpp | 553 +++++++++-------- .../lib/Conversion/FHEToTFHE/FHEToTFHE.cpp | 216 ++++--- .../TFHEGlobalParametrization.cpp | 64 +- .../TFHEToConcrete/TFHEToConcrete.cpp | 40 ++ .../BufferizableOpInterfaceImpl.cpp | 62 +- .../BConcrete/Transforms/CMakeLists.txt | 3 +- .../BConcrete/Transforms/EliminateCRTOps.cpp | 561 ++++++++++++++++++ .../Dialect/Concrete/IR/ConcreteDialect.cpp | 72 ++- .../Concrete/Transforms/Optimization.cpp | 21 +- compiler/lib/Dialect/TFHE/IR/TFHEDialect.cpp | 3 +- compiler/lib/Dialect/TFHE/IR/TFHEOps.cpp | 13 + compiler/lib/Dialect/TFHE/IR/TFHETypes.cpp | 74 ++- compiler/lib/Runtime/CMakeLists.txt | 3 +- compiler/lib/Runtime/wrappers.cpp | 8 +- compiler/lib/ServerLib/DynamicRankCall.cpp | 69 +-- compiler/lib/ServerLib/ServerLambda.cpp | 15 +- compiler/lib/ServerLib/genDynamicRankCall.py | 4 +- compiler/lib/Support/CompilerEngine.cpp | 49 +- compiler/lib/Support/Jit.cpp | 29 +- compiler/lib/Support/Pipeline.cpp | 13 +- compiler/lib/Support/V0ClientParameters.cpp | 23 +- compiler/src/main.cpp | 77 ++- .../ConcreteToBConcrete/add_lwe.mlir | 13 +- .../ConcreteToBConcrete/add_lwe_int.mlir | 38 +- .../ConcreteToBConcrete/identity.mlir | 7 + .../ConcreteToBConcrete/mul_lwe_int.mlir | 33 +- .../ConcreteToBConcrete/neg_lwe.mlir | 9 + .../tensor_exapand_collapse_shape.mlir | 52 +- .../ConcreteToBConcrete/tensor_identity.mlir | 8 + .../TFHEToConcrete/add_glwe_int.mlir | 18 +- .../TFHEToConcrete/mul_glwe_int.mlir | 18 +- .../TFHEToConcrete/sub_int_glwe.mlir | 22 +- .../check_tests/Dialect/BConcrete/ops.mlir | 38 +- .../Dialect/Concrete/no_optimization.mlir | 14 +- .../check_tests/Dialect/Concrete/ops.mlir | 39 +- .../Dialect/Concrete/optimization.mlir | 20 +- .../check_tests/Dialect/Concrete/types.mlir | 8 +- .../Dialect/TFHE/op_add_glwe.invalid.mlir | 17 + .../Dialect/TFHE/op_add_glwe_int.invalid.mlir | 10 + .../Dialect/TFHE/op_mul_glwe_int.invalid.mlir | 10 + .../Dialect/TFHE/op_neg_glwe.invalid.mlir | 9 + .../Dialect/TFHE/op_sub_int_glwe.invalid.mlir | 11 + .../check_tests/Dialect/TFHE/types_glwe.mlir | 12 + .../end_to_end_fixture/EndToEndFixture.cpp | 8 +- .../end_to_end_fixture/EndToEndFixture.h | 4 +- .../end_to_end_encrypted_tensor.yaml | 241 ++++++++ .../end_to_end_fixture/end_to_end_fhe.yaml | 334 ++++++++++- .../end_to_end_fhelinalg.yaml | 226 +++++++ .../end_to_end_tests/end_to_end_jit_fhe.cc | 3 + .../end_to_end_jit_fhelinalg.cc | 95 --- compiler/tests/tests_tools/assert.h | 9 +- .../concretelang/ClientLib/CMakeLists.txt | 1 + .../unit_tests/concretelang/ClientLib/CRT.cpp | 50 ++ .../ClientLib/ClientParameters.cpp | 2 +- .../concretelang/ClientLib/KeySet.cpp | 47 +- 82 files changed, 3192 insertions(+), 1037 deletions(-) create mode 100644 compiler/include/concretelang/ClientLib/CRT.h create mode 100644 compiler/lib/ClientLib/CRT.cpp create mode 100644 compiler/lib/Dialect/BConcrete/Transforms/EliminateCRTOps.cpp create mode 100644 compiler/tests/unit_tests/concretelang/ClientLib/CRT.cpp diff --git a/compiler/include/concretelang/ClientLib/CRT.h b/compiler/include/concretelang/ClientLib/CRT.h new file mode 100644 index 000000000..3ac27282a --- /dev/null +++ b/compiler/include/concretelang/ClientLib/CRT.h @@ -0,0 +1,46 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_CLIENTLIB_CRT_H_ +#define CONCRETELANG_CLIENTLIB_CRT_H_ + +#include +#include + +namespace concretelang { +namespace clientlib { +namespace crt { + +/// Compute the product of the moduli of the crt decomposition. +/// +/// \param moduli The moduli of the crt decomposition +/// \returns The product of moduli +uint64_t productOfModuli(std::vector moduli); + +/// Compute the crt decomposition of a `val` according the given `moduli`. +/// +/// \param moduli The moduli to compute the decomposition. +/// \param val The value to decompose. +/// \returns The remainders. +std::vector crt(std::vector moduli, uint64_t val); + +/// Compute the inverse of the crt decomposition. +/// +/// \param moduli The moduli used to compute the inverse decomposition. +/// \param remainders The remainders of the decomposition. +uint64_t iCrt(std::vector moduli, std::vector remainders); + +/// Encode the plaintext with the given modulus and the product of moduli of the +/// crt decomposition +uint64_t encode(int64_t plaintext, uint64_t modulus, uint64_t product); + +/// Decode follow the crt encoding +uint64_t decode(uint64_t val, uint64_t modulus); + +} // namespace crt +} // namespace clientlib +} // namespace concretelang + +#endif diff --git a/compiler/include/concretelang/ClientLib/ClientParameters.h b/compiler/include/concretelang/ClientLib/ClientParameters.h index 8e8f98940..58112eb54 100644 --- a/compiler/include/concretelang/ClientLib/ClientParameters.h +++ b/compiler/include/concretelang/ClientLib/ClientParameters.h @@ -7,6 +7,7 @@ #define CONCRETELANG_CLIENTLIB_CLIENTPARAMETERS_H_ #include +#include #include #include @@ -31,6 +32,7 @@ typedef size_t DecompositionBaseLog; typedef size_t PolynomialSize; typedef size_t Precision; typedef double Variance; +typedef std::vector CRTDecomposition; typedef uint64_t LweDimension; typedef uint64_t GlweDimension; @@ -87,6 +89,7 @@ static inline bool operator==(const KeyswitchKeyParam &lhs, struct Encoding { Precision precision; + CRTDecomposition crt; }; static inline bool operator==(const Encoding &lhs, const Encoding &rhs) { return lhs.precision == rhs.precision; @@ -168,6 +171,52 @@ struct ClientParameters { } return secretKey->second; } + + /// bufferSize returns the size of the whole buffer of a gate. + int64_t bufferSize(CircuitGate gate) { + if (!gate.encryption.hasValue()) { + // Value is not encrypted just returns the tensor size + return gate.shape.size; + } + auto shapeSize = gate.shape.size == 0 ? 1 : gate.shape.size; + // Size of the ciphertext + return shapeSize * lweBufferSize(gate); + } + + /// lweBufferSize returns the size of one ciphertext of a gate. + int64_t lweBufferSize(CircuitGate gate) { + assert(gate.encryption.hasValue()); + auto nbBlocks = gate.encryption->encoding.crt.size(); + nbBlocks = nbBlocks == 0 ? 1 : nbBlocks; + + auto param = lweSecretKeyParam(gate); + assert(param.has_value()); + return param.value().lweSize() * nbBlocks; + } + + /// bufferShape returns the shape of the tensor for the given gate. It returns + /// the shape used at low-level, i.e. contains the dimensions for ciphertexts. + std::vector bufferShape(CircuitGate gate) { + if (!gate.encryption.hasValue()) { + // Value is not encrypted just returns the tensor shape + return gate.shape.dimensions; + } + auto lweSecreteKeyParam = lweSecretKeyParam(gate); + assert(lweSecreteKeyParam.has_value()); + + // Copy the shape + std::vector shape(gate.shape.dimensions); + + auto crt = gate.encryption->encoding.crt; + + // CRT case: Add one dimension equals to the number of blocks + if (!crt.empty()) { + shape.push_back(crt.size()); + } + // Add one dimension for the size of ciphertext(s) + shape.push_back(lweSecreteKeyParam.value().lweSize()); + return shape; + } }; static inline bool operator==(const ClientParameters &lhs, diff --git a/compiler/include/concretelang/ClientLib/EncryptedArguments.h b/compiler/include/concretelang/ClientLib/EncryptedArguments.h index 71eb68e66..513f70efb 100644 --- a/compiler/include/concretelang/ClientLib/EncryptedArguments.h +++ b/compiler/include/concretelang/ClientLib/EncryptedArguments.h @@ -23,11 +23,24 @@ using concretelang::error::StringError; class PublicArguments; +inline size_t bitWidthAsWord(size_t exactBitWidth) { + if (exactBitWidth <= 8) + return 8; + if (exactBitWidth <= 16) + return 16; + if (exactBitWidth <= 32) + return 32; + if (exactBitWidth <= 64) + return 64; + assert(false && "Bit witdh > 64 not supported"); +} + /// Temporary object used to hold and encrypt parameters before calling a /// ClientLambda. Use preferably TypeClientLambda and serializeCall(Args...). /// Otherwise convert it to a PublicArguments and use /// serializeCall(PublicArguments, KeySet). class EncryptedArguments { + public: EncryptedArguments() : currentPos(0) {} @@ -73,7 +86,10 @@ public: /// Add a vector-tensor argument. outcome::checked pushArg(std::vector arg, - KeySet &keySet); + KeySet &keySet) { + return pushArg((uint8_t *)arg.data(), + llvm::ArrayRef{(int64_t)arg.size()}, keySet); + } /// Add a 1D tensor argument with data and size of the dimension. template @@ -82,26 +98,20 @@ public: return pushArg(std::vector(data, data + dim1), keySet); } - // Add a tensor argument. - template - outcome::checked - pushArg(const T *data, llvm::ArrayRef shape, KeySet &keySet) { - return pushArg(8 * sizeof(T), static_cast(data), shape, - keySet); - } - /// Add a 1D tensor argument. template outcome::checked pushArg(std::array arg, KeySet &keySet) { - return pushArg(8, (void *)arg.data(), {size}, keySet); + return pushArg((uint8_t *)arg.data(), llvm::ArrayRef{size}, + keySet); } /// Add a 2D tensor argument. template outcome::checked pushArg(std::array, size0> arg, KeySet &keySet) { - return pushArg(8, (void *)arg.data(), {size0, size1}, keySet); + return pushArg((uint8_t *)arg.data(), llvm::ArrayRef{size0, size1}, + keySet); } /// Add a 3D tensor argument. @@ -109,7 +119,8 @@ public: outcome::checked pushArg(std::array, size1>, size0> arg, KeySet &keySet) { - return pushArg(8, (void *)arg.data(), {size0, size1, size2}, keySet); + return pushArg((uint8_t *)arg.data(), + llvm::ArrayRef{size0, size1, size2}, keySet); } // Generalize by computing shape by template recursion @@ -125,13 +136,94 @@ public: template outcome::checked pushArg(T *data, llvm::ArrayRef shape, KeySet &keySet) { - return pushArg(8 * sizeof(T), static_cast(data), shape, - keySet); + return pushArg(static_cast(data), shape, keySet); } - outcome::checked pushArg(size_t width, const void *data, - llvm::ArrayRef shape, - KeySet &keySet); + template + outcome::checked + pushArg(const T *data, llvm::ArrayRef shape, KeySet &keySet) { + OUTCOME_TRYV(checkPushTooManyArgs(keySet)); + auto pos = currentPos; + CircuitGate input = keySet.inputGate(pos); + // Check the width of data + if (input.shape.width > 64) { + return StringError("argument #") + << pos << " width > 64 bits is not supported"; + } + // Check the shape of tensor + if (input.shape.dimensions.empty()) { + return StringError("argument #") << pos << "is not a tensor"; + } + if (shape.size() != input.shape.dimensions.size()) { + return StringError("argument #") + << pos << "has not the expected number of dimension, got " + << shape.size() << " expected " << input.shape.dimensions.size(); + } + // Allocate empty + ciphertextBuffers.resize(ciphertextBuffers.size() + 1); + TensorData &values_and_sizes = ciphertextBuffers.back(); + + // Check shape + for (size_t i = 0; i < shape.size(); i++) { + if (shape[i] != input.shape.dimensions[i]) { + return StringError("argument #") + << pos << " has not the expected dimension #" << i << " , got " + << shape[i] << " expected " << input.shape.dimensions[i]; + } + } + // Set sizes + values_and_sizes.sizes = keySet.clientParameters().bufferShape(input); + + if (input.encryption.hasValue()) { + // Allocate values + values_and_sizes.values.resize( + keySet.clientParameters().bufferSize(input)); + auto lweSize = keySet.clientParameters().lweBufferSize(input); + auto &values = values_and_sizes.values; + for (size_t i = 0, offset = 0; i < input.shape.size; + i++, offset += lweSize) { + OUTCOME_TRYV(keySet.encrypt_lwe(pos, values.data() + offset, data[i])); + } + } else { + // Allocate values take care of gate bitwidth + auto bitsPerValue = bitWidthAsWord(input.shape.width); + auto bytesPerValue = bitsPerValue / 8; + auto nbWordPerValue = 8 / bytesPerValue; + // ceil division + auto size = (input.shape.size / nbWordPerValue) + + (input.shape.size % nbWordPerValue != 0); + size = size == 0 ? 1 : size; + values_and_sizes.values.resize(size); + auto v = (uint8_t *)values_and_sizes.values.data(); + for (size_t i = 0; i < input.shape.size; i++) { + auto dst = v + i * bytesPerValue; + auto src = (const uint8_t *)&data[i]; + for (size_t j = 0; j < bytesPerValue; j++) { + dst[j] = src[j]; + } + } + } + // allocated + preparedArgs.push_back(nullptr); + // aligned + preparedArgs.push_back((void *)values_and_sizes.values.data()); + // offset + preparedArgs.push_back((void *)0); + // sizes + for (size_t size : values_and_sizes.sizes) { + preparedArgs.push_back((void *)size); + } + + // Set the stride for each dimension, equal to the product of the + // following dimensions. + int64_t stride = values_and_sizes.length(); + for (size_t size : values_and_sizes.sizes) { + stride = (size == 0 ? 0 : (stride / size)); + preparedArgs.push_back((void *)stride); + } + currentPos++; + return outcome::success(); + } /// Recursive case for scalars: extract first scalar argument from /// parameter pack and forward rest diff --git a/compiler/include/concretelang/ClientLib/KeySet.h b/compiler/include/concretelang/ClientLib/KeySet.h index 3a78be378..4aa2c8904 100644 --- a/compiler/include/concretelang/ClientLib/KeySet.h +++ b/compiler/include/concretelang/ClientLib/KeySet.h @@ -37,7 +37,10 @@ public: static outcome::checked, StringError> generate(ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb); - /// isInputEncrypted return true if the input at the given pos is encrypted. + /// Returns the ClientParameters associated with the KeySet. + ClientParameters clientParameters() { return _clientParameters; } + + // isInputEncrypted return true if the input at the given pos is encrypted. bool isInputEncrypted(size_t pos); /// getInputLweSecretKeyParam returns the parameters of the lwe secret key for @@ -155,6 +158,8 @@ private: std::map>> keyswitchKeys); + + clientlib::ClientParameters _clientParameters; }; } // namespace clientlib diff --git a/compiler/include/concretelang/ClientLib/PublicArguments.h b/compiler/include/concretelang/ClientLib/PublicArguments.h index 9baa12a44..3707fa310 100644 --- a/compiler/include/concretelang/ClientLib/PublicArguments.h +++ b/compiler/include/concretelang/ClientLib/PublicArguments.h @@ -108,8 +108,7 @@ struct PublicResult { } auto buffer = buffers[pos]; - auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize(); - + auto lweSize = clientParameters.lweBufferSize(gate); std::vector decryptedValues(buffer.length() / lweSize); for (size_t i = 0; i < decryptedValues.size(); i++) { auto ciphertext = &buffer.values[i * lweSize]; @@ -120,6 +119,13 @@ struct PublicResult { return decryptedValues; } + /// Return the shape of the clear tensor of a result. + outcome::checked, StringError> + asClearTextShape(size_t pos) { + OUTCOME_TRY(auto gate, clientParameters.ouput(pos)); + return gate.shape.dimensions; + } + // private: TODO tmp friend class ::concretelang::serverlib::ServerLambda; ClientParameters clientParameters; diff --git a/compiler/include/concretelang/Conversion/FHEToTFHE/Pass.h b/compiler/include/concretelang/Conversion/FHEToTFHE/Pass.h index aa2dbfe7f..0ea6408c5 100644 --- a/compiler/include/concretelang/Conversion/FHEToTFHE/Pass.h +++ b/compiler/include/concretelang/Conversion/FHEToTFHE/Pass.h @@ -10,8 +10,17 @@ namespace mlir { namespace concretelang { + +// ApplyLookupTableLowering indicates the strategy to lower an +// FHE.apply_loopup_table ops +enum ApplyLookupTableLowering { + KeySwitchBoostrapLowering, + WopPBSLowering, +}; + /// Create a pass to convert `FHE` dialect to `TFHE` dialect. -std::unique_ptr> createConvertFHEToTFHEPass(); +std::unique_ptr> +createConvertFHEToTFHEPass(ApplyLookupTableLowering lower); } // namespace concretelang } // namespace mlir diff --git a/compiler/include/concretelang/Conversion/FHEToTFHE/Patterns.h b/compiler/include/concretelang/Conversion/FHEToTFHE/Patterns.h index 52cbd4e67..4b94380d3 100644 --- a/compiler/include/concretelang/Conversion/FHEToTFHE/Patterns.h +++ b/compiler/include/concretelang/Conversion/FHEToTFHE/Patterns.h @@ -23,7 +23,8 @@ using TFHE::GLWECipherTextType; GLWECipherTextType convertTypeEncryptedIntegerToGLWE(mlir::MLIRContext *context, EncryptedIntegerType eint) { - return GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth()); + return GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth(), + llvm::ArrayRef()); } /// Converts the type `t` to `TFHE::GlweCiphetext` if `t` is a diff --git a/compiler/include/concretelang/Conversion/TFHEToConcrete/Patterns.h b/compiler/include/concretelang/Conversion/TFHEToConcrete/Patterns.h index f8a750266..701e2f73a 100644 --- a/compiler/include/concretelang/Conversion/TFHEToConcrete/Patterns.h +++ b/compiler/include/concretelang/Conversion/TFHEToConcrete/Patterns.h @@ -26,7 +26,8 @@ LweCiphertextType convertTypeToLWE(mlir::MLIRContext *context, auto glwe = type.dyn_cast_or_null(); if (glwe != nullptr) { assert(glwe.getPolynomialSize() == 1); - return LweCiphertextType::get(context, glwe.getDimension(), glwe.getP()); + return LweCiphertextType::get(context, glwe.getDimension(), glwe.getP(), + glwe.getCrtDecomposition()); } auto lwe = type.dyn_cast_or_null(); if (lwe != nullptr) { @@ -122,19 +123,10 @@ mlir::Value createConcreteOpFromTFHE(mlir::PatternRewriter &rewriter, mlir::Value createAddPlainLweCiphertextWithGlwe( mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Value arg0, mlir::Value arg1, mlir::OpResult result, mlir::Type encryptedType) { - PlaintextType encoded_type = - convertPlaintextTypeFromType(rewriter.getContext(), encryptedType); - // encode int into plaintext - mlir::Value encoded = rewriter - .create( - loc, encoded_type, arg1) - .plaintext(); - - // replace op using the encoded plaintext instead of int auto op = rewriter .create( - loc, result.getType(), arg0, encoded); + loc, result.getType(), arg0, arg1); convertOperandAndResultTypes(rewriter, op, convertTypeToLWEIfTFHEType); @@ -171,21 +163,11 @@ mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Value arg0, mlir::Value arg1, mlir::OpResult result) { - auto inType = arg0.getType(); - CleartextType encoded_type = - convertCleartextTypeFromType(rewriter.getContext(), inType); - // encode int into plaintext - mlir::Value encoded = - rewriter - .create( - loc, encoded_type, arg1) - .cleartext(); - // replace op using the encoded plaintext instead of int auto op = rewriter .create( - loc, result.getType(), arg0, encoded); + loc, result.getType(), arg0, arg1); convertOperandAndResultTypes(rewriter, op, convertTypeToLWEIfTFHEType); diff --git a/compiler/include/concretelang/Conversion/Utils/GlobalFHEContext.h b/compiler/include/concretelang/Conversion/Utils/GlobalFHEContext.h index fdba7ca80..9c6677b7f 100644 --- a/compiler/include/concretelang/Conversion/Utils/GlobalFHEContext.h +++ b/compiler/include/concretelang/Conversion/Utils/GlobalFHEContext.h @@ -6,15 +6,44 @@ #ifndef CONCRETELANG_CONVERSION_GLOBALFHECONTEXT_H_ #define CONCRETELANG_CONVERSION_GLOBALFHECONTEXT_H_ #include +#include +#include + +#include "llvm/ADT/Optional.h" namespace mlir { namespace concretelang { +typedef std::vector CRTDecomposition; + struct V0FHEConstraint { size_t norm2; size_t p; }; +struct PackingKeySwitchParameter { + size_t inputLweDimension; + size_t inputLweCount; + size_t outputPolynomialSize; + size_t level; + size_t baseLog; +}; + +struct CitcuitBoostrapParameter { + size_t level; + size_t baseLog; +}; + +struct WopPBSParameter { + PackingKeySwitchParameter packingKeySwitch; + CitcuitBoostrapParameter circuitBootstrap; +}; + +struct LargeIntegerParameter { + CRTDecomposition crtDecomposition; + WopPBSParameter wopPBS; +}; + struct V0Parameter { size_t glweDimension; size_t logPolynomialSize; @@ -24,6 +53,8 @@ struct V0Parameter { size_t ksLevel; size_t ksLogBase; + llvm::Optional largeInteger; + V0Parameter() = delete; V0Parameter(size_t glweDimension, size_t logPolynomialSize, size_t nSmall, diff --git a/compiler/include/concretelang/Conversion/Utils/TensorOpTypeConversion.h b/compiler/include/concretelang/Conversion/Utils/TensorOpTypeConversion.h index 543df439f..a61f3399b 100644 --- a/compiler/include/concretelang/Conversion/Utils/TensorOpTypeConversion.h +++ b/compiler/include/concretelang/Conversion/Utils/TensorOpTypeConversion.h @@ -28,6 +28,11 @@ populateWithTensorTypeConverterPatterns(mlir::RewritePatternSet &patterns, patterns.getContext(), typeConverter); addDynamicallyLegalTypeOp(target, typeConverter); + + // InsertOp + patterns.add>( + patterns.getContext(), typeConverter); + addDynamicallyLegalTypeOp(target, typeConverter); // InsertSliceOp patterns.add>( patterns.getContext(), typeConverter); diff --git a/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.td b/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.td index 73fabc8de..91d1dcb62 100644 --- a/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.td +++ b/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.td @@ -20,21 +20,56 @@ def BConcrete_AddLweBuffersOp : BConcrete_Op<"add_lwe_buffer"> { let results = (outs 1DTensorOf<[I64]>:$result); } +def BConcrete_AddCRTLweBuffersOp : BConcrete_Op<"add_crt_lwe_buffer"> { + let arguments = (ins + 2DTensorOf<[I64]>:$lhs, + 2DTensorOf<[I64]>:$rhs, + I64ArrayAttr:$crtDecomposition + ); + let results = (outs 2DTensorOf<[I64]>:$result); +} + def BConcrete_AddPlaintextLweBufferOp : BConcrete_Op<"add_plaintext_lwe_buffer"> { let arguments = (ins 1DTensorOf<[I64]>:$lhs, I64:$rhs); let results = (outs 1DTensorOf<[I64]>:$result); } +def BConcrete_AddPlaintextCRTLweBufferOp : BConcrete_Op<"add_plaintext_crt_lwe_buffer"> { + let arguments = (ins + 2DTensorOf<[I64]>:$lhs, + AnyInteger:$rhs, + I64ArrayAttr:$crtDecomposition + ); + let results = (outs 2DTensorOf<[I64]>:$result); +} + def BConcrete_MulCleartextLweBufferOp : BConcrete_Op<"mul_cleartext_lwe_buffer"> { let arguments = (ins 1DTensorOf<[I64]>:$lhs, I64:$rhs); let results = (outs 1DTensorOf<[I64]>:$result); } +def BConcrete_MulCleartextCRTLweBufferOp : BConcrete_Op<"mul_cleartext_crt_lwe_buffer"> { + let arguments = (ins + 2DTensorOf<[I64]>:$lhs, + AnyInteger:$rhs, + I64ArrayAttr:$crtDecomposition + ); + let results = (outs 2DTensorOf<[I64]>:$result); +} + def BConcrete_NegateLweBufferOp : BConcrete_Op<"negate_lwe_buffer"> { let arguments = (ins 1DTensorOf<[I64]>:$ciphertext); let results = (outs 1DTensorOf<[I64]>:$result); } +def BConcrete_NegateCRTLweBufferOp : BConcrete_Op<"negate_crt_lwe_buffer"> { + let arguments = (ins + 2DTensorOf<[I64]>:$ciphertext, + I64ArrayAttr:$crtDecomposition + ); + let results = (outs 2DTensorOf<[I64]>:$result); +} + def BConcrete_FillGlweFromTable : BConcrete_Op<"fill_glwe_from_table"> { let arguments = (ins 1DTensorOf<[I64]>:$glwe, @@ -67,5 +102,28 @@ def BConcrete_BootstrapLweBufferOp : BConcrete_Op<"bootstrap_lwe_buffer"> { let results = (outs 1DTensorOf<[I64]>:$result); } +// TODO(16bits): hack +def BConcrete_WopPBSCRTLweBufferOp : BConcrete_Op<"wop_pbs_crt_lwe_buffer"> { + let arguments = (ins + 2DTensorOf<[I64]>:$ciphertext, + 1DTensorOf<[I64]>:$lookupTable, + // Bootstrap parameters + I32Attr : $bootstrapLevel, + I32Attr : $bootstrapBaseLog, + // Keyswitch parameters + I32Attr : $keyswitchLevel, + I32Attr : $keyswitchBaseLog, + // Packing keyswitch key parameters + I32Attr : $packingKeySwitchInputLweDimension, + I32Attr : $packingKeySwitchinputLweCount, + I32Attr : $packingKeySwitchoutputPolynomialSize, + I32Attr : $packingKeySwitchLevel, + I32Attr : $packingKeySwitchBaseLog, + // Circuit bootstrap parameters + I32Attr : $circuitBootstrapLevel, + I32Attr : $circuitBootstrapBaseLog + ); + let results = (outs 2DTensorOf<[I64]>:$result); +} #endif diff --git a/compiler/include/concretelang/Dialect/BConcrete/Transforms/Passes.h b/compiler/include/concretelang/Dialect/BConcrete/Transforms/Passes.h index 7e0a49447..9367fbb46 100644 --- a/compiler/include/concretelang/Dialect/BConcrete/Transforms/Passes.h +++ b/compiler/include/concretelang/Dialect/BConcrete/Transforms/Passes.h @@ -14,6 +14,8 @@ namespace mlir { namespace concretelang { std::unique_ptr> createAddRuntimeContext(); + +std::unique_ptr> createEliminateCRTOps(); } // namespace concretelang } // namespace mlir diff --git a/compiler/include/concretelang/Dialect/BConcrete/Transforms/Passes.td b/compiler/include/concretelang/Dialect/BConcrete/Transforms/Passes.td index e77d7770a..54251a113 100644 --- a/compiler/include/concretelang/Dialect/BConcrete/Transforms/Passes.td +++ b/compiler/include/concretelang/Dialect/BConcrete/Transforms/Passes.td @@ -16,4 +16,9 @@ def AddRuntimeContext : Pass<"add-runtime-context", "mlir::ModuleOp"> { let constructor = "mlir::concretelang::createAddRuntimeContext()"; } +def EliminateCRTOps : Pass<"eliminate-bconcrete-crt-ops", "mlir::func::FuncOp"> { + let summary = "Eliminate the crt bconcrete operators."; + let constructor = "mlir::concretelang::createEliminateCRTOpsPass()"; +} + #endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES diff --git a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td index 13dcb73c7..d3b6e6fdc 100644 --- a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td +++ b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td @@ -34,14 +34,14 @@ def Concrete_AddLweCiphertextsOp : Concrete_Op<"add_lwe_ciphertexts"> { def Concrete_AddPlaintextLweCiphertextOp : Concrete_Op<"add_plaintext_lwe_ciphertext"> { let summary = "Returns the sum of a clear integer and a lwe ciphertext"; - let arguments = (ins Concrete_LweCiphertextType:$lhs, Concrete_PlaintextType:$rhs); + let arguments = (ins Concrete_LweCiphertextType:$lhs, AnyInteger:$rhs); let results = (outs Concrete_LweCiphertextType:$result); } def Concrete_MulCleartextLweCiphertextOp : Concrete_Op<"mul_cleartext_lwe_ciphertext"> { let summary = "Returns the product of a clear integer and a lwe ciphertext"; - let arguments = (ins Concrete_LweCiphertextType:$lhs, Concrete_CleartextType:$rhs); + let arguments = (ins Concrete_LweCiphertextType:$lhs, AnyInteger:$rhs); let results = (outs Concrete_LweCiphertextType:$result); } @@ -82,18 +82,30 @@ def Concrete_KeySwitchLweOp : Concrete_Op<"keyswitch_lwe"> { let results = (outs Concrete_LweCiphertextType:$result); } -def Concrete_EncodeIntOp : Concrete_Op<"encode_int"> { - let summary = "Encodes an integer (for it to later be added to a LWE ciphertext)"; +// TODO(16bits): hack +def Concrete_WopPBSLweOp : Concrete_Op<"wop_pbs_lwe"> { + let summary = ""; - let arguments = (ins AnyInteger:$i); - let results = (outs Concrete_PlaintextType:$plaintext); -} - -def Concrete_IntToCleartextOp : Concrete_Op<"int_to_cleartext", [NoSideEffect]> { - let summary = "Keyswitches a LWE ciphertext"; - - let arguments = (ins AnyInteger:$i); - let results = (outs Concrete_CleartextType:$cleartext); + let arguments = (ins + Concrete_LweCiphertextType:$ciphertext, + 1DTensorOf<[I64]>:$accumulator, + // Bootstrap parameters + I32Attr : $bootstrapLevel, + I32Attr : $bootstrapBaseLog, + // Keyswitch parameters + I32Attr : $keyswitchLevel, + I32Attr : $keyswitchBaseLog, + // Packing keyswitch key parameters + I32Attr : $packingKeySwitchInputLweDimension, + I32Attr : $packingKeySwitchinputLweCount, + I32Attr : $packingKeySwitchoutputPolynomialSize, + I32Attr : $packingKeySwitchLevel, + I32Attr : $packingKeySwitchBaseLog, + // Circuit bootstrap parameters + I32Attr : $circuitBootstrapLevel, + I32Attr : $circuitBootstrapBaseLog + ); + let results = (outs Concrete_LweCiphertextType:$result); } #endif diff --git a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteTypes.td b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteTypes.td index 7f052f906..12c64fefb 100644 --- a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteTypes.td +++ b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteTypes.td @@ -40,7 +40,10 @@ def Concrete_LweCiphertextType : Concrete_Type<"LweCiphertext", [MemRefElementTy // The dimension of the lwe ciphertext "signed":$dimension, // Precision of the lwe ciphertext - "signed":$p + "signed":$p, + // CRT decomposition for large integers + ArrayRefParameter<"int64_t", "CRT decomposition">:$crtDecomposition + ); let hasCustomAssemblyFormat = 1; diff --git a/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td b/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td index 182f6e274..a83417186 100644 --- a/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td +++ b/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td @@ -115,4 +115,29 @@ def TFHE_BootstrapGLWEOp : TFHE_Op<"bootstrap_glwe"> { let results = (outs TFHE_GLWECipherTextType : $result); } +def TFHE_WopPBSGLWEOp : TFHE_Op<"wop_pbs_glwe"> { + let summary = ""; + + let arguments = (ins + TFHE_GLWECipherTextType : $ciphertext, + 1DTensorOf<[I64]> : $lookupTable, + // Bootstrap parameters + I32Attr : $bootstrapLevel, + I32Attr : $bootstrapBaseLog, + // Keyswitch parameters + I32Attr : $keyswitchLevel, + I32Attr : $keyswitchBaseLog, + // Packing keyswitch key parameters + I32Attr : $packingKeySwitchInputLweDimension, + I32Attr : $packingKeySwitchinputLweCount, + I32Attr : $packingKeySwitchoutputPolynomialSize, + I32Attr : $packingKeySwitchLevel, + I32Attr : $packingKeySwitchBaseLog, + // Circuit bootstrap parameters + I32Attr : $circuitBootstrapLevel, + I32Attr : $circuitBootstrapBaseLog + ); + let results = (outs TFHE_GLWECipherTextType:$result); +} + #endif diff --git a/compiler/include/concretelang/Dialect/TFHE/IR/TFHETypes.td b/compiler/include/concretelang/Dialect/TFHE/IR/TFHETypes.td index 582a2795d..4ec31a728 100644 --- a/compiler/include/concretelang/Dialect/TFHE/IR/TFHETypes.td +++ b/compiler/include/concretelang/Dialect/TFHE/IR/TFHETypes.td @@ -6,18 +6,18 @@ include "concretelang/Dialect/TFHE/IR/TFHEDialect.td" include "mlir/IR/BuiltinTypes.td" -class TFHE_Type traits = []> : TypeDef { } +class TFHE_Type traits = []> + : TypeDef {} -def TFHE_GLWECipherTextType : TFHE_Type<"GLWECipherText", [MemRefElementTypeInterface]> { - let mnemonic = "glwe"; +def TFHE_GLWECipherTextType + : TFHE_Type<"GLWECipherText", [MemRefElementTypeInterface]> { + let mnemonic = "glwe"; - let summary = "A GLWE ciphertext"; + let summary = "A GLWE ciphertext"; - let description = [{ - An GLWE cipher text - }]; + let description = [{An GLWE cipher text}]; - let parameters = (ins + let parameters = (ins // The size of the mask "signed":$dimension, // Size of the polynome @@ -25,22 +25,22 @@ def TFHE_GLWECipherTextType : TFHE_Type<"GLWECipherText", [MemRefElementTypeInte // Number of bits of the ciphertext "signed":$bits, // Number of bits of the plain text representation - "signed":$p + "signed":$p, + // CRT decomposition for large integers + ArrayRefParameter<"int64_t", "CRT decomposition">:$crtDecomposition ); - let hasCustomAssemblyFormat = 1; + let hasCustomAssemblyFormat = 1; - let genVerifyDecl = true; + let genVerifyDecl = true; - let extraClassDeclaration = [{ - // Returns true if has an unparametrized parameters - bool hasUnparametrizedParameters() { - return getDimension() ==-1 || - getPolynomialSize() == -1 || - getBits() == -1 || - getP() == -1; - }; - }]; + let extraClassDeclaration = [{ + // Returns true if has an unparametrized parameters + bool hasUnparametrizedParameters() { + return getDimension() == -1 || getPolynomialSize() == -1 || + getBits() == -1 || getP() == -1; + }; + }]; } #endif diff --git a/compiler/include/concretelang/Runtime/wrappers.h b/compiler/include/concretelang/Runtime/wrappers.h index d9c6b4d1c..d9cd39363 100644 --- a/compiler/include/concretelang/Runtime/wrappers.h +++ b/compiler/include/concretelang/Runtime/wrappers.h @@ -58,6 +58,24 @@ void memref_bootstrap_lwe_u64( uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride, mlir::concretelang::RuntimeContext *context); +uint64_t encode_crt(int64_t plaintext, uint64_t modulus, uint64_t product); + +// TODO(16bits): Hackish wrapper for the 16 bits quick win +void memref_wop_pbs_crt_buffer( + // Output memref 2D memref + uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset_0, + uint64_t out_offset_1, uint64_t out_size_0, uint64_t out_size_1, + uint64_t out_stride_0, uint64_t out_stride_1, + // Input memref + uint64_t *in_allocated, uint64_t *in_aligned, uint64_t in_offset_0, + uint64_t in_offset_1, uint64_t in_size_0, uint64_t in_size_1, + uint64_t in_stride_0, uint64_t in_stride_1, + // clear text lut + uint64_t *lut_ct_allocated, uint64_t *lut_ct_aligned, + uint64_t lut_ct_offset, uint64_t lut_ct_size, uint64_t lut_ct_stride, + // runtime context that hold evluation keys + mlir::concretelang::RuntimeContext *context); + void memref_copy_one_rank(uint64_t *src_allocated, uint64_t *src_aligned, uint64_t src_offset, uint64_t src_size, uint64_t src_stride, uint64_t *dst_allocated, diff --git a/compiler/include/concretelang/Support/CompilerEngine.h b/compiler/include/concretelang/Support/CompilerEngine.h index 3b753a63e..2bd608dc2 100644 --- a/compiler/include/concretelang/Support/CompilerEngine.h +++ b/compiler/include/concretelang/Support/CompilerEngine.h @@ -42,6 +42,11 @@ struct CompilationOptions { llvm::Optional v0Parameter; + /// largeIntegerParameter force the compiler engine to lower FHE.eint using + /// the large integers strategy with the given parameters. + llvm::Optional + largeIntegerParameter; + bool verifyDiagnostics; bool autoParallelize; diff --git a/compiler/include/concretelang/Support/LambdaSupport.h b/compiler/include/concretelang/Support/LambdaSupport.h index e1c5646f5..8fa73faf0 100644 --- a/compiler/include/concretelang/Support/LambdaSupport.h +++ b/compiler/include/concretelang/Support/LambdaSupport.h @@ -94,14 +94,15 @@ buildTensorLambdaResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { llvm::Expected> tensorOrError = typedResult>(keySet, result); - if (auto err = tensorOrError.takeError()) return std::move(err); - std::vector tensorDim(result.buffers[0].sizes.begin(), - result.buffers[0].sizes.end() - 1); + + auto tensorDim = result.asClearTextShape(0); + if (tensorDim.has_error()) + return StreamStringError(tensorDim.error().mesg); return std::make_unique>>( - *tensorOrError, tensorDim); + *tensorOrError, tensorDim.value()); } /// pecialization of `typedResult()` for a single result wrapped into diff --git a/compiler/include/concretelang/Support/Pipeline.h b/compiler/include/concretelang/Support/Pipeline.h index a888e5ca5..cbab04756 100644 --- a/compiler/include/concretelang/Support/Pipeline.h +++ b/compiler/include/concretelang/Support/Pipeline.h @@ -36,6 +36,7 @@ markFHELinalgForTiling(mlir::MLIRContext &context, mlir::ModuleOp &module, mlir::LogicalResult lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, + llvm::Optional &fheContext, std::function enablePass); mlir::LogicalResult diff --git a/compiler/lib/ClientLib/CMakeLists.txt b/compiler/lib/ClientLib/CMakeLists.txt index 653b97111..2f4554dac 100644 --- a/compiler/lib/ClientLib/CMakeLists.txt +++ b/compiler/lib/ClientLib/CMakeLists.txt @@ -1,12 +1,12 @@ -add_compile_options( -Werror ) +add_compile_options(-Werror) -if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang") +if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang") # using Clang - add_compile_options( -Wno-error=pessimizing-move -Wno-pessimizing-move ) -elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + add_compile_options(-Wno-error=pessimizing-move -Wno-pessimizing-move) +elseif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") # using GCC if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0) - add_compile_options( -Werror -Wno-error=pessimizing-move -Wno-pessimizing-move ) + add_compile_options(-Werror -Wno-error=pessimizing-move -Wno-pessimizing-move) endif() endif() @@ -14,6 +14,7 @@ add_mlir_library( ConcretelangClientLib ClientLambda.cpp ClientParameters.cpp + CRT.cpp EncryptedArguments.cpp KeySet.cpp KeySetCache.cpp @@ -23,8 +24,6 @@ add_mlir_library( ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/concretelang/ClientLib - LINK_LIBS - ConcretelangRuntime LINK_LIBS PUBLIC Concrete ) diff --git a/compiler/lib/ClientLib/CRT.cpp b/compiler/lib/ClientLib/CRT.cpp new file mode 100644 index 000000000..ec7eabad0 --- /dev/null +++ b/compiler/lib/ClientLib/CRT.cpp @@ -0,0 +1,99 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include +#include + +#include "concretelang/ClientLib/CRT.h" + +namespace concretelang { +namespace clientlib { +namespace crt { +uint64_t productOfModuli(std::vector moduli) { + uint64_t product = 1; + for (auto modulus : moduli) { + product *= modulus; + } + return product; +} + +std::vector crt(std::vector moduli, uint64_t val) { + std::vector remainders(moduli.size(), 0); + + for (size_t i = 0; i < moduli.size(); i++) { + remainders[i] = val % moduli[i]; + } + return remainders; +} + +// https://www.geeksforgeeks.org/multiplicative-inverse-under-modulo-m/ +// Returns modulo inverse of a with respect +// to m using extended Euclid Algorithm +// Assumption: a and m are coprimes, i.e., +// gcd(a, m) = 1 +int64_t modInverse(int64_t a, int64_t m) { + int64_t m0 = m; + int64_t y = 0, x = 1; + + if (m == 1) + return 0; + + while (a > 1) { + // q is quotient + int64_t q = a / m; + int64_t t = m; + + // m is remainder now, process same as + // Euclid's algo + m = a % m; + a = t; + t = y; + + // Update y and x + y = x - q * y; + x = t; + } + + // Make x positive + if (x < 0) + x += m0; + + return x; +} + +uint64_t iCrt(std::vector moduli, std::vector remainders) { + // Compute the product of moduli + int64_t product = productOfModuli(moduli); + + int64_t result = 0; + + // Apply above formula + for (size_t i = 0; i < remainders.size(); i++) { + int tmp = product / moduli[i]; + result += remainders[i] * modInverse(tmp, moduli[i]) * tmp; + } + + return result % product; +} + +uint64_t encode(int64_t plaintext, uint64_t modulus, uint64_t product) { + // values are represented on the interval [0; product[ so we represent + // plantext on this interval + if (plaintext < 0) { + plaintext = product + plaintext; + } + __uint128_t m = plaintext % modulus; + return m * ((__uint128_t)(1) << 64) / modulus; +} + +uint64_t decode(uint64_t val, uint64_t modulus) { + auto result = (__uint128_t)val * (__uint128_t)modulus; + result = result + ((result & ((__uint128_t)(1) << 63)) << 1); + result = result / ((__uint128_t)(1) << 64); + return (uint64_t)result % modulus; +} +} // namespace crt +} // namespace clientlib +} // namespace concretelang diff --git a/compiler/lib/ClientLib/ClientParameters.cpp b/compiler/lib/ClientLib/ClientParameters.cpp index 7354feadc..403fa399a 100644 --- a/compiler/lib/ClientLib/ClientParameters.cpp +++ b/compiler/lib/ClientLib/ClientParameters.cpp @@ -234,6 +234,9 @@ llvm::json::Value toJSON(const Encoding &v) { llvm::json::Object object{ {"precision", v.precision}, }; + if (!v.crt.empty()) { + object.insert({"crt", v.crt}); + } return object; } bool fromJSON(const llvm::json::Value j, Encoding &v, llvm::json::Path p) { @@ -248,6 +251,18 @@ bool fromJSON(const llvm::json::Value j, Encoding &v, llvm::json::Path p) { return false; } v.precision = precision.getValue(); + auto crt = obj->getArray("crt"); + if (crt != nullptr) { + for (auto dim : *crt) { + auto iDim = dim.getAsInteger(); + if (!iDim.hasValue()) { + p.report("dimensions must be integer"); + return false; + } + v.crt.push_back(iDim.getValue()); + } + } + return true; } diff --git a/compiler/lib/ClientLib/EncryptedArguments.cpp b/compiler/lib/ClientLib/EncryptedArguments.cpp index 7ef765e72..c193c7085 100644 --- a/compiler/lib/ClientLib/EncryptedArguments.cpp +++ b/compiler/lib/ClientLib/EncryptedArguments.cpp @@ -11,17 +11,6 @@ namespace clientlib { using StringError = concretelang::error::StringError; -size_t bitWidthAsWord(size_t exactBitWidth) { - size_t sortedWordBitWidths[] = {8, 16, 32, 64}; - size_t previousWidth = 0; - for (auto currentWidth : sortedWordBitWidths) { - if (previousWidth < exactBitWidth && exactBitWidth <= currentWidth) { - return currentWidth; - } - } - return exactBitWidth; -} - outcome::checked, StringError> EncryptedArguments::exportPublicArguments(ClientParameters clientParameters, RuntimeContext runtimeContext) { @@ -33,7 +22,7 @@ outcome::checked EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) { OUTCOME_TRYV(checkPushTooManyArgs(keySet)); auto pos = currentPos++; - CircuitGate input = keySet.inputGate(pos); + OUTCOME_TRY(CircuitGate input, keySet.clientParameters().input(pos)); if (input.shape.size != 0) { return StringError("argument #") << pos << " is not a scalar"; } @@ -42,12 +31,11 @@ EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) { preparedArgs.push_back((void *)arg); return outcome::success(); } - ciphertextBuffers.resize(ciphertextBuffers.size() + 1); // Allocate empty + // Allocate empty + ciphertextBuffers.resize(ciphertextBuffers.size() + 1); TensorData &values_and_sizes = ciphertextBuffers.back(); - auto lweSize = keySet.getInputLweSecretKeyParam(pos).lweSize(); - values_and_sizes.sizes.push_back(lweSize); - values_and_sizes.values.resize(lweSize); - + values_and_sizes.sizes = keySet.clientParameters().bufferShape(input); + values_and_sizes.values.resize(keySet.clientParameters().bufferSize(input)); OUTCOME_TRYV(keySet.encrypt_lwe(pos, values_and_sizes.values.data(), arg)); // Note: Since we bufferized lwe ciphertext take care of memref calling // convention @@ -57,97 +45,18 @@ EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) { preparedArgs.push_back((void *)values_and_sizes.values.data()); // offset preparedArgs.push_back((void *)0); - // size - preparedArgs.push_back((void *)values_and_sizes.values.size()); - // stride - preparedArgs.push_back((void *)1); - return outcome::success(); -} - -outcome::checked -EncryptedArguments::pushArg(std::vector arg, KeySet &keySet) { - return pushArg(8, (void *)arg.data(), {(int64_t)arg.size()}, keySet); -} - -outcome::checked -EncryptedArguments::pushArg(size_t width, const void *data, - llvm::ArrayRef shape, KeySet &keySet) { - OUTCOME_TRYV(checkPushTooManyArgs(keySet)); - auto pos = currentPos; - CircuitGate input = keySet.inputGate(pos); - // Check the width of data - if (input.shape.width > 64) { - return StringError("argument #") - << pos << " width > 64 bits is not supported"; - } - auto roundedSize = bitWidthAsWord(input.shape.width); - if (width != roundedSize) { - return StringError("argument #") << pos << "width mismatch, got " << width - << " expected " << roundedSize; - } - // Check the shape of tensor - if (input.shape.dimensions.empty()) { - return StringError("argument #") << pos << "is not a tensor"; - } - if (shape.size() != input.shape.dimensions.size()) { - return StringError("argument #") - << pos << "has not the expected number of dimension, got " - << shape.size() << " expected " << input.shape.dimensions.size(); - } - ciphertextBuffers.resize(ciphertextBuffers.size() + 1); // Allocate empty - TensorData &values_and_sizes = ciphertextBuffers.back(); - for (size_t i = 0; i < shape.size(); i++) { - values_and_sizes.sizes.push_back(shape[i]); - if (shape[i] != input.shape.dimensions[i]) { - return StringError("argument #") - << pos << " has not the expected dimension #" << i << " , got " - << shape[i] << " expected " << input.shape.dimensions[i]; - } - } - if (input.encryption.hasValue()) { - auto lweSize = keySet.getInputLweSecretKeyParam(pos).lweSize(); - values_and_sizes.sizes.push_back(lweSize); - - // Encrypted tensor: for now we support only 8 bits for encrypted tensor - if (width != 8) { - return StringError("argument #") - << pos << " width mismatch, expected 8 got " << width; - } - const uint8_t *data8 = (const uint8_t *)data; - - // Allocate a buffer for ciphertexts of size of tensor - values_and_sizes.values.resize(input.shape.size * lweSize); - auto &values = values_and_sizes.values; - // Allocate ciphertexts and encrypt, for every values in tensor - for (size_t i = 0, offset = 0; i < input.shape.size; - i++, offset += lweSize) { - OUTCOME_TRYV(keySet.encrypt_lwe(pos, values.data() + offset, data8[i])); - } - } else { - values_and_sizes.values.resize(input.shape.size); - for (size_t i = 0; i < input.shape.size; i++) { - values_and_sizes.values[i] = ((const uint64_t *)data)[i]; - } - } - // allocated - preparedArgs.push_back(nullptr); - // aligned - preparedArgs.push_back((void *)values_and_sizes.values.data()); - // offset - preparedArgs.push_back((void *)0); // sizes - for (size_t size : values_and_sizes.sizes) { + for (auto size : values_and_sizes.sizes) { preparedArgs.push_back((void *)size); } - - // Set the stride for each dimension, equal to the product of the - // following dimensions. + // strides int64_t stride = values_and_sizes.length(); - for (size_t size : values_and_sizes.sizes) { + for (size_t i = 0; i < values_and_sizes.sizes.size() - 1; i++) { + auto size = values_and_sizes.sizes[i]; stride = (size == 0 ? 0 : (stride / size)); preparedArgs.push_back((void *)stride); } - currentPos++; + preparedArgs.push_back((void *)1); return outcome::success(); } diff --git a/compiler/lib/ClientLib/KeySet.cpp b/compiler/lib/ClientLib/KeySet.cpp index 18bda3732..e4e19af3b 100644 --- a/compiler/lib/ClientLib/KeySet.cpp +++ b/compiler/lib/ClientLib/KeySet.cpp @@ -4,6 +4,7 @@ // for license information. #include "concretelang/ClientLib/KeySet.h" +#include "concretelang/ClientLib/CRT.h" #include "concretelang/Support/Error.h" #define CAPI_ERR_TO_STRINGERROR(instr, msg) \ @@ -31,16 +32,16 @@ outcome::checked, StringError> KeySet::generate(ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb) { auto keySet = std::make_unique(); - OUTCOME_TRYV(keySet->generateKeysFromParams(params, seed_msb, seed_lsb)); OUTCOME_TRYV(keySet->setupEncryptionMaterial(params, seed_msb, seed_lsb)); - return std::move(keySet); } outcome::checked KeySet::setupEncryptionMaterial(ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb) { + _clientParameters = params; + // Set inputs and outputs LWE secret keys { for (auto param : params.inputs) { @@ -189,9 +190,16 @@ KeySet::allocate_lwe(size_t argPos, uint64_t **ciphertext, uint64_t &size) { return StringError("allocate_lwe position of argument is too high"); } auto inputSk = inputs[argPos]; + auto encryption = std::get<0>(inputSk).encryption; + if (!encryption.hasValue()) { + return StringError("allocate_lwe argument #") + << argPos << "is not encypeted"; + } + auto numBlocks = + encryption->encoding.crt.empty() ? 1 : encryption->encoding.crt.size(); size = std::get<1>(inputSk).lweSize(); - *ciphertext = (uint64_t *)malloc(sizeof(uint64_t) * size); + *ciphertext = (uint64_t *)malloc(sizeof(uint64_t) * size * numBlocks); return outcome::success(); } @@ -205,20 +213,40 @@ bool KeySet::isOutputEncrypted(size_t argPos) { std::get<0>(outputs[argPos]).encryption.hasValue(); } +/// Return the number of bits to represents the given value +uint64_t bitWidthOfValue(uint64_t value) { return std::ceil(std::log2(value)); } + outcome::checked KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input) { if (argPos >= inputs.size()) { return StringError("encrypt_lwe position of argument is too high"); } auto inputSk = inputs[argPos]; - if (!std::get<0>(inputSk).encryption.hasValue()) { + auto encryption = std::get<0>(inputSk).encryption; + if (!encryption.hasValue()) { return StringError("encrypt_lwe the positional argument is not encrypted"); } - // Encode - TODO we could check if the input value is in the right range - uint64_t plaintext = - input << (64 - (std::get<0>(inputSk).encryption->encoding.precision + 1)); - ::encrypt_lwe_u64(engine, std::get<2>(inputSk), ciphertext, plaintext, - std::get<0>(inputSk).encryption->variance); + auto encoding = encryption->encoding; + auto lweSecretKeyParam = std::get<1>(inputSk); + auto lweSecretKey = std::get<2>(inputSk); + // CRT encoding - N blocks with crt encoding + auto crt = encryption->encoding.crt; + if (!crt.empty()) { + // Put each decomposition into a new ciphertext + auto product = crt::productOfModuli(crt); + for (auto modulus : crt) { + auto plaintext = crt::encode(input, modulus, product); + ::encrypt_lwe_u64(engine, lweSecretKey, ciphertext, plaintext, + encryption->variance); + ciphertext = ciphertext + lweSecretKeyParam.lweSize(); + } + return outcome::success(); + } + // Simple TFHE integers - 1 blocks with one padding bits + // TODO we could check if the input value is in the right range + uint64_t plaintext = input << (64 - (encryption->encoding.precision + 1)); + ::encrypt_lwe_u64(engine, lweSecretKey, ciphertext, plaintext, + encryption->variance); return outcome::success(); } @@ -228,13 +256,31 @@ KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) { return StringError("decrypt_lwe: position of argument is too high"); } auto outputSk = outputs[argPos]; - if (!std::get<0>(outputSk).encryption.hasValue()) { + auto lweSecretKey = std::get<2>(outputSk); + auto lweSecretKeyParam = std::get<1>(outputSk); + auto encryption = std::get<0>(outputSk).encryption; + if (!encryption.hasValue()) { return StringError("decrypt_lwe: the positional argument is not encrypted"); } - uint64_t plaintext = - ::decrypt_lwe_u64(engine, std::get<2>(outputSk), ciphertext); + auto crt = encryption->encoding.crt; + // CRT encoding - N blocks with crt encoding + if (!crt.empty()) { + std::vector remainders; + // decrypt and decode remainders + for (auto modulus : crt) { + auto decrypted = ::decrypt_lwe_u64(engine, lweSecretKey, ciphertext); + auto plaintext = crt::decode(decrypted, modulus); + remainders.push_back(plaintext); + ciphertext = ciphertext + lweSecretKeyParam.lweSize(); + } + // compute the inverse crt + output = crt::iCrt(crt, remainders); + return outcome::success(); + } + // Simple TFHE integers - 1 blocks with one padding bits + uint64_t plaintext = ::decrypt_lwe_u64(engine, lweSecretKey, ciphertext); // Decode - size_t precision = std::get<0>(outputSk).encryption->encoding.precision; + size_t precision = encryption->encoding.precision; output = plaintext >> (64 - precision - 2); size_t carry = output % 2; output = ((output >> 1) + carry) % (1 << (precision + 1)); diff --git a/compiler/lib/ClientLib/KeySetCache.cpp b/compiler/lib/ClientLib/KeySetCache.cpp index 5970df70b..3cfb72d71 100644 --- a/compiler/lib/ClientLib/KeySetCache.cpp +++ b/compiler/lib/ClientLib/KeySetCache.cpp @@ -90,7 +90,6 @@ KeySetCache::loadKeys(ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb, std::string folderPath) { // TODO: text dump of all parameter in /hash auto key_set = std::make_unique(); - // Mark the folder as recently use. // e.g. so the CI can do some cleanup of unused keys. utime(folderPath.c_str(), nullptr); diff --git a/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp b/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp index f57192a07..86865b07b 100644 --- a/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp +++ b/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp @@ -38,6 +38,9 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/IR/Function.h" +namespace Concrete = ::mlir::concretelang::Concrete; +namespace BConcrete = ::mlir::concretelang::BConcrete; + namespace { struct ConcreteToBConcretePass : public ConcreteToBConcreteBase { @@ -68,9 +71,14 @@ public: }); addConversion([&](mlir::concretelang::Concrete::LweCiphertextType type) { assert(type.getDimension() != -1); + llvm::SmallVector shape; + auto crt = type.getCrtDecomposition(); + if (!crt.empty()) { + shape.push_back(crt.size()); + } + shape.push_back(type.getDimension() + 1); return mlir::RankedTensorType::get( - {type.getDimension() + 1}, - mlir::IntegerType::get(type.getContext(), 64)); + shape, mlir::IntegerType::get(type.getContext(), 64)); }); addConversion([&](mlir::concretelang::Concrete::GlweCiphertextType type) { assert(type.getGlweDimension() != -1); @@ -91,91 +99,18 @@ public: mlir::SmallVector newShape; newShape.reserve(type.getShape().size() + 1); newShape.append(type.getShape().begin(), type.getShape().end()); + auto crt = lwe.getCrtDecomposition(); + if (!crt.empty()) { + newShape.push_back(crt.size()); + } newShape.push_back(lwe.getDimension() + 1); mlir::Type r = mlir::RankedTensorType::get( newShape, mlir::IntegerType::get(type.getContext(), 64)); return r; }); - addConversion([&](mlir::MemRefType type) { - auto lwe = type.getElementType() - .dyn_cast_or_null< - mlir::concretelang::Concrete::LweCiphertextType>(); - if (lwe == nullptr) { - return (mlir::Type)(type); - } - assert(lwe.getDimension() != -1); - mlir::SmallVector newShape; - newShape.reserve(type.getShape().size() + 1); - newShape.append(type.getShape().begin(), type.getShape().end()); - newShape.push_back(lwe.getDimension() + 1); - mlir::Type r = mlir::MemRefType::get( - newShape, mlir::IntegerType::get(type.getContext(), 64)); - return r; - }); } }; -struct ConcreteEncodeIntOpPattern - : public mlir::OpRewritePattern { - ConcreteEncodeIntOpPattern(mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : mlir::OpRewritePattern( - context, benefit) {} - - mlir::LogicalResult - matchAndRewrite(mlir::concretelang::Concrete::EncodeIntOp op, - mlir::PatternRewriter &rewriter) const override { - { - mlir::Value castedInt = rewriter.create( - op.getLoc(), rewriter.getIntegerType(64), op->getOperands().front()); - mlir::Value constantShiftOp = rewriter.create( - op.getLoc(), rewriter.getI64IntegerAttr(64 - op.getType().getP())); - - mlir::Type resultType = rewriter.getIntegerType(64); - rewriter.replaceOpWithNewOp( - op, resultType, castedInt, constantShiftOp); - } - return mlir::success(); - }; -}; - -struct ConcreteIntToCleartextOpPattern - : public mlir::OpRewritePattern< - mlir::concretelang::Concrete::IntToCleartextOp> { - ConcreteIntToCleartextOpPattern(mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : mlir::OpRewritePattern( - context, benefit) {} - - mlir::LogicalResult - matchAndRewrite(mlir::concretelang::Concrete::IntToCleartextOp op, - mlir::PatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - op, rewriter.getIntegerType(64), op->getOperands().front()); - return mlir::success(); - }; -}; - -/// This rewrite pattern transforms any instance of `Concrete.zero_tensor` -/// operators. -/// -/// Example: -/// -/// ```mlir -/// %0 = "Concrete.zero_tensor" () : -/// tensor<...x!Concrete.lwe_ciphertext> -/// ``` -/// -/// becomes: -/// -/// ```mlir -/// %0 = tensor.generate { -/// ^bb0(... : index): -/// %c0 = arith.constant 0 : i64 -/// tensor.yield %z -/// }: tensor<...xlweDim+1xi64> -/// i64> -/// ``` template struct ZeroOpPattern : public mlir::OpRewritePattern { ZeroOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) @@ -205,20 +140,7 @@ struct ZeroOpPattern : public mlir::OpRewritePattern { }; }; -/// This template rewrite pattern transforms any instance of -/// `ConcreteOp` to an instance of `BConcreteOp`. -/// -/// Example: -/// -/// %0 = "ConcreteOp"(%arg0, ...) : -/// (!Concrete.lwe_ciphertext, ...) -> -/// (!Concrete.lwe_ciphertext) -/// -/// becomes: -/// -/// %0 = "BConcreteOp"(%arg0, ...) : (tensor>, ..., ) -> -/// (tensor>) -template +template struct LowToBConcrete : public mlir::OpRewritePattern { LowToBConcrete(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) : ::mlir::OpRewritePattern(context, benefit) {} @@ -236,9 +158,126 @@ struct LowToBConcrete : public mlir::OpRewritePattern { llvm::ArrayRef<::mlir::NamedAttribute> attributes = concreteOp.getOperation()->getAttrs(); - BConcreteOp bConcreteOp = rewriter.replaceOpWithNewOp( - concreteOp, newResultTy, concreteOp.getOperation()->getOperands(), - attributes); + auto crt = resultTy.getCrtDecomposition(); + mlir::Operation *bConcreteOp; + if (crt.empty()) { + bConcreteOp = rewriter.replaceOpWithNewOp( + concreteOp, newResultTy, concreteOp.getOperation()->getOperands(), + attributes); + } else { + auto newAttributes = attributes.vec(); + newAttributes.push_back(rewriter.getNamedAttr( + "crtDecomposition", rewriter.getI64ArrayAttr(crt))); + bConcreteOp = rewriter.replaceOpWithNewOp( + concreteOp, newResultTy, concreteOp.getOperation()->getOperands(), + newAttributes); + } + + mlir::concretelang::convertOperandAndResultTypes( + rewriter, bConcreteOp, [&](mlir::MLIRContext *, mlir::Type t) { + return converter.convertType(t); + }); + + return ::mlir::success(); + }; +}; + +struct AddPlaintextLweCiphertextOpPattern + : public mlir::OpRewritePattern { + AddPlaintextLweCiphertextOpPattern(::mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern( + context, benefit) {} + + ::mlir::LogicalResult + matchAndRewrite(Concrete::AddPlaintextLweCiphertextOp concreteOp, + ::mlir::PatternRewriter &rewriter) const override { + ConcreteToBConcreteTypeConverter converter; + auto loc = concreteOp.getLoc(); + mlir::concretelang::Concrete::LweCiphertextType resultTy = + ((mlir::Type)concreteOp->getResult(0).getType()) + .cast(); + auto newResultTy = + converter.convertType(resultTy).cast(); + + llvm::ArrayRef<::mlir::NamedAttribute> attributes = + concreteOp.getOperation()->getAttrs(); + + auto crt = resultTy.getCrtDecomposition(); + mlir::Operation *bConcreteOp; + if (crt.empty()) { + // Encode the plaintext value + mlir::Value castedInt = rewriter.create( + loc, rewriter.getIntegerType(64), concreteOp.rhs()); + mlir::Value constantShiftOp = rewriter.create( + loc, + rewriter.getI64IntegerAttr(64 - concreteOp.getType().getP() - 1)); + auto encoded = rewriter.create( + loc, rewriter.getI64Type(), castedInt, constantShiftOp); + bConcreteOp = + rewriter.replaceOpWithNewOp( + concreteOp, newResultTy, + mlir::ValueRange{concreteOp.lhs(), encoded}, attributes); + } else { + // The encoding is done when we eliminate CRT ops + auto newAttributes = attributes.vec(); + newAttributes.push_back(rewriter.getNamedAttr( + "crtDecomposition", rewriter.getI64ArrayAttr(crt))); + bConcreteOp = + rewriter.replaceOpWithNewOp( + concreteOp, newResultTy, concreteOp.getOperation()->getOperands(), + newAttributes); + } + + mlir::concretelang::convertOperandAndResultTypes( + rewriter, bConcreteOp, [&](mlir::MLIRContext *, mlir::Type t) { + return converter.convertType(t); + }); + + return ::mlir::success(); + }; +}; + +struct MulCleartextLweCiphertextOpPattern + : public mlir::OpRewritePattern { + MulCleartextLweCiphertextOpPattern(::mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern( + context, benefit) {} + + ::mlir::LogicalResult + matchAndRewrite(Concrete::MulCleartextLweCiphertextOp concreteOp, + ::mlir::PatternRewriter &rewriter) const override { + ConcreteToBConcreteTypeConverter converter; + auto loc = concreteOp.getLoc(); + mlir::concretelang::Concrete::LweCiphertextType resultTy = + ((mlir::Type)concreteOp->getResult(0).getType()) + .cast(); + auto newResultTy = + converter.convertType(resultTy).cast(); + + llvm::ArrayRef<::mlir::NamedAttribute> attributes = + concreteOp.getOperation()->getAttrs(); + + auto crt = resultTy.getCrtDecomposition(); + mlir::Operation *bConcreteOp; + if (crt.empty()) { + // Encode the plaintext value + mlir::Value castedInt = rewriter.create( + loc, rewriter.getIntegerType(64), concreteOp.rhs()); + bConcreteOp = + rewriter.replaceOpWithNewOp( + concreteOp, newResultTy, + mlir::ValueRange{concreteOp.lhs(), castedInt}, attributes); + } else { + auto newAttributes = attributes.vec(); + newAttributes.push_back(rewriter.getNamedAttr( + "crtDecomposition", rewriter.getI64ArrayAttr(crt))); + bConcreteOp = + rewriter.replaceOpWithNewOp( + concreteOp, newResultTy, concreteOp.getOperation()->getOperands(), + newAttributes); + } mlir::concretelang::convertOperandAndResultTypes( rewriter, bConcreteOp, [&](mlir::MLIRContext *, mlir::Type t) { @@ -249,27 +288,6 @@ struct LowToBConcrete : public mlir::OpRewritePattern { }; }; -/// This rewrite pattern transforms any instance of -/// `Concrete.glwe_from_table` operators. -/// -/// Example: -/// -/// ```mlir -/// %0 = "Concrete.glwe_from_table"(%tlu) -/// : (tensor<$Dxi64>) -> -/// !Concrete.glwe_ciphertext<$polynomialSize,$glweDimension,$p> -/// ``` -/// -/// with $D = 2^$p -/// -/// becomes: -/// -/// ```mlir -/// %0 = linalg.init_tensor [polynomialSize*(glweDimension+1)] -/// : tensor -/// "BConcrete.fill_glwe_from_table" : (%0, polynomialSize, glweDimension, %tlu) -/// : tensor, i64, i64, tensor<$Dxi64> -/// ``` struct GlweFromTablePattern : public mlir::OpRewritePattern< mlir::concretelang::Concrete::GlweFromTable> { GlweFromTablePattern(::mlir::MLIRContext *context, @@ -307,26 +325,6 @@ struct GlweFromTablePattern : public mlir::OpRewritePattern< }; }; -/// This rewrite pattern transforms any instance of -/// `tensor.extract_slice` operators that operates on tensor of lwe ciphertext. -/// -/// Example: -/// -/// ```mlir -/// %0 = tensor.extract_slice %arg0 -/// [offsets...] [sizes...] [strides...] -/// : tensor<...x!Concrete.lwe_ciphertext> to -/// tensor<...x!Concrete.lwe_ciphertext> -/// ``` -/// -/// becomes: -/// -/// ```mlir -/// %0 = tensor.extract_slice %arg0 -/// [offsets..., 0] [sizes..., lweDimension+1] [strides..., 1] -/// : tensor<...xlweDimension+1,i64> to -/// tensor<...xlweDimension+1,i64> -/// ``` struct ExtractSliceOpPattern : public mlir::OpRewritePattern { ExtractSliceOpPattern(::mlir::MLIRContext *context, @@ -339,29 +337,41 @@ struct ExtractSliceOpPattern ::mlir::PatternRewriter &rewriter) const override { ConcreteToBConcreteTypeConverter converter; auto resultTy = extractSliceOp.result().getType(); - auto resultEltTy = + auto lweResultTy = resultTy.cast() .getElementType() .cast(); - auto newResultTy = converter.convertType(resultTy); + auto nbBlock = lweResultTy.getCrtDecomposition().size(); + auto newResultTy = + converter.convertType(resultTy).cast(); // add 0 to the static_offsets mlir::SmallVector staticOffsets; staticOffsets.append(extractSliceOp.static_offsets().begin(), extractSliceOp.static_offsets().end()); + if (nbBlock != 0) { + staticOffsets.push_back(rewriter.getI64IntegerAttr(0)); + } staticOffsets.push_back(rewriter.getI64IntegerAttr(0)); // add the lweSize to the sizes mlir::SmallVector staticSizes; staticSizes.append(extractSliceOp.static_sizes().begin(), extractSliceOp.static_sizes().end()); - staticSizes.push_back( - rewriter.getI64IntegerAttr(resultEltTy.getDimension() + 1)); + if (nbBlock != 0) { + staticSizes.push_back(rewriter.getI64IntegerAttr( + newResultTy.getDimSize(newResultTy.getRank() - 2))); + } + staticSizes.push_back(rewriter.getI64IntegerAttr( + newResultTy.getDimSize(newResultTy.getRank() - 1))); // add 1 to the strides mlir::SmallVector staticStrides; staticStrides.append(extractSliceOp.static_strides().begin(), extractSliceOp.static_strides().end()); + if (nbBlock != 0) { + staticStrides.push_back(rewriter.getI64IntegerAttr(1)); + } staticStrides.push_back(rewriter.getI64IntegerAttr(1)); // replace tensor.extract_slice to the new one @@ -382,26 +392,6 @@ struct ExtractSliceOpPattern }; }; -/// This rewrite pattern transforms any instance of -/// `tensor.extract` operators that operates on tensor of lwe ciphertext. -/// -/// Example: -/// -/// ```mlir -/// %0 = tensor.extract %t[offsets...] -/// : tensor<...x!Concrete.lwe_ciphertext> -/// ``` -/// -/// becomes: -/// -/// ```mlir -/// %1 = tensor.extract_slice %arg0 -/// [offsets...] [1..., lweDimension+1] [1...] -/// : tensor<...xlweDimension+1,i64> to -/// tensor<1...xlweDimension+1,i64> -/// %0 = linalg.tensor_collapse_shape %0 [[...]] : -/// tensor<1x1xlweDimension+1xi64> into tensor -/// ``` // TODO: since they are a bug on lowering extract_slice with rank reduction we // add a linalg.tensor_collapse_shape after the extract_slice without rank // reduction. See @@ -424,20 +414,29 @@ struct ExtractOpPattern if (lweResultTy == nullptr) { return mlir::failure(); } - + auto nbBlock = lweResultTy.getCrtDecomposition().size(); auto newResultTy = converter.convertType(lweResultTy).cast(); - auto rankOfResult = extractOp.indices().size() + 1; - + auto rankOfResult = extractOp.indices().size() + + /* for the lwe dimension */ 1 + + /* for the block dimension */ + (nbBlock == 0 ? 0 : 1); // [min..., 0] for static_offsets () mlir::SmallVector staticOffsets( rankOfResult, rewriter.getI64IntegerAttr(std::numeric_limits::min())); + if (nbBlock != 0) { + staticOffsets[staticOffsets.size() - 2] = rewriter.getI64IntegerAttr(0); + } staticOffsets[staticOffsets.size() - 1] = rewriter.getI64IntegerAttr(0); - // [1..., lweDimension+1] for static_sizes + // [1..., lweDimension+1] for static_sizes or + // [1..., nbBlock, lweDimension+1] mlir::SmallVector staticSizes( rankOfResult, rewriter.getI64IntegerAttr(1)); + if (nbBlock != 0) { + staticSizes[staticSizes.size() - 2] = rewriter.getI64IntegerAttr(nbBlock); + } staticSizes[staticSizes.size() - 1] = rewriter.getI64IntegerAttr( newResultTy.getDimSize(newResultTy.getRank() - 1)); @@ -446,38 +445,45 @@ struct ExtractOpPattern rankOfResult, rewriter.getI64IntegerAttr(1)); // replace tensor.extract_slice to the new one - mlir::SmallVector extractedSliceShape( - extractOp.indices().size() + 1, 0); - extractedSliceShape.reserve(extractOp.indices().size() + 1); - for (size_t i = 0; i < extractedSliceShape.size() - 1; i++) { - extractedSliceShape[i] = 1; + mlir::SmallVector extractedSliceShape(rankOfResult, 1); + if (nbBlock != 0) { + extractedSliceShape[extractedSliceShape.size() - 2] = nbBlock; + extractedSliceShape[extractedSliceShape.size() - 1] = + newResultTy.getDimSize(1); + } else { + extractedSliceShape[extractedSliceShape.size() - 1] = + newResultTy.getDimSize(0); } - extractedSliceShape[extractedSliceShape.size() - 1] = - newResultTy.getDimSize(0); auto extractedSliceType = mlir::RankedTensorType::get(extractedSliceShape, rewriter.getI64Type()); + auto extractedSlice = rewriter.create( extractOp.getLoc(), extractedSliceType, extractOp.tensor(), extractOp.indices(), mlir::SmallVector{}, mlir::SmallVector{}, rewriter.getArrayAttr(staticOffsets), rewriter.getArrayAttr(staticSizes), rewriter.getArrayAttr(staticStrides)); - mlir::concretelang::convertOperandAndResultTypes( rewriter, extractedSlice, [&](mlir::MLIRContext *, mlir::Type t) { return converter.convertType(t); }); mlir::ReassociationIndices reassociation; - for (int64_t i = 0; i < extractedSliceType.getRank(); i++) { + for (int64_t i = 0; + i < extractedSliceType.getRank() - (nbBlock == 0 ? 0 : 1); i++) { reassociation.push_back(i); } + mlir::SmallVector reassocs{reassociation}; + + if (nbBlock != 0) { + reassocs.push_back({extractedSliceType.getRank() - 1}); + } + mlir::tensor::CollapseShapeOp collapseOp = rewriter.replaceOpWithNewOp( - extractOp, newResultTy, extractedSlice, - mlir::SmallVector{reassociation}); + extractOp, newResultTy, extractedSlice, reassocs); mlir::concretelang::convertOperandAndResultTypes( rewriter, collapseOp, [&](mlir::MLIRContext *, mlir::Type t) { @@ -488,26 +494,6 @@ struct ExtractOpPattern }; }; -/// This rewrite pattern transforms any instance of -/// `tensor.insert_slice` operators that operates on tensor of lwe ciphertext. -/// -/// Example: -/// -/// ```mlir -/// %0 = tensor.insert_slice %arg1 -/// into %arg0[offsets...] [sizes...] [strides...] -/// : tensor<...x!Concrete.lwe_ciphertext> into -/// tensor<...x!Concrete.lwe_ciphertext> -/// ``` -/// -/// becomes: -/// -/// ```mlir -/// %0 = tensor.insert_slice %arg1 -/// into %arg0[offsets..., 0] [sizes..., lweDimension+1] [strides..., 1] -/// : tensor<...xlweDimension+1xi64> into -/// tensor<...xlweDimension+1xi64> -/// ``` struct InsertSliceOpPattern : public mlir::OpRewritePattern { InsertSliceOpPattern(::mlir::MLIRContext *context, @@ -520,7 +506,14 @@ struct InsertSliceOpPattern ::mlir::PatternRewriter &rewriter) const override { ConcreteToBConcreteTypeConverter converter; auto resultTy = insertSliceOp.result().getType(); - + auto lweResultTy = + resultTy.cast() + .getElementType() + .cast(); + if (lweResultTy == nullptr) { + return mlir::failure(); + } + auto nbBlock = lweResultTy.getCrtDecomposition().size(); auto newResultTy = converter.convertType(resultTy).cast(); @@ -528,12 +521,19 @@ struct InsertSliceOpPattern mlir::SmallVector staticOffsets; staticOffsets.append(insertSliceOp.static_offsets().begin(), insertSliceOp.static_offsets().end()); + if (nbBlock != 0) { + staticOffsets.push_back(rewriter.getI64IntegerAttr(0)); + } staticOffsets.push_back(rewriter.getI64IntegerAttr(0)); // add lweDimension+1 to static_sizes mlir::SmallVector staticSizes; staticSizes.append(insertSliceOp.static_sizes().begin(), insertSliceOp.static_sizes().end()); + if (nbBlock != 0) { + staticSizes.push_back(rewriter.getI64IntegerAttr( + newResultTy.getDimSize(newResultTy.getRank() - 2))); + } staticSizes.push_back(rewriter.getI64IntegerAttr( newResultTy.getDimSize(newResultTy.getRank() - 1))); @@ -541,6 +541,9 @@ struct InsertSliceOpPattern mlir::SmallVector staticStrides; staticStrides.append(insertSliceOp.static_strides().begin(), insertSliceOp.static_strides().end()); + if (nbBlock != 0) { + staticStrides.push_back(rewriter.getI64IntegerAttr(1)); + } staticStrides.push_back(rewriter.getI64IntegerAttr(1)); // replace tensor.insert_slice with the new one @@ -560,28 +563,6 @@ struct InsertSliceOpPattern }; }; -/// This rewrite pattern transforms any instance of `tensor.insert` -/// operators that operates on an lwe ciphertexts to a -/// `tensor.insert_slice` op operating on the bufferized representation -/// of the ciphertext. -/// -/// Example: -/// -/// ```mlir -/// %0 = tensor.insert %arg1 -/// into %arg0[offsets...] -/// : !Concrete.lwe_ciphertext into -/// tensor<...x!Concrete.lwe_ciphertext> -/// ``` -/// -/// becomes: -/// -/// ```mlir -/// %0 = tensor.insert_slice %arg1 -/// into %arg0[offsets..., 0] [sizes..., lweDimension+1] [strides..., 1] -/// : tensor into -/// tensor<...xlweDimension+1xi64> -/// ``` struct InsertOpPattern : public mlir::OpRewritePattern { InsertOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) @@ -591,14 +572,24 @@ struct InsertOpPattern : public mlir::OpRewritePattern { matchAndRewrite(mlir::tensor::InsertOp insertOp, ::mlir::PatternRewriter &rewriter) const override { ConcreteToBConcreteTypeConverter converter; - mlir::Type resultTy = insertOp.result().getType(); + auto resultTy = + insertOp.result().getType().dyn_cast_or_null(); + auto lweResultTy = resultTy.getElementType() + .dyn_cast_or_null(); + if (lweResultTy == nullptr) { + return mlir::failure(); + }; + auto hasBlock = lweResultTy.getCrtDecomposition().size() != 0; mlir::RankedTensorType newResultTy = converter.convertType(resultTy).cast(); - // add 0 to static_offsets + // add zeros to static_offsets mlir::SmallVector offsets; offsets.append(insertOp.indices().begin(), insertOp.indices().end()); offsets.push_back(rewriter.getIndexAttr(0)); + if (hasBlock) { + offsets.push_back(rewriter.getIndexAttr(0)); + } // Inserting a smaller tensor into a (potentially) bigger one. Set // dimensions for all leading dimensions of the target tensor not @@ -607,6 +598,10 @@ struct InsertOpPattern : public mlir::OpRewritePattern { rewriter.getI64IntegerAttr(1)); // Add size for the bufferized source element + if (hasBlock) { + sizes.push_back(rewriter.getI64IntegerAttr( + newResultTy.getDimSize(newResultTy.getRank() - 2))); + } sizes.push_back(rewriter.getI64IntegerAttr( newResultTy.getDimSize(newResultTy.getRank() - 1))); @@ -711,26 +706,26 @@ struct FromElementsOpPattern }; }; -/// This template rewrite pattern transforms any instance of -/// `ShapeOp` operators that operates on tensor of lwe ciphertext by adding the -/// lwe size as a size of the tensor result and by adding a trivial -/// reassociation at the end of the reassociations map. -/// -/// Example: -/// -/// ```mlir -/// %0 = "ShapeOp" %arg0 [reassocations...] -/// : tensor<...x!Concrete.lwe_ciphertext> into -/// tensor<...x!Concrete.lwe_ciphertext> -/// ``` -/// -/// becomes: -/// -/// ```mlir -/// %0 = "ShapeOp" %arg0 [reassociations..., [inRank or outRank]] -/// : tensor<...xlweDimesion+1xi64> into -/// tensor<...xlweDimesion+1xi64> -/// ``` +// This template rewrite pattern transforms any instance of +// `ShapeOp` operators that operates on tensor of lwe ciphertext by adding the +// lwe size as a size of the tensor result and by adding a trivial +// reassociation at the end of the reassociations map. +// +// Example: +// +// ```mlir +// %0 = "ShapeOp" %arg0 [reassocations...] +// : tensor<...x!Concrete.lwe_ciphertext> into +// tensor<...x!Concrete.lwe_ciphertext> +// ``` +// +// becomes: +// +// ```mlir +// %0 = "ShapeOp" %arg0 [reassociations..., [inRank or outRank]] +// : tensor<...xlweDimesion+1xi64> into +// tensor<...xlweDimesion+1xi64> +// ``` template struct TensorShapeOpPattern : public mlir::OpRewritePattern { TensorShapeOpPattern(::mlir::MLIRContext *context, @@ -741,22 +736,35 @@ struct TensorShapeOpPattern : public mlir::OpRewritePattern { matchAndRewrite(ShapeOp shapeOp, ::mlir::PatternRewriter &rewriter) const override { ConcreteToBConcreteTypeConverter converter; - auto resultTy = shapeOp.result().getType(); + auto resultTy = ((mlir::Type)shapeOp.result().getType()).cast(); + auto lweResultTy = + ((mlir::Type)resultTy.getElementType()) + .cast(); auto newResultTy = ((mlir::Type)converter.convertType(resultTy)).cast(); - // add [rank] to reassociations - auto oldReassocs = shapeOp.getReassociationIndices(); - mlir::SmallVector newReassocs; - newReassocs.append(oldReassocs.begin(), oldReassocs.end()); - mlir::ReassociationIndices lweAssoc; auto reassocTy = ((mlir::Type)converter.convertType( (inRank ? shapeOp.src() : shapeOp.result()).getType())) .cast(); - lweAssoc.push_back(reassocTy.getRank() - 1); - newReassocs.push_back(lweAssoc); + + auto oldReassocs = shapeOp.getReassociationIndices(); + mlir::SmallVector newReassocs; + newReassocs.append(oldReassocs.begin(), oldReassocs.end()); + // add [rank-1] to reassociations if crt decomp + if (!lweResultTy.getCrtDecomposition().empty()) { + mlir::ReassociationIndices lweAssoc; + lweAssoc.push_back(reassocTy.getRank() - 2); + newReassocs.push_back(lweAssoc); + } + + // add [rank] to reassociations + { + mlir::ReassociationIndices lweAssoc; + lweAssoc.push_back(reassocTy.getRank() - 1); + newReassocs.push_back(lweAssoc); + } ShapeOp op = rewriter.replaceOpWithNewOp( shapeOp, newResultTy, shapeOp.src(), newReassocs); @@ -785,20 +793,6 @@ void insertTensorShapeOpPattern(mlir::MLIRContext &context, }); } -/// Rewrites `bufferization.alloc_tensor` ops for which the converted type in -/// BConcrete is different from the original type. -/// -/// Example: -/// -/// ``` -/// bufferization.alloc_tensor() : tensor<4x!Concrete.lwe_ciphertext<4096,6>> -/// ``` -/// -/// becomes: -/// -/// ``` -/// bufferization.alloc_tensor() : tensor<4x4097xi64> -/// ``` struct AllocTensorOpPattern : public mlir::OpRewritePattern { AllocTensorOpPattern(::mlir::MLIRContext *context, @@ -877,10 +871,6 @@ void ConcreteToBConcretePass::runOnOperation() { // Add Concrete ops are illegal after the conversion target.addIllegalDialect(); - // Add patterns to convert cleartext and plaintext to i64 - patterns - .insert( - &getContext()); target.addLegalDialect(); // Add patterns to convert the zero ops to tensor.generate @@ -894,27 +884,25 @@ void ConcreteToBConcretePass::runOnOperation() { // BConcrete op patterns.insert< LowToBConcrete, - LowToBConcrete< - mlir::concretelang::Concrete::AddPlaintextLweCiphertextOp, - mlir::concretelang::BConcrete::AddPlaintextLweBufferOp>, - LowToBConcrete< - mlir::concretelang::Concrete::MulCleartextLweCiphertextOp, - mlir::concretelang::BConcrete::MulCleartextLweBufferOp>, - LowToBConcrete< - mlir::concretelang::Concrete::MulCleartextLweCiphertextOp, - mlir::concretelang::BConcrete::MulCleartextLweBufferOp>, + mlir::concretelang::BConcrete::AddLweBuffersOp, + BConcrete::AddCRTLweBuffersOp>, + AddPlaintextLweCiphertextOpPattern, MulCleartextLweCiphertextOpPattern, LowToBConcrete, + mlir::concretelang::BConcrete::NegateLweBufferOp, + BConcrete::NegateCRTLweBufferOp>, LowToBConcrete, LowToBConcrete>( - &getContext()); + mlir::concretelang::BConcrete::BootstrapLweBufferOp, + mlir::concretelang::BConcrete::KeySwitchLweBufferOp>, + LowToBConcrete>(&getContext()); patterns.insert(&getContext()); - // Add patterns to rewrite tensor operators that works on encrypted tensors + // Add patterns to rewrite tensor operators that works on encrypted + // tensors patterns .insert(&getContext()); @@ -939,7 +927,8 @@ void ConcreteToBConcretePass::runOnOperation() { patterns.insert(&getContext()); // Add patterns to rewrite some of memref ops that was introduced by the - // linalg bufferization of encrypted tensor (first conversion of this pass) + // linalg bufferization of encrypted tensor (first conversion of this + // pass) insertTensorShapeOpPattern(getContext(), patterns, target); insertTensorShapeOpPattern { - void runOnOperation() final; -}; -} // namespace using mlir::concretelang::FHE::EncryptedIntegerType; using mlir::concretelang::TFHE::GLWECipherTextType; @@ -84,10 +80,10 @@ public: /// : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) -> /// !TFHE.glwe<{_,_,_}{2}> /// ``` -struct ApplyLookupTableEintOpPattern +struct ApplyLookupTableEintOpToKeyswitchBootstrapPattern : public mlir::OpRewritePattern { - ApplyLookupTableEintOpPattern(mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) + ApplyLookupTableEintOpToKeyswitchBootstrapPattern( + mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) : ::mlir::OpRewritePattern(context, benefit) {} @@ -115,6 +111,51 @@ struct ApplyLookupTableEintOpPattern }; }; +/// This rewrite pattern transforms any instance of `FHE.apply_lookup_table` +/// operators. +/// +/// Example: +/// +/// ```mlir +/// %0 = "FHE.apply_lookup_table"(%ct, %lut): (!FHE.eint<2>, tensor<4xi64>) +/// ->(!FHE.eint<2>) +/// ``` +/// +/// becomes: +/// +/// ```mlir +/// %0 = "TFHE.wop_pbs_glwe"(%ct, %lut) +/// : (!TFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) -> +/// (!TFHE.glwe<{_,_,_}{2}>) +/// ``` +struct ApplyLookupTableEintOpToWopPBSPattern + : public mlir::OpRewritePattern { + ApplyLookupTableEintOpToWopPBSPattern(mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern(context, + benefit) {} + + ::mlir::LogicalResult + matchAndRewrite(FHE::ApplyLookupTableEintOp lutOp, + mlir::PatternRewriter &rewriter) const override { + FHEToTFHETypeConverter converter; + auto inputTy = converter.convertType(lutOp.a().getType()) + .cast(); + auto resultTy = converter.convertType(lutOp.getType()); + // %0 = "TFHE.wop_pbs_glwe"(%ct, %lut) + // : (!TFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) -> + // (!TFHE.glwe<{_,_,_}{2}>) + auto wopPBS = rewriter.replaceOpWithNewOp( + lutOp, resultTy, lutOp.a(), lutOp.lut(), -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1); + mlir::concretelang::convertOperandAndResultTypes( + rewriter, wopPBS, [&](mlir::MLIRContext *, mlir::Type t) { + return converter.convertType(t); + }); + return ::mlir::success(); + }; +}; + /// This rewrite pattern transforms any instance of `FHE.sub_eint_int` /// operators to a negation and an addition. struct SubEintIntOpPattern : public mlir::OpRewritePattern { @@ -194,97 +235,118 @@ struct SubEintOpPattern : public mlir::OpRewritePattern { }; }; -void FHEToTFHEPass::runOnOperation() { - auto op = this->getOperation(); +struct FHEToTFHEPass : public FHEToTFHEBase { - mlir::ConversionTarget target(getContext()); - FHEToTFHETypeConverter converter; + FHEToTFHEPass(mlir::concretelang::ApplyLookupTableLowering lutLowerStrategy) + : lutLowerStrategy(lutLowerStrategy) {} - // Mark ops from the target dialect as legal operations - target.addLegalDialect(); - target.addLegalDialect(); + void runOnOperation() { + auto op = this->getOperation(); - // Make sure that no ops from `FHE` remain after the lowering - target.addIllegalDialect(); + mlir::ConversionTarget target(getContext()); + FHEToTFHETypeConverter converter; - // Make sure that no ops `linalg.generic` that have illegal types - target - .addDynamicallyLegalOp( - [&](mlir::Operation *op) { - return ( - converter.isLegal(op->getOperandTypes()) && - converter.isLegal(op->getResultTypes()) && - converter.isLegal(op->getRegion(0).front().getArgumentTypes())); - }); + // Mark ops from the target dialect as legal operations + target.addLegalDialect(); + target.addLegalDialect(); - // Make sure that func has legal signature - target.addDynamicallyLegalOp( - [&](mlir::func::FuncOp funcOp) { - return converter.isSignatureLegal(funcOp.getFunctionType()) && - converter.isLegal(&funcOp.getBody()); - }); + // Make sure that no ops from `FHE` remain after the lowering + target.addIllegalDialect(); - // Add all patterns required to lower all ops from `FHE` to - // `TFHE` - mlir::RewritePatternSet patterns(&getContext()); + // Make sure that no ops `linalg.generic` that have illegal types + target.addDynamicallyLegalOp( + [&](mlir::Operation *op) { + return ( + converter.isLegal(op->getOperandTypes()) && + converter.isLegal(op->getResultTypes()) && + converter.isLegal(op->getRegion(0).front().getArgumentTypes())); + }); - populateWithGeneratedFHEToTFHE(patterns); + // Make sure that func has legal signature + target.addDynamicallyLegalOp( + [&](mlir::func::FuncOp funcOp) { + return converter.isSignatureLegal(funcOp.getFunctionType()) && + converter.isLegal(&funcOp.getBody()); + }); - patterns.add< - mlir::concretelang::GenericTypeConverterPattern>( - patterns.getContext(), converter); + // Add all patterns required to lower all ops from `FHE` to + // `TFHE` + mlir::RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); + populateWithGeneratedFHEToTFHE(patterns); - patterns.add>( - &getContext(), converter); + patterns.add< + mlir::concretelang::GenericTypeConverterPattern>( + patterns.getContext(), converter); - patterns.add< - mlir::concretelang::GenericTypeConverterPattern>( - patterns.getContext(), converter); + switch (lutLowerStrategy) { + case mlir::concretelang::KeySwitchBoostrapLowering: + patterns.add( + &getContext()); + break; + case mlir::concretelang::WopPBSLowering: + patterns.add(&getContext()); + break; + } - patterns.add>( - &getContext(), converter); + patterns.add(&getContext()); + patterns.add(&getContext()); - patterns.add< - RegionOpTypeConverterPattern>( - &getContext(), converter); - patterns.add>(&getContext(), converter); + patterns.add>( + &getContext(), converter); - mlir::concretelang::populateWithTensorTypeConverterPatterns(patterns, target, - converter); + patterns.add< + mlir::concretelang::GenericTypeConverterPattern>( + patterns.getContext(), converter); - mlir::populateFunctionOpInterfaceTypeConversionPattern( - patterns, converter); + patterns.add>( + &getContext(), converter); - // Conversion of RT Dialect Ops - patterns.add>(patterns.getContext(), - converter); - mlir::concretelang::addDynamicallyLegalTypeOp< - mlir::concretelang::RT::DataflowTaskOp>(target, converter); - patterns.add>(patterns.getContext(), - converter); - mlir::concretelang::addDynamicallyLegalTypeOp< - mlir::concretelang::RT::DataflowYieldOp>(target, converter); + patterns.add< + RegionOpTypeConverterPattern>( + &getContext(), converter); + patterns.add>(&getContext(), converter); - // Apply conversion - if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) { - this->signalPassFailure(); + mlir::concretelang::populateWithTensorTypeConverterPatterns( + patterns, target, converter); + + mlir::populateFunctionOpInterfaceTypeConversionPattern( + patterns, converter); + + // Conversion of RT Dialect Ops + patterns.add>(patterns.getContext(), + converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::DataflowTaskOp>(target, converter); + patterns.add>(patterns.getContext(), + converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::DataflowYieldOp>(target, converter); + + // Apply conversion + if (mlir::applyPartialConversion(op, target, std::move(patterns)) + .failed()) { + this->signalPassFailure(); + } } -} + +private: + mlir::concretelang::ApplyLookupTableLowering lutLowerStrategy; +}; +} // namespace namespace mlir { namespace concretelang { -std::unique_ptr> createConvertFHEToTFHEPass() { - return std::make_unique(); +std::unique_ptr> +createConvertFHEToTFHEPass(ApplyLookupTableLowering lower) { + return std::make_unique(lower); } } // namespace concretelang } // namespace mlir diff --git a/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp b/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp index 26706d19b..1d72d8408 100644 --- a/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp +++ b/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp @@ -59,12 +59,17 @@ public: auto dimension = cryptoParameters.getNBigGlweDimension(); auto polynomialSize = 1; auto precision = (signed)type.getP(); + auto crtDecomposition = + cryptoParameters.largeInteger.hasValue() + ? cryptoParameters.largeInteger->crtDecomposition + : mlir::concretelang::CRTDecomposition{}; if ((int)dimension == type.getDimension() && (int)polynomialSize == type.getPolynomialSize()) { return type; } return TFHE::GLWECipherTextType::get(type.getContext(), dimension, - polynomialSize, bits, precision); + polynomialSize, bits, precision, + crtDecomposition); } TFHE::GLWECipherTextType glweLookupTableType(GLWECipherTextType &type) { @@ -73,7 +78,7 @@ public: auto polynomialSize = cryptoParameters.getPolynomialSize(); auto precision = (signed)type.getP(); return TFHE::GLWECipherTextType::get(type.getContext(), dimension, - polynomialSize, bits, precision); + polynomialSize, bits, precision, {}); } TFHE::GLWECipherTextType glweIntraPBSType(GLWECipherTextType &type) { @@ -82,7 +87,7 @@ public: auto polynomialSize = 1; auto precision = (signed)type.getP(); return TFHE::GLWECipherTextType::get(type.getContext(), dimension, - polynomialSize, bits, precision); + polynomialSize, bits, precision, {}); } mlir::concretelang::V0Parameter cryptoParameters; @@ -154,6 +159,49 @@ private: mlir::concretelang::V0Parameter &cryptoParameters; }; +struct WopPBSGLWEOpPattern : public mlir::OpRewritePattern { + WopPBSGLWEOpPattern(mlir::MLIRContext *context, + TFHEGlobalParametrizationTypeConverter &converter, + mlir::concretelang::V0Parameter &cryptoParameters, + mlir::PatternBenefit benefit = + mlir::concretelang::DEFAULT_PATTERN_BENEFIT) + : mlir::OpRewritePattern(context, benefit), + converter(converter), cryptoParameters(cryptoParameters) {} + + mlir::LogicalResult + matchAndRewrite(TFHE::WopPBSGLWEOp wopPBSOp, + mlir::PatternRewriter &rewriter) const override { + auto newOp = rewriter.replaceOpWithNewOp( + wopPBSOp, converter.convertType(wopPBSOp.result().getType()), + wopPBSOp.ciphertext(), wopPBSOp.lookupTable(), + // Bootstrap parameters + cryptoParameters.brLevel, cryptoParameters.brLogBase, + // Keyswitch parameters + cryptoParameters.ksLevel, cryptoParameters.ksLogBase, + // Packing keyswitch key parameters + cryptoParameters.largeInteger->wopPBS.packingKeySwitch + .inputLweDimension, + cryptoParameters.largeInteger->wopPBS.packingKeySwitch.inputLweCount, + cryptoParameters.largeInteger->wopPBS.packingKeySwitch + .outputPolynomialSize, + cryptoParameters.largeInteger->wopPBS.packingKeySwitch.level, + cryptoParameters.largeInteger->wopPBS.packingKeySwitch.baseLog, + // Circuit bootstrap parameters + cryptoParameters.largeInteger->wopPBS.circuitBootstrap.level, + cryptoParameters.largeInteger->wopPBS.circuitBootstrap.baseLog); + rewriter.startRootUpdate(newOp); + auto ciphertextType = + wopPBSOp.ciphertext().getType().cast(); + newOp.ciphertext().setType(converter.glweInterPBSType(ciphertextType)); + rewriter.finalizeRootUpdate(newOp); + return mlir::success(); + }; + +private: + TFHEGlobalParametrizationTypeConverter &converter; + mlir::concretelang::V0Parameter &cryptoParameters; +}; + /// This rewrite pattern transforms any instance of `TFHE.glwe_from_table` by /// parametrize GLWE return type and pad the table if the precision has been /// changed. @@ -271,6 +319,16 @@ void TFHEGlobalParametrizationPass::runOnOperation() { return converter.isLegal(op->getResultTypes()); }); + // Parametrize wop pbs + patterns.add(&getContext(), converter, + cryptoParameters); + target.addDynamicallyLegalOp( + [&](TFHE::WopPBSGLWEOp op) { + return !op.getType() + .cast() + .hasUnparametrizedParameters(); + }); + // Add all patterns to convert TFHE types populateWithTFHEOpTypeConversionPatterns(patterns, target, converter); patterns.add { + WopPBSGLWEOpPattern(mlir::MLIRContext *context, + mlir::TypeConverter &converter, + mlir::PatternBenefit benefit = 100) + : mlir::OpRewritePattern(context, benefit), + converter(converter) {} + + mlir::LogicalResult + matchAndRewrite(TFHE::WopPBSGLWEOp wopOp, + mlir::PatternRewriter &rewriter) const override { + mlir::Type resultType = converter.convertType(wopOp.getType()); + + auto newOp = rewriter.replaceOpWithNewOp( + wopOp, resultType, wopOp.ciphertext(), wopOp.lookupTable(), + // Bootstrap parameters + wopOp.bootstrapLevel(), wopOp.bootstrapBaseLog(), + // Keyswitch parameters + wopOp.keyswitchLevel(), wopOp.keyswitchBaseLog(), + // Packing keyswitch key parameters + wopOp.packingKeySwitchInputLweDimension(), + wopOp.packingKeySwitchinputLweCount(), + wopOp.packingKeySwitchoutputPolynomialSize(), + wopOp.packingKeySwitchLevel(), wopOp.packingKeySwitchBaseLog(), + // Circuit bootstrap parameters + wopOp.circuitBootstrapLevel(), wopOp.circuitBootstrapBaseLog()); + + rewriter.startRootUpdate(newOp); + + newOp.ciphertext().setType( + converter.convertType(wopOp.ciphertext().getType())); + + rewriter.finalizeRootUpdate(newOp); + return ::mlir::success(); + } + +private: + mlir::TypeConverter &converter; +}; + void TFHEToConcretePass::runOnOperation() { auto op = this->getOperation(); @@ -147,6 +186,7 @@ void TFHEToConcretePass::runOnOperation() { mlir::concretelang::Concrete::ZeroTensorLWEOp>>(&getContext(), converter); patterns.add(&getContext()); patterns.add(&getContext(), converter); + patterns.add(&getContext(), converter); target.addDynamicallyLegalOp( [&](Concrete::BootstrapLweOp op) { return (converter.isLegal(op->getOperandTypes()) && diff --git a/compiler/lib/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.cpp b/compiler/lib/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.cpp index 48492d422..279d7ff11 100644 --- a/compiler/lib/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/compiler/lib/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.cpp @@ -36,23 +36,23 @@ namespace {} // namespace namespace { -mlir::Type getDynamic1DMemrefWithUnknownOffset(mlir::RewriterBase &rewriter) { +mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter, + size_t rank) { mlir::MLIRContext *ctx = rewriter.getContext(); - - return mlir::MemRefType::get( - {-1}, rewriter.getI64Type(), - mlir::AffineMap::get(1, 1, - mlir::getAffineDimExpr(0, ctx) + - mlir::getAffineSymbolExpr(0, ctx))); + std::vector shape(rank, -1); + return mlir::MemRefType::get(shape, rewriter.getI64Type(), + rewriter.getMultiDimIdentityMap(rank)); } -/// Returns `memref.cast %0 : memref to memref` if %0 a 1D memref -mlir::Value getCasted1DMemRef(mlir::RewriterBase &rewriter, mlir::Location loc, - mlir::Value value) { +// Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>` +mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Location loc, + mlir::Value value) { mlir::Type valueType = value.getType(); - if (valueType.isa()) { + if (auto memrefTy = valueType.dyn_cast_or_null()) { return rewriter.create( - loc, getDynamic1DMemrefWithUnknownOffset(rewriter), value); + loc, + getDynamicMemrefWithUnknownOffset(rewriter, memrefTy.getShape().size()), + value); } else { return value; } @@ -69,10 +69,13 @@ char memref_bootstrap_lwe_u64[] = "memref_bootstrap_lwe_u64"; char memref_expand_lut_in_trivial_glwe_ct_u64[] = "memref_expand_lut_in_trivial_glwe_ct_u64"; +char memref_wop_pbs_crt_buffer[] = "memref_wop_pbs_crt_buffer"; + mlir::LogicalResult insertForwardDeclarationOfTheCAPI( mlir::Operation *op, mlir::RewriterBase &rewriter, char const *funcName) { - auto memref1DType = getDynamic1DMemrefWithUnknownOffset(rewriter); + auto memref1DType = getDynamicMemrefWithUnknownOffset(rewriter, 1); + auto memref2DType = getDynamicMemrefWithUnknownOffset(rewriter, 2); auto contextType = mlir::concretelang::Concrete::ContextType::get(rewriter.getContext()); @@ -109,6 +112,15 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI( memref1DType, }, {}); + } else if (funcName == memref_wop_pbs_crt_buffer) { + funcType = mlir::FunctionType::get(rewriter.getContext(), + { + memref2DType, + memref2DType, + memref1DType, + contextType, + }, + {}); } else { op->emitError("unknwon external function") << funcName; return mlir::failure(); @@ -185,7 +197,7 @@ struct BufferizableWithCallOpInterface // The first operand is the result mlir::SmallVector operands{ - getCasted1DMemRef(rewriter, loc, *outMemref), + getCastedMemRef(rewriter, loc, *outMemref), }; // For all tensor operand get the corresponding casted buffer for (auto &operand : op->getOpOperands()) { @@ -194,7 +206,7 @@ struct BufferizableWithCallOpInterface } else { auto memrefOperand = bufferization::getBuffer(rewriter, operand.get(), options); - operands.push_back(getCasted1DMemRef(rewriter, loc, memrefOperand)); + operands.push_back(getCastedMemRef(rewriter, loc, memrefOperand)); } } // Append the context argument @@ -264,14 +276,14 @@ struct BufferizableGlweFromTableOpInterface auto loc = op->getLoc(); auto castOp = cast(op); - auto glweOp = getCasted1DMemRef( - rewriter, loc, - bufferization::getBuffer(rewriter, castOp->getOpOperand(0).get(), - options)); - auto lutOp = getCasted1DMemRef( - rewriter, loc, - bufferization::getBuffer(rewriter, castOp->getOpOperand(1).get(), - options)); + auto glweOp = + getCastedMemRef(rewriter, loc, + bufferization::getBuffer( + rewriter, castOp->getOpOperand(0).get(), options)); + auto lutOp = + getCastedMemRef(rewriter, loc, + bufferization::getBuffer( + rewriter, castOp->getOpOperand(1).get(), options)); auto polySizeOp = rewriter.create( op->getLoc(), rewriter.getI32IntegerAttr(castOp.polynomialSize())); @@ -326,6 +338,10 @@ void mlir::concretelang::BConcrete:: BConcrete::BootstrapLweBufferOp::attachInterface< BufferizableWithCallOpInterface>(*ctx); + // TODO(16bits): hack + BConcrete::WopPBSCRTLweBufferOp::attachInterface< + BufferizableWithCallOpInterface>(*ctx); BConcrete::FillGlweFromTable::attachInterface< BufferizableGlweFromTableOpInterface>(*ctx); }); diff --git a/compiler/lib/Dialect/BConcrete/Transforms/CMakeLists.txt b/compiler/lib/Dialect/BConcrete/Transforms/CMakeLists.txt index 7b49ffef7..cef63baf0 100644 --- a/compiler/lib/Dialect/BConcrete/Transforms/CMakeLists.txt +++ b/compiler/lib/Dialect/BConcrete/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(ConcretelangBConcreteTransforms BufferizableOpInterfaceImpl.cpp AddRuntimeContext.cpp + EliminateCRTOps.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/BConcrete @@ -18,4 +19,4 @@ add_mlir_dialect_library(ConcretelangBConcreteTransforms MLIRMemRefDialect MLIRPass MLIRTransforms - ) +) diff --git a/compiler/lib/Dialect/BConcrete/Transforms/EliminateCRTOps.cpp b/compiler/lib/Dialect/BConcrete/Transforms/EliminateCRTOps.cpp new file mode 100644 index 000000000..3c3cf1161 --- /dev/null +++ b/compiler/lib/Dialect/BConcrete/Transforms/EliminateCRTOps.cpp @@ -0,0 +1,561 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "concretelang/ClientLib/CRT.h" +#include "concretelang/Conversion/Tools.h" +#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h" +#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h" +#include "concretelang/Dialect/BConcrete/Transforms/Passes.h" + +namespace arith = mlir::arith; +namespace tensor = mlir::tensor; +namespace bufferization = mlir::bufferization; +namespace scf = mlir::scf; +namespace BConcrete = mlir::concretelang::BConcrete; +namespace crt = concretelang::clientlib::crt; + +namespace { + +char encode_crt[] = "encode_crt"; + +// This template rewrite pattern transforms any instance of +// `BConcreteCRTOp` operators to `BConcreteOp` on +// each block. +// +// Example: +// +// ```mlir +// %0 = "BConcreteCRTOp"(%arg0, %arg1) {crtDecomposition = [...]} +// : (tensor, tensor) -> +// (tensor) +// ``` +// +// becomes: +// +// ```mlir +// %c0 = arith.constant 0 : index +// %c1 = arith.constant 1 : index +// %cB = arith.constant nbBlocks : index +// %init = linalg.tensor_init [B, lweSize] : tensor +// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) -> +// (tensor) { +// %blockArg = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1] +// : tensor +// %tmp = "BConcreteOp"(%blockArg) +// : (tensor) -> (tensor) +// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, 1] +// : tensor into tensor +// scf.yield %res : tensor +// } +// ``` +template +struct BConcreteCRTUnaryOpPattern + : public mlir::OpRewritePattern { + BConcreteCRTUnaryOpPattern(mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : mlir::OpRewritePattern(context, benefit) {} + + mlir::LogicalResult + matchAndRewrite(BConcreteCRTOp op, + mlir::PatternRewriter &rewriter) const override { + auto resultTy = + ((mlir::Type)op.getResult().getType()).cast(); + auto loc = op.getLoc(); + assert(resultTy.getShape().size() == 2); + auto shape = resultTy.getShape(); + + // %c0 = arith.constant 0 : index + // %c1 = arith.constant 1 : index + // %cB = arith.constant nbBlocks : index + auto c0 = rewriter.create(loc, 0); + auto c1 = rewriter.create(loc, 1); + auto cB = rewriter.create(loc, shape[0]); + + // %init = linalg.tensor_init [B, lweSize] : tensor + mlir::Value init = rewriter.create( + op.getLoc(), resultTy, mlir::ValueRange{}); + + // %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) -> + // (tensor) { + rewriter.replaceOpWithNewOp( + op, c0, cB, c1, init, + [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value i, + mlir::ValueRange iterArgs) { + // [%i, 0] + mlir::SmallVector offsets{ + i, rewriter.getI64IntegerAttr(0)}; + // [1, lweSize] + mlir::SmallVector sizes{ + rewriter.getI64IntegerAttr(1), + rewriter.getI64IntegerAttr(shape[1])}; + // [1, 1] + mlir::SmallVector strides{ + rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1)}; + + auto blockTy = mlir::RankedTensorType::get({shape[1]}, + resultTy.getElementType()); + + // %blockArg = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1] + // : tensor + auto blockArg = builder.create( + loc, blockTy, op.ciphertext(), offsets, sizes, strides); + // %tmp = "BConcrete.add_lwe_buffer"(%blockArg0, %blockArg1) + // : (tensor, tensor) -> + // (tensor) + auto tmp = builder.create(loc, blockTy, blockArg); + + // %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, + // 1] : tensor into tensor + auto res = builder.create( + loc, tmp, iterArgs[0], offsets, sizes, strides); + // scf.yield %res : tensor + builder.create(loc, (mlir::Value)res); + }); + + return mlir::success(); + } +}; + +// This template rewrite pattern transforms any instance of +// `BConcreteCRTOp` operators to `BConcreteOp` on +// each block. +// +// Example: +// +// ```mlir +// %0 = "BConcreteCRTOp"(%arg0, %arg1) {crtDecomposition = [...]} +// : (tensor, tensor) -> +// (tensor) +// ``` +// +// becomes: +// +// ```mlir +// %c0 = arith.constant 0 : index +// %c1 = arith.constant 1 : index +// %cB = arith.constant nbBlocks : index +// %init = linalg.tensor_init [B, lweSize] : tensor +// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) -> +// (tensor) { +// %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1] +// : tensor +// %blockArg1 = tensor.extract_slice %arg1[%i, 0] [1, lweSize] [1, 1] +// : tensor +// %tmp = "BConcreteOp"(%blockArg0, %blockArg1) +// : (tensor, tensor) -> +// (tensor) +// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, 1] +// : tensor into tensor +// scf.yield %res : tensor +// } +// ``` +template +struct BConcreteCRTBinaryOpPattern + : public mlir::OpRewritePattern { + BConcreteCRTBinaryOpPattern(mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : mlir::OpRewritePattern(context, benefit) {} + + mlir::LogicalResult + matchAndRewrite(BConcreteCRTOp op, + mlir::PatternRewriter &rewriter) const override { + auto resultTy = + ((mlir::Type)op.getResult().getType()).cast(); + auto loc = op.getLoc(); + assert(resultTy.getShape().size() == 2); + auto shape = resultTy.getShape(); + + // %c0 = arith.constant 0 : index + // %c1 = arith.constant 1 : index + // %cB = arith.constant nbBlocks : index + auto c0 = rewriter.create(loc, 0); + auto c1 = rewriter.create(loc, 1); + auto cB = rewriter.create(loc, shape[0]); + + // %init = linalg.tensor_init [B, lweSize] : tensor + mlir::Value init = rewriter.create( + op.getLoc(), resultTy, mlir::ValueRange{}); + + // %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) -> + // (tensor) { + rewriter.replaceOpWithNewOp( + op, c0, cB, c1, init, + [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value i, + mlir::ValueRange iterArgs) { + // [%i, 0] + mlir::SmallVector offsets{ + i, rewriter.getI64IntegerAttr(0)}; + // [1, lweSize] + mlir::SmallVector sizes{ + rewriter.getI64IntegerAttr(1), + rewriter.getI64IntegerAttr(shape[1])}; + // [1, 1] + mlir::SmallVector strides{ + rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1)}; + + auto blockTy = mlir::RankedTensorType::get({shape[1]}, + resultTy.getElementType()); + + // %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1] + // : tensor + auto blockArg0 = builder.create( + loc, blockTy, op.lhs(), offsets, sizes, strides); + // %blockArg1 = tensor.extract_slice %arg1[%i, 0] [1, lweSize] [1, 1] + // : tensor + auto blockArg1 = builder.create( + loc, blockTy, op.rhs(), offsets, sizes, strides); + // %tmp = "BConcrete.add_lwe_buffer"(%blockArg0, %blockArg1) + // : (tensor, tensor) -> + // (tensor) + auto tmp = + builder.create(loc, blockTy, blockArg0, blockArg1); + + // %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, + // 1] : tensor into tensor + auto res = builder.create( + loc, tmp, iterArgs[0], offsets, sizes, strides); + // scf.yield %res : tensor + builder.create(loc, (mlir::Value)res); + }); + + return mlir::success(); + } +}; + +// This template rewrite pattern transforms any instance of +// `BConcreteCRTOp` operators to `BConcreteOp` on +// each block with the crt decomposition of the cleartext. +// +// Example: +// +// ```mlir +// %0 = "BConcreteCRTOp"(%arg0, %x) {crtDecomposition = [d0...dn]} +// : (tensor, i64) -> (tensor) +// ``` +// +// becomes: +// +// ```mlir +// // Build the decomposition of the plaintext +// %x0_a = arith.constant 64/d0 : f64 +// %x0_b = arith.mulf %x, %x0_a : i64 +// %x0 = arith.fptoui %x0_b : f64 to i64 +// ... +// %xn_a = arith.constant 64/dn : f64 +// %xn_b = arith.mulf %x, %xn_a : i64 +// %xn = arith.fptoui %xn_b : f64 to i64 +// %x_decomp = tensor.from_elements %x0, ..., %xn : tensor +// // Loop on blocks +// %c0 = arith.constant 0 : index +// %c1 = arith.constant 1 : index +// %cB = arith.constant nbBlocks : index +// %init = linalg.tensor_init [B, lweSize] : tensor +// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) -> +// (tensor) { +// %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1] +// : tensor +// %blockArg1 = tensor.extract %x_decomp[%i] : tensor +// %tmp = "BConcreteOp"(%blockArg0, %blockArg1) +// : (tensor, i64) -> (tensor) +// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, 1] +// : tensor into tensor +// scf.yield %res : tensor +// } +// ``` +struct AddPlaintextCRTLweBufferOpPattern + : public mlir::OpRewritePattern { + AddPlaintextCRTLweBufferOpPattern(mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : mlir::OpRewritePattern(context, + benefit) { + } + + mlir::LogicalResult + matchAndRewrite(BConcrete::AddPlaintextCRTLweBufferOp op, + mlir::PatternRewriter &rewriter) const override { + auto resultTy = + ((mlir::Type)op.getResult().getType()).cast(); + auto loc = op.getLoc(); + assert(resultTy.getShape().size() == 2); + auto shape = resultTy.getShape(); + + auto rhs = op.rhs(); + mlir::SmallVector plaintextElements; + uint64_t moduliProduct = 1; + for (mlir::Attribute di : op.crtDecomposition()) { + moduliProduct *= di.cast().getValue().getZExtValue(); + } + if (auto cst = + mlir::dyn_cast_or_null(rhs.getDefiningOp())) { + auto apCst = cst.getValue().cast().getValue(); + auto value = apCst.getSExtValue(); + + // constant value, encode at compile time + for (mlir::Attribute di : op.crtDecomposition()) { + auto modulus = di.cast().getValue().getZExtValue(); + + auto encoded = crt::encode(value, modulus, moduliProduct); + plaintextElements.push_back( + rewriter.create(loc, encoded, 64)); + } + } else { + // dynamic value, encode at runtime + if (insertForwardDeclaration( + op, rewriter, encode_crt, + mlir::FunctionType::get(rewriter.getContext(), + {rewriter.getI64Type(), + rewriter.getI64Type(), + rewriter.getI64Type()}, + {rewriter.getI64Type()})) + .failed()) { + return mlir::failure(); + } + auto extOp = + rewriter.create(loc, rewriter.getI64Type(), rhs); + auto moduliProductOp = + rewriter.create(loc, moduliProduct, 64); + for (mlir::Attribute di : op.crtDecomposition()) { + auto modulus = di.cast().getValue().getZExtValue(); + auto modulusOp = + rewriter.create(loc, modulus, 64); + plaintextElements.push_back( + rewriter + .create( + loc, encode_crt, mlir::TypeRange{rewriter.getI64Type()}, + mlir::ValueRange{extOp, modulusOp, moduliProductOp}) + .getResult(0)); + } + } + + // %x_decomp = tensor.from_elements %x0, ..., %xn : tensor + auto x_decomp = + rewriter.create(loc, plaintextElements); + + // %c0 = arith.constant 0 : index + // %c1 = arith.constant 1 : index + // %cB = arith.constant nbBlocks : index + auto c0 = rewriter.create(loc, 0); + auto c1 = rewriter.create(loc, 1); + auto cB = rewriter.create(loc, shape[0]); + + // %init = linalg.tensor_init [B, lweSize] : tensor + mlir::Value init = rewriter.create( + op.getLoc(), resultTy, mlir::ValueRange{}); + + // %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) -> + // (tensor) { + rewriter.replaceOpWithNewOp( + op, c0, cB, c1, init, + [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value i, + mlir::ValueRange iterArgs) { + // [%i, 0] + mlir::SmallVector offsets{ + i, rewriter.getI64IntegerAttr(0)}; + // [1, lweSize] + mlir::SmallVector sizes{ + rewriter.getI64IntegerAttr(1), + rewriter.getI64IntegerAttr(shape[1])}; + // [1, 1] + mlir::SmallVector strides{ + rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1)}; + + auto blockTy = mlir::RankedTensorType::get({shape[1]}, + resultTy.getElementType()); + + // %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1] + // : tensor + auto blockArg0 = builder.create( + loc, blockTy, op.lhs(), offsets, sizes, strides); + // %blockArg1 = tensor.extract %x_decomp[%i] : tensor + auto blockArg1 = builder.create(loc, x_decomp, i); + // %tmp = "BConcreteOp"(%blockArg0, %blockArg1) + // : (tensor, i64) -> (tensor) + auto tmp = builder.create( + loc, blockTy, blockArg0, blockArg1); + + // %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, + // 1] : tensor into tensor + auto res = builder.create( + loc, tmp, iterArgs[0], offsets, sizes, strides); + // scf.yield %res : tensor + builder.create(loc, (mlir::Value)res); + }); + + return mlir::success(); + } +}; + +// This template rewrite pattern transforms any instance of +// `BConcreteCRTOp` operators to `BConcreteOp` on +// each block with the crt decomposition of the cleartext. +// +// Example: +// +// ```mlir +// %0 = "BConcreteCRTOp"(%arg0, %x) {crtDecomposition = [d0...dn]} +// : (tensor, i64) -> (tensor) +// ``` +// +// becomes: +// +// ```mlir +// // Build the decomposition of the plaintext +// %x0_a = arith.constant 64/d0 : f64 +// %x0_b = arith.mulf %x, %x0_a : i64 +// %x0 = arith.fptoui %x0_b : f64 to i64 +// ... +// %xn_a = arith.constant 64/dn : f64 +// %xn_b = arith.mulf %x, %xn_a : i64 +// %xn = arith.fptoui %xn_b : f64 to i64 +// %x_decomp = tensor.from_elements %x0, ..., %xn : tensor +// // Loop on blocks +// %c0 = arith.constant 0 : index +// %c1 = arith.constant 1 : index +// %cB = arith.constant nbBlocks : index +// %init = linalg.tensor_init [B, lweSize] : tensor +// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) -> +// (tensor) { +// %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1] +// : tensor +// %blockArg1 = tensor.extract %x_decomp[%i] : tensor +// %tmp = "BConcreteOp"(%blockArg0, %blockArg1) +// : (tensor, i64) -> (tensor) +// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, 1] +// : tensor into tensor +// scf.yield %res : tensor +// } +// ``` +struct MulCleartextCRTLweBufferOpPattern + : public mlir::OpRewritePattern { + MulCleartextCRTLweBufferOpPattern(mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : mlir::OpRewritePattern(context, + benefit) { + } + + mlir::LogicalResult + matchAndRewrite(BConcrete::MulCleartextCRTLweBufferOp op, + mlir::PatternRewriter &rewriter) const override { + auto resultTy = + ((mlir::Type)op.getResult().getType()).cast(); + auto loc = op.getLoc(); + assert(resultTy.getShape().size() == 2); + auto shape = resultTy.getShape(); + + // %c0 = arith.constant 0 : index + // %c1 = arith.constant 1 : index + // %cB = arith.constant nbBlocks : index + auto c0 = rewriter.create(loc, 0); + auto c1 = rewriter.create(loc, 1); + auto cB = rewriter.create(loc, shape[0]); + + // %init = linalg.tensor_init [B, lweSize] : tensor + mlir::Value init = rewriter.create( + op.getLoc(), resultTy, mlir::ValueRange{}); + + auto rhs = rewriter.create(op.getLoc(), + rewriter.getI64Type(), op.rhs()); + + // %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) -> + // (tensor) { + rewriter.replaceOpWithNewOp( + op, c0, cB, c1, init, + [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value i, + mlir::ValueRange iterArgs) { + // [%i, 0] + mlir::SmallVector offsets{ + i, rewriter.getI64IntegerAttr(0)}; + // [1, lweSize] + mlir::SmallVector sizes{ + rewriter.getI64IntegerAttr(1), + rewriter.getI64IntegerAttr(shape[1])}; + // [1, 1] + mlir::SmallVector strides{ + rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1)}; + + auto blockTy = mlir::RankedTensorType::get({shape[1]}, + resultTy.getElementType()); + + // %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1] + // : tensor + auto blockArg0 = builder.create( + loc, blockTy, op.lhs(), offsets, sizes, strides); + + // %tmp = BConcrete.mul_cleartext_lwe_buffer(%blockArg0, %x) + // : (tensor, i64) -> (tensor) + auto tmp = builder.create( + loc, blockTy, blockArg0, rhs); + + // %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, + // 1] : tensor into tensor + auto res = builder.create( + loc, tmp, iterArgs[0], offsets, sizes, strides); + // scf.yield %res : tensor + builder.create(loc, (mlir::Value)res); + }); + + return mlir::success(); + } +}; + +struct EliminateCRTOpsPass : public EliminateCRTOpsBase { + void runOnOperation() final; +}; + +void EliminateCRTOpsPass::runOnOperation() { + auto op = getOperation(); + + mlir::ConversionTarget target(getContext()); + mlir::RewritePatternSet patterns(&getContext()); + + // add_crt_lwe_buffers + target.addIllegalOp(); + patterns.add>( + &getContext()); + + // add_plaintext_crt_lwe_buffers + target.addIllegalOp(); + patterns.add(&getContext()); + + // mul_cleartext_crt_lwe_buffer + target.addIllegalOp(); + patterns.add(&getContext()); + + target.addIllegalOp(); + patterns.add>( + &getContext()); + + // This dialect are used to transforms crt ops to bconcrete ops + target + .addLegalDialect(); + + // Apply the conversion + if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) { + this->signalPassFailure(); + return; + } +} +} // namespace + +namespace mlir { +namespace concretelang { +std::unique_ptr> createEliminateCRTOps() { + return std::make_unique(); +} +} // namespace concretelang +} // namespace mlir diff --git a/compiler/lib/Dialect/Concrete/IR/ConcreteDialect.cpp b/compiler/lib/Dialect/Concrete/IR/ConcreteDialect.cpp index e31c355f5..f9dd3e8e5 100644 --- a/compiler/lib/Dialect/Concrete/IR/ConcreteDialect.cpp +++ b/compiler/lib/Dialect/Concrete/IR/ConcreteDialect.cpp @@ -25,6 +25,13 @@ void ConcreteDialect::initialize() { >(); } +void printSigned(mlir::AsmPrinter &p, signed i) { + if (i == -1) + p << "_"; + else + p << i; +} + mlir::Type GlweCiphertextType::parse(mlir::AsmParser &parser) { if (parser.parseLess()) return Type(); @@ -50,42 +57,61 @@ mlir::Type GlweCiphertextType::parse(mlir::AsmParser &parser) { void GlweCiphertextType::print(mlir::AsmPrinter &p) const { p << "<"; - if (getImpl()->glweDimension == -1) - p << "_"; - else - p << getImpl()->glweDimension; + printSigned(p, getGlweDimension()); p << ","; - if (getImpl()->polynomialSize == -1) - p << "_"; - else - p << getImpl()->polynomialSize; + printSigned(p, getPolynomialSize()); p << ","; - if (getImpl()->p == -1) - p << "_"; - else - p << getImpl()->p; + printSigned(p, getP()); p << ">"; } void LweCiphertextType::print(mlir::AsmPrinter &p) const { p << "<"; - - if (getDimension() == -1) - p << "_"; - else - p << getDimension(); - + // decomposition parameters if any + auto crt = getCrtDecomposition(); + if (!crt.empty()) { + p << "crt=["; + for (auto c : crt.drop_back(1)) { + printSigned(p, c); + p << ","; + } + printSigned(p, crt.back()); + p << "]"; + p << ","; + } + printSigned(p, getDimension()); p << ","; - if (getP() == -1) - p << "_"; - else - p << getP(); + printSigned(p, getP()); p << ">"; } mlir::Type LweCiphertextType::parse(mlir::AsmParser &parser) { if (parser.parseLess()) return mlir::Type(); + + // Parse for the crt decomposition if any + std::vector crtDecomposition; + if (!parser.parseOptionalKeyword("crt")) { + if (parser.parseEqual() || parser.parseLSquare()) + return mlir::Type(); + while (true) { + int64_t c = -1; + if (parser.parseOptionalKeyword("_") && parser.parseInteger(c)) { + return mlir::Type(); + } + crtDecomposition.push_back(c); + if (parser.parseOptionalComma()) { + if (parser.parseRSquare()) { + return mlir::Type(); + } else { + break; + } + } + } + if (parser.parseComma()) + return mlir::Type(); + } + int dimension = -1; if (parser.parseOptionalKeyword("_") && parser.parseInteger(dimension)) return mlir::Type(); @@ -99,7 +125,7 @@ mlir::Type LweCiphertextType::parse(mlir::AsmParser &parser) { mlir::Location loc = parser.getEncodedSourceLoc(parser.getNameLoc()); - return getChecked(loc, loc.getContext(), dimension, p); + return getChecked(loc, loc.getContext(), dimension, p, crtDecomposition); } void CleartextType::print(mlir::AsmPrinter &p) const { diff --git a/compiler/lib/Dialect/Concrete/Transforms/Optimization.cpp b/compiler/lib/Dialect/Concrete/Transforms/Optimization.cpp index 50c153e09..ba59ba2b3 100644 --- a/compiler/lib/Dialect/Concrete/Transforms/Optimization.cpp +++ b/compiler/lib/Dialect/Concrete/Transforms/Optimization.cpp @@ -16,29 +16,10 @@ namespace concretelang { namespace { -/// Get the integer value that the cleartext was created from if it exists. -llvm::Optional -getIntegerFromCleartextIfExists(mlir::Value cleartext) { - assert( - cleartext.getType().isa()); - // Cleartext are supposed to be created from integers - auto intToCleartextOp = cleartext.getDefiningOp(); - if (intToCleartextOp == nullptr) - return {}; - if (llvm::isa(intToCleartextOp)) { - // We want to match when the integer value is constant - return intToCleartextOp->getOperand(0); - } - return {}; -} - /// Get the constant integer that the cleartext was created from if it exists. llvm::Optional getConstantIntFromCleartextIfExists(mlir::Value cleartext) { - auto cleartextInt = getIntegerFromCleartextIfExists(cleartext); - if (!cleartextInt.hasValue()) - return {}; - auto constantOp = cleartextInt.getValue().getDefiningOp(); + auto constantOp = cleartext.getDefiningOp(); if (constantOp == nullptr) return {}; if (llvm::isa(constantOp)) { diff --git a/compiler/lib/Dialect/TFHE/IR/TFHEDialect.cpp b/compiler/lib/Dialect/TFHE/IR/TFHEDialect.cpp index 73c8f99c9..fecf5a585 100644 --- a/compiler/lib/Dialect/TFHE/IR/TFHEDialect.cpp +++ b/compiler/lib/Dialect/TFHE/IR/TFHEDialect.cpp @@ -32,7 +32,8 @@ void TFHEDialect::initialize() { /// - The bits parameter is 64 (we support only this for v0) ::mlir::LogicalResult GLWECipherTextType::verify( ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, - signed dimension, signed polynomialSize, signed bits, signed p) { + signed dimension, signed polynomialSize, signed bits, signed p, + llvm::ArrayRef) { if (bits != -1 && bits != 64) { emitError() << "GLWE bits parameter can only be 64"; return ::mlir::failure(); diff --git a/compiler/lib/Dialect/TFHE/IR/TFHEOps.cpp b/compiler/lib/Dialect/TFHE/IR/TFHEOps.cpp index 7afcf4b1c..5cacc3035 100644 --- a/compiler/lib/Dialect/TFHE/IR/TFHEOps.cpp +++ b/compiler/lib/Dialect/TFHE/IR/TFHEOps.cpp @@ -40,6 +40,10 @@ mlir::LogicalResult _verifyGLWEIntegerOperator(mlir::OpState &op, emitOpErrorForIncompatibleGLWEParameter(op, "p"); return mlir::failure(); } + if (a.getCrtDecomposition() != result.getCrtDecomposition()) { + emitOpErrorForIncompatibleGLWEParameter(op, "crt"); + return mlir::failure(); + } // verify consistency of width of inputs if ((int)b.getWidth() > a.getP() + 1) { @@ -107,6 +111,11 @@ mlir::LogicalResult verifyBinaryGLWEOperator(Operator &op) { emitOpErrorForIncompatibleGLWEParameter(op, "p"); return mlir::failure(); } + if (a.getCrtDecomposition() != b.getCrtDecomposition() || + a.getCrtDecomposition() != result.getCrtDecomposition()) { + emitOpErrorForIncompatibleGLWEParameter(op, "crt"); + return mlir::failure(); + } return mlir::success(); } @@ -137,6 +146,10 @@ mlir::LogicalResult verifyUnaryGLWEOperator(Operator &op) { emitOpErrorForIncompatibleGLWEParameter(op, "p"); return mlir::failure(); } + if (a.getCrtDecomposition() != result.getCrtDecomposition()) { + emitOpErrorForIncompatibleGLWEParameter(op, "crt"); + return mlir::failure(); + } return mlir::success(); } diff --git a/compiler/lib/Dialect/TFHE/IR/TFHETypes.cpp b/compiler/lib/Dialect/TFHE/IR/TFHETypes.cpp index 338e00df8..eccc7f04c 100644 --- a/compiler/lib/Dialect/TFHE/IR/TFHETypes.cpp +++ b/compiler/lib/Dialect/TFHE/IR/TFHETypes.cpp @@ -9,28 +9,35 @@ namespace mlir { namespace concretelang { namespace TFHE { +void printSigned(mlir::AsmPrinter &p, signed i) { + if (i == -1) + p << "_"; + else + p << i; +} + void GLWECipherTextType::print(mlir::AsmPrinter &p) const { - p << "<{"; - if (getDimension() == -1) - p << "_"; - else - p << getDimension(); - p << ","; - if (getPolynomialSize() == -1) - p << "_"; - else - p << getPolynomialSize(); - p << ","; - if (getBits() == -1) - p << "_"; - else - p << getBits(); - p << "}"; + p << "<"; + auto crt = getCrtDecomposition(); + if (!crt.empty()) { + p << "crt=["; + for (auto c : crt.drop_back(1)) { + printSigned(p, c); + p << ","; + } + printSigned(p, crt.back()); + p << "]"; + } p << "{"; - if (getP() == -1) - p << "_"; - else - p << getP(); + printSigned(p, getDimension()); + p << ","; + printSigned(p, getPolynomialSize()); + p << ","; + printSigned(p, getBits()); + p << "}"; + + p << "{"; + printSigned(p, getP()); p << "}>"; } @@ -38,9 +45,31 @@ mlir::Type GLWECipherTextType::parse(AsmParser &parser) { if (parser.parseLess()) return mlir::Type(); - // First parameters block + // Parse for the crt decomposition if any + std::vector crtDecomposition; + if (!parser.parseOptionalKeyword("crt")) { + if (parser.parseEqual() || parser.parseLSquare()) + return mlir::Type(); + while (true) { + signed c = -1; + if (parser.parseOptionalKeyword("_") && parser.parseInteger(c)) { + return mlir::Type(); + } + crtDecomposition.push_back(c); + if (parser.parseOptionalComma()) { + if (parser.parseRSquare()) { + return mlir::Type(); + } else { + break; + } + } + } + } + if (parser.parseLBrace()) return mlir::Type(); + + // First parameters block int dimension = -1; if (parser.parseOptionalKeyword("_") && parser.parseInteger(dimension)) return mlir::Type(); @@ -69,7 +98,8 @@ mlir::Type GLWECipherTextType::parse(AsmParser &parser) { if (parser.parseGreater()) return mlir::Type(); Location loc = parser.getEncodedSourceLoc(parser.getNameLoc()); - return getChecked(loc, loc.getContext(), dimension, polynomialSize, bits, p); + return getChecked(loc, loc.getContext(), dimension, polynomialSize, bits, p, + llvm::ArrayRef(crtDecomposition)); } } // namespace TFHE } // namespace concretelang diff --git a/compiler/lib/Runtime/CMakeLists.txt b/compiler/lib/Runtime/CMakeLists.txt index 4ad011298..b15790ebf 100644 --- a/compiler/lib/Runtime/CMakeLists.txt +++ b/compiler/lib/Runtime/CMakeLists.txt @@ -26,10 +26,11 @@ target_link_libraries( ConcretelangRuntime PUBLIC Concrete + ConcretelangClientLib + pthread m dl $ ) install(TARGETS ConcretelangRuntime omp EXPORT ConcretelangRuntime) install(EXPORT ConcretelangRuntime DESTINATION "./") - diff --git a/compiler/lib/Runtime/wrappers.cpp b/compiler/lib/Runtime/wrappers.cpp index f901ba25f..502e98abd 100644 --- a/compiler/lib/Runtime/wrappers.cpp +++ b/compiler/lib/Runtime/wrappers.cpp @@ -3,11 +3,13 @@ // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. -#include "concretelang/Runtime/wrappers.h" #include #include #include +#include "concretelang/ClientLib/CRT.h" +#include "concretelang/Runtime/wrappers.h" + void memref_expand_lut_in_trivial_glwe_ct_u64( uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned, uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride, @@ -98,6 +100,10 @@ void memref_bootstrap_lwe_u64( glwe_ct_aligned + glwe_ct_offset); } +uint64_t encode_crt(int64_t plaintext, uint64_t modulus, uint64_t product) { + return concretelang::clientlib::crt::encode(plaintext, modulus, product); +} + void memref_copy_one_rank(uint64_t *src_allocated, uint64_t *src_aligned, uint64_t src_offset, uint64_t src_size, uint64_t src_stride, uint64_t *dst_allocated, diff --git a/compiler/lib/ServerLib/DynamicRankCall.cpp b/compiler/lib/ServerLib/DynamicRankCall.cpp index 74947ceac..197820d6d 100644 --- a/compiler/lib/ServerLib/DynamicRankCall.cpp +++ b/compiler/lib/ServerLib/DynamicRankCall.cpp @@ -38,171 +38,166 @@ TensorData multi_arity_call_dynamic_rank(void *(*func)(void *...), using concretelang::clientlib::MemRefDescriptor; constexpr auto convert = concretelang::clientlib::tensorDataFromMemRef; switch (rank) { - case 0: { + case 1: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(1, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 1: { + case 2: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(2, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 2: { + case 3: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(3, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 3: { + case 4: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(4, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 4: { + case 5: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(5, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 5: { + case 6: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(6, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 6: { + case 7: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(7, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 7: { + case 8: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(8, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 8: { + case 9: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(9, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 9: { + case 10: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(10, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 10: { + case 11: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(11, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 11: { + case 12: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(12, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 12: { + case 13: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(13, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 13: { + case 14: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(14, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 14: { + case 15: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(15, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 15: { + case 16: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(16, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 16: { + case 17: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(17, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 17: { + case 18: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(18, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 18: { + case 19: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(19, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 19: { + case 20: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(20, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 20: { + case 21: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(21, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 21: { + case 22: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(22, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 22: { + case 23: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(23, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 23: { + case 24: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(24, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 24: { + case 25: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(25, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 25: { + case 26: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(26, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 26: { + case 27: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(27, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 27: { + case 28: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(28, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 28: { + case 29: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(29, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 29: { + case 30: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(30, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 30: { + case 31: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(31, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 31: { + case 32: { auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); return convert(32, m.allocated, m.aligned, m.offset, m.sizes, m.strides); } - case 32: { - auto m = multi_arity_call( - convert_fnptr (*)(void *...)>(func), args); - return convert(33, m.allocated, m.aligned, m.offset, m.sizes, m.strides); - } default: assert(false); diff --git a/compiler/lib/ServerLib/ServerLambda.cpp b/compiler/lib/ServerLib/ServerLambda.cpp index 9ae923bbd..7786b54eb 100644 --- a/compiler/lib/ServerLib/ServerLambda.cpp +++ b/compiler/lib/ServerLib/ServerLambda.cpp @@ -69,12 +69,6 @@ ServerLambda::load(std::string funcName, std::string outputPath) { return ServerLambda::loadFromModule(module, funcName); } -TensorData dynamicCall(void *(*func)(void *...), - std::vector &preparedArgs, CircuitGate &output) { - size_t rank = output.shape.dimensions.size(); - return multi_arity_call_dynamic_rank(func, preparedArgs, rank); -} - std::unique_ptr ServerLambda::call(PublicArguments &args, EvaluationKeys &evaluationKeys) { std::vector preparedArgs(args.preparedArgs.begin(), @@ -84,9 +78,12 @@ ServerLambda::call(PublicArguments &args, EvaluationKeys &evaluationKeys) { runtimeContext.evaluationKeys = evaluationKeys; preparedArgs.push_back((void *)&runtimeContext); - return clientlib::PublicResult::fromBuffers( - clientParameters, - {dynamicCall(this->func, preparedArgs, clientParameters.outputs[0])}); + assert(clientParameters.outputs.size() == 1 && + "ServerLambda::call is implemented for only one output"); + auto output = args.clientParameters.outputs[0]; + auto rank = args.clientParameters.bufferShape(output).size(); + auto result = multi_arity_call_dynamic_rank(func, preparedArgs, rank); + return clientlib::PublicResult::fromBuffers(clientParameters, {result}); ; } diff --git a/compiler/lib/ServerLib/genDynamicRankCall.py b/compiler/lib/ServerLib/genDynamicRankCall.py index 6ea898990..cc9959298 100644 --- a/compiler/lib/ServerLib/genDynamicRankCall.py +++ b/compiler/lib/ServerLib/genDynamicRankCall.py @@ -43,8 +43,8 @@ TensorData multi_arity_call_dynamic_rank(void *(*func)(void *...), constexpr auto convert = concretelang::clientlib::tensorDataFromMemRef; switch (rank) {""") -for tensor_rank in range(0, 33): - memref_rank = tensor_rank + 1 +for tensor_rank in range(1, 33): + memref_rank = tensor_rank print(f""" case {tensor_rank}: {{ auto m = multi_arity_call( convert_fnptr (*)(void *...)>(func), args); diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index e91bdc0ee..0b73ed0d0 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -166,30 +166,39 @@ CompilerEngine::getConcreteOptimizerDescription(CompilationResult &res) { return std::move(descriptions->begin()->second); } +/// set the fheContext field if the v0Constraint can be computed /// set the fheContext field if the v0Constraint can be computed llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) { - auto descrOrErr = getConcreteOptimizerDescription(res); - if (auto err = descrOrErr.takeError()) { - return err; - } - // The function is non-crypto and without constraint override - if (!descrOrErr.get().hasValue()) { + if (compilerOptions.v0Parameter.hasValue()) { + // parameters come from the compiler options + V0Parameter v0Params = compilerOptions.v0Parameter.value(); + if (compilerOptions.largeIntegerParameter.hasValue()) { + v0Params.largeInteger = compilerOptions.largeIntegerParameter; + } + res.fheContext.emplace(mlir::concretelang::V0FHEContext{{0, 0}, v0Params}); return llvm::Error::success(); } - auto descr = std::move(descrOrErr.get().getValue()); - auto config = this->compilerOptions.optimizerConfig; - - auto fheParams = (compilerOptions.v0Parameter.hasValue()) - ? compilerOptions.v0Parameter - : getParameter(descr, config); - if (!fheParams) { - return StreamStringError() - << "Could not determine V0 parameters for 2-norm of " - << (*descrOrErr)->constraint.norm2 << " and p of " - << (*descrOrErr)->constraint.p; + // compute parameters + else { + auto descr = getConcreteOptimizerDescription(res); + if (auto err = descr.takeError()) { + return err; + } + if (!descr.get().hasValue()) { + return llvm::Error::success(); + } + auto optV0Params = + getParameter(descr.get().value(), compilerOptions.optimizerConfig); + if (!optV0Params) { + return StreamStringError() + << "Could not determine V0 parameters for 2-norm of " + << (*descr)->constraint.norm2 << " and p of " + << (*descr)->constraint.p; + } + res.fheContext.emplace(mlir::concretelang::V0FHEContext{ + descr.get().value().constraint, optV0Params.getValue()}); } - res.fheContext.emplace( - mlir::concretelang::V0FHEContext{descr.constraint, fheParams.getValue()}); + return llvm::Error::success(); } @@ -282,7 +291,7 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { // FHE -> TFHE if (mlir::concretelang::pipeline::lowerFHEToTFHE(mlirContext, module, - enablePass) + res.fheContext, enablePass) .failed()) { return errorDiag("Lowering from FHE to TFHE failed"); } diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index 296b6e1d7..81ab0d7b6 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -110,21 +110,13 @@ JITLambda::call(clientlib::PublicArguments &args, // Prepare the outputs vector to store the output value of the lambda. auto numOutputs = 0; for (auto &output : args.clientParameters.outputs) { - if (output.shape.dimensions.empty()) { + auto shape = args.clientParameters.bufferShape(output); + if (shape.size() == 0) { // scalar gate - if (output.encryption.hasValue()) { - // encrypted scalar : memref - numOutputs += numArgOfRankedMemrefCallingConvention(1); - } else { - // clear scalar - numOutputs += 1; - } + numOutputs += 1; } else { - // memref gate : rank+1 if the output is encrypted for the lwe size - // dimension - auto rank = output.shape.dimensions.size() + - (output.encryption.hasValue() ? 1 : 0); - numOutputs += numArgOfRankedMemrefCallingConvention(rank); + // buffer gate + numOutputs += numArgOfRankedMemrefCallingConvention(shape.size()); } } std::vector outputs(numOutputs); @@ -148,7 +140,6 @@ JITLambda::call(clientlib::PublicArguments &args, for (auto &out : outputs) { rawArgs[i++] = &out; } - // Invoke if (auto err = invokeRaw(rawArgs)) { return std::move(err); @@ -159,14 +150,14 @@ JITLambda::call(clientlib::PublicArguments &args, { size_t outputOffset = 0; for (auto &output : args.clientParameters.outputs) { - if (output.shape.dimensions.empty() && !output.encryption.hasValue()) { - // clear scalar + auto shape = args.clientParameters.bufferShape(output); + if (shape.size() == 0) { + // scalar scalar buffers.push_back( clientlib::tensorDataFromScalar((uint64_t)outputs[outputOffset++])); } else { - // encrypted scalar, and tensor gate are memref - auto rank = output.shape.dimensions.size() + - (output.encryption.hasValue() ? 1 : 0); + // buffer gate + auto rank = shape.size(); auto allocated = (uint64_t *)outputs[outputOffset++]; auto aligned = (uint64_t *)outputs[outputOffset++]; auto offset = (size_t)outputs[outputOffset++]; diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index 74a46068b..2fd3cbe78 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -182,6 +182,7 @@ markFHELinalgForTiling(mlir::MLIRContext &context, mlir::ModuleOp &module, mlir::LogicalResult lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, + llvm::Optional &fheContext, std::function enablePass) { mlir::PassManager pm(&context); pipelinePrinting("FHEToTFHE", pm, context); @@ -192,8 +193,14 @@ lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, // to linalg.generic operations addPotentiallyNestedPass(pm, mlir::createLinalgGeneralizationPass(), enablePass); - addPotentiallyNestedPass(pm, mlir::concretelang::createConvertFHEToTFHEPass(), - enablePass); + mlir::concretelang::ApplyLookupTableLowering lowerStrategy = + mlir::concretelang::KeySwitchBoostrapLowering; + if (fheContext.hasValue() && fheContext->parameter.largeInteger.hasValue()) { + lowerStrategy = mlir::concretelang::WopPBSLowering; + } + addPotentiallyNestedPass( + pm, mlir::concretelang::createConvertFHEToTFHEPass(lowerStrategy), + enablePass); return pm.run(module.getOperation()); } @@ -260,6 +267,8 @@ lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass) { mlir::PassManager pm(&context); pipelinePrinting("BConcreteToStd", pm, context); + addPotentiallyNestedPass(pm, mlir::concretelang::createEliminateCRTOps(), + enablePass); addPotentiallyNestedPass(pm, mlir::concretelang::createAddRuntimeContext(), enablePass); return pm.run(module.getOperation()); diff --git a/compiler/lib/Support/V0ClientParameters.cpp b/compiler/lib/Support/V0ClientParameters.cpp index ae05933d9..d98072a5c 100644 --- a/compiler/lib/Support/V0ClientParameters.cpp +++ b/compiler/lib/Support/V0ClientParameters.cpp @@ -3,6 +3,7 @@ // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. +#include #include #include @@ -12,6 +13,7 @@ #include "concretelang/ClientLib/ClientParameters.h" #include "concretelang/Conversion/Utils/GlobalFHEContext.h" #include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h" +#include "concretelang/Support/Error.h" #include "concretelang/Support/V0Curves.h" namespace mlir { @@ -20,6 +22,7 @@ namespace concretelang { using ::concretelang::clientlib::BIG_KEY; using ::concretelang::clientlib::CircuitGate; using ::concretelang::clientlib::ClientParameters; +using ::concretelang::clientlib::Encoding; using ::concretelang::clientlib::EncryptionGate; using ::concretelang::clientlib::LweSecretKeyID; using ::concretelang::clientlib::Precision; @@ -52,23 +55,21 @@ llvm::Expected gateFromMLIRType(LweSecretKeyID secretKeyID, }, }; } - if (auto lweType = type.dyn_cast_or_null< - mlir::concretelang::Concrete::LweCiphertextType>()) { - // TODO - Get the width from the LWECiphertextType instead of global - // precision (could be possible after merge concrete-ciphertext-parameter) - size_t precision = (size_t)lweType.getP(); + if (auto lweTy = type.dyn_cast_or_null< + mlir::concretelang::Concrete::LweCiphertextType>()) { return CircuitGate{ /* .encryption = */ llvm::Optional({ /* .secretKeyID = */ secretKeyID, /* .variance = */ variance, /* .encoding = */ { - /* .precision = */ precision, + /* .precision = */ lweTy.getP(), + /* .crt = */ lweTy.getCrtDecomposition().vec(), }, }), /*.shape = */ { - /*.width = */ precision, + /*.width = */ lweTy.getP(), /*.dimensions = */ std::vector(), /*.size = */ 0, }, @@ -99,6 +100,7 @@ createClientParametersForV0(V0FHEContext fheContext, auto v0Param = fheContext.parameter; Variance encryptionVariance = v0Curve->getVariance( v0Param.glweDimension, v0Param.getPolynomialSize(), 64); + // Variance encryptionVariance = 0.; Variance keyswitchVariance = v0Curve->getVariance(1, v0Param.nSmall, 64); // Static client parameters from global parameters for v0 ClientParameters c; @@ -138,10 +140,9 @@ createClientParametersForV0(V0FHEContext fheContext, return op.getName() == functionName; }); if (funcOp == rangeOps.end()) { - return llvm::make_error( - "cannot find the function for generate client parameters '" + - functionName + "'", - llvm::inconvertibleErrorCode()); + return StreamStringError( + "cannot find the function for generate client parameters: ") + << functionName; } // Create input and output circuit gate parameters diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index 2d4a59cb9..52d7cbc4f 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -170,7 +170,8 @@ llvm::cl::opt llvm::cl::list jitArgs("jit-args", llvm::cl::desc("Value of arguments to pass to the main func"), - llvm::cl::value_desc("argument(uint64)"), llvm::cl::ZeroOrMore); + llvm::cl::value_desc("argument(uint64)"), llvm::cl::ZeroOrMore, + llvm::cl::MiscFlags::CommaSeparated); llvm::cl::opt jitKeySetCachePath( "jit-keyset-cache-path", @@ -210,6 +211,31 @@ llvm::cl::list v0Parameter( "logPolynomialSize, nSmall, brLevel, brLobBase, ksLevel, ksLogBase]"), llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated); +llvm::cl::list largeIntegerCRTDecomposition( + "large-integer-crt-decomposition", + llvm::cl::desc( + "Use the large integer to lower FHE.eint with the given decomposition, " + "must be used with the other large-integers options (experimental)"), + llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated); + +llvm::cl::list largeIntegerPackingKeyswitch( + "large-integer-packing-keyswitch", + llvm::cl::desc( + "Use the large integer to lower FHE.eint with the given parameters for " + "packing keyswitch, must be used with the other large-integers options " + "(experimental) [inputLweDimension, inputLweCount, " + "outputPolynomialSize, level, baseLog]"), + llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated); + +llvm::cl::list largeIntegerCircuitBootstrap( + "large-integer-circuit-bootstrap", + llvm::cl::desc( + "Use the large integer to lower FHE.eint with the given parameters for " + "the cicuit boostrap, must be used with the other large-integers " + "options " + "(experimental) [level, baseLog]"), + llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated); + } // namespace cmdline namespace llvm { @@ -257,6 +283,7 @@ cmdlineCompilationOptions() { if (!cmdline::fhelinalgTileSizes.empty()) options.fhelinalgTileSizes.emplace(cmdline::fhelinalgTileSizes); + // Setup the v0 parameter options if (!cmdline::v0Parameter.empty()) { if (cmdline::v0Parameter.size() != 7) { return llvm::make_error( @@ -270,10 +297,44 @@ cmdlineCompilationOptions() { cmdline::v0Parameter[6]); } - if (!cmdline::v0Constraint.empty() && !cmdline::optimizerV0) { - return llvm::make_error( - "You must use --v0-constraint with --optimizer-v0-strategy", - llvm::inconvertibleErrorCode()); + // Setup the large integer options + if (!cmdline::largeIntegerCRTDecomposition.empty() || + !cmdline::largeIntegerPackingKeyswitch.empty() || + !cmdline::largeIntegerPackingKeyswitch.empty()) { + if (cmdline::largeIntegerCRTDecomposition.empty() || + cmdline::largeIntegerPackingKeyswitch.empty() || + cmdline::largeIntegerPackingKeyswitch.empty()) { + return llvm::make_error( + "The large-integers options should all be set", + llvm::inconvertibleErrorCode()); + } + if (cmdline::largeIntegerPackingKeyswitch.size() != 5) { + return llvm::make_error( + "The large-integers-packing-keyswitch must be a list of 5 integer", + llvm::inconvertibleErrorCode()); + } + if (cmdline::largeIntegerCircuitBootstrap.size() != 2) { + return llvm::make_error( + "The large-integers-packing-keyswitch must be a list of 2 integer", + llvm::inconvertibleErrorCode()); + } + options.largeIntegerParameter = mlir::concretelang::LargeIntegerParameter(); + options.largeIntegerParameter->crtDecomposition = + cmdline::largeIntegerCRTDecomposition; + options.largeIntegerParameter->wopPBS.packingKeySwitch.inputLweDimension = + cmdline::largeIntegerPackingKeyswitch[0]; + options.largeIntegerParameter->wopPBS.packingKeySwitch.inputLweCount = + cmdline::largeIntegerPackingKeyswitch[1]; + options.largeIntegerParameter->wopPBS.packingKeySwitch + .outputPolynomialSize = cmdline::largeIntegerPackingKeyswitch[2]; + options.largeIntegerParameter->wopPBS.packingKeySwitch.level = + cmdline::largeIntegerPackingKeyswitch[3]; + options.largeIntegerParameter->wopPBS.packingKeySwitch.baseLog = + cmdline::largeIntegerPackingKeyswitch[4]; + options.largeIntegerParameter->wopPBS.circuitBootstrap.level = + cmdline::largeIntegerCircuitBootstrap[0]; + options.largeIntegerParameter->wopPBS.circuitBootstrap.baseLog = + cmdline::largeIntegerCircuitBootstrap[1]; } options.optimizerConfig.p_error = cmdline::pbsErrorProbability; @@ -502,9 +563,9 @@ mlir::LogicalResult compilerMain(int argc, char **argv) { } if (cmdline::action == Action::COMPILE) { - auto err = - outputLib->emitArtifacts(/*sharedLib=*/true, /*staticLib=*/true, - /*clientParameters=*/true, /*cppHeader=*/true); + auto err = outputLib->emitArtifacts( + /*sharedLib=*/true, /*staticLib=*/true, + /*clientParameters=*/true, /*cppHeader=*/true); if (err) { return mlir::failure(); } diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/add_lwe.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/add_lwe.mlir index dbca786ba..c2f5eec44 100644 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/add_lwe.mlir +++ b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/add_lwe.mlir @@ -1,10 +1,19 @@ // RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s -//CHECK: func.func @add_glwe(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<2049xi64>) -> tensor<2049xi64> { +//CHECK: func @add_lwe_ciphertexts(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<2049xi64>) -> tensor<2049xi64> { //CHECK: %[[V0:.*]] = "BConcrete.add_lwe_buffer"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, tensor<2049xi64>) -> tensor<2049xi64> //CHECK: return %[[V0]] : tensor<2049xi64> //CHECK: } -func.func @add_glwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> { +func.func @add_lwe_ciphertexts(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> { %0 = "Concrete.add_lwe_ciphertexts"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> return %0 : !Concrete.lwe_ciphertext<2048,7> } + +//CHECK: func @add_crt_lwe_ciphertexts(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: tensor<5x2049xi64>) -> tensor<5x2049xi64> { +//CHECK: %[[V0:.*]] = "BConcrete.add_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> tensor<5x2049xi64> +//CHECK: return %[[V0]] : tensor<5x2049xi64> +//CHECK: } +func.func @add_crt_lwe_ciphertexts(%arg0: !Concrete.lwe_ciphertext, %arg1: !Concrete.lwe_ciphertext) -> !Concrete.lwe_ciphertext { + %0 = "Concrete.add_lwe_ciphertexts"(%arg0, %arg1) : (!Concrete.lwe_ciphertext, !Concrete.lwe_ciphertext) -> !Concrete.lwe_ciphertext + return %0 : !Concrete.lwe_ciphertext +} diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/add_lwe_int.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/add_lwe_int.mlir index a736b823f..4cfcb99a4 100644 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/add_lwe_int.mlir +++ b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/add_lwe_int.mlir @@ -1,30 +1,40 @@ // RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s -//CHECK: func.func @add_glwe_const_int(%arg0: tensor<1025xi64>) -> tensor<1025xi64> { +//CHECK: func.func @add_glwe_const_int(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> { //CHECK: %c1_i8 = arith.constant 1 : i8 -//CHECK: %0 = arith.extui %c1_i8 : i8 to i64 +//CHECK: %[[V0:.*]] = arith.extui %c1_i8 : i8 to i64 //CHECK: %c56_i64 = arith.constant 56 : i64 -//CHECK: %1 = arith.shli %0, %c56_i64 : i64 -//CHECK: %2 = "BConcrete.add_plaintext_lwe_buffer"(%arg0, %1) : (tensor<1025xi64>, i64) -> tensor<1025xi64> -//CHECK: return %2 : tensor<1025xi64> +//CHECK: %[[V1:.*]] = arith.shli %[[V0]], %c56_i64 : i64 +//CHECK: %[[V2:.*]] = "BConcrete.add_plaintext_lwe_buffer"(%[[A0]], %[[V1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64> +//CHECK: return %[[V2]] : tensor<1025xi64> //CHECK: } func.func @add_glwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> { %0 = arith.constant 1 : i8 - %1 = "Concrete.encode_int"(%0) : (i8) -> !Concrete.plaintext<8> - %2 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %1) : (!Concrete.lwe_ciphertext<1024,7>, !Concrete.plaintext<8>) -> !Concrete.lwe_ciphertext<1024,7> + %2 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<1024,7>, i8) -> !Concrete.lwe_ciphertext<1024,7> return %2 : !Concrete.lwe_ciphertext<1024,7> } -//CHECK: func.func @add_glwe_int(%arg0: tensor<1025xi64>, %arg1: i5) -> tensor<1025xi64> { -//CHECK: %0 = arith.extui %arg1 : i5 to i64 +//CHECK: func.func @add_glwe_int(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: i5) -> tensor<1025xi64> { +//CHECK: %[[V0:.*]] = arith.extui %[[A1]] : i5 to i64 //CHECK: %c59_i64 = arith.constant 59 : i64 -//CHECK: %1 = arith.shli %0, %c59_i64 : i64 -//CHECK: %2 = "BConcrete.add_plaintext_lwe_buffer"(%arg0, %1) : (tensor<1025xi64>, i64) -> tensor<1025xi64> -//CHECK: return %2 : tensor<1025xi64> +//CHECK: %[[V1:.*]] = arith.shli %[[V0]], %c59_i64 : i64 +//CHECK: %[[V2:.*]] = "BConcrete.add_plaintext_lwe_buffer"(%[[A0]], %[[V1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64> +//CHECK: return %[[V2]] : tensor<1025xi64> //CHECK: } func.func @add_glwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4> { - %0 = "Concrete.encode_int"(%arg1) : (i5) -> !Concrete.plaintext<5> - %1 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<1024,4>, !Concrete.plaintext<5>) -> !Concrete.lwe_ciphertext<1024,4> + %1 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<1024,4>, i5) -> !Concrete.lwe_ciphertext<1024,4> return %1 : !Concrete.lwe_ciphertext<1024,4> } + + +//CHECK: func.func @add_plaintext_lwe_ciphertext(%[[A0:.*]]: tensor<5x1025xi64>) -> tensor<5x1025xi64> { +//CHECK: %c1_i8 = arith.constant 1 : i8 +//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_crt_lwe_buffer"(%[[A0]], %c1_i8) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>, i8) -> tensor<5x1025xi64> +//CHECK: return %[[V0]] : tensor<5x1025xi64> +//CHECK: } +func.func @add_plaintext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext) -> !Concrete.lwe_ciphertext { + %0 = arith.constant 1 : i8 + %2 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext, i8) -> !Concrete.lwe_ciphertext + return %2 : !Concrete.lwe_ciphertext +} diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/identity.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/identity.mlir index 3b0224725..e96804928 100644 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/identity.mlir +++ b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/identity.mlir @@ -6,3 +6,10 @@ func.func @identity(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> { return %arg0 : !Concrete.lwe_ciphertext<1024,7> } + +// CHECK: func.func @identity_crt(%arg0: tensor<5x1025xi64>) -> tensor<5x1025xi64> { +// CHECK-NEXT: return %arg0 : tensor<5x1025xi64> +// CHECK-NEXT: } +func.func @identity_crt(%arg0: !Concrete.lwe_ciphertext) -> !Concrete.lwe_ciphertext { + return %arg0 : !Concrete.lwe_ciphertext +} diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/mul_lwe_int.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/mul_lwe_int.mlir index 48eeadc7a..eb70c0cae 100644 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/mul_lwe_int.mlir +++ b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/mul_lwe_int.mlir @@ -1,26 +1,33 @@ // RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s -//CHECK: func.func @mul_lwe_const_int(%arg0: tensor<1025xi64>) -> tensor<1025xi64> { +//CHECK: func.func @mul_lwe_const_int(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> { //CHECK: %c1_i8 = arith.constant 1 : i8 -//CHECK: %0 = arith.extui %c1_i8 : i8 to i64 -//CHECK: %1 = "BConcrete.mul_cleartext_lwe_buffer"(%arg0, %0) : (tensor<1025xi64>, i64) -> tensor<1025xi64> -//CHECK: return %1 : tensor<1025xi64> +//CHECK: %[[V0:.*]] = arith.extui %c1_i8 : i8 to i64 +//CHECK: %[[V1:.*]] = "BConcrete.mul_cleartext_lwe_buffer"(%[[A0]], %[[V0]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64> +//CHECK: return %[[V1]] : tensor<1025xi64> //CHECK: } - func.func @mul_lwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> { %0 = arith.constant 1 : i8 - %1 = "Concrete.int_to_cleartext"(%0) : (i8) -> !Concrete.cleartext<8> - %2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %1) : (!Concrete.lwe_ciphertext<1024,7>, !Concrete.cleartext<8>) -> !Concrete.lwe_ciphertext<1024,7> + %2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<1024,7>, i8) -> !Concrete.lwe_ciphertext<1024,7> return %2 : !Concrete.lwe_ciphertext<1024,7> } -//CHECK: func.func @mul_lwe_int(%arg0: tensor<1025xi64>, %arg1: i5) -> tensor<1025xi64> { -//CHECK: %0 = arith.extui %arg1 : i5 to i64 -//CHECK: %1 = "BConcrete.mul_cleartext_lwe_buffer"(%arg0, %0) : (tensor<1025xi64>, i64) -> tensor<1025xi64> -//CHECK: return %1 : tensor<1025xi64> +//CHECK: func.func @mul_lwe_int(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: i5) -> tensor<1025xi64> { +//CHECK: %[[V0:.*]] = arith.extui %[[A1]] : i5 to i64 +//CHECK: %[[V1:.*]] = "BConcrete.mul_cleartext_lwe_buffer"(%[[A0]], %[[V0]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64> +//CHECK: return %[[V1]] : tensor<1025xi64> //CHECK: } func.func @mul_lwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4> { - %0 = "Concrete.int_to_cleartext"(%arg1) : (i5) -> !Concrete.cleartext<5> - %1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<1024,4>, !Concrete.cleartext<5>) -> !Concrete.lwe_ciphertext<1024,4> + %1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<1024,4>, i5) -> !Concrete.lwe_ciphertext<1024,4> return %1 : !Concrete.lwe_ciphertext<1024,4> } + + +//CHECK: func.func @mul_cleartext_lwe_ciphertext_crt(%[[A0:.*]]: tensor<5x1025xi64>, %[[A1:.*]]: i5) -> tensor<5x1025xi64> { +//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>, i5) -> tensor<5x1025xi64> +//CHECK: return %[[V0]] : tensor<5x1025xi64> +//CHECK: } +func.func @mul_cleartext_lwe_ciphertext_crt(%arg0: !Concrete.lwe_ciphertext, %arg1: i5) -> !Concrete.lwe_ciphertext { + %1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext, i5) -> !Concrete.lwe_ciphertext + return %1 : !Concrete.lwe_ciphertext +} diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/neg_lwe.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/neg_lwe.mlir index 0f6d9bdd3..6157557e7 100644 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/neg_lwe.mlir +++ b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/neg_lwe.mlir @@ -8,3 +8,12 @@ func.func @neg_lwe(%arg0: !Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_cip %0 = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4> return %0 : !Concrete.lwe_ciphertext<1024,4> } + +//CHECK: func.func @negate_lwe_ciphertext(%[[A0:.*]]: tensor<5x1025xi64>) -> tensor<5x1025xi64> { +//CHECK: %[[V0:.*]] = "BConcrete.negate_crt_lwe_buffer"(%[[A0]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>) -> tensor<5x1025xi64> +//CHECK: return %[[V0]] : tensor<5x1025xi64> +//CHECK: } +func.func @negate_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext) -> !Concrete.lwe_ciphertext { + %0 = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext) -> !Concrete.lwe_ciphertext + return %0 : !Concrete.lwe_ciphertext +} diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/tensor_exapand_collapse_shape.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/tensor_exapand_collapse_shape.mlir index 6cf6c66dc..401b8a880 100644 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/tensor_exapand_collapse_shape.mlir +++ b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/tensor_exapand_collapse_shape.mlir @@ -1,24 +1,20 @@ // RUN: concretecompiler --split-input-file --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s -// CHECK: func -// DISABLED-CHECK: func.func @tensor_collapse_shape(%arg0: tensor<2x3x4x5x6x1025xi64>) -> tensor<720x1025xi64> { -// DISABLED-CHECK-NEXT: %0 = bufferization.to_memref %arg0 : memref<2x3x4x5x6x1025xi64> -// DISABLED-CHECK-NEXT: %1 = memref.collapse_shape %0 [[_:\[\[0, 1, 2, 3, 4\], \[5\]\]]] : memref<2x3x4x5x6x1025xi64> into memref<720x1025xi64> -// DISABLED-CHECK-NEXT: %2 = bufferization.to_tensor %1 : memref<720x1025xi64> -// DISABLED-CHECK-NEXT: return %2 : tensor<720x1025xi64> +//CHECK: func.func @tensor_collapse_shape(%[[A0:.*]]: tensor<2x3x4x5x6x1025xi64>) -> tensor<720x1025xi64> { +//CHECK: %[[V0:.*]] = tensor.collapse_shape %[[A0]] [[_:\[\[0, 1, 2, 3, 4\], \[5\]\]]] : tensor<2x3x4x5x6x1025xi64> into tensor<720x1025xi64> +//CHECK: return %[[V0]] : tensor<720x1025xi64> +//CHECK: } func.func @tensor_collapse_shape(%arg0: tensor<2x3x4x5x6x!Concrete.lwe_ciphertext<1024,4>>) -> tensor<720x!Concrete.lwe_ciphertext<1024,4>> { %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3, 4]] {MANP = 1 : ui1}: tensor<2x3x4x5x6x!Concrete.lwe_ciphertext<1024,4>> into tensor<720x!Concrete.lwe_ciphertext<1024,4>> return %0 : tensor<720x!Concrete.lwe_ciphertext<1024,4>> } // ----- - -// DISABLED-CHECK: func.func @tensor_collatenspse_shape(%arg0: tensor<2x3x5x1025xi64>) -> tensor<5x6x1025xi64> { -// DISABLED-CHECK-NEXT: %0 = bufferization.to_memref %arg0 : memref<2x3x5x1025xi64> -// DISABLED-CHECK-NEXT: %1 = memref.collapse_shape %0 [[_:\[\[0, 1, 2\], \[3\]\]]] : memref<2x3x5x1025xi64> into memref<30x1025xi64> -// DISABLED-CHECK-NEXT: %2 = memref.expand_shape %1 [[_:\[\[0, 1\], \[2\]\]]] : memref<30x1025xi64> into memref<5x6x1025xi64> -// DISABLED-CHECK-NEXT: %3 = bufferization.to_tensor %2 : memref<5x6x1025xi64> -// DISABLED-CHECK-NEXT: return %3 : tensor<5x6x1025xi64> +//CHECK: func.func @tensor_collatenspse_shape(%[[A0:.*]]: tensor<2x3x5x1025xi64>) -> tensor<5x6x1025xi64> { +//CHECK: %[[V0:.*]] = tensor.collapse_shape %[[A0]] [[_:\[\[0, 1, 2\], \[3\]\]]] : tensor<2x3x5x1025xi64> into tensor<30x1025xi64> +//CHECK: %[[V1:.*]] = tensor.expand_shape %[[V0]] [[_:\[\[0, 1\], \[2\]\]]] : tensor<30x1025xi64> into tensor<5x6x1025xi64> +//CHECK: return %[[V1]] : tensor<5x6x1025xi64> +//CHECK: } func.func @tensor_collatenspse_shape(%arg0: tensor<2x3x5x!Concrete.lwe_ciphertext<1024,4>>) -> tensor<5x6x!Concrete.lwe_ciphertext<1024,4>> { %0 = tensor.collapse_shape %arg0 [[0, 1, 2]] {MANP = 1 : ui1}: tensor<2x3x5x!Concrete.lwe_ciphertext<1024,4>> into tensor<30x!Concrete.lwe_ciphertext<1024,4>> %1 = tensor.expand_shape %0 [[0, 1]] {MANP = 1 : ui1}: tensor<30x!Concrete.lwe_ciphertext<1024,4>> into tensor<5x6x!Concrete.lwe_ciphertext<1024,4>> @@ -26,13 +22,31 @@ func.func @tensor_collatenspse_shape(%arg0: tensor<2x3x5x!Concrete.lwe_ciphertex } // ----- - -// DISABLED-CHECK: func.func @tensor_collatenspse_shape(%arg0: tensor<2x3x2x3x4x1025xi64>) -> tensor<6x2x12x1025xi64> { -// DISABLED-CHECK-NEXT: %0 = bufferization.to_memref %arg0 : memref<2x3x2x3x4x1025xi64> -// DISABLED-CHECK-NEXT: %1 = memref.collapse_shape %0 [[_:\[\[0, 1\], \[2\], \[3, 4\], \[5\]\]]] : memref<2x3x2x3x4x1025xi64> into memref<6x2x12x1025xi64> -// DISABLED-CHECK-NEXT: %2 = bufferization.to_tensor %1 : memref<6x2x12x1025xi64> -// DISABLED-CHECK-NEXT: return %2 : tensor<6x2x12x1025xi64> +//CHECK: func.func @tensor_collatenspse_shape(%[[A0:.*]]: tensor<2x3x2x3x4x1025xi64>) -> tensor<6x2x12x1025xi64> { +//CHECK: %[[V0:.*]] = tensor.collapse_shape %[[A0]] [[_:\[\[0, 1\], \[2\], \[3, 4\], \[5\]\]]] : tensor<2x3x2x3x4x1025xi64> into tensor<6x2x12x1025xi64> +//CHECK: return %[[V0]] : tensor<6x2x12x1025xi64> +//CHECK: } func.func @tensor_collatenspse_shape(%arg0: tensor<2x3x2x3x4x!Concrete.lwe_ciphertext<1024,4>>) -> tensor<6x2x12x!Concrete.lwe_ciphertext<1024,4>> { %0 = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] {MANP = 1 : ui1}: tensor<2x3x2x3x4x!Concrete.lwe_ciphertext<1024,4>> into tensor<6x2x12x!Concrete.lwe_ciphertext<1024,4>> return %0 : tensor<6x2x12x!Concrete.lwe_ciphertext<1024,4>> } + +// ----- +//CHECK: func.func @tensor_collapse_shape_crt(%[[A0:.*]]: tensor<2x3x4x5x6x5x1025xi64>) -> tensor<720x5x1025xi64> { +//CHECK: %[[V0:.*]] = tensor.collapse_shape %[[A0]] [[_:\[\[0, 1, 2, 3, 4\], \[5\], \[6\]\]]] : tensor<2x3x4x5x6x5x1025xi64> into tensor<720x5x1025xi64> +//CHECK: return %[[V0]] : tensor<720x5x1025xi64> +//CHECK: } +func.func @tensor_collapse_shape_crt(%arg0: tensor<2x3x4x5x6x!Concrete.lwe_ciphertext>) -> tensor<720x!Concrete.lwe_ciphertext> { + %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3, 4]] {MANP = 1 : ui1}: tensor<2x3x4x5x6x!Concrete.lwe_ciphertext> into tensor<720x!Concrete.lwe_ciphertext> + return %0 : tensor<720x!Concrete.lwe_ciphertext> +} + +// ----- +//CHECK: func.func @tensor_expand_shape_crt(%[[A0:.*]]: tensor<30x1025xi64>) -> tensor<5x6x1025xi64> { +//CHECK: %[[V0:.*]] = tensor.expand_shape %[[A0]] [[_:\[\[0, 1\], \[2\]\]]] : tensor<30x1025xi64> into tensor<5x6x1025xi64> +//CHECK: return %[[V0]] : tensor<5x6x1025xi64> +//CHECK: } +func.func @tensor_expand_shape_crt(%arg0: tensor<30x!Concrete.lwe_ciphertext<1024,4>>) -> tensor<5x6x!Concrete.lwe_ciphertext<1024,4>> { + %0 = tensor.expand_shape %arg0 [[0, 1]] {MANP = 1 : ui1}: tensor<30x!Concrete.lwe_ciphertext<1024,4>> into tensor<5x6x!Concrete.lwe_ciphertext<1024,4>> + return %0 : tensor<5x6x!Concrete.lwe_ciphertext<1024,4>> +} diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/tensor_identity.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/tensor_identity.mlir index ccb542826..a2bc4fad7 100644 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/tensor_identity.mlir +++ b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/tensor_identity.mlir @@ -1,7 +1,15 @@ // RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s + // CHECK: func.func @tensor_identity(%arg0: tensor<2x3x4x1025xi64>) -> tensor<2x3x4x1025xi64> { // CHECK-NEXT: return %arg0 : tensor<2x3x4x1025xi64> // CHECK-NEXT: } func.func @tensor_identity(%arg0: tensor<2x3x4x!Concrete.lwe_ciphertext<1024,7>>) -> tensor<2x3x4x!Concrete.lwe_ciphertext<1024,7>> { return %arg0 : tensor<2x3x4x!Concrete.lwe_ciphertext<1024,7>> } + +// CHECK: func.func @tensor_identity_crt(%arg0: tensor<2x3x4x5x1025xi64>) -> tensor<2x3x4x5x1025xi64> { +// CHECK-NEXT: return %arg0 : tensor<2x3x4x5x1025xi64> +// CHECK-NEXT: } +func.func @tensor_identity_crt(%arg0: tensor<2x3x4x!Concrete.lwe_ciphertext>) -> tensor<2x3x4x!Concrete.lwe_ciphertext> { + return %arg0 : tensor<2x3x4x!Concrete.lwe_ciphertext> +} diff --git a/compiler/tests/check_tests/Conversion/TFHEToConcrete/add_glwe_int.mlir b/compiler/tests/check_tests/Conversion/TFHEToConcrete/add_glwe_int.mlir index 0878bb32c..51c81b7c2 100644 --- a/compiler/tests/check_tests/Conversion/TFHEToConcrete/add_glwe_int.mlir +++ b/compiler/tests/check_tests/Conversion/TFHEToConcrete/add_glwe_int.mlir @@ -1,22 +1,22 @@ // RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s -// CHECK-LABEL: func.func @add_glwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> +//CHECK: func.func @add_glwe_const_int(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> { +//CHECK: %c1_i8 = arith.constant 1 : i8 +//CHECK: %[[V0:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[A0]], %c1_i8) : (!Concrete.lwe_ciphertext<1024,7>, i8) -> !Concrete.lwe_ciphertext<1024,7> +//CHECK: return %[[V0]] : !Concrete.lwe_ciphertext<1024,7> +//CHECK: } func.func @add_glwe_const_int(%arg0: !TFHE.glwe<{1024,1,64}{7}>) -> !TFHE.glwe<{1024,1,64}{7}> { - // CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8 - // CHECK-NEXT: %[[V2:.*]] = "Concrete.encode_int"(%[[V1]]) : (i8) -> !Concrete.plaintext<8> - // CHECK-NEXT: %[[V3:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %[[V2]]) : (!Concrete.lwe_ciphertext<1024,7>, !Concrete.plaintext<8>) -> !Concrete.lwe_ciphertext<1024,7> - // CHECK-NEXT: return %[[V3]] : !Concrete.lwe_ciphertext<1024,7> %0 = arith.constant 1 : i8 %1 = "TFHE.add_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,1,64}{7}>, i8) -> (!TFHE.glwe<{1024,1,64}{7}>) return %1: !TFHE.glwe<{1024,1,64}{7}> } -// CHECK-LABEL: func.func @add_glwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4> +//CHECK: func.func @add_glwe_int(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,4>, %[[A1:.*]]: i5) -> !Concrete.lwe_ciphertext<1024,4> { +//CHECK: %[[V0:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[A0]], %[[A1]]) : (!Concrete.lwe_ciphertext<1024,4>, i5) -> !Concrete.lwe_ciphertext<1024,4> +//CHECK: return %[[V0]] : !Concrete.lwe_ciphertext<1024,4> +//CHECK: } func.func @add_glwe_int(%arg0: !TFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !TFHE.glwe<{1024,1,64}{4}> { - // CHECK-NEXT: %[[V1:.*]] = "Concrete.encode_int"(%arg1) : (i5) -> !Concrete.plaintext<5> - // CHECK-NEXT: %[[V2:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %[[V1]]) : (!Concrete.lwe_ciphertext<1024,4>, !Concrete.plaintext<5>) -> !Concrete.lwe_ciphertext<1024,4> - // CHECK-NEXT: return %[[V2]] : !Concrete.lwe_ciphertext<1024,4> %1 = "TFHE.add_glwe_int"(%arg0, %arg1): (!TFHE.glwe<{1024,1,64}{4}>, i5) -> (!TFHE.glwe<{1024,1,64}{4}>) return %1: !TFHE.glwe<{1024,1,64}{4}> } diff --git a/compiler/tests/check_tests/Conversion/TFHEToConcrete/mul_glwe_int.mlir b/compiler/tests/check_tests/Conversion/TFHEToConcrete/mul_glwe_int.mlir index 720212115..3195bcd9c 100644 --- a/compiler/tests/check_tests/Conversion/TFHEToConcrete/mul_glwe_int.mlir +++ b/compiler/tests/check_tests/Conversion/TFHEToConcrete/mul_glwe_int.mlir @@ -1,22 +1,22 @@ // RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s -// CHECK-LABEL: func.func @mul_glwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> +//CHECK: func.func @mul_glwe_const_int(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> { +//CHECK: %c1_i8 = arith.constant 1 : i8 +//CHECK: %[[V0:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%[[A0]], %c1_i8) : (!Concrete.lwe_ciphertext<1024,7>, i8) -> !Concrete.lwe_ciphertext<1024,7> +//CHECK: return %[[V0]] : !Concrete.lwe_ciphertext<1024,7> +//CHECK: } func.func @mul_glwe_const_int(%arg0: !TFHE.glwe<{1024,1,64}{7}>) -> !TFHE.glwe<{1024,1,64}{7}> { - // CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8 - // CHECK-NEXT: %[[V2:.*]] = "Concrete.int_to_cleartext"(%[[V1]]) : (i8) -> !Concrete.cleartext<8> - // CHECK-NEXT: %[[V3:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %[[V2]]) : (!Concrete.lwe_ciphertext<1024,7>, !Concrete.cleartext<8>) -> !Concrete.lwe_ciphertext<1024,7> - // CHECK-NEXT: return %[[V3]] : !Concrete.lwe_ciphertext<1024,7> %0 = arith.constant 1 : i8 %1 = "TFHE.mul_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,1,64}{7}>, i8) -> (!TFHE.glwe<{1024,1,64}{7}>) return %1: !TFHE.glwe<{1024,1,64}{7}> } -// CHECK-LABEL: func.func @mul_glwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4> +//CHECK: func.func @mul_glwe_int(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,4>, %[[A1:.*]]: i5) -> !Concrete.lwe_ciphertext<1024,4> { +//CHECK: %[[V0:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%[[A0]], %[[A1]]) : (!Concrete.lwe_ciphertext<1024,4>, i5) -> !Concrete.lwe_ciphertext<1024,4> +//CHECK: return %[[V0]] : !Concrete.lwe_ciphertext<1024,4> +//CHECK: } func.func @mul_glwe_int(%arg0: !TFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !TFHE.glwe<{1024,1,64}{4}> { - // CHECK-NEXT: %[[V1:.*]] = "Concrete.int_to_cleartext"(%arg1) : (i5) -> !Concrete.cleartext<5> - // CHECK-NEXT: %[[V2:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %[[V1]]) : (!Concrete.lwe_ciphertext<1024,4>, !Concrete.cleartext<5>) -> !Concrete.lwe_ciphertext<1024,4> - // CHECK-NEXT: return %[[V2]] : !Concrete.lwe_ciphertext<1024,4> %1 = "TFHE.mul_glwe_int"(%arg0, %arg1): (!TFHE.glwe<{1024,1,64}{4}>, i5) -> (!TFHE.glwe<{1024,1,64}{4}>) return %1: !TFHE.glwe<{1024,1,64}{4}> } diff --git a/compiler/tests/check_tests/Conversion/TFHEToConcrete/sub_int_glwe.mlir b/compiler/tests/check_tests/Conversion/TFHEToConcrete/sub_int_glwe.mlir index 593e889c3..8d3703554 100644 --- a/compiler/tests/check_tests/Conversion/TFHEToConcrete/sub_int_glwe.mlir +++ b/compiler/tests/check_tests/Conversion/TFHEToConcrete/sub_int_glwe.mlir @@ -1,23 +1,23 @@ // RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s -// CHECK-LABEL: func.func @sub_const_int_glwe(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> +//CHECK: func.func @sub_const_int_glwe(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> { +//CHECK: %c1_i8 = arith.constant 1 : i8 +//CHECK: %[[V0:.*]] = "Concrete.negate_lwe_ciphertext"(%[[A0]]) : (!Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> +//CHECK: %[[V1:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[V0]], %c1_i8) : (!Concrete.lwe_ciphertext<1024,7>, i8) -> !Concrete.lwe_ciphertext<1024,7> +//CHECK: return %[[V1]] : !Concrete.lwe_ciphertext<1024,7> +//CHECK: } func.func @sub_const_int_glwe(%arg0: !TFHE.glwe<{1024,1,64}{7}>) -> !TFHE.glwe<{1024,1,64}{7}> { - // CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8 - // CHECK-NEXT: %[[NEG:.*]] = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> - // CHECK-NEXT: %[[V2:.*]] = "Concrete.encode_int"(%[[V1]]) : (i8) -> !Concrete.plaintext<8> - // CHECK-NEXT: %[[V3:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[NEG]], %[[V2]]) : (!Concrete.lwe_ciphertext<1024,7>, !Concrete.plaintext<8>) -> !Concrete.lwe_ciphertext<1024,7> - // CHECK-NEXT: return %[[V3]] : !Concrete.lwe_ciphertext<1024,7> %0 = arith.constant 1 : i8 %1 = "TFHE.sub_int_glwe"(%0, %arg0): (i8, !TFHE.glwe<{1024,1,64}{7}>) -> (!TFHE.glwe<{1024,1,64}{7}>) return %1: !TFHE.glwe<{1024,1,64}{7}> } -// CHECK-LABEL: func.func @sub_int_glwe(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4> +//CHECK: func.func @sub_int_glwe(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,4>, %[[A1:.*]]: i5) -> !Concrete.lwe_ciphertext<1024,4> { +//CHECK: %[[V0:.*]] = "Concrete.negate_lwe_ciphertext"(%[[A0]]) : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4> +//CHECK: %[[V1:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[V0]], %[[A1]]) : (!Concrete.lwe_ciphertext<1024,4>, i5) -> !Concrete.lwe_ciphertext<1024,4> +//CHECK: return %[[V1]] : !Concrete.lwe_ciphertext<1024,4> +//CHECK: } func.func @sub_int_glwe(%arg0: !TFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !TFHE.glwe<{1024,1,64}{4}> { - // CHECK-NEXT: %[[NEG:.*]] = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4> - // CHECK-NEXT: %[[V1:.*]] = "Concrete.encode_int"(%arg1) : (i5) -> !Concrete.plaintext<5> - // CHECK-NEXT: %[[V2:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[NEG]], %[[V1]]) : (!Concrete.lwe_ciphertext<1024,4>, !Concrete.plaintext<5>) -> !Concrete.lwe_ciphertext<1024,4> - // CHECK-NEXT: return %[[V2]] : !Concrete.lwe_ciphertext<1024,4> %1 = "TFHE.sub_int_glwe"(%arg1, %arg0): (i5, !TFHE.glwe<{1024,1,64}{4}>) -> (!TFHE.glwe<{1024,1,64}{4}>) return %1: !TFHE.glwe<{1024,1,64}{4}> } diff --git a/compiler/tests/check_tests/Dialect/BConcrete/ops.mlir b/compiler/tests/check_tests/Dialect/BConcrete/ops.mlir index 7ce0e30b8..e57b7d656 100644 --- a/compiler/tests/check_tests/Dialect/BConcrete/ops.mlir +++ b/compiler/tests/check_tests/Dialect/BConcrete/ops.mlir @@ -9,6 +9,15 @@ func.func @add_lwe_ciphertexts(%arg0: tensor<2049xi64>, %arg1: tensor<2049xi64>) return %0 : tensor<2049xi64> } +//CHECK: func.func @add_crt_lwe_ciphertexts(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: tensor<5x2049xi64>) -> tensor<5x2049xi64> { +//CHECK: %[[V0:.*]] = "BConcrete.add_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> tensor<5x2049xi64> +//CHECK: return %[[V0]] : tensor<5x2049xi64> +//CHECK: } +func.func @add_crt_lwe_ciphertexts(%arg0: tensor<5x2049xi64>, %arg1: tensor<5x2049xi64>) -> tensor<5x2049xi64> { + %0 = "BConcrete.add_crt_lwe_buffer"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> ( tensor<5x2049xi64>) + return %0 : tensor<5x2049xi64> +} + //CHECK: func.func @add_plaintext_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: i64) -> tensor<2049xi64> { //CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_lwe_buffer"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, i64) -> tensor<2049xi64> //CHECK: return %[[V0]] : tensor<2049xi64> @@ -18,7 +27,16 @@ func.func @add_plaintext_lwe_ciphertext(%arg0: tensor<2049xi64>, %arg1: i64) -> return %0 : tensor<2049xi64> } -//CHECK: func.func @mul_cleartext_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: i64) -> tensor<2049xi64> { +//CHECK: func.func @add_plaintext_crt_lwe_ciphertext(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: i64) -> tensor<5x2049xi64> { +//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> tensor<5x2049xi64> +//CHECK: return %[[V0]] : tensor<5x2049xi64> +//CHECK: } +func.func @add_plaintext_crt_lwe_ciphertext(%arg0: tensor<5x2049xi64>, %arg1: i64) -> tensor<5x2049xi64> { + %0 = "BConcrete.add_plaintext_crt_lwe_buffer"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> ( tensor<5x2049xi64>) + return %0 : tensor<5x2049xi64> +} + +//CHECK: func @mul_cleartext_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: i64) -> tensor<2049xi64> { //CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_lwe_buffer"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, i64) -> tensor<2049xi64> //CHECK: return %[[V0]] : tensor<2049xi64> //CHECK: } @@ -27,6 +45,15 @@ func.func @mul_cleartext_lwe_ciphertext(%arg0: tensor<2049xi64>, %arg1: i64) -> return %0 : tensor<2049xi64> } +//CHECK: func.func @mul_cleartext_crt_lwe_ciphertext(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: i64) -> tensor<5x2049xi64> { +//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> tensor<5x2049xi64> +//CHECK: return %[[V0]] : tensor<5x2049xi64> +//CHECK: } +func.func @mul_cleartext_crt_lwe_ciphertext(%arg0: tensor<5x2049xi64>, %arg1: i64) -> tensor<5x2049xi64> { + %0 = "BConcrete.mul_cleartext_crt_lwe_buffer"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> (tensor<5x2049xi64>) + return %0 : tensor<5x2049xi64> +} + //CHECK: func.func @negate_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>) -> tensor<2049xi64> { //CHECK: %[[V0:.*]] = "BConcrete.negate_lwe_buffer"(%[[A0]]) : (tensor<2049xi64>) -> tensor<2049xi64> //CHECK: return %[[V0]] : tensor<2049xi64> @@ -36,6 +63,15 @@ func.func @negate_lwe_ciphertext(%arg0: tensor<2049xi64>) -> tensor<2049xi64> { return %0 : tensor<2049xi64> } +//CHECK: func.func @negate_crt_lwe_ciphertext(%[[A0:.*]]: tensor<5x2049xi64>) -> tensor<5x2049xi64> { +//CHECK: %[[V0:.*]] = "BConcrete.negate_crt_lwe_buffer"(%[[A0]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>) -> tensor<5x2049xi64> +//CHECK: return %[[V0]] : tensor<5x2049xi64> +//CHECK: } +func.func @negate_crt_lwe_ciphertext(%arg0: tensor<5x2049xi64>) -> tensor<5x2049xi64> { + %0 = "BConcrete.negate_crt_lwe_buffer"(%arg0) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>) -> (tensor<5x2049xi64>) + return %0 : tensor<5x2049xi64> +} + //CHECK: func.func @bootstrap_lwe(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<4096xi64>) -> tensor<2049xi64> { //CHECK: %[[V0:.*]] = "BConcrete.bootstrap_lwe_buffer"(%[[A0]], %[[A1]]) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (tensor<2049xi64>, tensor<4096xi64>) -> tensor<2049xi64> //CHECK: return %[[V0]] : tensor<2049xi64> diff --git a/compiler/tests/check_tests/Dialect/Concrete/no_optimization.mlir b/compiler/tests/check_tests/Dialect/Concrete/no_optimization.mlir index e90c87268..ec69bb7a7 100644 --- a/compiler/tests/check_tests/Dialect/Concrete/no_optimization.mlir +++ b/compiler/tests/check_tests/Dialect/Concrete/no_optimization.mlir @@ -1,14 +1,12 @@ // RUN: concretecompiler --optimize-concrete=false --action=dump-concrete %s 2>&1| FileCheck %s -// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext_0(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> +//CHECK: func.func @mul_cleartext_lwe_ciphertext_0(%[[A0:.*]]: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> { +//CHECK: %c0_i7 = arith.constant 0 : i7 +//CHECK: %[[V0:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%[[A0]], %c0_i7) : (!Concrete.lwe_ciphertext<2048,7>, i7) -> !Concrete.lwe_ciphertext<2048,7> +//CHECK: return %[[V0]] : !Concrete.lwe_ciphertext<2048,7> +//CHECK: } func.func @mul_cleartext_lwe_ciphertext_0(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> { - // CHECK-NEXT: %[[V0:.*]] = arith.constant 0 : i7 - // CHECK-NEXT: %[[V1:.*]] = "Concrete.int_to_cleartext"(%[[V0]]) : (i7) -> !Concrete.cleartext<7> - // CHECK-NEXT: %[[V2:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %[[V1]]) : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7> - // CHECK-NEXT: return %[[V2]] : !Concrete.lwe_ciphertext<2048,7> - %0 = arith.constant 0 : i7 - %1 = "Concrete.int_to_cleartext"(%0) : (i7) -> !Concrete.cleartext<7> - %2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> (!Concrete.lwe_ciphertext<2048,7>) + %2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %0): (!Concrete.lwe_ciphertext<2048,7>, i7) -> (!Concrete.lwe_ciphertext<2048,7>) return %2: !Concrete.lwe_ciphertext<2048,7> } diff --git a/compiler/tests/check_tests/Dialect/Concrete/ops.mlir b/compiler/tests/check_tests/Dialect/Concrete/ops.mlir index 5aff8aa08..70135e8b3 100644 --- a/compiler/tests/check_tests/Dialect/Concrete/ops.mlir +++ b/compiler/tests/check_tests/Dialect/Concrete/ops.mlir @@ -9,21 +9,21 @@ func.func @add_lwe_ciphertexts(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: ! return %1: !Concrete.lwe_ciphertext<2048,7> } -// CHECK-LABEL: func.func @add_plaintext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.plaintext<5>) -> !Concrete.lwe_ciphertext<2048,7> -func.func @add_plaintext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.plaintext<5>) -> !Concrete.lwe_ciphertext<2048,7> { - // CHECK-NEXT: %[[V1:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.plaintext<5>) -> !Concrete.lwe_ciphertext<2048,7> +// CHECK-LABEL: func.func @add_plaintext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: i5) -> !Concrete.lwe_ciphertext<2048,7> +func.func @add_plaintext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: i5) -> !Concrete.lwe_ciphertext<2048,7> { + // CHECK-NEXT: %[[V1:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, i5) -> !Concrete.lwe_ciphertext<2048,7> // CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7> - %1 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %arg1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.plaintext<5>) -> (!Concrete.lwe_ciphertext<2048,7>) + %1 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %arg1): (!Concrete.lwe_ciphertext<2048,7>, i5) -> (!Concrete.lwe_ciphertext<2048,7>) return %1: !Concrete.lwe_ciphertext<2048,7> } -// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7> -func.func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7> { - // CHECK-NEXT: %[[V1:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7> +// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: i7) -> !Concrete.lwe_ciphertext<2048,7> +func.func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: i7) -> !Concrete.lwe_ciphertext<2048,7> { + // CHECK-NEXT: %[[V1:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, i7) -> !Concrete.lwe_ciphertext<2048,7> // CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7> - %1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> (!Concrete.lwe_ciphertext<2048,7>) + %1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1): (!Concrete.lwe_ciphertext<2048,7>, i7) -> (!Concrete.lwe_ciphertext<2048,7>) return %1: !Concrete.lwe_ciphertext<2048,7> } @@ -38,9 +38,9 @@ func.func @negate_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Co // CHECK-LABEL: func.func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7> func.func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7> { - // CHECK-NEXT: %[[V1:.*]] = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7> + // CHECK-NEXT: %[[V1:.*]] = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, level = -1 : i32} : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7> // CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7> - %1 = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7> + %1 = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, level = -1 : i32} : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7> return %1: !Concrete.lwe_ciphertext<2048,7> } @@ -51,22 +51,3 @@ func.func @keyswitch_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.l %1 = "Concrete.keyswitch_lwe"(%arg0){baseLog = 2 : i32, level = 3 : i32}: (!Concrete.lwe_ciphertext<2048,7>) -> (!Concrete.lwe_ciphertext<2048,7>) return %1: !Concrete.lwe_ciphertext<2048,7> } - -// CHECK-LABEL: func.func @encode_int(%arg0: i6) -> !Concrete.plaintext<6> -func.func @encode_int(%arg0: i6) -> (!Concrete.plaintext<6>) { - // CHECK-NEXT: %[[V1:.*]] = "Concrete.encode_int"(%arg0) : (i6) -> !Concrete.plaintext<6> - // CHECK-NEXT: return %[[V1]] : !Concrete.plaintext<6> - - %0 = "Concrete.encode_int"(%arg0): (i6) -> !Concrete.plaintext<6> - return %0: !Concrete.plaintext<6> -} - -// CHECK-LABEL: func.func @int_to_cleartext() -> !Concrete.cleartext<6> -func.func @int_to_cleartext() -> !Concrete.cleartext<6> { - // CHECK-NEXT: %[[V0:.*]] = arith.constant 5 : i6 - // CHECK-NEXT: %[[V1:.*]] = "Concrete.int_to_cleartext"(%[[V0]]) : (i6) -> !Concrete.cleartext<6> - // CHECK-NEXT: return %[[V1]] : !Concrete.cleartext<6> - %0 = arith.constant 5 : i6 - %1 = "Concrete.int_to_cleartext"(%0) : (i6) -> !Concrete.cleartext<6> - return %1 : !Concrete.cleartext<6> -} diff --git a/compiler/tests/check_tests/Dialect/Concrete/optimization.mlir b/compiler/tests/check_tests/Dialect/Concrete/optimization.mlir index 5aba799b8..5f10ec7a0 100644 --- a/compiler/tests/check_tests/Dialect/Concrete/optimization.mlir +++ b/compiler/tests/check_tests/Dialect/Concrete/optimization.mlir @@ -1,14 +1,14 @@ // RUN: concretecompiler --action=dump-concrete %s 2>&1| FileCheck %s -// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7> - func.func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7> { - // CHECK-NEXT: %[[V1:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7> - // CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7> +// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: i7) -> !Concrete.lwe_ciphertext<2048,7> +func.func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: i7) -> !Concrete.lwe_ciphertext<2048,7> { + // CHECK-NEXT: %[[V1:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, i7) -> !Concrete.lwe_ciphertext<2048,7> + // CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7> - %1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> (!Concrete.lwe_ciphertext<2048,7>) - return %1: !Concrete.lwe_ciphertext<2048,7> - } + %1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1): (!Concrete.lwe_ciphertext<2048,7>, i7) -> (!Concrete.lwe_ciphertext<2048,7>) + return %1: !Concrete.lwe_ciphertext<2048,7> +} // CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext_0(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> func.func @mul_cleartext_lwe_ciphertext_0(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> { @@ -16,8 +16,7 @@ func.func @mul_cleartext_lwe_ciphertext_0(%arg0: !Concrete.lwe_ciphertext<2048,7 // CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7> %0 = arith.constant 0 : i7 - %1 = "Concrete.int_to_cleartext"(%0) : (i7) -> !Concrete.cleartext<7> - %2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> (!Concrete.lwe_ciphertext<2048,7>) + %2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %0): (!Concrete.lwe_ciphertext<2048,7>, i7) -> (!Concrete.lwe_ciphertext<2048,7>) return %2: !Concrete.lwe_ciphertext<2048,7> } @@ -27,7 +26,6 @@ func.func @mul_cleartext_lwe_ciphertext_minus_1(%arg0: !Concrete.lwe_ciphertext< // CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7> %0 = arith.constant -1 : i7 - %1 = "Concrete.int_to_cleartext"(%0) : (i7) -> !Concrete.cleartext<7> - %2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> (!Concrete.lwe_ciphertext<2048,7>) + %2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %0): (!Concrete.lwe_ciphertext<2048,7>, i7) -> (!Concrete.lwe_ciphertext<2048,7>) return %2: !Concrete.lwe_ciphertext<2048,7> } diff --git a/compiler/tests/check_tests/Dialect/Concrete/types.mlir b/compiler/tests/check_tests/Dialect/Concrete/types.mlir index 0ec07ef2e..224ba8120 100644 --- a/compiler/tests/check_tests/Dialect/Concrete/types.mlir +++ b/compiler/tests/check_tests/Dialect/Concrete/types.mlir @@ -13,7 +13,13 @@ func.func @type_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Conc return %arg0: !Concrete.lwe_ciphertext<2048,7> } -// CHECK-LABEL: func.func @type_cleartext(%arg0: !Concrete.cleartext<5>) -> !Concrete.cleartext<5> +// CHECK-LABEL: func @type_lwe_ciphertext_with_crt(%arg0: !Concrete.lwe_ciphertext) -> !Concrete.lwe_ciphertext +func.func @type_lwe_ciphertext_with_crt(%arg0: !Concrete.lwe_ciphertext) -> !Concrete.lwe_ciphertext { + // CHECK-NEXT: return %arg0 : !Concrete.lwe_ciphertext + return %arg0: !Concrete.lwe_ciphertext +} + +// CHECK-LABEL: func @type_cleartext(%arg0: !Concrete.cleartext<5>) -> !Concrete.cleartext<5> func.func @type_cleartext(%arg0: !Concrete.cleartext<5>) -> !Concrete.cleartext<5> { // CHECK-NEXT: return %arg0 : !Concrete.cleartext<5> return %arg0: !Concrete.cleartext<5> diff --git a/compiler/tests/check_tests/Dialect/TFHE/op_add_glwe.invalid.mlir b/compiler/tests/check_tests/Dialect/TFHE/op_add_glwe.invalid.mlir index 563076d19..799035b02 100644 --- a/compiler/tests/check_tests/Dialect/TFHE/op_add_glwe.invalid.mlir +++ b/compiler/tests/check_tests/Dialect/TFHE/op_add_glwe.invalid.mlir @@ -52,3 +52,20 @@ func.func @add_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>, %arg1: !TFHE.glwe<{1024, return %1: !TFHE.glwe<{1024,12,64}{7}> } +// ----- + +// GLWE polynomialSize parameter result +func.func @add_glwe(%arg0: !TFHE.glwe, %arg1: !TFHE.glwe) -> !TFHE.glwe { + // expected-error @+1 {{'TFHE.add_glwe' op should have the same GLWE 'crt' parameter}} + %1 = "TFHE.add_glwe"(%arg0, %arg1): (!TFHE.glwe, !TFHE.glwe) -> (!TFHE.glwe) + return %1: !TFHE.glwe +} + +// ----- + +// GLWE polynomialSize parameter inputs +func.func @add_glwe(%arg0: !TFHE.glwe, %arg1: !TFHE.glwe) -> !TFHE.glwe { + // expected-error @+1 {{'TFHE.add_glwe' op should have the same GLWE 'crt' parameter}} + %1 = "TFHE.add_glwe"(%arg0, %arg1): (!TFHE.glwe, !TFHE.glwe) -> (!TFHE.glwe) + return %1: !TFHE.glwe +} diff --git a/compiler/tests/check_tests/Dialect/TFHE/op_add_glwe_int.invalid.mlir b/compiler/tests/check_tests/Dialect/TFHE/op_add_glwe_int.invalid.mlir index 7dcf3b0be..b7336e673 100644 --- a/compiler/tests/check_tests/Dialect/TFHE/op_add_glwe_int.invalid.mlir +++ b/compiler/tests/check_tests/Dialect/TFHE/op_add_glwe_int.invalid.mlir @@ -30,6 +30,16 @@ func.func @add_glwe_int(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024, // ----- +// GLWE crt parameter +func.func @add_glwe_int(%arg0: !TFHE.glwe) -> !TFHE.glwe { + %0 = arith.constant 1 : i8 + // expected-error @+1 {{'TFHE.add_glwe_int' op should have the same GLWE 'crt' parameter}} + %1 = "TFHE.add_glwe_int"(%arg0, %0): (!TFHE.glwe, i8) -> (!TFHE.glwe) + return %1: !TFHE.glwe +} + +// ----- + // integer width doesn't match GLWE parameter func.func @add_glwe_int(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,12,64}{7}> { %0 = arith.constant 1 : i9 diff --git a/compiler/tests/check_tests/Dialect/TFHE/op_mul_glwe_int.invalid.mlir b/compiler/tests/check_tests/Dialect/TFHE/op_mul_glwe_int.invalid.mlir index f1118aeba..ba7379cd6 100644 --- a/compiler/tests/check_tests/Dialect/TFHE/op_mul_glwe_int.invalid.mlir +++ b/compiler/tests/check_tests/Dialect/TFHE/op_mul_glwe_int.invalid.mlir @@ -30,6 +30,16 @@ func.func @mul_glwe_int(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024, // ----- +// GLWE crt parameter +func.func @mul_glwe_int(%arg0: !TFHE.glwe) -> !TFHE.glwe { + %0 = arith.constant 1 : i8 + // expected-error @+1 {{'TFHE.mul_glwe_int' op should have the same GLWE 'crt' parameter}} + %1 = "TFHE.mul_glwe_int"(%arg0, %0): (!TFHE.glwe, i8) -> (!TFHE.glwe) + return %1: !TFHE.glwe +} + +// ----- + // integer width doesn't match GLWE parameter func.func @mul_glwe_int(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,12,64}{7}> { %0 = arith.constant 1 : i9 diff --git a/compiler/tests/check_tests/Dialect/TFHE/op_neg_glwe.invalid.mlir b/compiler/tests/check_tests/Dialect/TFHE/op_neg_glwe.invalid.mlir index dc81bb31e..26747b61c 100644 --- a/compiler/tests/check_tests/Dialect/TFHE/op_neg_glwe.invalid.mlir +++ b/compiler/tests/check_tests/Dialect/TFHE/op_neg_glwe.invalid.mlir @@ -27,6 +27,15 @@ func.func @neg_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,11,6 // ----- +// GLWE crt parameter +func.func @neg_glwe(%arg0: !TFHE.glwe) -> !TFHE.glwe { + // expected-error @+1 {{'TFHE.neg_glwe' op should have the same GLWE 'crt' parameter}} + %1 = "TFHE.neg_glwe"(%arg0): (!TFHE.glwe) -> (!TFHE.glwe) + return %1: !TFHE.glwe +} + +// ----- + // integer width doesn't match GLWE parameter func.func @neg_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,11,64}{7}> { // expected-error @+1 {{'TFHE.neg_glwe' op should have the same GLWE 'polynomialSize' parameter}} diff --git a/compiler/tests/check_tests/Dialect/TFHE/op_sub_int_glwe.invalid.mlir b/compiler/tests/check_tests/Dialect/TFHE/op_sub_int_glwe.invalid.mlir index eaca7061a..6de784f1f 100644 --- a/compiler/tests/check_tests/Dialect/TFHE/op_sub_int_glwe.invalid.mlir +++ b/compiler/tests/check_tests/Dialect/TFHE/op_sub_int_glwe.invalid.mlir @@ -28,6 +28,17 @@ func.func @sub_int_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024, return %1: !TFHE.glwe<{1024,11,64}{7}> } +// ----- + +// GLWE polynomialSize parameter +func.func @sub_int_glwe(%arg0: !TFHE.glwe) -> !TFHE.glwe { + %0 = arith.constant 1 : i8 + // expected-error @+1 {{'TFHE.sub_int_glwe' op should have the same GLWE 'crt' parameter}} + %1 = "TFHE.sub_int_glwe"(%0, %arg0): (i8, !TFHE.glwe) -> (!TFHE.glwe) + return %1: !TFHE.glwe +} + + // ----- // integer width doesn't match GLWE parameter diff --git a/compiler/tests/check_tests/Dialect/TFHE/types_glwe.mlir b/compiler/tests/check_tests/Dialect/TFHE/types_glwe.mlir index 5c232d96b..c06fa82df 100644 --- a/compiler/tests/check_tests/Dialect/TFHE/types_glwe.mlir +++ b/compiler/tests/check_tests/Dialect/TFHE/types_glwe.mlir @@ -11,3 +11,15 @@ func.func @glwe_1(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> { // CHECK-LABEL: return %arg0 : !TFHE.glwe<{_,_,_}{7}> return %arg0: !TFHE.glwe<{_,_,_}{7}> } + +// CHECK-LABEL: func.func @glwe_crt(%arg0: !TFHE.glwe) -> !TFHE.glwe +func.func @glwe_crt(%arg0: !TFHE.glwe) -> !TFHE.glwe { + // CHECK-LABEL: return %arg0 : !TFHE.glwe + return %arg0: !TFHE.glwe +} + +// CHECK-LABEL: func.func @glwe_crt_undef(%arg0: !TFHE.glwe) -> !TFHE.glwe +func.func @glwe_crt_undef(%arg0: !TFHE.glwe) -> !TFHE.glwe { + // CHECK-LABEL: return %arg0 : !TFHE.glwe + return %arg0: !TFHE.glwe +} diff --git a/compiler/tests/end_to_end_fixture/EndToEndFixture.cpp b/compiler/tests/end_to_end_fixture/EndToEndFixture.cpp index 9fc6baba4..024ef7f71 100644 --- a/compiler/tests/end_to_end_fixture/EndToEndFixture.cpp +++ b/compiler/tests/end_to_end_fixture/EndToEndFixture.cpp @@ -67,7 +67,7 @@ llvm::Error checkResult(ScalarDesc &desc, } if (desc.value != res64->getValue()) { return StreamStringError("unexpected result value: got ") - << res64->getValue() << "expected " << desc.value; + << res64->getValue() << " expected " << desc.value; } return llvm::Error::success(); } @@ -204,6 +204,12 @@ template <> struct llvm::yaml::MappingTraits { desc.v0Constraint->p = v0constraint[0]; desc.v0Constraint->norm2 = v0constraint[1]; } + mlir::concretelang::LargeIntegerParameter largeInterger; + io.mapOptional("large-integer-crt-decomposition", + largeInterger.crtDecomposition); + if (!largeInterger.crtDecomposition.empty()) { + desc.largeIntegerParameter = largeInterger; + } } }; diff --git a/compiler/tests/end_to_end_fixture/EndToEndFixture.h b/compiler/tests/end_to_end_fixture/EndToEndFixture.h index bc115b5d0..b458710eb 100644 --- a/compiler/tests/end_to_end_fixture/EndToEndFixture.h +++ b/compiler/tests/end_to_end_fixture/EndToEndFixture.h @@ -14,7 +14,7 @@ struct TensorDescription { ValueWidth width; }; struct ScalarDesc { - uint64_t value; + int64_t value; ValueWidth width; }; @@ -43,6 +43,8 @@ struct EndToEndDesc { std::vector tests; llvm::Optional v0Parameter; llvm::Optional v0Constraint; + llvm::Optional + largeIntegerParameter; }; llvm::Expected diff --git a/compiler/tests/end_to_end_fixture/end_to_end_encrypted_tensor.yaml b/compiler/tests/end_to_end_fixture/end_to_end_encrypted_tensor.yaml index ae17d900f..5d888da56 100644 --- a/compiler/tests/end_to_end_fixture/end_to_end_encrypted_tensor.yaml +++ b/compiler/tests/end_to_end_fixture/end_to_end_encrypted_tensor.yaml @@ -1,3 +1,17 @@ +## Part of the Concrete Compiler Project, under the BSD3 License with Zama +## Exceptions. See +## https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +## for license information. +## +## This file contains programs and references to end to end test the compiler. +## Program in this file aims to test the `tensor` dialects on encrypted integers. +## +## Operators: +## - tensor.extract +## - tensor.insert +## - tensor.extract_slice +## - tensor.insert_slice + description: identity program: | func.func @main(%t: tensor<2x10x!FHE.eint<6>>) -> tensor<2x10x!FHE.eint<6>> { @@ -14,6 +28,24 @@ tests: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9] shape: [2,10] --- +description: identity_16bits +program: | + func.func @main(%t: tensor<2x10x!FHE.eint<16>>) -> tensor<2x10x!FHE.eint<16>> { + return %t : tensor<2x10x!FHE.eint<16>> + } +tests: + - inputs: + - tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769, + 38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0] + shape: [2,10] + outputs: + - tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769, + 38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0] + shape: [2,10] +v0-constraint: [16, 0] +v0-parameter: [1,11,565,1,23,5,3] +large-integer-crt-decomposition: [7,8,9,11,13] +--- description: extract program: | func.func @main(%t: tensor<2x10x!FHE.eint<6>>, %i: index, %j: index) -> @@ -59,6 +91,157 @@ tests: outputs: - scalar: 9 --- +description: extract_16bits +program: | + func.func @main(%t: tensor<2x10x!FHE.eint<16>>, %i: index, %j: index) -> + !FHE.eint<16> { + %c = tensor.extract %t[%i, %j] : tensor<2x10x!FHE.eint<16>> + return %c : !FHE.eint<16> + } +tests: + - inputs: + - tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769, + 38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0] + shape: [2,10] + - scalar: 0 + - scalar: 0 + outputs: + - scalar: 65535 + - inputs: + - tensor: [63, 12, 7, 43, 52, 9, 26, 34, 22, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + shape: [2,10] + - scalar: 0 + - scalar: 9 + outputs: + - scalar: 0 + - inputs: + - tensor: [63, 12, 7, 43, 52, 9, 26, 34, 22, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + shape: [2,10] + - scalar: 1 + - scalar: 0 + outputs: + - scalar: 0 + - inputs: + - tensor: [63, 12, 7, 43, 52, 9, 26, 34, 22, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + shape: [2,10] + - scalar: 1 + - scalar: 9 + outputs: + - scalar: 9 +v0-constraint: [16, 0] +v0-parameter: [1,11,565,1,23,5,3] +large-integer-crt-decomposition: [7,8,9,11,13] +--- +description: insert +program: | + func.func @main(%t: tensor<2x10x!FHE.eint<6>>, %i: index, %j: index, %x: !FHE.eint<6>) -> tensor<2x10x!FHE.eint<6>> { + %r = tensor.insert %x into %t[%i, %j] : tensor<2x10x!FHE.eint<6>> + return %r : tensor<2x10x!FHE.eint<6>> + } +tests: + - inputs: + - tensor: [63, 12, 7, 43, 52, 9, 26, 34, 22, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + shape: [2,10] + - scalar: 0 + - scalar: 0 + - scalar: 42 + outputs: + - tensor: [42, 12, 7, 43, 52, 9, 26, 34, 22, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + shape: [2,10] + - inputs: + - tensor: [63, 12, 7, 43, 52, 9, 26, 34, 22, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + shape: [2,10] + - scalar: 0 + - scalar: 1 + - scalar: 42 + outputs: + - tensor: [63, 42, 7, 43, 52, 9, 26, 34, 22, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + shape: [2,10] + - inputs: + - tensor: [63, 12, 7, 43, 52, 9, 26, 34, 22, 0, + 42, 1, 2, 3, 4, 5, 6, 7, 8, 9] + shape: [2,10] + - scalar: 1 + - scalar: 0 + - scalar: 42 + outputs: + - tensor: [63, 12, 7, 43, 52, 9, 26, 34, 22, 0, + 42, 1, 2, 3, 4, 5, 6, 7, 8, 9] + shape: [2,10] + - inputs: + - tensor: [63, 12, 7, 43, 52, 9, 26, 34, 22, 0, + 42, 1, 2, 3, 4, 5, 6, 7, 8, 9] + shape: [2,10] + - scalar: 1 + - scalar: 9 + - scalar: 42 + outputs: + - tensor: [63, 12, 7, 43, 52, 9, 26, 34, 22, 0, + 42, 1, 2, 3, 4, 5, 6, 7, 8, 42] + shape: [2,10] +--- +description: insert_16bits +program: | + func.func @main(%t: tensor<2x10x!FHE.eint<16>>, %i: index, %j: index, %x: !FHE.eint<16>) -> tensor<2x10x!FHE.eint<16>> { + %r = tensor.insert %x into %t[%i, %j] : tensor<2x10x!FHE.eint<16>> + return %r : tensor<2x10x!FHE.eint<16>> + } +tests: + - inputs: + - tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769, + 38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0] + shape: [2,10] + - scalar: 0 + - scalar: 0 + - scalar: 42 + outputs: + - tensor: [42, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769, + 38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0] + shape: [2,10] + - inputs: + - tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769, + 38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0] + shape: [2,10] + - scalar: 0 + - scalar: 1 + - scalar: 42 + outputs: + - tensor: [65535, 42, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769, + 38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0] + shape: [2,10] + - inputs: + - tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769, + 38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0] + shape: [2,10] + - scalar: 1 + - scalar: 0 + - scalar: 42 + outputs: + - tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769, + 42, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0] + shape: [2,10] + - inputs: + - tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769, + 38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0] + shape: [2,10] + - scalar: 1 + - scalar: 9 + - scalar: 42 + outputs: + - tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769, + 38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 42] + shape: [2,10] +v0-constraint: [16, 0] +v0-parameter: [1,11,565,1,23,5,3] +large-integer-crt-decomposition: [7,8,9,11,13] +--- description: extract_slice program: | func.func @main(%t: tensor<2x10x!FHE.eint<6>>) -> tensor<1x5x!FHE.eint<6>> { @@ -91,6 +274,24 @@ tests: - tensor: [ 5, 6, 7, 8, 9] shape: [5] --- +description: extract_slice_16bits +program: | + func.func @main(%t: tensor<2x10x!FHE.eint<16>>) -> tensor<1x5x!FHE.eint<16>> { + %r = tensor.extract_slice %t[1, 5][1, 5][1, 1] : tensor<2x10x!FHE.eint<16>> to tensor<1x5x!FHE.eint<16>> + return %r : tensor<1x5x!FHE.eint<16>> + } +tests: + - inputs: + - tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769, + 38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0] + shape: [2,10] + outputs: + - tensor: [64814, 65491, 4271, 9294, 0] + shape: [1,5] +v0-constraint: [16, 0] +v0-parameter: [1,11,565,1,23,5,3] +large-integer-crt-decomposition: [7,8,9,11,13] +--- description: extract_slice_stride program: | func.func @main(%t: tensor<2x10x!FHE.eint<6>>) -> tensor<1x5x!FHE.eint<6>> { @@ -122,6 +323,24 @@ tests: - tensor: [3, 2, 1] shape: [3] --- +description: extract_slice_stride_16bits +program: | + func.func @main(%t: tensor<2x10x!FHE.eint<6>>) -> tensor<1x5x!FHE.eint<6>> { + %r = tensor.extract_slice %t[1, 0][1, 5][1, 2] : tensor<2x10x!FHE.eint<6>> to tensor<1x5x!FHE.eint<6>> + return %r : tensor<1x5x!FHE.eint<6>> + } +tests: + - inputs: + - tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769, + 38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0] + shape: [2,10] + outputs: + - tensor: [38786, 65112, 60515, 65491, 9294] + shape: [1,5] +v0-constraint: [16, 0] +v0-parameter: [1,11,565,1,23,5,3] +large-integer-crt-decomposition: [7,8,9,11,13] +--- description: insert_slice program: | func.func @main(%t0: tensor<2x10x!FHE.eint<6>>, %t1: tensor<2x2x!FHE.eint<6>>) -> tensor<2x10x!FHE.eint<6>> { @@ -179,3 +398,25 @@ tests: - tensor: [0, 1, 2, 3, 4, 5] shape: [2, 3] +--- +description: insert_slice_16bits +program: | + func.func @main(%t0: tensor<2x10x!FHE.eint<16>>, %t1: tensor<2x2x!FHE.eint<16>>) -> tensor<2x10x!FHE.eint<16>> { + %r = tensor.insert_slice %t1 into %t0[0, 5][2, 2][1, 1] : tensor<2x2x!FHE.eint<16>> into tensor<2x10x!FHE.eint<16>> + return %r : tensor<2x10x!FHE.eint<16>> + } +tests: + - inputs: + - tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769, + 38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0] + shape: [2,10] + - tensor: [1000, 1001, + 1002, 1003] + shape: [2,2] + outputs: + - tensor: [65535, 46706, 18752, 55384, 55709, 1000, 1001, 57650, 45551, 5769, + 38786, 36362, 65112, 5748, 60515, 1002, 1003, 4271, 9294, 0] + shape: [2,10] +v0-constraint: [16, 0] +v0-parameter: [1,11,565,1,23,5,3] +large-integer-crt-decomposition: [7,8,9,11,13] diff --git a/compiler/tests/end_to_end_fixture/end_to_end_fhe.yaml b/compiler/tests/end_to_end_fixture/end_to_end_fhe.yaml index f1cc70cfe..a3590dfd6 100644 --- a/compiler/tests/end_to_end_fixture/end_to_end_fhe.yaml +++ b/compiler/tests/end_to_end_fixture/end_to_end_fhe.yaml @@ -9,6 +9,28 @@ tests: outputs: - scalar: 1 --- +description: identity_16bits +program: | + func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> { + return %arg0: !FHE.eint<16> + } +tests: + - inputs: + - scalar: 1 + outputs: + - scalar: 1 + - inputs: + - scalar: 0 + outputs: + - scalar: 0 + - inputs: + - scalar: 72071 + outputs: + - scalar: 72071 +v0-constraint: [16, 0] +v0-parameter: [1,11,565,1,23,5,3] +large-integer-crt-decomposition: [7,8,9,11,13] +--- description: zero_tensor program: | func.func @main() -> tensor<2x2x4x!FHE.eint<6>> { @@ -20,6 +42,20 @@ tests: - tensor: [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0] shape: [2,2,4] --- +description: zero_tensor_16bits +program: | + func.func @main() -> tensor<2x2x4x!FHE.eint<16>> { + %0 = "FHE.zero_tensor"() : () -> tensor<2x2x4x!FHE.eint<16>> + return %0 : tensor<2x2x4x!FHE.eint<16>> + } +tests: + - outputs: + - tensor: [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0] + shape: [2,2,4] +v0-constraint: [16, 0] +v0-parameter: [1,11,565,1,23,5,3] +large-integer-crt-decomposition: [7,8,9,11,13] +--- description: add_eint_int_cst program: | func.func @main(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { @@ -32,18 +68,30 @@ tests: - scalar: 0 outputs: - scalar: 1 +--- +description: add_eint_int_cst_16bits +program: | + func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> { + %0 = arith.constant 1 : i17 + %1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.eint<16>, i17) -> (!FHE.eint<16>) + return %1: !FHE.eint<16> + } +tests: - inputs: + - scalar: 0 + outputs: - scalar: 1 - outputs: - - scalar: 2 - inputs: - - scalar: 2 + - scalar: 72070 outputs: - - scalar: 3 + - scalar: 72071 - inputs: - - scalar: 3 + - scalar: 72071 outputs: - - scalar: 4 + - scalar: 0 +v0-constraint: [16, 0] +v0-parameter: [1,11,565,1,23,5,3] +large-integer-crt-decomposition: [7,8,9,11,13] --- description: add_eint_int_arg program: | @@ -63,30 +111,75 @@ tests: outputs: - scalar: 3 --- -description: sub_int_eint_cst +description: add_eint_int_arg_16bits program: | - func.func @main(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { - %0 = arith.constant 7 : i3 - %1 = "FHE.sub_int_eint"(%0, %arg0): (i3, !FHE.eint<2>) -> (!FHE.eint<2>) + func.func @main(%arg0: !FHE.eint<16>, %arg1: i17) -> !FHE.eint<16> { + %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<16>, i17) -> (!FHE.eint<16>) + return %1: !FHE.eint<16> + } +tests: + - inputs: + - scalar: 0 + - scalar: 0 + outputs: + - scalar: 0 + - inputs: + - scalar: 0 + - scalar: 5 + outputs: + - scalar: 5 + - inputs: + - scalar: 36036 + - scalar: 36035 + outputs: + - scalar: 72071 +v0-constraint: [16, 0] +v0-parameter: [1,11,565,1,23,5,3] +large-integer-crt-decomposition: [7,8,9,11,13] +--- +description: add_eint +program: | + func.func @main(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2> { + %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<2>) return %1: !FHE.eint<2> } tests: - inputs: + - scalar: 0 - scalar: 1 outputs: - - scalar: 6 + - scalar: 1 - inputs: + - scalar: 1 - scalar: 2 outputs: + - scalar: 3 +--- +description: add_eint_16bits +program: | + func.func @main(%arg0: !FHE.eint<16>, %arg1: !FHE.eint<16>) -> !FHE.eint<16> { + %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<16>, !FHE.eint<16>) -> (!FHE.eint<16>) + return %1: !FHE.eint<16> + } +tests: + - inputs: + - scalar: 0 + - scalar: 0 + outputs: + - scalar: 0 + - inputs: + - scalar: 0 + - scalar: 5 + outputs: - scalar: 5 - inputs: - - scalar: 3 + - scalar: 36036 + - scalar: 36035 outputs: - - scalar: 4 - - inputs: - - scalar: 4 - outputs: - - scalar: 3 + - scalar: 72071 +v0-constraint: [16, 0] +v0-parameter: [1,11,565,1,23,5,3] +large-integer-crt-decomposition: [7,8,9,11,13] --- description: sub_eint_int_cst program: | @@ -105,6 +198,22 @@ tests: outputs: - scalar: 0 --- +description: sub_eint_int_cst_16bits +program: | + func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> { + %0 = arith.constant 7 : i17 + %1 = "FHE.sub_eint_int"(%arg0, %0): (!FHE.eint<16>, i17) -> (!FHE.eint<16>) + return %1: !FHE.eint<16> + } +tests: + - inputs: + - scalar: 7 + outputs: + - scalar: 0 +v0-constraint: [16, 0] +v0-parameter: [1,11,565,1,23,5,3] +large-integer-crt-decomposition: [7,8,9,11,13] +--- description: sub_eint_int_arg program: | func.func @main(%arg0: !FHE.eint<4>, %arg1: i5) -> !FHE.eint<4> { @@ -128,6 +237,22 @@ tests: outputs: - scalar: 3 --- +description: sub_eint_int_arg_16bits_fixme +program: | + func.func @main(%arg0: !FHE.eint<16>, %arg1: i17) -> !FHE.eint<16> { + %1 = "FHE.sub_eint_int"(%arg0, %arg1): (!FHE.eint<16>, i17) -> (!FHE.eint<16>) + return %1: !FHE.eint<16> + } +tests: + - inputs: + - scalar: 72071 + - scalar: 2 + outputs: + - scalar: 72069 +v0-constraint: [16, 0] +v0-parameter: [1,11,565,1,23,5,3] +large-integer-crt-decomposition: [7,8,9,11,13] +--- description: sub_eint program: | func.func @main(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<4> { @@ -151,6 +276,32 @@ tests: outputs: - scalar: 3 --- +description: sub_eint_16bits +program: | + func.func @main(%arg0: !FHE.eint<16>, %arg1: !FHE.eint<16>) -> !FHE.eint<16> { + %1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.eint<16>, !FHE.eint<16>) -> (!FHE.eint<16>) + return %1: !FHE.eint<16> + } +tests: + - inputs: + - scalar: 0 + - scalar: 0 + outputs: + - scalar: 0 + - inputs: + - scalar: 72071 + - scalar: 72071 + outputs: + - scalar: 0 + - inputs: + - scalar: 7 + - scalar: 4 + outputs: + - scalar: 3 +v0-constraint: [16, 0] +v0-parameter: [1,11,565,1,23,5,3] +large-integer-crt-decomposition: [7,8,9,11,13] +--- description: sub_int_eint_arg program: | func.func @main(%arg0: i3, %arg1: !FHE.eint<2>) -> !FHE.eint<2> { @@ -174,6 +325,79 @@ tests: outputs: - scalar: 5 --- +description: sub_int_eint_arg_16bits +program: | + func.func @main(%arg0: i17, %arg1: !FHE.eint<16>) -> !FHE.eint<16> { + %1 = "FHE.sub_int_eint"(%arg0, %arg1): (i17, !FHE.eint<16>) -> (!FHE.eint<16>) + return %1: !FHE.eint<16> + } +tests: + - inputs: + - scalar: 0 + - scalar: 0 + outputs: + - scalar: 0 + - inputs: + - scalar: 0 + - scalar: -1 + outputs: + - scalar: 1 + #- inputs: + # - scalar: 72071 + # - scalar: 0 + # outputs: + # - scalar: 72071 + #- inputs: + # - scalar: 7 + # - scalar: 4 + # outputs: + # - scalar: 3 + +v0-constraint: [16, 0] +v0-parameter: [1,11,565,1,23,5,3] +large-integer-crt-decomposition: [7,8,9,11,13] +--- +description: sub_int_eint_cst +program: | + func.func @main(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { + %0 = arith.constant 7 : i3 + %1 = "FHE.sub_int_eint"(%0, %arg0): (i3, !FHE.eint<2>) -> (!FHE.eint<2>) + return %1 : !FHE.eint<2> + } +tests: + - inputs: + - scalar: 1 + outputs: + - scalar: 6 + - inputs: + - scalar: 2 + outputs: + - scalar: 5 +--- +description: sub_int_eint_cst_16bits +program: | + func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> { + %0 = arith.constant -1 : i17 + %1 = "FHE.sub_int_eint"(%0, %arg0): (i17, !FHE.eint<16>) -> (!FHE.eint<16>) + return %1: !FHE.eint<16> + } +tests: + - inputs: + - scalar: 72071 + outputs: + - scalar: 0 + - inputs: + - scalar: 0 + outputs: + - scalar: 72071 + - inputs: + - scalar: 32000 + outputs: + - scalar: 40071 +v0-constraint: [16, 0] +v0-parameter: [1,11,565,1,23,5,3] +large-integer-crt-decomposition: [7,8,9,11,13] +--- description: neg_eint program: | func.func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { @@ -198,6 +422,29 @@ tests: outputs: - scalar: 6 --- +description: neg_eint_16bits +program: | + func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> { + %1 = "FHE.neg_eint"(%arg0): (!FHE.eint<16>) -> (!FHE.eint<16>) + return %1: !FHE.eint<16> + } +tests: + - inputs: + - scalar: 0 + outputs: + - scalar: 0 + - inputs: + - scalar: 1 + outputs: + - scalar: 72071 + - inputs: + - scalar: 72071 + outputs: + - scalar: 1 +v0-constraint: [16, 0] +v0-parameter: [1,11,565,1,23,5,3] +large-integer-crt-decomposition: [7,8,9,11,13] +--- description: neg_eint_3bits program: | func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> { @@ -247,6 +494,30 @@ tests: outputs: - scalar: 4 --- +description: mul_eint_int_cst_16bits +program: | + func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> { + %0 = arith.constant 2 : i17 + %1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<16>, i17) -> (!FHE.eint<16>) + return %1: !FHE.eint<16> + } +tests: + - inputs: + - scalar: 0 + outputs: + - scalar: 0 + - inputs: + - scalar: 1 + outputs: + - scalar: 2 + - inputs: + - scalar: 36035 + outputs: + - scalar: 72070 +v0-constraint: [16, 0] +v0-parameter: [1,11,565,1,23,5,3] +large-integer-crt-decomposition: [7,8,9,11,13] +--- description: mul_eint_int_arg program: | func.func @main(%arg0: !FHE.eint<2>, %arg1: i3) -> !FHE.eint<2> { @@ -270,28 +541,31 @@ tests: outputs: - scalar: 4 --- -description: add_eint +description: mul_eint_int_arg_16bits program: | - func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> { - %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>) - return %1: !FHE.eint<7> + func.func @main(%arg0: !FHE.eint<16>, %arg1: i17) -> !FHE.eint<16> { + %1 = "FHE.mul_eint_int"(%arg0, %arg1): (!FHE.eint<16>, i17) -> (!FHE.eint<16>) + return %1: !FHE.eint<16> } tests: - inputs: - - scalar: 1 - - scalar: 2 + - scalar: 0 + - scalar: 87 outputs: - - scalar: 3 - - inputs: - - scalar: 4 - - scalar: 5 - outputs: - - scalar: 9 + - scalar: 0 - inputs: - scalar: 1 - - scalar: 1 + - scalar: 72071 outputs: + - scalar: 72071 + - inputs: - scalar: 2 + - scalar: 3572 + outputs: + - scalar: 7144 +v0-constraint: [16, 0] +v0-parameter: [1,11,565,1,23,5,3] +large-integer-crt-decomposition: [7,8,9,11,13] --- description: apply_lookup_table_1_bits program: | diff --git a/compiler/tests/end_to_end_fixture/end_to_end_fhelinalg.yaml b/compiler/tests/end_to_end_fixture/end_to_end_fhelinalg.yaml index 5cac37fcd..5d2753ddc 100644 --- a/compiler/tests/end_to_end_fixture/end_to_end_fhelinalg.yaml +++ b/compiler/tests/end_to_end_fixture/end_to_end_fhelinalg.yaml @@ -1,3 +1,229 @@ +description: add_eint_int_term_to_term +program: | + // Returns the term to term addition of `%a0` with `%a1` + func.func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!FHE.eint<6>> { + %res = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4xi7>) -> tensor<4x!FHE.eint<6>> + return %res : tensor<4x!FHE.eint<6>> + } +tests: + - inputs: + - tensor: [31, 6, 12, 9] + shape: [4] + width: 8 + - tensor: [32, 9, 2, 3] + shape: [4] + width: 8 + outputs: + - tensor: [63, 15, 14, 12] + shape: [4] +--- +description: add_eint_int_term_to_term_16bits +program: | + // Returns the term to term addition of `%a0` with `%a1` + func.func @main(%a0: tensor<4x!FHE.eint<16>>, %a1: tensor<4xi17>) -> tensor<4x!FHE.eint<16>> { + %res = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<16>>, tensor<4xi17>) -> tensor<4x!FHE.eint<16>> + return %res : tensor<4x!FHE.eint<16>> + } +tests: + - inputs: + - tensor: [32767, 1276, 10212, 0] + shape: [4] + - tensor: [32768, 20967, 3, 0] + shape: [4] + outputs: + - tensor: [65535, 22243, 10215, 0] + shape: [4] +v0-constraint: [16, 0] +v0-parameter: [1,11,565,1,23,5,3] +large-integer-crt-decomposition: [7,8,9,11,13] +--- +description: add_eint_term_to_term +program: | + func.func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>> { + %res = "FHELinalg.add_eint"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>> + return %res : tensor<4x!FHE.eint<6>> + } +tests: + - inputs: + - tensor: [31, 6, 12, 9] + shape: [4] + - tensor: [32, 9, 2, 3] + shape: [4] + outputs: + - tensor: [63, 15, 14, 12] + shape: [4] +--- +description: add_eint_term_to_term_16bits +program: | + func.func @main(%a0: tensor<4x!FHE.eint<16>>, %a1: tensor<4x!FHE.eint<16>>) -> tensor<4x!FHE.eint<16>> { + %res = "FHELinalg.add_eint"(%a0, %a1) : (tensor<4x!FHE.eint<16>>, tensor<4x!FHE.eint<16>>) -> tensor<4x!FHE.eint<16>> + return %res : tensor<4x!FHE.eint<16>> + } +tests: + - inputs: + - tensor: [32767, 1276, 10212, 0] + shape: [4] + - tensor: [32768, 20967, 3, 0] + shape: [4] + outputs: + - tensor: [65535, 22243, 10215, 0] + shape: [4] +v0-constraint: [16, 0] +v0-parameter: [1,11,565,1,23,5,3] +large-integer-crt-decomposition: [7,8,9,11,13] +--- +description: sub_int_eint_term_to_term +program: | + // Returns the term to term substraction of `%a0` with `%a1` + func.func @main(%a0: tensor<4xi5>, %a1: tensor<4x!FHE.eint<4>>) -> tensor<4x!FHE.eint<4>> { + %res = "FHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4xi5>, tensor<4x!FHE.eint<4>>) -> tensor<4x!FHE.eint<4>> + return %res : tensor<4x!FHE.eint<4>> + } +tests: + - inputs: + - tensor: [15, 9, 12, 9] + shape: [4] + width: 8 + - tensor: [15, 6, 2, 3] + shape: [4] + width: 8 + outputs: + - tensor: [0, 3, 10, 6] + shape: [4] +--- +description: sub_int_eint_term_to_term_16bits +program: | + // Returns the term to term substraction of `%a0` with `%a1` + func.func @main(%a0: tensor<4xi17>, %a1: tensor<4x!FHE.eint<16>>) -> tensor<4x!FHE.eint<16>> { + %res = "FHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4xi17>, tensor<4x!FHE.eint<16>>) -> tensor<4x!FHE.eint<16>> + return %res : tensor<4x!FHE.eint<16>> + } +tests: + - inputs: + - tensor: [65535, 22243, 10215, 0] + shape: [4] + - tensor: [65535, 1276, 10212, 0] + shape: [4] + outputs: + - tensor: [0, 20967, 3, 0] + shape: [4] +v0-constraint: [16, 0] +v0-parameter: [1,11,565,1,23,5,3] +large-integer-crt-decomposition: [7,8,9,11,13] +--- +description: sub_eint_int_term_to_term +program: | + func.func @main(%a0: tensor<4xi5>, %a1: tensor<4x!FHE.eint<4>>) -> tensor<4x!FHE.eint<4>> { + %res = "FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<4x!FHE.eint<4>>, tensor<4xi5>) -> tensor<4x!FHE.eint<4>> + return %res : tensor<4x!FHE.eint<4>> + } +tests: + - inputs: + - tensor: [15, 6, 2, 3] + shape: [4] + width: 8 + - tensor: [15, 9, 12, 9] + shape: [4] + width: 8 + outputs: + - tensor: [0, 3, 10, 6] + shape: [4] +--- +description: sub_eint_int_term_to_term_16bits +program: | + func.func @main(%a0: tensor<4xi17>, %a1: tensor<4x!FHE.eint<16>>) -> tensor<4x!FHE.eint<16>> { + %res = "FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<4x!FHE.eint<16>>, tensor<4xi17>) -> tensor<4x!FHE.eint<16>> + return %res : tensor<4x!FHE.eint<16>> + } +tests: + - inputs: + - tensor: [65535, 1276, 10212, 0] + shape: [4] + - tensor: [65535, 22243, 10215, 0] + shape: [4] + outputs: + - tensor: [0, 20967, 3, 0] + shape: [4] +v0-constraint: [16, 0] +v0-parameter: [1,11,565,1,23,5,3] +large-integer-crt-decomposition: [7,8,9,11,13] +--- +description: sub_eint_term_to_term +program: | + func.func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>> { + %res = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>> + return %res : tensor<4x!FHE.eint<6>> + } +tests: + - inputs: + - tensor: [31, 6, 12, 9] + shape: [4] + width: 8 + - tensor: [4, 2, 9, 3] + shape: [4] + width: 8 + outputs: + - tensor: [27, 4, 3, 6] + shape: [4] +--- +description: sub_eint_term_to_term_16bits +program: | + func.func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>> { + %res = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>> + return %res : tensor<4x!FHE.eint<6>> + } +tests: + - inputs: + - tensor: [65535, 22243, 10215, 0] + shape: [4] + - tensor: [65535, 1276, 10212, 0] + shape: [4] + outputs: + - tensor: [0, 20967, 3, 0] + shape: [4] +v0-constraint: [16, 0] +v0-parameter: [1,11,565,1,23,5,3] +large-integer-crt-decomposition: [7,8,9,11,13] +--- +description: mul_eint_int_term_to_term +program: | + // Returns the term to term multiplication of `%a0` with `%a1` + func.func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!FHE.eint<6>> { + %res = "FHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4xi7>) -> tensor<4x!FHE.eint<6>> + return %res : tensor<4x!FHE.eint<6>> + } +tests: + - inputs: + - tensor: [31, 6, 12, 9] + shape: [4] + width: 8 + - tensor: [2, 3, 2, 3] + shape: [4] + width: 8 + outputs: + - tensor: [62, 18, 24, 27] + shape: [4] +--- +description: mul_eint_int_term_to_term_16bits +program: | + // Returns the term to term multiplication of `%a0` with `%a1` + func.func @main(%a0: tensor<4x!FHE.eint<16>>, %a1: tensor<4xi17>) -> tensor<4x!FHE.eint<16>> { + %res = "FHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<16>>, tensor<4xi17>) -> tensor<4x!FHE.eint<16>> + return %res : tensor<4x!FHE.eint<16>> + } +tests: + - inputs: + - tensor: [1, 65535, 12, 0] + shape: [4] + - tensor: [65535, 1, 1987, 0] + shape: [4] + outputs: + - tensor: [65535, 65535, 23844, 0] + shape: [4] +v0-constraint: [16, 0] +v0-parameter: [1,11,565,1,23,5,3] +large-integer-crt-decomposition: [7,8,9,11,13] +--- description: transpose1d program: | func.func @main(%input: tensor<3x!FHE.eint<6>>) -> tensor<3x!FHE.eint<6>> { diff --git a/compiler/tests/end_to_end_tests/end_to_end_jit_fhe.cc b/compiler/tests/end_to_end_tests/end_to_end_jit_fhe.cc index 533424cd1..9fb3495ea 100644 --- a/compiler/tests/end_to_end_tests/end_to_end_jit_fhe.cc +++ b/compiler/tests/end_to_end_tests/end_to_end_jit_fhe.cc @@ -20,6 +20,9 @@ void compile_and_run(EndToEndDesc desc, LambdaSupport support) { if (desc.v0Parameter.hasValue()) { options.v0Parameter = *desc.v0Parameter; } + if (desc.largeIntegerParameter.hasValue()) { + options.largeIntegerParameter = *desc.largeIntegerParameter; + } /* 0 - Enable parallel testing where required */ #ifdef CONCRETELANG_PARALLEL_TESTING_ENABLED diff --git a/compiler/tests/end_to_end_tests/end_to_end_jit_fhelinalg.cc b/compiler/tests/end_to_end_tests/end_to_end_jit_fhelinalg.cc index 35e3642b2..8fcf28f9f 100644 --- a/compiler/tests/end_to_end_tests/end_to_end_jit_fhelinalg.cc +++ b/compiler/tests/end_to_end_tests/end_to_end_jit_fhelinalg.cc @@ -70,37 +70,6 @@ func.func @main(%arg0: tensor<3x2x!FHE.eint<4>>) -> tensor<3x!FHE.eint<4>> { EXPECT_EQ((*res)[2], 6); } -TEST(End2EndJit_FHELinalg, add_eint_int_term_to_term) { - - checkedJit(lambda, R"XXX( - // Returns the term to term addition of `%a0` with `%a1` - func.func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!FHE.eint<6>> { - %res = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4xi7>) -> tensor<4x!FHE.eint<6>> - return %res : tensor<4x!FHE.eint<6>> - } -)XXX"); - std::vector a0{31, 6, 12, 9}; - std::vector a1{32, 9, 2, 3}; - - mlir::concretelang::TensorLambdaArgument< - mlir::concretelang::IntLambdaArgument> - arg0(a0); - mlir::concretelang::TensorLambdaArgument< - mlir::concretelang::IntLambdaArgument> - arg1(a1); - - llvm::Expected> res = - lambda.operator()>({&arg0, &arg1}); - - ASSERT_EXPECTED_SUCCESS(res); - - ASSERT_EQ(res->size(), (size_t)4); - - for (size_t i = 0; i < 4; i++) { - EXPECT_EQ((*res)[i], (uint64_t)a0[i] + a1[i]); - } -} - // Same as add_eint_int_term_to_term test above, but returning a lambda argument TEST(End2EndJit_FHELinalg, add_eint_int_term_to_term_ret_lambda_argument) { @@ -371,40 +340,6 @@ TEST(End2EndJit_FHELinalg, add_eint_int_matrix_line_missing_dim) { // FHELinalg add_eint /////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////// -TEST(End2EndJit_FHELinalg, add_eint_term_to_term) { - - checkedJit(lambda, R"XXX( - // Returns the term to term addition of `%a0` with `%a1` - func.func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>> { - %res = "FHELinalg.add_eint"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>> - return %res : tensor<4x!FHE.eint<6>> - } -)XXX"); - - std::vector a0{31, 6, 12, 9}; - std::vector a1{32, 9, 2, 3}; - - mlir::concretelang::TensorLambdaArgument< - mlir::concretelang::IntLambdaArgument> - arg0(a0); - mlir::concretelang::TensorLambdaArgument< - mlir::concretelang::IntLambdaArgument> - arg1(a1); - - llvm::Expected> res = - lambda.operator()>({&arg0, &arg1}); - - ASSERT_EXPECTED_SUCCESS(res); - - ASSERT_EQ(res->size(), (uint64_t)4); - - for (size_t i = 0; i < 4; i++) { - EXPECT_EQ((*res)[i], (uint64_t)a0[i] + a1[i]) - << "result differ at pos " << i << ", expect " << a0[i] + a1[i] - << " got " << (*res)[i]; - } -} - TEST(End2EndJit_FHELinalg, add_eint_term_to_term_broadcast) { checkedJit(lambda, R"XXX( @@ -854,36 +789,6 @@ TEST(End2EndJit_FHELinalg, sub_int_eint_matrix_line_missing_dim) { // FHELinalg sub_eint_int /////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////// -TEST(End2EndJit_FHELinalg, sub_eint_int_term_to_term) { - - checkedJit(lambda, R"XXX( - func.func @main(%a0: tensor<4xi5>, %a1: tensor<4x!FHE.eint<4>>) -> tensor<4x!FHE.eint<4>> { - %res = "FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<4x!FHE.eint<4>>, tensor<4xi5>) -> tensor<4x!FHE.eint<4>> - return %res : tensor<4x!FHE.eint<4>> - } -)XXX"); - std::vector a0{31, 6, 2, 3}; - std::vector a1{32, 9, 12, 9}; - - mlir::concretelang::TensorLambdaArgument< - mlir::concretelang::IntLambdaArgument> - arg0(a0); - mlir::concretelang::TensorLambdaArgument< - mlir::concretelang::IntLambdaArgument> - arg1(a1); - - llvm::Expected> res = - lambda.operator()>({&arg0, &arg1}); - - ASSERT_EXPECTED_SUCCESS(res); - - ASSERT_EQ(res->size(), (uint64_t)4); - - for (size_t i = 0; i < 4; i++) { - EXPECT_EQ((*res)[i], (uint64_t)a1[i] - a0[i]); - } -} - TEST(End2EndJit_FHELinalg, sub_eint_int_term_to_term_broadcast) { checkedJit(lambda, R"XXX( diff --git a/compiler/tests/tests_tools/assert.h b/compiler/tests/tests_tools/assert.h index 9629618d9..14045db0d 100644 --- a/compiler/tests/tests_tools/assert.h +++ b/compiler/tests/tests_tools/assert.h @@ -6,8 +6,13 @@ #include #define ASSERT_LLVM_ERROR(err) \ - if (err) { \ - ASSERT_TRUE(false) << llvm::toString(err); \ + { \ + llvm::Error e = err; \ + if (e) { \ + handleAllErrors(std::move(e), [](const llvm::ErrorInfoBase &ei) { \ + ASSERT_TRUE(false) << ei.message(); \ + }); \ + } \ } // Checks that the value `val` is not in an error state. Returns diff --git a/compiler/tests/unit_tests/concretelang/ClientLib/CMakeLists.txt b/compiler/tests/unit_tests/concretelang/ClientLib/CMakeLists.txt index bda851800..694547f8e 100644 --- a/compiler/tests/unit_tests/concretelang/ClientLib/CMakeLists.txt +++ b/compiler/tests/unit_tests/concretelang/ClientLib/CMakeLists.txt @@ -8,6 +8,7 @@ add_unittest( unit_tests_concretelang_clientlib ClientParameters.cpp + CRT.cpp KeySet.cpp ) diff --git a/compiler/tests/unit_tests/concretelang/ClientLib/CRT.cpp b/compiler/tests/unit_tests/concretelang/ClientLib/CRT.cpp new file mode 100644 index 000000000..b1bdbe257 --- /dev/null +++ b/compiler/tests/unit_tests/concretelang/ClientLib/CRT.cpp @@ -0,0 +1,50 @@ +#include + +#include "concretelang/ClientLib/CRT.h" +#include "tests_tools/assert.h" +namespace { +namespace crt = concretelang::clientlib::crt; +typedef std::vector CRTModuli; + +// Define a fixture for instantiate test with client parameters +class CRTTest : public ::testing::TestWithParam {}; + +TEST_P(CRTTest, crt_iCrt) { + auto moduli = GetParam(); + + // Max representable value from moduli + uint64_t maxValue = 1; + for (auto modulus : moduli) + maxValue *= modulus; + maxValue = maxValue - 1; + + std::vector valuesToTest{0, maxValue / 2, maxValue}; + for (auto a : valuesToTest) { + auto remainders = crt::crt(moduli, a); + auto b = crt::iCrt(moduli, remainders); + + ASSERT_EQ(a, b); + } +} + +std::vector generateAllParameters() { + return { + // This is our default moduli for the 16 bits + {7, 8, 9, 11, 13}, + }; +} + +INSTANTIATE_TEST_SUITE_P(CRTSuite, CRTTest, + ::testing::ValuesIn(generateAllParameters()), + [](const testing::TestParamInfo info) { + auto moduli = info.param; + std::string desc("mod"); + if (!moduli.empty()) { + for (auto b : moduli) { + desc = desc + "_" + std::to_string(b); + } + } + return desc; + }); + +} // namespace diff --git a/compiler/tests/unit_tests/concretelang/ClientLib/ClientParameters.cpp b/compiler/tests/unit_tests/concretelang/ClientLib/ClientParameters.cpp index 9e7e06faa..2ef688186 100644 --- a/compiler/tests/unit_tests/concretelang/ClientLib/ClientParameters.cpp +++ b/compiler/tests/unit_tests/concretelang/ClientLib/ClientParameters.cpp @@ -42,7 +42,7 @@ TEST(Support, client_parameters_json_serde) { }}}; params0.inputs = { { - /*.encryption = */ {{clientlib::SMALL_KEY, 0.00, {4}}}, + /*.encryption = */ {{clientlib::SMALL_KEY, 0.00, {4, {1, 2, 3, 4}}}}, /*.shape = */ {32, {1, 2, 3, 4}, 1 * 2 * 3 * 4}, }, { diff --git a/compiler/tests/unit_tests/concretelang/ClientLib/KeySet.cpp b/compiler/tests/unit_tests/concretelang/ClientLib/KeySet.cpp index d8c143209..6dc8ebdb5 100644 --- a/compiler/tests/unit_tests/concretelang/ClientLib/KeySet.cpp +++ b/compiler/tests/unit_tests/concretelang/ClientLib/KeySet.cpp @@ -8,14 +8,14 @@ namespace clientlib = concretelang::clientlib; // Define a fixture for instantiate test with client parameters -class ClientParametersTest +class KeySetTest : public ::testing::TestWithParam { protected: clientlib::ClientParameters clientParameters; }; // Test case encrypt and decrypt -TEST_P(ClientParametersTest, encrypt_decrypt) { +TEST_P(KeySetTest, encrypt_decrypt) { auto clientParameters = GetParam(); @@ -29,7 +29,7 @@ TEST_P(ClientParametersTest, encrypt_decrypt) { ASSERT_OUTCOME_HAS_VALUE(keySet->allocate_lwe(0, &ciphertext, size)); // Encrypt - uint64_t input = 3; + uint64_t input = 0; ASSERT_OUTCOME_HAS_VALUE(keySet->encrypt_lwe(0, ciphertext, input)); // Decrypt @@ -45,9 +45,9 @@ TEST_P(ClientParametersTest, encrypt_decrypt) { /// Create a client parameters with just one secret key of `dimension` and with /// one input scalar gate and one output scalar gate on the same key -clientlib::ClientParameters -generateClientParameterOneScalarOneScalar(clientlib::LweDimension dimension, - clientlib::Precision precision) { +clientlib::ClientParameters generateClientParameterOneScalarOneScalar( + clientlib::LweDimension dimension, clientlib::Precision precision, + clientlib::CRTDecomposition crtDecomposition) { // One secret key with the given dimension clientlib::ClientParameters params; params.secretKeys.insert({clientlib::SMALL_KEY, {/*.dimension =*/dimension}}); @@ -56,6 +56,7 @@ generateClientParameterOneScalarOneScalar(clientlib::LweDimension dimension, clientlib::EncryptionGate encryption; encryption.secretKeyID = clientlib::SMALL_KEY; encryption.encoding.precision = precision; + encryption.encoding.crt = crtDecomposition; clientlib::CircuitGate gate; gate.encryption = encryption; params.inputs.push_back(gate); @@ -74,13 +75,23 @@ std::vector generateAllParameters() { llvm::for_each(llvm::enumerate(precisions), [](auto p) { p.value() = p.index() + 1; }); + // All crt decomposition to test + std::vector crtDecompositions{ + // Empty crt decompositon means no decomposition + {}, + // The default decomposition for 16 bits + {7, 8, 9, 11, 13}, + }; + // All client parameters to test std::vector parameters; for (auto dimension : lweDimensions) { for (auto precision : precisions) { - parameters.push_back( - generateClientParameterOneScalarOneScalar(dimension, precision)); + for (auto crtDecomposition : crtDecompositions) { + parameters.push_back(generateClientParameterOneScalarOneScalar( + dimension, precision, crtDecomposition)); + } } } @@ -88,13 +99,21 @@ std::vector generateAllParameters() { } INSTANTIATE_TEST_SUITE_P( - OneScalarOnScalar, ClientParametersTest, - ::testing::ValuesIn(generateAllParameters()), + OneScalarOnScalar, KeySetTest, ::testing::ValuesIn(generateAllParameters()), [](const testing::TestParamInfo info) { auto cp = info.param; auto input_0 = cp.inputs[0]; - return std::string("lweDimension_") + - std::to_string(cp.lweSecretKeyParam(input_0).value().dimension) + - "_precision_" + - std::to_string(input_0.encryption.getValue().encoding.precision); + auto paramDescription = + std::string("lweDimension_") + + std::to_string(cp.lweSecretKeyParam(input_0).value().dimension) + + "_precision_" + + std::to_string(input_0.encryption.getValue().encoding.precision); + auto crt = input_0.encryption.getValue().encoding.crt; + if (!crt.empty()) { + paramDescription = paramDescription + "_crt_"; + for (auto b : crt) { + paramDescription = paramDescription + "_" + std::to_string(b); + } + } + return paramDescription; });