diff --git a/compiler/include/concretelang/ClientLib/ClientParameters.h b/compiler/include/concretelang/ClientLib/ClientParameters.h index 237e62842..4b4cc5057 100644 --- a/compiler/include/concretelang/ClientLib/ClientParameters.h +++ b/compiler/include/concretelang/ClientLib/ClientParameters.h @@ -133,9 +133,10 @@ static inline bool operator==(const PackingKeySwitchParam &lhs, struct Encoding { Precision precision; CRTDecomposition crt; + bool isSigned; }; static inline bool operator==(const Encoding &lhs, const Encoding &rhs) { - return lhs.precision == rhs.precision; + return lhs.precision == rhs.precision && lhs.isSigned == rhs.isSigned; } struct EncryptionGate { diff --git a/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.td b/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.td index 2d026a4ae..5f93581a1 100644 --- a/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.td +++ b/compiler/include/concretelang/Dialect/BConcrete/IR/BConcreteOps.td @@ -71,7 +71,8 @@ def BConcrete_EncodeExpandLutForBootstrapTensorOp : BConcrete_Op<"encode_expand_ let arguments = (ins 1DTensorOf<[I64]> : $input_lookup_table, I32Attr: $polySize, - I32Attr: $outputBits + I32Attr: $outputBits, + BoolAttr: $isSigned ); let results = (outs 1DTensorOf<[I64]> : $result); @@ -86,7 +87,8 @@ def BConcrete_EncodeExpandLutForWopPBSTensorOp : BConcrete_Op<"encode_expand_lut I64ArrayAttr: $crtDecomposition, I64ArrayAttr: $crtBits, I32Attr : $polySize, - I32Attr : $modulusProduct + I32Attr : $modulusProduct, + BoolAttr: $isSigned ); let results = (outs 1DTensorOf<[I64]> : $result); @@ -226,7 +228,8 @@ def BConcrete_EncodeExpandLutForBootstrapBufferOp : BConcrete_Op<"encode_expand_ BConcrete_LutBuffer: $result, BConcrete_LutBuffer: $input_lookup_table, I32Attr: $polySize, - I32Attr: $outputBits + I32Attr: $outputBits, + BoolAttr : $isSigned ); } @@ -240,7 +243,8 @@ def BConcrete_EncodeExpandLutForWopPBSBufferOp : BConcrete_Op<"encode_expand_lut I64ArrayAttr: $crtDecomposition, I64ArrayAttr: $crtBits, I32Attr : $polySize, - I32Attr : $modulusProduct + I32Attr : $modulusProduct, + BoolAttr: $isSigned ); } diff --git a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td index d9df7788e..130762fc2 100644 --- a/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td +++ b/compiler/include/concretelang/Dialect/Concrete/IR/ConcreteOps.td @@ -60,7 +60,8 @@ def Concrete_EncodeExpandLutForBootstrapOp : Concrete_Op<"encode_expand_lut_for_ let arguments = (ins 1DTensorOf<[I64]> : $input_lookup_table, I32Attr: $polySize, - I32Attr: $outputBits + I32Attr: $outputBits, + BoolAttr: $isSigned ); let results = (outs 1DTensorOf<[I64]> : $result); @@ -75,7 +76,8 @@ let summary = I64ArrayAttr: $crtDecomposition, I64ArrayAttr: $crtBits, I32Attr : $polySize, - I32Attr : $modulusProduct + I32Attr : $modulusProduct, + BoolAttr: $isSigned ); let results = (outs 1DTensorOf<[I64]> : $result); diff --git a/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td b/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td index 5efd0d2c6..cd129b51a 100644 --- a/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td +++ b/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td @@ -355,9 +355,9 @@ def FHE_ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table", [NoSideEffect]> { ``` }]; - let arguments = (ins FHE_EncryptedIntegerType:$a, + let arguments = (ins FHE_AnyEncryptedInteger:$a, TensorOf<[AnyInteger]>:$lut); - let results = (outs FHE_EncryptedIntegerType); + let results = (outs FHE_AnyEncryptedInteger); let hasVerifier = 1; } diff --git a/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td b/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td index ca36767dd..cd952e1aa 100644 --- a/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td +++ b/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td @@ -26,7 +26,8 @@ def TFHE_EncodeExpandLutForBootstrapOp : TFHE_Op<"encode_expand_lut_for_bootstra let arguments = (ins 1DTensorOf<[I64]> : $input_lookup_table, I32Attr: $polySize, - I32Attr: $outputBits + I32Attr: $outputBits, + BoolAttr: $isSigned ); let results = (outs 1DTensorOf<[I64]> : $result); @@ -41,7 +42,8 @@ def TFHE_EncodeExpandLutForWopPBSOp : TFHE_Op<"encode_expand_lut_for_woppbs"> { I64ArrayAttr: $crtDecomposition, I64ArrayAttr: $crtBits, I32Attr : $polySize, - I32Attr : $modulusProduct + I32Attr : $modulusProduct, + BoolAttr: $isSigned ); let results = (outs 1DTensorOf<[I64]> : $result); diff --git a/compiler/include/concretelang/Runtime/wrappers.h b/compiler/include/concretelang/Runtime/wrappers.h index 2cddfa5c9..200786e59 100644 --- a/compiler/include/concretelang/Runtime/wrappers.h +++ b/compiler/include/concretelang/Runtime/wrappers.h @@ -27,7 +27,7 @@ void memref_encode_expand_lut_for_bootstrap( uint64_t output_lut_stride, uint64_t *input_lut_allocated, uint64_t *input_lut_aligned, uint64_t input_lut_offset, uint64_t input_lut_size, uint64_t input_lut_stride, uint32_t poly_size, - uint32_t out_MESSAGE_BITS); + uint32_t out_MESSAGE_BITS, bool is_signed); void memref_encode_expand_lut_for_woppbs( uint64_t *output_lut_allocated, uint64_t *output_lut_aligned, @@ -40,7 +40,7 @@ void memref_encode_expand_lut_for_woppbs( uint64_t crt_decomposition_stride, uint64_t *crt_bits_allocated, uint64_t *crt_bits_aligned, uint64_t crt_bits_offset, uint64_t crt_bits_size, uint64_t crt_bits_stride, uint32_t poly_size, - uint32_t modulus_product); + uint32_t modulus_product, bool is_signed); void memref_encode_plaintext_with_crt( uint64_t *output_allocated, uint64_t *output_aligned, diff --git a/compiler/include/concretelang/Support/LambdaSupport.h b/compiler/include/concretelang/Support/LambdaSupport.h index 8e4ff894c..d3e68ced1 100644 --- a/compiler/include/concretelang/Support/LambdaSupport.h +++ b/compiler/include/concretelang/Support/LambdaSupport.h @@ -290,7 +290,8 @@ public: // treatment, since it may alias none of the fixed size integer // types llvm::Expected successOrError = - LambdaArgumentAdaptor::tryAddArg(encryptedArgs, arg, keySet); if (!successOrError) diff --git a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index 3e75aa2bb..e6e4e2b10 100644 --- a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -238,6 +238,19 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( [](mlir::concretelang::ClientParameters &clientParameters) { return pybind11::bytes( clientParametersSerialize(clientParameters)); + }) + .def("output_signs", + [](mlir::concretelang::ClientParameters &clientParameters) { + std::vector result; + for (auto output : clientParameters.outputs) { + if (output.encryption.hasValue()) { + result.push_back( + output.encryption.getValue().encoding.isSigned); + } else { + result.push_back(true); + } + } + return result; }); pybind11::class_(m, "KeySet") diff --git a/compiler/lib/Bindings/Python/concrete/compiler/client_parameters.py b/compiler/lib/Bindings/Python/concrete/compiler/client_parameters.py index 4b388a030..e5d14cfa7 100644 --- a/compiler/lib/Bindings/Python/concrete/compiler/client_parameters.py +++ b/compiler/lib/Bindings/Python/concrete/compiler/client_parameters.py @@ -3,6 +3,8 @@ """Client parameters.""" +from typing import List + # pylint: disable=no-name-in-module,import-error from mlir._mlir_libs._concretelang._compiler import ( ClientParameters as _ClientParameters, @@ -35,6 +37,14 @@ class ClientParameters(WrapperCpp): ) super().__init__(client_parameters) + def output_signs(self) -> List[bool]: + """Return the sign information of outputs. + + Returns: + List[bool]: list of booleans to indicate whether the outputs are signed or not + """ + return self.cpp().output_signs() + def serialize(self) -> bytes: """Serialize the ClientParameters. diff --git a/compiler/lib/Bindings/Python/concrete/compiler/client_support.py b/compiler/lib/Bindings/Python/concrete/compiler/client_support.py index b98f0cae8..2166bb16b 100644 --- a/compiler/lib/Bindings/Python/concrete/compiler/client_support.py +++ b/compiler/lib/Bindings/Python/concrete/compiler/client_support.py @@ -123,11 +123,14 @@ class ClientSupport(WrapperCpp): @staticmethod def decrypt_result( - keyset: KeySet, public_result: PublicResult + client_parameters: ClientParameters, + keyset: KeySet, + public_result: PublicResult, ) -> Union[int, np.ndarray]: """Decrypt a public result using the keyset. Args: + client_parameters (ClientParameters): client parameters for decryption keyset (KeySet): keyset used for decryption public_result: public result to decrypt @@ -148,12 +151,20 @@ class ClientSupport(WrapperCpp): lambda_arg = LambdaArgument.wrap( _ClientSupport.decrypt_result(keyset.cpp(), public_result.cpp()) ) + + output_signs = client_parameters.output_signs() + assert len(output_signs) == 1 + + is_signed = output_signs[0] if lambda_arg.is_scalar(): - return lambda_arg.get_scalar() + result = lambda_arg.get_scalar() + return ( + result if not is_signed else int(np.array([result]).astype(np.int64)[0]) + ) if lambda_arg.is_tensor(): shape = lambda_arg.get_tensor_shape() tensor = np.array(lambda_arg.get_tensor_data()).reshape(shape) - return tensor + return tensor if not is_signed else tensor.astype(np.int64) raise RuntimeError("unknown return type") @staticmethod @@ -171,29 +182,42 @@ class ClientSupport(WrapperCpp): """ if not isinstance(value, ACCEPTED_TYPES): raise TypeError( - "value of lambda argument must be either int, numpy.array or numpy.uint{8,16,32,64}" + "value of lambda argument must be either int, numpy.array or numpy.(u)int{8,16,32,64}" ) if isinstance(value, ACCEPTED_INTS): - if isinstance(value, int) and not 0 <= value < np.iinfo(np.uint64).max: + if ( + isinstance(value, int) + and not np.iinfo(np.int64).min <= value < np.iinfo(np.uint64).max + ): raise TypeError( - "single integer must be in the range [0, 2**64 - 1] (uint64)" + "single integer must be in the range [-2**63, 2**64 - 1]" ) + if value < 0: + value = int(np.int64(value).astype(np.uint64)) return LambdaArgument.from_scalar(value) assert isinstance(value, np.ndarray) if value.dtype not in ACCEPTED_NUMPY_UINTS: - raise TypeError("numpy.array must be of dtype uint{8,16,32,64}") + raise TypeError("numpy.array must be of dtype (u)int{8,16,32,64}") if value.shape == (): if isinstance(value, np.ndarray): # extract the single element value = value.max() # should be a single uint here return LambdaArgument.from_scalar(value) - if value.dtype == np.uint8: - return LambdaArgument.from_tensor_8(value.flatten().tolist(), value.shape) - if value.dtype == np.uint16: - return LambdaArgument.from_tensor_16(value.flatten().tolist(), value.shape) - if value.dtype == np.uint32: - return LambdaArgument.from_tensor_32(value.flatten().tolist(), value.shape) - if value.dtype == np.uint64: - return LambdaArgument.from_tensor_64(value.flatten().tolist(), value.shape) - raise TypeError("numpy.array must be of dtype uint{8,16,32,64}") + if value.dtype in [np.uint8, np.int8]: + return LambdaArgument.from_tensor_8( + value.astype(np.uint8).flatten().tolist(), value.shape + ) + if value.dtype in [np.uint16, np.int16]: + return LambdaArgument.from_tensor_16( + value.astype(np.uint16).flatten().tolist(), value.shape + ) + if value.dtype in [np.uint32, np.int32]: + return LambdaArgument.from_tensor_32( + value.astype(np.uint32).flatten().tolist(), value.shape + ) + if value.dtype in [np.uint64, np.int64]: + return LambdaArgument.from_tensor_64( + value.astype(np.uint64).flatten().tolist(), value.shape + ) + raise TypeError("numpy.array must be of dtype (u)int{8,16,32,64}") diff --git a/compiler/lib/Bindings/Python/concrete/compiler/lambda_argument.py b/compiler/lib/Bindings/Python/concrete/compiler/lambda_argument.py index 870a269ae..034a7ada2 100644 --- a/compiler/lib/Bindings/Python/concrete/compiler/lambda_argument.py +++ b/compiler/lib/Bindings/Python/concrete/compiler/lambda_argument.py @@ -58,7 +58,7 @@ class LambdaArgument(WrapperCpp): """ if not isinstance(scalar, ACCEPTED_INTS): raise TypeError( - f"scalar must be of type int or numpy.uint, not {type(scalar)}" + f"scalar must be of type int or numpy.(u)int, not {type(scalar)}" ) return LambdaArgument.wrap(_LambdaArgument.from_scalar(scalar)) diff --git a/compiler/lib/Bindings/Python/concrete/compiler/utils.py b/compiler/lib/Bindings/Python/concrete/compiler/utils.py index a0be5c2f8..d0933eb40 100644 --- a/compiler/lib/Bindings/Python/concrete/compiler/utils.py +++ b/compiler/lib/Bindings/Python/concrete/compiler/utils.py @@ -6,7 +6,16 @@ import os import numpy as np -ACCEPTED_NUMPY_UINTS = (np.uint8, np.uint16, np.uint32, np.uint64) +ACCEPTED_NUMPY_UINTS = ( + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, +) ACCEPTED_INTS = (int,) + ACCEPTED_NUMPY_UINTS ACCEPTED_TYPES = (np.ndarray,) + ACCEPTED_INTS diff --git a/compiler/lib/ClientLib/ClientParameters.cpp b/compiler/lib/ClientLib/ClientParameters.cpp index 17ba2216b..6df51747e 100644 --- a/compiler/lib/ClientLib/ClientParameters.cpp +++ b/compiler/lib/ClientLib/ClientParameters.cpp @@ -306,6 +306,7 @@ bool fromJSON(const llvm::json::Value j, CircuitGateShape &v, llvm::json::Value toJSON(const Encoding &v) { llvm::json::Object object{ {"precision", v.precision}, + {"isSigned", v.isSigned}, }; if (!v.crt.empty()) { object.insert({"crt", v.crt}); @@ -324,6 +325,12 @@ bool fromJSON(const llvm::json::Value j, Encoding &v, llvm::json::Path p) { return false; } v.precision = precision.getValue(); + auto isSigned = obj->getBoolean("isSigned"); + if (!isSigned.hasValue()) { + p.report("missing isSigned field"); + return false; + } + v.isSigned = isSigned.getValue(); auto crt = obj->getArray("crt"); if (crt != nullptr) { for (auto dim : *crt) { diff --git a/compiler/lib/ClientLib/KeySet.cpp b/compiler/lib/ClientLib/KeySet.cpp index dd7ebe7fc..f4bb8fcab 100644 --- a/compiler/lib/ClientLib/KeySet.cpp +++ b/compiler/lib/ClientLib/KeySet.cpp @@ -355,37 +355,62 @@ KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) { if (!encryption.hasValue()) { return StringError("decrypt_lwe: the positional argument is not encrypted"); } + auto crt = encryption->encoding.crt; - // CRT encoding - N blocks with crt encoding - if (!crt.empty()) { + + if (!crt.empty()) { // The ciphertext used the crt strategy. + + // Decrypt and decode remainders std::vector remainders; - // decrypt and decode remainders for (auto modulus : crt) { uint64_t decrypted; CAPI_ASSERT_ERROR( default_engine_decrypt_lwe_ciphertext_u64_raw_ptr_buffers( engine, lweSecretKey, ciphertext, &decrypted)); - auto plaintext = crt::decode(decrypted, modulus); remainders.push_back(plaintext); ciphertext = ciphertext + lweSecretKeyParam.lweSize(); } - // compute the inverse crt + + // Compute the inverse crt output = crt::iCrt(crt, remainders); - return outcome::success(); + + // Further decode signed integers + if (encryption->encoding.isSigned) { + uint64_t maxPos = 1; + for (auto prime : encryption->encoding.crt) { + maxPos *= prime; + } + maxPos /= 2; + if (output >= maxPos) { + output -= maxPos * 2; + } + } + } else { // The ciphertext used the scalar strategy + + // Decrypt + uint64_t plaintext; + CAPI_ASSERT_ERROR(default_engine_decrypt_lwe_ciphertext_u64_raw_ptr_buffers( + engine, lweSecretKey, ciphertext, &plaintext)); + + // Decode unsigned integer + uint64_t precision = encryption->encoding.precision; + output = plaintext >> (64 - precision - 2); + auto carry = output % 2; + uint64_t mod = (((uint64_t)1) << (precision + 1)); + output = ((output >> 1) + carry) % mod; + + // Further decode signed integers. + if (encryption->encoding.isSigned) { + uint64_t maxPos = (((uint64_t)1) << (precision - 1)); + if (output >= maxPos) { // The output is actually negative. + // Set the preceding bits to zero + output |= UINT64_MAX << precision; + // This makes sure when the value is cast to int64, it has the correct + // value + }; + } } - // Simple TFHE integers - 1 blocks with one padding bits - uint64_t plaintext; - - CAPI_ASSERT_ERROR(default_engine_decrypt_lwe_ciphertext_u64_raw_ptr_buffers( - engine, lweSecretKey, ciphertext, &plaintext)); - - // Decode - uint64_t precision = encryption->encoding.precision; - output = plaintext >> (64 - precision - 2); - auto carry = output % 2; - uint64_t mod = (((uint64_t)1) << (precision + 1)); - output = ((output >> 1) + carry) % mod; return outcome::success(); } diff --git a/compiler/lib/Conversion/BConcreteToCAPI/BConcreteToCAPI.cpp b/compiler/lib/Conversion/BConcreteToCAPI/BConcreteToCAPI.cpp index 55b559f81..b5f6ebe0a 100644 --- a/compiler/lib/Conversion/BConcreteToCAPI/BConcreteToCAPI.cpp +++ b/compiler/lib/Conversion/BConcreteToCAPI/BConcreteToCAPI.cpp @@ -174,16 +174,16 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI( memref1DType, rewriter.getI64Type()}, {}); } else if (funcName == memref_encode_expand_lut_for_bootstrap) { - funcType = - mlir::FunctionType::get(rewriter.getContext(), - {memref1DType, memref1DType, - rewriter.getI32Type(), rewriter.getI32Type()}, - {}); + funcType = mlir::FunctionType::get( + rewriter.getContext(), + {memref1DType, memref1DType, rewriter.getI32Type(), + rewriter.getI32Type(), rewriter.getI1Type()}, + {}); } else if (funcName == memref_encode_expand_lut_for_woppbs) { funcType = mlir::FunctionType::get( rewriter.getContext(), {memref1DType, memref1DType, memref1DType, memref1DType, - rewriter.getI32Type(), rewriter.getI32Type()}, + rewriter.getI32Type(), rewriter.getI32Type(), rewriter.getI1Type()}, {}); } else { op->emitError("unknwon external function") << funcName; @@ -359,6 +359,9 @@ void encodeExpandLutForBootstrapAddOperands( // output bits operands.push_back(rewriter.create( op.getLoc(), op.outputBitsAttr())); + // is_signed + operands.push_back( + rewriter.create(op.getLoc(), op.isSignedAttr())); } void encodeExpandLutForWopPBSAddOperands( @@ -409,6 +412,9 @@ void encodeExpandLutForWopPBSAddOperands( // modulus_product operands.push_back(rewriter.create( op.getLoc(), op.modulusProductAttr())); + // is_signed + operands.push_back( + rewriter.create(op.getLoc(), op.isSignedAttr())); } struct BConcreteToCAPIPass : public BConcreteToCAPIBase { diff --git a/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp b/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp index f2acfb64f..5a89a171a 100644 --- a/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp +++ b/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp @@ -38,41 +38,42 @@ namespace fhe_to_tfhe_crt_conversion { namespace typing { -/// Converts `FHE::EncryptedInteger` into `Tensor`. -mlir::RankedTensorType convertEint(mlir::MLIRContext *context, - FHE::EncryptedIntegerType eint, - uint64_t crtLength) { +/// Converts an encrypted integer into `TFHE::GlweCiphertext`. +mlir::RankedTensorType convertEncrypted(mlir::MLIRContext *context, + FHE::FheIntegerInterface enc, + uint64_t crtLength) { return mlir::RankedTensorType::get( mlir::ArrayRef((int64_t)crtLength), - TFHE::GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth())); + TFHE::GLWECipherTextType::get(context, -1, -1, -1, enc.getWidth())); } -/// Converts `Tensor` into a -/// `Tensor` if the element type is appropriate. Otherwise -/// return the input type. -mlir::Type maybeConvertEintTensor(mlir::MLIRContext *context, - mlir::RankedTensorType maybeEintTensor, - uint64_t crtLength) { - if (!maybeEintTensor.getElementType().isa()) { - return (mlir::Type)(maybeEintTensor); +/// Converts `Tensor` into a +/// `Tensor` if the element type is appropriate. +/// Otherwise return the input type. +mlir::Type +maybeConvertEncryptedTensor(mlir::MLIRContext *context, + mlir::RankedTensorType maybeEncryptedTensor, + uint64_t crtLength) { + if (!maybeEncryptedTensor.getElementType().isa()) { + return (mlir::Type)(maybeEncryptedTensor); } - auto eint = - maybeEintTensor.getElementType().cast(); - auto currentShape = maybeEintTensor.getShape(); + auto encType = + maybeEncryptedTensor.getElementType().cast(); + auto currentShape = maybeEncryptedTensor.getShape(); mlir::SmallVector newShape = mlir::SmallVector(currentShape.begin(), currentShape.end()); newShape.push_back((int64_t)crtLength); return mlir::RankedTensorType::get( llvm::ArrayRef(newShape), - TFHE::GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth())); + TFHE::GLWECipherTextType::get(context, -1, -1, -1, encType.getWidth())); } -/// Converts the type `FHE::EncryptedInteger` to `Tensor` -/// if the input type is appropriate. Otherwise return the input type. -mlir::Type maybeConvertEint(mlir::MLIRContext *context, mlir::Type t, - uint64_t crtLength) { - if (auto eint = t.dyn_cast()) - return convertEint(context, eint, crtLength); +/// Converts any encrypted type to `TFHE::GlweCiphetext` if the +/// input type is appropriate. Otherwise return the input type. +mlir::Type maybeConvertEncrypted(mlir::MLIRContext *context, mlir::Type t, + uint64_t crtLength) { + if (auto eint = t.dyn_cast()) + return convertEncrypted(context, eint, crtLength); return t; } @@ -85,11 +86,11 @@ public: TypeConverter(concretelang::CrtLoweringParameters loweringParameters) { size_t nMods = loweringParameters.nMods; addConversion([](mlir::Type type) { return type; }); - addConversion([=](FHE::EncryptedIntegerType type) { - return convertEint(type.getContext(), type, nMods); + addConversion([=](FHE::FheIntegerInterface type) { + return convertEncrypted(type.getContext(), type, nMods); }); addConversion([=](mlir::RankedTensorType type) { - return maybeConvertEintTensor(type.getContext(), type, nMods); + return maybeConvertEncryptedTensor(type.getContext(), type, nMods); }); addConversion([&](concretelang::RT::FutureType type) { return concretelang::RT::FutureType::get(this->convertType( @@ -517,6 +518,8 @@ struct ApplyLookupTableEintOpPattern mlir::ConversionPatternRewriter &rewriter) const override { mlir::TypeConverter *converter = this->getTypeConverter(); + auto originalInputType = op.a().getType().cast(); + mlir::Value newLut = rewriter .create( @@ -530,7 +533,8 @@ struct ApplyLookupTableEintOpPattern rewriter.getI64ArrayAttr( mlir::ArrayRef(loweringParameters.bits)), rewriter.getI32IntegerAttr(loweringParameters.polynomialSize), - rewriter.getI32IntegerAttr(loweringParameters.modsProd)) + rewriter.getI32IntegerAttr(loweringParameters.modsProd), + rewriter.getBoolAttr(originalInputType.isSigned())) .getResult(); // Replace the lut with an encoded / expanded one. diff --git a/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp b/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp index d2254beb8..e374e68b6 100644 --- a/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp +++ b/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp @@ -37,34 +37,34 @@ namespace fhe_to_tfhe_scalar_conversion { namespace typing { -/// Converts `FHE::EncryptedInteger` into `TFHE::GlweCiphetext`. -TFHE::GLWECipherTextType convertEint(mlir::MLIRContext *context, - FHE::EncryptedIntegerType eint) { - return TFHE::GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth()); +/// Converts an encrypted integer into `TFHE::GlweCiphetext`. +TFHE::GLWECipherTextType convertEncrypted(mlir::MLIRContext *context, + FHE::FheIntegerInterface enc) { + return TFHE::GLWECipherTextType::get(context, -1, -1, -1, enc.getWidth()); } -/// Converts `Tensor` into a +/// Converts `Tensor` into a /// `Tensor` if the element type is appropriate. /// Otherwise return the input type. -mlir::Type maybeConvertEintTensor(mlir::MLIRContext *context, - mlir::RankedTensorType maybeEintTensor) { - if (!maybeEintTensor.getElementType().isa()) { - return (mlir::Type)(maybeEintTensor); +mlir::Type +maybeConvertEncryptedTensor(mlir::MLIRContext *context, + mlir::RankedTensorType maybeEncryptedTensor) { + if (!maybeEncryptedTensor.getElementType().isa()) { + return (mlir::Type)(maybeEncryptedTensor); } - auto eint = - maybeEintTensor.getElementType().cast(); - auto currentShape = maybeEintTensor.getShape(); + auto enc = + maybeEncryptedTensor.getElementType().cast(); + auto currentShape = maybeEncryptedTensor.getShape(); return mlir::RankedTensorType::get( currentShape, - TFHE::GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth())); + TFHE::GLWECipherTextType::get(context, -1, -1, -1, enc.getWidth())); } -/// Converts the type `FHE::EncryptedInteger` to `TFHE::GlweCiphetext` if the +/// Converts any encrypted type to `TFHE::GlweCiphetext` if the /// input type is appropriate. Otherwise return the input type. -mlir::Type maybeConvertEint(mlir::MLIRContext *context, mlir::Type t) { - if (auto eint = t.dyn_cast()) - return convertEint(context, eint); - +mlir::Type maybeConvertEncrypted(mlir::MLIRContext *context, mlir::Type t) { + if (auto eint = t.dyn_cast()) + return convertEncrypted(context, eint); return t; } @@ -75,8 +75,8 @@ class TypeConverter : public mlir::TypeConverter { public: TypeConverter() { addConversion([](mlir::Type type) { return type; }); - addConversion([](FHE::EncryptedIntegerType type) { - return convertEint(type.getContext(), type); + addConversion([](FHE::FheIntegerInterface type) { + return convertEncrypted(type.getContext(), type); }); addConversion([](FHE::EncryptedBooleanType type) { return TFHE::GLWECipherTextType::get( @@ -84,7 +84,7 @@ public: mlir::concretelang::FHE::EncryptedBooleanType::getWidth()); }); addConversion([](mlir::RankedTensorType type) { - return maybeConvertEintTensor(type.getContext(), type); + return maybeConvertEncryptedTensor(type.getContext(), type); }); addConversion([&](concretelang::RT::FutureType type) { return concretelang::RT::FutureType::get(this->convertType( @@ -145,7 +145,7 @@ struct AddEintIntOpPattern : public ScalarOpPattern { // Write the plaintext encoding mlir::Value encodedInt = writePlaintextShiftEncoding( op.getLoc(), adaptor.b(), - op.getType().cast().getWidth(), rewriter); + op.getType().cast().getWidth(), rewriter); // Write the new op rewriter.replaceOpWithNewOp( @@ -183,7 +183,7 @@ struct SubEintIntOpPattern : public ScalarOpPattern { // Write the plaintext encoding mlir::Value encodedInt = writePlaintextShiftEncoding( op.getLoc(), negative, - eintOperand.getType().cast().getWidth(), + eintOperand.getType().cast().getWidth(), rewriter); // Write the new op @@ -208,7 +208,7 @@ struct SubIntEintOpPattern : public ScalarOpPattern { // Write the plaintext encoding mlir::Value encodedInt = writePlaintextShiftEncoding( op.getLoc(), adaptor.a(), - op.b().getType().cast().getWidth(), + op.b().getType().cast().getWidth(), rewriter); // Write the new op @@ -290,8 +290,9 @@ struct ApplyLookupTableEintOpPattern FHE::ApplyLookupTableEintOp::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { + auto inputType = op.a().getType().cast(); size_t outputBits = - op.getResult().getType().cast().getWidth(); + op.getResult().getType().cast().getWidth(); mlir::Value newLut = rewriter .create( @@ -301,12 +302,36 @@ struct ApplyLookupTableEintOpPattern rewriter.getI64Type()), op.lut(), rewriter.getI32IntegerAttr(loweringParameters.polynomialSize), - rewriter.getI32IntegerAttr(outputBits)) + rewriter.getI32IntegerAttr(outputBits), + rewriter.getBoolAttr(inputType.isSigned())) .getResult(); + typing::TypeConverter converter; + mlir::Value input = adaptor.a(); + + if (inputType.isSigned()) { + // If the input is a signed integer, it comes to the bootstrap with a + // signed-leveled encoding (compatible with 2s complement semantics). + // Unfortunately pbs is not compatible with this encoding, since the + // (virtual) msb must be 0 to avoid a lookup in the phantom negative lut. + uint64_t constantRaw = (uint64_t)1 << (inputType.getWidth() - 1); + // Note that the constant must be encoded with one more bit to ensure the + // signed extension used in the plaintext encoding works as expected. + mlir::Value constant = rewriter.create( + op.getLoc(), + rewriter.getIntegerAttr( + rewriter.getIntegerType(inputType.getWidth() + 1), constantRaw)); + mlir::Value encodedConstant = writePlaintextShiftEncoding( + op.getLoc(), constant, inputType.getWidth(), rewriter); + auto inputOp = rewriter.create( + op.getLoc(), converter.convertType(input.getType()), input, + encodedConstant); + input = inputOp; + } + // Insert keyswitch auto ksOp = rewriter.create( - op.getLoc(), adaptor.a().getType(), adaptor.a(), -1, -1); + op.getLoc(), getTypeConverter()->convertType(adaptor.a().getType()), input, -1, -1); // Insert bootstrap rewriter.replaceOpWithNewOp( diff --git a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp index 0a5209c77..15ab78667 100644 --- a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp @@ -1215,7 +1215,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { .getType() .cast() .getElementType() - .isa()) { + .isa()) { norm2SqEquiv = getSqMANP(transposeOp, operands); } else { isDummy = true; @@ -1227,7 +1227,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { else if (auto extractOp = llvm::dyn_cast(op)) { if (extractOp.result() .getType() - .isa()) { + .isa()) { norm2SqEquiv = getSqMANP(extractOp, operands); } else { isDummy = true; @@ -1240,7 +1240,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { .getType() .cast() .getElementType() - .isa()) { + .isa()) { norm2SqEquiv = getSqMANP(extractSliceOp, operands); } else { isDummy = true; @@ -1252,7 +1252,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { .getType() .cast() .getElementType() - .isa()) { + .isa()) { norm2SqEquiv = getSqMANP(insertOp, operands); } else { isDummy = true; @@ -1265,7 +1265,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { .getType() .cast() .getElementType() - .isa()) { + .isa()) { norm2SqEquiv = getSqMANP(insertSliceOp, operands); } else { isDummy = true; @@ -1277,7 +1277,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { .getType() .cast() .getElementType() - .isa()) { + .isa()) { norm2SqEquiv = getSqMANP(fromOp, operands); } else { isDummy = true; @@ -1290,7 +1290,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { .getType() .cast() .getElementType() - .isa()) { + .isa()) { norm2SqEquiv = getSqMANP(reshapeOp, operands); } else { isDummy = true; @@ -1302,7 +1302,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { .getType() .cast() .getElementType() - .isa()) { + .isa()) { norm2SqEquiv = getSqMANP(reshapeOp, operands); } else { isDummy = true; @@ -1410,16 +1410,15 @@ protected: // Process all results using MANP attribute from MANP pas for (mlir::OpResult res : op->getResults()) { - mlir::concretelang::FHE::EncryptedIntegerType eTy = + mlir::concretelang::FHE::FheIntegerInterface eTy = res.getType() - .dyn_cast_or_null< - mlir::concretelang::FHE::EncryptedIntegerType>(); + .dyn_cast_or_null(); if (eTy == nullptr) { auto tensorTy = res.getType().dyn_cast_or_null(); if (tensorTy != nullptr) { eTy = tensorTy.getElementType() .dyn_cast_or_null< - mlir::concretelang::FHE::EncryptedIntegerType>(); + mlir::concretelang::FHE::FheIntegerInterface>(); } } diff --git a/compiler/lib/Dialect/FHE/Analysis/utils.cpp b/compiler/lib/Dialect/FHE/Analysis/utils.cpp index 5f23cdb1c..483536952 100644 --- a/compiler/lib/Dialect/FHE/Analysis/utils.cpp +++ b/compiler/lib/Dialect/FHE/Analysis/utils.cpp @@ -13,14 +13,13 @@ namespace utils { /// Returns `true` if the given value is a scalar or tensor argument of /// a function, for which a MANP of 1 can be assumed. bool isEncryptedValue(mlir::Value value) { - return ( - value.getType().isa() || - value.getType().isa() || + return (value.getType().isa() || + value.getType().isa() || (value.getType().isa() && value.getType() .cast() .getElementType() - .isa())); + .isa())); } /// Returns the bit width of `value` if `value` is an encrypted integer, @@ -30,7 +29,7 @@ bool isEncryptedValue(mlir::Value value) { unsigned int getEintPrecision(mlir::Value value) { if (auto ty = value.getType() .dyn_cast_or_null< - mlir::concretelang::FHE::EncryptedIntegerType>()) { + mlir::concretelang::FHE::FheIntegerInterface>()) { return ty.getWidth(); } if (auto ty = value.getType() @@ -41,7 +40,7 @@ unsigned int getEintPrecision(mlir::Value value) { value.getType().dyn_cast_or_null()) { if (auto ty = tensorTy.getElementType() .dyn_cast_or_null< - mlir::concretelang::FHE::EncryptedIntegerType>()) + mlir::concretelang::FHE::FheIntegerInterface>()) return ty.getWidth(); } diff --git a/compiler/lib/Dialect/FHE/IR/FHEOps.cpp b/compiler/lib/Dialect/FHE/IR/FHEOps.cpp index 6a4de8497..5896fb09c 100644 --- a/compiler/lib/Dialect/FHE/IR/FHEOps.cpp +++ b/compiler/lib/Dialect/FHE/IR/FHEOps.cpp @@ -228,7 +228,7 @@ mlir::LogicalResult GenGateOp::verify() { } ::mlir::LogicalResult ApplyLookupTableEintOp::verify() { - auto ct = this->a().getType().cast(); + auto ct = this->a().getType().cast(); auto lut = this->lut().getType().cast(); // Check the shape of lut argument diff --git a/compiler/lib/Runtime/wrappers.cpp b/compiler/lib/Runtime/wrappers.cpp index 46a3d2e6e..3dc732f2e 100644 --- a/compiler/lib/Runtime/wrappers.cpp +++ b/compiler/lib/Runtime/wrappers.cpp @@ -8,6 +8,7 @@ #include "concretelang/Runtime/seeder.h" #include #include +#include #include #include #include @@ -253,7 +254,7 @@ void memref_encode_expand_lut_for_bootstrap( uint64_t output_lut_stride, uint64_t *input_lut_allocated, uint64_t *input_lut_aligned, uint64_t input_lut_offset, uint64_t input_lut_size, uint64_t input_lut_stride, uint32_t poly_size, - uint32_t out_MESSAGE_BITS) { + uint32_t out_MESSAGE_BITS, bool is_signed) { assert(input_lut_stride == 1 && "Runtime: stride not equal to 1, check " "memref_encode_expand_lut_bootstrap"); @@ -265,19 +266,41 @@ void memref_encode_expand_lut_for_bootstrap( assert((mega_case_size % 2) == 0); - for (size_t idx = 0; idx < mega_case_size / 2; ++idx) { - output_lut_aligned[output_lut_offset + idx] = - input_lut_aligned[input_lut_offset] << (64 - out_MESSAGE_BITS - 1); + // When the bootstrap is executed on encrypted signed integers, the lut must + // be half-rotated. This map takes care about properly indexing into the input + // lut depending on what bootstrap gets executed. + std::function indexMap; + if (is_signed) { + size_t halfInputSize = input_lut_size / 2; + indexMap = [=](size_t idx) { + if (idx < halfInputSize) { + return idx + halfInputSize; + } else { + return idx - halfInputSize; + } + }; + } else { + indexMap = [=](size_t idx) { return idx; }; } + // The first lut value should be centered over zero. This means that half of + // it should appear at the beginning of the output lut, and half of it at the + // end (but negated). + for (size_t idx = 0; idx < mega_case_size / 2; ++idx) { + output_lut_aligned[output_lut_offset + idx] = + input_lut_aligned[input_lut_offset + indexMap(0)] + << (64 - out_MESSAGE_BITS - 1); + } for (size_t idx = (input_lut_size - 1) * mega_case_size + mega_case_size / 2; idx < output_lut_size; ++idx) { output_lut_aligned[output_lut_offset + idx] = - -(input_lut_aligned[input_lut_offset] << (64 - out_MESSAGE_BITS - 1)); + -(input_lut_aligned[input_lut_offset + indexMap(0)] + << (64 - out_MESSAGE_BITS - 1)); } + // Treats the other ut values. for (size_t lut_idx = 1; lut_idx < input_lut_size; ++lut_idx) { - uint64_t lut_value = input_lut_aligned[input_lut_offset + lut_idx] + uint64_t lut_value = input_lut_aligned[input_lut_offset + indexMap(lut_idx)] << (64 - out_MESSAGE_BITS - 1); size_t start = mega_case_size * (lut_idx - 1) + mega_case_size / 2; for (size_t output_idx = start; output_idx < start + mega_case_size; @@ -306,7 +329,7 @@ void memref_encode_expand_lut_for_woppbs( uint64_t *crt_bits_allocated, uint64_t *crt_bits_aligned, uint64_t crt_bits_offset, uint64_t crt_bits_size, uint64_t crt_bits_stride, // Crypto parameters - uint32_t poly_size, uint32_t modulus_product) { + uint32_t poly_size, uint32_t modulus_product, bool is_signed) { assert(input_lut_stride == 1 && "Runtime: stride not equal to 1, check " "memref_encode_expand_lut_woppbs"); @@ -314,22 +337,77 @@ void memref_encode_expand_lut_for_woppbs( "memref_encode_expand_lut_woppbs"); assert(modulus_product > input_lut_size); + // When the woppbs is executed on encrypted signed integers, the index of the + // lut elements must be adapted to fit the way signed are encrypted in CRT + // (to ensure the lookup falls into the proper case). + // This map takes care about properly indexing into the output lut depending + // on what bootstrap gets executed. + std::function indexMap; + if (!is_signed) { + // When not signed, the integer values are encoded in increasing order. That + // is (example of 9 bits values, using crt decomposition [5,7,16]): + // + // |0 511| + // |---------| + // |0 511| + // + // is encoded as + // + // |0 511| INVALID | + // |-------|-----------| + // |0 511|512 559| + // + // Where on top are represented the semantic values, and below, the actual + // encoding of values, either on uint64_t or as increasing crt values. + // + // As a consequence, there is nothing particular to do to map the index of + // the input lut to an index of the output lut. + indexMap = [=](uint64_t plaintext) { return plaintext; }; + } else { + // When signed, the integer values are encoded in a way that resembles 2s + // complement. That is (example of 9 bits values, using crt decomposition + // [5,7,16]): + // + // |0 255|-256 -1| + // |---------|----------| + // |0 255|256 511| + // + // is encoded as + // + // |0 255| INVALID |-256 -1| + // |---------|-------------|----------| + // |0 255|256 303|304 559| + // + // Where on top are represented the semantic values, and below, the actual + // encoding of values, either on uint64_t or as increasing crt values. + // + // As a consequence, to map the index of the input lut to an index of the + // output lut we must take care of crossing the invalid range in between + // positive values and negative values. + indexMap = [=](uint64_t plaintext) { + if (plaintext >= (input_lut_size / 2)) { + plaintext += modulus_product - input_lut_size; + } + return plaintext; + }; + } + uint64_t lut_crt_size = output_lut_size / crt_decomposition_size; - for (uint64_t value = 0; value < input_lut_size; value++) { + for (uint64_t index = 0; index < input_lut_size; index++) { uint64_t index_lut = 0; uint64_t tmp = 1; for (size_t block = 0; block < crt_decomposition_size; block++) { auto base = crt_decomposition_aligned[crt_decomposition_offset + block]; auto bits = crt_bits_aligned[crt_bits_offset + block]; - index_lut += (((value % base) << bits) / base) * tmp; + index_lut += (((indexMap(index) % base) << bits) / base) * tmp; tmp <<= bits; } for (size_t block = 0; block < crt_decomposition_size; block++) { auto base = crt_decomposition_aligned[crt_decomposition_offset + block]; - auto v = encode_crt(input_lut_aligned[input_lut_offset + value], base, + auto v = encode_crt(input_lut_aligned[input_lut_offset + index], base, modulus_product); output_lut_aligned[output_lut_offset + block * lut_crt_size + index_lut] = v; diff --git a/compiler/lib/Support/V0ClientParameters.cpp b/compiler/lib/Support/V0ClientParameters.cpp index 8224cd20d..1a1782e91 100644 --- a/compiler/lib/Support/V0ClientParameters.cpp +++ b/compiler/lib/Support/V0ClientParameters.cpp @@ -58,8 +58,8 @@ llvm::Expected gateFromMLIRType(V0FHEContext fheContext, }; } if (auto lweTy = type.dyn_cast_or_null< - mlir::concretelang::FHE::EncryptedIntegerType>()) { - bool sign = lweTy.isSignedInteger(); + mlir::concretelang::FHE::FheIntegerInterface>()) { + bool sign = lweTy.isSigned(); std::vector crt; if (fheContext.parameter.largeInteger.has_value()) { crt = fheContext.parameter.largeInteger.value().crtDecomposition; @@ -72,15 +72,14 @@ llvm::Expected gateFromMLIRType(V0FHEContext fheContext, { /* .precision = */ lweTy.getWidth(), /* .crt = */ crt, + /*.sign = */ sign, }, }), /*.shape = */ - { - /*.width = */ (size_t)lweTy.getWidth(), - /*.dimensions = */ std::vector(), - /*.size = */ 0, - /*.sign = */ sign, - }, + {/*.width = */ (size_t)lweTy.getWidth(), + /*.dimensions = */ std::vector(), + /*.size = */ 0, + /*.sign = */ sign}, }; } if (auto lweTy = type.dyn_cast_or_null< @@ -214,17 +213,12 @@ createClientParametersForV0(V0FHEContext fheContext, auto funcType = (*funcOp).getFunctionType(); auto inputs = funcType.getInputs(); - bool hasContext = - inputs.empty() - ? false - : inputs.back().isa(); auto gateFromType = [&](mlir::Type ty) { return gateFromMLIRType(fheContext, clientlib::BIG_KEY, inputVariance, ty); }; - for (auto inType = funcType.getInputs().begin(); - inType < funcType.getInputs().end() - hasContext; inType++) { - auto gate = gateFromType(*inType); + for (auto inType : inputs) { + auto gate = gateFromType(inType); if (auto err = gate.takeError()) { return std::move(err); } diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/encode_expand_lut_for_bootstrap.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/encode_expand_lut_for_bootstrap.mlir index 768599ff8..57f99eada 100644 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/encode_expand_lut_for_bootstrap.mlir +++ b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/encode_expand_lut_for_bootstrap.mlir @@ -1,10 +1,10 @@ // RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s // CHECK: func.func @apply_lookup_table(%arg0: tensor<4xi64>) -> tensor<1024xi64> { -// CHECK-NEXT: %0 = "BConcrete.encode_expand_lut_for_bootstrap_tensor"(%arg0) {outputBits = 3 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<1024xi64> +// CHECK-NEXT: %0 = "BConcrete.encode_expand_lut_for_bootstrap_tensor"(%arg0) {isSigned = true, outputBits = 3 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<1024xi64> // CHECK-NEXT: return %0 : tensor<1024xi64> // CHECK-NEXT: } func.func @apply_lookup_table(%arg0: tensor<4xi64>) -> tensor<1024xi64> { - %0 = "Concrete.encode_expand_lut_for_bootstrap"(%arg0) {outputBits = 3 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<1024xi64> + %0 = "Concrete.encode_expand_lut_for_bootstrap"(%arg0) {outputBits = 3 : i32, polySize = 1024 : i32, isSigned = true} : (tensor<4xi64>) -> tensor<1024xi64> return %0 : tensor<1024xi64> } diff --git a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/encode_expand_lut_for_woppbs.mlir b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/encode_expand_lut_for_woppbs.mlir index 1f3484ae9..2d56b4e6d 100644 --- a/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/encode_expand_lut_for_woppbs.mlir +++ b/compiler/tests/check_tests/Conversion/ConcreteToBConcrete/encode_expand_lut_for_woppbs.mlir @@ -1,10 +1,10 @@ // RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s // CHECK: func.func @main(%arg0: tensor<4xi64>) -> tensor<40960xi64> { -// CHECK-NEXT: %0 = "BConcrete.encode_expand_lut_for_woppbs_tensor"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64> +// CHECK-NEXT: %0 = "BConcrete.encode_expand_lut_for_woppbs_tensor"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64> // CHECK-NEXT: return %0 : tensor<40960xi64> // CHECK-NEXT: } func.func @main(%arg0: tensor<4xi64>) -> tensor<40960xi64> { - %0 = "Concrete.encode_expand_lut_for_woppbs"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64> + %0 = "Concrete.encode_expand_lut_for_woppbs"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64> return %0 : tensor<40960xi64> } diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/apply_univariate.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/apply_univariate.mlir index 117443b24..a958311d8 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/apply_univariate.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/apply_univariate.mlir @@ -1,7 +1,7 @@ // RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s // CHECK: func.func @apply_lookup_table(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<4xi64>) -> tensor<5x!TFHE.glwe<{_,_,_}{3}>> -// CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_woppbs"(%arg1) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64> +// CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_woppbs"(%arg1) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64> // CHECK-NEXT: %1 = "TFHE.wop_pbs_glwe"(%arg0, %0) {bootstrapBaseLog = -1 : i32, bootstrapLevel = -1 : i32, circuitBootstrapBaseLog = -1 : i32, circuitBootstrapLevel = -1 : i32, crtDecomposition = [], keyswitchBaseLog = -1 : i32, keyswitchLevel = -1 : i32, packingKeySwitchBaseLog = -1 : i32, packingKeySwitchInputLweDimension = -1 : i32, packingKeySwitchLevel = -1 : i32, packingKeySwitchoutputPolynomialSize = -1 : i32} : (tensor<5x!TFHE.glwe<{_,_,_}{2}>>, tensor<40960xi64>) -> tensor<5x!TFHE.glwe<{_,_,_}{3}>> // CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<{_,_,_}{3}>> func.func @apply_lookup_table(%arg0: !FHE.eint<2>, %arg1: tensor<4xi64>) -> !FHE.eint<3> { diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/apply_univariate_cst.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/apply_univariate_cst.mlir index ec4a022ee..ce3b8359d 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/apply_univariate_cst.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/apply_univariate_cst.mlir @@ -2,7 +2,7 @@ // CHECK: func.func @apply_lookup_table_cst(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> { // CHECK-NEXT: %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64> -// CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_woppbs"(%cst) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<128xi64>) -> tensor<40960xi64> +// CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_woppbs"(%cst) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<128xi64>) -> tensor<40960xi64> // CHECK-NEXT: %1 = "TFHE.wop_pbs_glwe"(%arg0, %0) {bootstrapBaseLog = -1 : i32, bootstrapLevel = -1 : i32, circuitBootstrapBaseLog = -1 : i32, circuitBootstrapLevel = -1 : i32, crtDecomposition = [], keyswitchBaseLog = -1 : i32, keyswitchLevel = -1 : i32, packingKeySwitchBaseLog = -1 : i32, packingKeySwitchInputLweDimension = -1 : i32, packingKeySwitchLevel = -1 : i32, packingKeySwitchoutputPolynomialSize = -1 : i32} : (tensor<5x!TFHE.glwe<{_,_,_}{7}>>, tensor<40960xi64>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> // CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<{_,_,_}{7}>> func.func @apply_lookup_table_cst(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/apply_univariate.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/apply_univariate.mlir index bcf0fb342..a17912704 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/apply_univariate.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/apply_univariate.mlir @@ -1,7 +1,7 @@ // RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s // CHECK: func.func @apply_lookup_table(%arg0: !TFHE.glwe<{_,_,_}{2}>, %arg1: tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{3}> { -// CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_bootstrap"(%arg1) {outputBits = 3 : i32, polySize = 256 : i32} : (tensor<4xi64>) -> tensor<256xi64> +// CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_bootstrap"(%arg1) {isSigned = false, outputBits = 3 : i32, polySize = 256 : i32} : (tensor<4xi64>) -> tensor<256xi64> // CHECK-NEXT: %1 = "TFHE.keyswitch_glwe"(%arg0) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}> // CHECK-NEXT: %2 = "TFHE.bootstrap_glwe"(%1, %0) {baseLog = -1 : i32, glweDimension = -1 : i32, level = -1 : i32, polySize = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>, tensor<256xi64>) -> !TFHE.glwe<{_,_,_}{3}> // CHECK-NEXT: return %2 : !TFHE.glwe<{_,_,_}{3}> diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/apply_univariate_cst.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/apply_univariate_cst.mlir index be8bd4687..4a3df25b0 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/apply_univariate_cst.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/apply_univariate_cst.mlir @@ -3,7 +3,7 @@ //CHECK: func.func @apply_lookup_table_cst(%[[A0:.*]]: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> { //CHECK-NEXT: %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64> -//CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_bootstrap"(%cst) {outputBits = 7 : i32, polySize = 8192 : i32} : (tensor<128xi64>) -> tensor<8192xi64> +//CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_bootstrap"(%cst) {isSigned = false, outputBits = 7 : i32, polySize = 8192 : i32} : (tensor<128xi64>) -> tensor<8192xi64> //CHECK-NEXT: %1 = "TFHE.keyswitch_glwe"(%arg0) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> //CHECK-NEXT: %2 = "TFHE.bootstrap_glwe"(%1, %0) {baseLog = -1 : i32, glweDimension = -1 : i32, level = -1 : i32, polySize = -1 : i32} : (!TFHE.glwe<{_,_,_}{7}>, tensor<8192xi64>) -> !TFHE.glwe<{_,_,_}{7}> //CHECK-NEXT: return %2 : !TFHE.glwe<{_,_,_}{7}> diff --git a/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_expand_lut_for_bootstrap.mlir b/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_expand_lut_for_bootstrap.mlir index 21b001408..b93fe73a6 100644 --- a/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_expand_lut_for_bootstrap.mlir +++ b/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_expand_lut_for_bootstrap.mlir @@ -1,10 +1,10 @@ // RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s // CHECK: func.func @apply_lookup_table(%arg0: tensor<4xi64>) -> tensor<1024xi64> { -// CHECK-NEXT: %0 = "Concrete.encode_expand_lut_for_bootstrap"(%arg0) {outputBits = 3 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<1024xi64> +// CHECK-NEXT: %0 = "Concrete.encode_expand_lut_for_bootstrap"(%arg0) {isSigned = true, outputBits = 3 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<1024xi64> // CHECK-NEXT: return %0 : tensor<1024xi64> // CHECK-NEXT: } func.func @apply_lookup_table(%arg1: tensor<4xi64>) -> tensor<1024xi64> { - %0 = "TFHE.encode_expand_lut_for_bootstrap"(%arg1) {outputBits = 3 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<1024xi64> + %0 = "TFHE.encode_expand_lut_for_bootstrap"(%arg1) {outputBits = 3 : i32, polySize = 1024 : i32, isSigned = true} : (tensor<4xi64>) -> tensor<1024xi64> return %0: tensor<1024xi64> } diff --git a/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_expand_lut_for_woppbs.mlir b/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_expand_lut_for_woppbs.mlir index 3c054f33d..271398f66 100644 --- a/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_expand_lut_for_woppbs.mlir +++ b/compiler/tests/check_tests/Conversion/TFHEToConcrete/encode_expand_lut_for_woppbs.mlir @@ -1,10 +1,10 @@ // RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s // CHECK: func.func @main(%arg0: tensor<4xi64>) -> tensor<40960xi64> { -// CHECK-NEXT: %0 = "Concrete.encode_expand_lut_for_woppbs"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64> +// CHECK-NEXT: %0 = "Concrete.encode_expand_lut_for_woppbs"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64> // CHECK-NEXT: return %0 : tensor<40960xi64> // CHECK-NEXT: } func.func @main(%arg1: tensor<4xi64>) -> tensor<40960xi64> { - %0 = "TFHE.encode_expand_lut_for_woppbs"(%arg1) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64> + %0 = "TFHE.encode_expand_lut_for_woppbs"(%arg1) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64> return %0: tensor<40960xi64> } diff --git a/compiler/tests/end_to_end_fixture/EndToEndFixture.cpp b/compiler/tests/end_to_end_fixture/EndToEndFixture.cpp index 30c49c095..a4aa9b13b 100644 --- a/compiler/tests/end_to_end_fixture/EndToEndFixture.cpp +++ b/compiler/tests/end_to_end_fixture/EndToEndFixture.cpp @@ -111,19 +111,14 @@ llvm::Error checkResult(const mlir::concretelang::TensorLambdaArgument< if (!expectedNumElts) return expectedNumElts.takeError(); - auto hasError = false; StreamStringError err("result value differ"); for (size_t i = 0; i < *expectedNumElts; i++) { - - if (resValues[i] != expectedValues[i]) { - hasError = true; - err << " [pos(" << i << "), got " << resValues[i] << " expected " - << expectedValues[i] << "]"; + if ((uint64_t)resValues[i] != (uint64_t)expectedValues[i]) { + return StreamStringError("result value differ at pos(") + << i << "), got " << resValues[i] << " expected " + << expectedValues[i]; } } - if (hasError) { - return err; - } return llvm::Error::success(); } diff --git a/compiler/tests/end_to_end_fixture/end_to_end_apply_lookup_table_gen.py b/compiler/tests/end_to_end_fixture/end_to_end_apply_lookup_table_gen.py index 73d7935c0..54fb87821 100644 --- a/compiler/tests/end_to_end_fixture/end_to_end_apply_lookup_table_gen.py +++ b/compiler/tests/end_to_end_fixture/end_to_end_apply_lookup_table_gen.py @@ -11,10 +11,11 @@ def generate(args): print("# /!\ DO NOT EDIT MANUALLY THIS FILE MANUALLY") print("# /!\ THIS FILE HAS BEEN GENERATED") np.random.seed(0) + # unsigned_unsigned for p in args.bitwidth: max_value = (2 ** p) - 1 random_lut = np.random.randint(max_value+1, size=2**p) - print(f"description: apply_lookup_table_{p}bits") + print(f"description: unsigned_apply_lookup_table_{p}bits") print("program: |") print( f" func.func @main(%arg0: !FHE.eint<{p}>) -> !FHE.eint<{p}> {{") @@ -41,6 +42,139 @@ def generate(args): print(" outputs:") print(f" - scalar: {random_lut[max_value]}") print("---") + # unsigned_signed + for p in args.bitwidth: + lower_bound = -(2 ** (p-1)) + upper_bound = (2 ** (p-1)) - 1 + max_value = (2 ** p) - 1 + random_lut = np.random.randint(lower_bound, upper_bound, size=2**p) + print(f"description: unsigned_signed_apply_lookup_table_{p}bits") + print("program: |") + print( + f" func.func @main(%arg0: !FHE.eint<{p}>) -> !FHE.esint<{p}> {{") + print(f" %tlu = arith.constant dense<[{','.join(map(str, random_lut))}]> : tensor<{2**p}xi64>") + print( + f" %1 = \"FHE.apply_lookup_table\"(%arg0, %tlu): (!FHE.eint<{p}>, tensor<{2**p}xi64>) -> (!FHE.esint<{p}>)") + print(f" return %1: !FHE.esint<{p}>") + print(" }") + if p >= PRECISION_FORCE_CRT: + print("encoding: crt") + print(f"p-error: {P_ERROR}") + print("tests:") + print(" - inputs:") + print(" - scalar: 0") + print(" outputs:") + print(f" - scalar: {random_lut[0]}") + print(f" signed: true") + print(" - inputs:") + random_i = np.random.randint(max_value) + print(f" - scalar: {random_i}") + print(" outputs:") + print(f" - scalar: {random_lut[random_i]}") + print(f" signed: true") + print(" - inputs:") + print(f" - scalar: {max_value}") + print(" outputs:") + print(f" - scalar: {random_lut[max_value]}") + print(f" signed: true") + print("---") + # signed_signed + for p in args.bitwidth: + lower_bound = -(2 ** (p-1)) + upper_bound = (2 ** (p-1)) - 1 + random_lut = np.random.randint(lower_bound, upper_bound, size=2**p) + print(f"description: signed_apply_lookup_table_{p}bits") + print("program: |") + print( + f" func.func @main(%arg0: !FHE.esint<{p}>) -> !FHE.esint<{p}> {{") + print(f" %tlu = arith.constant dense<[{','.join(map(str, random_lut))}]> : tensor<{2**p}xi64>") + print( + f" %1 = \"FHE.apply_lookup_table\"(%arg0, %tlu): (!FHE.esint<{p}>, tensor<{2**p}xi64>) -> (!FHE.esint<{p}>)") + print(f" return %1: !FHE.esint<{p}>") + print(" }") + if p >= PRECISION_FORCE_CRT: + print("encoding: crt") + print(f"p-error: {P_ERROR}") + print("tests:") + print(" - inputs:") + print(f" - scalar: 0") + print(f" signed: true") + print(" outputs:") + print(f" - scalar: {random_lut[0]}") + print(f" signed: true") + print(" - inputs:") + print(f" - scalar: {upper_bound}") + print(f" signed: true") + print(" outputs:") + print(f" - scalar: {random_lut[upper_bound]}") + print(f" signed: true") + print(" - inputs:") + print(f" - scalar: {lower_bound}") + print(f" signed: true") + print(" outputs:") + print(f" - scalar: {random_lut[lower_bound]}") + print(f" signed: true") + print(" - inputs:") + print(f" - scalar: -1") + print(f" signed: true") + print(" outputs:") + print(f" - scalar: {random_lut[-1]}") + print(f" signed: true") + print(" - inputs:") + random_i = np.random.randint(lower_bound, upper_bound) + print(f" - scalar: {random_i}") + print(f" signed: true") + print(" outputs:") + print(f" - scalar: {random_lut[random_i]}") + print(f" signed: true") + print("---") + + # signed_unsigned + for p in args.bitwidth: + lower_bound = -(2 ** (p-1)) + upper_bound = (2 ** (p-1)) - 1 + max_value = (2 ** p) - 1 + random_lut = np.random.randint(max_value+1, size=2**p) + print(f"description: signed_unsigned_apply_lookup_table_{p}bits") + print("program: |") + print( + f" func.func @main(%arg0: !FHE.esint<{p}>) -> !FHE.eint<{p}> {{") + print(f" %tlu = arith.constant dense<[{','.join(map(str, random_lut))}]> : tensor<{2**p}xi64>") + print( + f" %1 = \"FHE.apply_lookup_table\"(%arg0, %tlu): (!FHE.esint<{p}>, tensor<{2**p}xi64>) -> (!FHE.eint<{p}>)") + print(f" return %1: !FHE.eint<{p}>") + print(" }") + if p >= PRECISION_FORCE_CRT: + print("encoding: crt") + print(f"p-error: {P_ERROR}") + print("tests:") + print(" - inputs:") + print(f" - scalar: 0") + print(f" signed: true") + print(" outputs:") + print(f" - scalar: {random_lut[0]}") + print(" - inputs:") + print(f" - scalar: {upper_bound}") + print(f" signed: true") + print(" outputs:") + print(f" - scalar: {random_lut[upper_bound]}") + print(" - inputs:") + print(f" - scalar: {lower_bound}") + print(f" signed: true") + print(" outputs:") + print(f" - scalar: {random_lut[lower_bound]}") + print(" - inputs:") + print(f" - scalar: -1") + print(f" signed: true") + print(" outputs:") + print(f" - scalar: {random_lut[-1]}") + print(" - inputs:") + random_i = np.random.randint(lower_bound, upper_bound) + print(f" - scalar: {random_i}") + print(f" signed: true") + print(" outputs:") + print(f" - scalar: {random_lut[random_i]}") + print("---") if __name__ == "__main__": CLI = argparse.ArgumentParser() diff --git a/compiler/tests/end_to_end_fixture/end_to_end_leveled_gen.py b/compiler/tests/end_to_end_fixture/end_to_end_leveled_gen.py index b45126d7e..c64a80426 100644 --- a/compiler/tests/end_to_end_fixture/end_to_end_leveled_gen.py +++ b/compiler/tests/end_to_end_fixture/end_to_end_leveled_gen.py @@ -17,6 +17,7 @@ def main(): print("# /!\ DO NOT EDIT MANUALLY THIS FILE MANUALLY") print("# /!\ THIS FILE HAS BEEN GENERATED THANKS THE end_to_end_levelled_gen.py scripts") print("# This reference file aims to test all levelled ops with all bitwidth than we known that the compiler/optimizer support.\n\n") + # unsigned for p in range(MIN_PRECISON, MAX_PRECISION+1): if p != 1: print("---") @@ -301,6 +302,579 @@ def main(): print(" - scalar: {0}".format(max_value)) may_check_error_rate() print("---") + # signed + for p in range(MIN_PRECISON, MAX_PRECISION+1): + print("---") + def may_check_error_rate(): + if p in PRECISIONS_WITH_ERROR_RATES: + print(TEST_ERROR_RATES) + min_value = -(2 ** (p - 1)) + max_value = abs(min_value) - 1 + integer_bitwidth = p + 1 + max_constant = min((2 ** (57-p)) - 1, max_value) + + # identity + print("description: signed_identity_{0}bits".format(p)) + print("program: |") + print(" func.func @main(%arg0: !FHE.esint<{0}>) -> !FHE.esint<{0}> {{".format(p)) + print(" return %arg0: !FHE.esint<{0}>".format(p)) + print(" }") + print("tests:") + print(" - inputs:") + print(" - scalar: {0}".format(min_value)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(min_value)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + may_check_error_rate() + + print("---") + + # zero_tensor + print("description: signed_zero_tensor_{0}bits".format(p)) + print("program: |") + print(" func.func @main() -> tensor<2x2x4x!FHE.esint<{0}>> {{".format(p)) + print(" %0 = \"FHE.zero_tensor\"() : () -> tensor<2x2x4x!FHE.esint<{0}>>".format(p)) + print(" return %0: tensor<2x2x4x!FHE.esint<{0}>>".format(p)) + print(" }") + print("tests:") + print(" - outputs:") + print(" - tensor: [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]") + print(" shape: [2,2,4]") + print(" signed: true") + may_check_error_rate() + + print("---") + + # add_eint_int_cst + print("description: signed_add_eint_int_cst_{0}bits".format(p)) + print("program: |") + print(" func.func @main(%arg0: !FHE.esint<{0}>) -> !FHE.esint<{0}> {{".format(p)) + print(" %0 = arith.constant 1 : i{0}".format(integer_bitwidth)) + print(" %1 = \"FHE.add_eint_int\"(%arg0, %0): (!FHE.esint<{0}>, i{1}) -> (!FHE.esint<{0}>)".format(p, integer_bitwidth)) + print(" return %1: !FHE.esint<{0}>".format(p)) + print(" }") + print("tests:") + print(" - inputs:") + print(" - scalar: {0}".format(-1)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(max_value-1)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(min_value)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(min_value + 1)) + print(" signed: true") + may_check_error_rate() + + print("---") + + # add_eint_int_arg + if p <= 28: + # above 28 bits the *arg test doesn't have solution + # TODO: Make a test that test that + print("description: signed_add_eint_int_arg_{0}bits".format(p)) + print("program: |") + print(" func.func @main(%arg0: !FHE.esint<{0}>, %arg1: i{1}) -> !FHE.esint<{0}> {{".format(p, integer_bitwidth)) + print(" %0 = \"FHE.add_eint_int\"(%arg0, %arg1): (!FHE.esint<{0}>, i{1}) -> (!FHE.esint<{0}>)".format(p, integer_bitwidth)) + print(" return %0: !FHE.esint<{0}>".format(p)) + print(" }") + print("tests:") + print(" - inputs:") + print(" - scalar: {0}".format(min_value)) + print(" signed: true") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(min_value)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(min_value)) + print(" signed: true") + print(" - scalar: {0}".format(1)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(min_value + 1)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(max_value-1)) + print(" signed: true") + print(" - scalar: {0}".format(1)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(-1)) + print(" signed: true") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(-1)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(-1)) + print(" signed: true") + print(" - scalar: {0}".format(1)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(0)) + print(" signed: true") + may_check_error_rate() + + print("---") + + # add_eint + print("description: signed_add_eint_{0}bits".format(p)) + print("program: |") + print(" func.func @main(%arg0: !FHE.esint<{0}>, %arg1: !FHE.esint<{0}>) -> !FHE.esint<{0}> {{".format(p)) + print(" %res = \"FHE.add_eint\"(%arg0, %arg1): (!FHE.esint<{0}>, !FHE.esint<{0}>) -> !FHE.esint<{0}>".format(p)) + print(" return %res: !FHE.esint<{0}>".format(p)) + print(" }") + print("tests:") + print(" - inputs:") + print(" - scalar: {0}".format(((2 ** (p - 1)) >> 1) - 1)) + print(" signed: true") + print(" - scalar: {0}".format(((2 ** (p - 1)) >> 1))) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(-1 if p == 1 else (2 ** (p - 1)) - 1)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(-(2 ** (p - 1)))) + print(" signed: true") + print(" - scalar: {0}".format(((2 ** (p - 1)) - 1))) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(-1)) + print(" signed: true") + may_check_error_rate() + + print("---") + + # sub_eint_int_cst + print("description: signed_sub_eint_int_cst_{0}bits".format(p)) + print("program: |") + print(" func.func @main(%arg0: !FHE.esint<{0}>) -> !FHE.esint<{0}> {{".format(p)) + print(" %0 = arith.constant 1 : i{0}".format(integer_bitwidth)) + print(" %1 = \"FHE.sub_eint_int\"(%arg0, %0): (!FHE.esint<{0}>, i{1}) -> (!FHE.esint<{0}>)".format(p, integer_bitwidth)) + print(" return %1: !FHE.esint<{0}>".format(p)) + print(" }") + print("tests:") + print(" - inputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(max_value - 1)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(-1)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(min_value + 1)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(min_value)) + print(" signed: true") + may_check_error_rate() + + print("---") + + # sub_eint_int_arg + if p <= 28: + # above 28 bits the *arg test doesn't have solution + # TODO: Make a test that test that + print("description: signed_sub_eint_int_arg_{0}bits".format(p)) + print("program: |") + print(" func.func @main(%arg0: !FHE.esint<{0}>, %arg1: i{1}) -> !FHE.esint<{0}> {{".format(p, integer_bitwidth)) + print(" %1 = \"FHE.sub_eint_int\"(%arg0, %arg1): (!FHE.esint<{0}>, i{1}) -> (!FHE.esint<{0}>)".format(p, integer_bitwidth)) + print(" return %1: !FHE.esint<{0}>".format(p)) + print(" }") + print("tests:") + print(" - inputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" - scalar: {0}".format(1)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(max_value - 1)) + print(" signed: true") + if p != 28: + print(" - inputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" - scalar: {0}".format(2 * max_value)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(-max_value)) + print(" signed: true") + may_check_error_rate() + + print("---") + + # sub_int_eint_cst + if p != 1: + print("description: signed_sub_int_eint_cst_{0}bits".format(p)) + print("program: |") + print(" func.func @main(%arg0: !FHE.esint<{0}>) -> !FHE.esint<{0}> {{".format(p)) + print(" %0 = arith.constant 1 : i{0}".format(integer_bitwidth)) + print(" %1 = \"FHE.sub_int_eint\"(%0, %arg0): (i{1}, !FHE.esint<{0}>) -> (!FHE.esint<{0}>)".format(p, integer_bitwidth)) + print(" return %1: !FHE.esint<{0}>".format(p)) + print(" }") + print("tests:") + print(" - inputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(min_value + 2)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(1)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(min_value + 2)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + may_check_error_rate() + + print("---") + + # sub_int_eint_arg + if p <= 28: + # above 28 bits the *arg test doesn't have solution + # TODO: Make a test that test that + print("description: signed_sub_int_eint_arg_{0}bits".format(p)) + print("program: |") + print(" func.func @main(%arg0: i{1}, %arg1: !FHE.esint<{0}>) -> !FHE.esint<{0}> {{".format(p, integer_bitwidth)) + print(" %1 = \"FHE.sub_int_eint\"(%arg0, %arg1): (i{1}, !FHE.esint<{0}>) -> (!FHE.esint<{0}>)".format(p, integer_bitwidth)) + print(" return %1: !FHE.esint<{0}>".format(p)) + print(" }") + print("tests:") + print(" - inputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" - scalar: {0}".format(0)) + print(" outputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" - scalar: {0}".format(1)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(max_value - 1)) + print(" signed: true") + if p != 28: + print(" - inputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" - scalar: {0}".format(2 * max_value)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(-max_value)) + print(" signed: true") + may_check_error_rate() + + print("---") + + # sub_eint + print("description: signed_sub_eint_{0}bits".format(p)) + print("program: |") + print(" func.func @main(%arg0: !FHE.esint<{0}>, %arg1: !FHE.esint<{0}>) -> !FHE.esint<{0}> {{".format(p)) + print(" %res = \"FHE.sub_eint\"(%arg0, %arg1): (!FHE.esint<{0}>, !FHE.esint<{0}>) -> !FHE.esint<{0}>".format(p)) + print(" return %res: !FHE.esint<{0}>".format(p)) + print(" }") + print("tests:") + print(" - inputs:") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" - scalar: {0}".format(1)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(-1)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(min_value)) + print(" signed: true") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(min_value)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" - scalar: {0}".format(min_value + 1)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(min_value + 1)) + print(" signed: true") + may_check_error_rate() + + print("---") + + # mul_eint_int_cst + print("description: signed_mul_eint_int_cst_{0}bits".format(p)) + print("program: |") + print(" func.func @main(%arg0: !FHE.esint<{0}>) -> !FHE.esint<{0}> {{".format(p)) + print(" %0 = arith.constant 2 : i{0}".format(integer_bitwidth)) + print(" %1 = \"FHE.mul_eint_int\"(%arg0, %0): (!FHE.esint<{0}>, i{1}) -> (!FHE.esint<{0}>)".format(p, integer_bitwidth)) + print(" return %1: !FHE.esint<{0}>".format(p)) + print(" }") + print("tests:") + print(" - inputs:") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(0)) + print(" signed: true") + if p != 1: + print(" - inputs:") + print(" - scalar: {0}".format(max_value // 2)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(max_value - 1)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(min_value // 2)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(min_value)) + print(" signed: true") + may_check_error_rate() + + print("---") + + # mul_eint_int_arg + if p <= 28: + # above 28 bits the *arg test doesn't have solution + # TODO: Make a test that test that + print("description: signed_mul_eint_int_arg_{0}bits".format(p)) + print("program: |") + print(" func.func @main(%arg0: !FHE.esint<{0}>, %arg1: i{1}) -> !FHE.esint<{0}> {{".format(p, integer_bitwidth)) + print(" %0 = \"FHE.mul_eint_int\"(%arg0, %arg1): (!FHE.esint<{0}>, i{1}) -> (!FHE.esint<{0}>)".format(p, integer_bitwidth)) + print(" return %0: !FHE.esint<{0}>".format(p)) + print(" }") + print("tests:") + print(" - inputs:") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(min_value)) + print(" signed: true") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" - scalar: {0}".format(min_value)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(0)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(min_value)) + print(" signed: true") + print(" - scalar: {0}".format(1)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(min_value)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(1)) + print(" signed: true") + print(" - scalar: {0}".format(min_value)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(min_value)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" - scalar: {0}".format(1)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(1)) + print(" signed: true") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" - scalar: {0}".format(-1)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(min_value + 1)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(-1)) + print(" signed: true") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(min_value + 1)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(min_value + 1)) + print(" signed: true") + print(" - scalar: {0}".format(-1)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(-1)) + print(" signed: true") + print(" - scalar: {0}".format(min_value + 1)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(max_value)) + print(" signed: true") + if p > 2: + print(" - inputs:") + print(" - scalar: {0}".format(3)) + print(" signed: true") + print(" - scalar: {0}".format(1)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(3)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(3)) + print(" signed: true") + print(" - scalar: {0}".format(-1)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(-3)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(-3)) + print(" signed: true") + print(" - scalar: {0}".format(1)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(-3)) + print(" signed: true") + print(" - inputs:") + print(" - scalar: {0}".format(-3)) + print(" signed: true") + print(" - scalar: {0}".format(-1)) + print(" signed: true") + print(" outputs:") + print(" - scalar: {0}".format(3)) + print(" signed: true") + may_check_error_rate() + + print("---") if __name__ == "__main__": main() diff --git a/compiler/tests/python/test_argument_support.py b/compiler/tests/python/test_argument_support.py index 08eb77f80..5ed872ed0 100644 --- a/compiler/tests/python/test_argument_support.py +++ b/compiler/tests/python/test_argument_support.py @@ -11,11 +11,8 @@ from concrete.compiler import ClientSupport pytest.param([0, 1, 2], id="list"), pytest.param(0.5, id="float"), pytest.param(2**70, id="large int"), - pytest.param(-8, id="negative int"), pytest.param("aze", id="str"), pytest.param(np.float64(0.8), id="np.float64"), - pytest.param(np.int8(9), id="np.int8"), - pytest.param(np.array([1, 2, 3], dtype=np.int64), id="np.array(np.int64)"), ], ) def test_invalid_arg_type(garbage): diff --git a/compiler/tests/python/test_client_server.py b/compiler/tests/python/test_client_server.py index 11c671455..377edf4d5 100644 --- a/compiler/tests/python/test_client_server.py +++ b/compiler/tests/python/test_client_server.py @@ -108,5 +108,7 @@ def test_client_server_end_to_end(mlir, args, expected_result, keyset_cache): client_parameters, result_serialized ) - output = ClientSupport.decrypt_result(keyset, result_unserialized) + output = ClientSupport.decrypt_result( + client_parameters, keyset, result_unserialized + ) assert np.array_equal(output, expected_result) diff --git a/compiler/tests/python/test_compilation.py b/compiler/tests/python/test_compilation.py index 0b7f7477a..dba612cf3 100644 --- a/compiler/tests/python/test_compilation.py +++ b/compiler/tests/python/test_compilation.py @@ -47,7 +47,7 @@ def run(engine, args, compilation_result, keyset_cache): evaluation_keys = key_set.get_evaluation_keys() public_result = engine.server_call(server_lambda, public_arguments, evaluation_keys) # Client - result = ClientSupport.decrypt_result(key_set, public_result) + result = ClientSupport.decrypt_result(client_parameters, key_set, public_result) return result diff --git a/compiler/tests/unit_tests/concretelang/ClientLib/ClientParameters.cpp b/compiler/tests/unit_tests/concretelang/ClientLib/ClientParameters.cpp index 6128831ed..ceae4f269 100644 --- a/compiler/tests/unit_tests/concretelang/ClientLib/ClientParameters.cpp +++ b/compiler/tests/unit_tests/concretelang/ClientLib/ClientParameters.cpp @@ -42,17 +42,20 @@ TEST(Support, client_parameters_json_serde) { }}}; params0.inputs = { { - /*.encryption = */ {{clientlib::SMALL_KEY, 0.00, {4, {1, 2, 3, 4}}}}, + /*.encryption = */ { + {clientlib::SMALL_KEY, 0.00, {4, {1, 2, 3, 4}, false}}}, /*.shape = */ {32, {1, 2, 3, 4}, 1 * 2 * 3 * 4, false}, }, { - /*.encryption = */ {{clientlib::SMALL_KEY, 0.00, {5, {1, 2, 3, 4}}}}, + /*.encryption = */ { + {clientlib::SMALL_KEY, 0.00, {5, {1, 2, 3, 4}, false}}}, /*.shape = */ {8, {4, 4, 4, 4}, 4 * 4 * 4 * 4, false}, }, }; params0.outputs = { { - /*.encryption = */ {{clientlib::SMALL_KEY, 0.00, {5, {1, 2, 3, 4}}}}, + /*.encryption = */ { + {clientlib::SMALL_KEY, 0.00, {5, {1, 2, 3, 4}, false}}}, /*.shape = */ {8, {4, 4, 4, 4}, 4 * 4 * 4 * 4, false}, }, };