mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: support signed execution
Author: aPere3 <alexandre.pere@zama.ai> Co-authored-by: Umut <umutsahin@protonmail.com>
This commit is contained in:
committed by
Quentin Bourgerie
parent
f913c39e5b
commit
e95c53f2ff
@@ -133,9 +133,10 @@ static inline bool operator==(const PackingKeySwitchParam &lhs,
|
||||
struct Encoding {
|
||||
Precision precision;
|
||||
CRTDecomposition crt;
|
||||
bool isSigned;
|
||||
};
|
||||
static inline bool operator==(const Encoding &lhs, const Encoding &rhs) {
|
||||
return lhs.precision == rhs.precision;
|
||||
return lhs.precision == rhs.precision && lhs.isSigned == rhs.isSigned;
|
||||
}
|
||||
|
||||
struct EncryptionGate {
|
||||
|
||||
@@ -71,7 +71,8 @@ def BConcrete_EncodeExpandLutForBootstrapTensorOp : BConcrete_Op<"encode_expand_
|
||||
let arguments = (ins
|
||||
1DTensorOf<[I64]> : $input_lookup_table,
|
||||
I32Attr: $polySize,
|
||||
I32Attr: $outputBits
|
||||
I32Attr: $outputBits,
|
||||
BoolAttr: $isSigned
|
||||
);
|
||||
|
||||
let results = (outs 1DTensorOf<[I64]> : $result);
|
||||
@@ -86,7 +87,8 @@ def BConcrete_EncodeExpandLutForWopPBSTensorOp : BConcrete_Op<"encode_expand_lut
|
||||
I64ArrayAttr: $crtDecomposition,
|
||||
I64ArrayAttr: $crtBits,
|
||||
I32Attr : $polySize,
|
||||
I32Attr : $modulusProduct
|
||||
I32Attr : $modulusProduct,
|
||||
BoolAttr: $isSigned
|
||||
);
|
||||
|
||||
let results = (outs 1DTensorOf<[I64]> : $result);
|
||||
@@ -226,7 +228,8 @@ def BConcrete_EncodeExpandLutForBootstrapBufferOp : BConcrete_Op<"encode_expand_
|
||||
BConcrete_LutBuffer: $result,
|
||||
BConcrete_LutBuffer: $input_lookup_table,
|
||||
I32Attr: $polySize,
|
||||
I32Attr: $outputBits
|
||||
I32Attr: $outputBits,
|
||||
BoolAttr : $isSigned
|
||||
);
|
||||
}
|
||||
|
||||
@@ -240,7 +243,8 @@ def BConcrete_EncodeExpandLutForWopPBSBufferOp : BConcrete_Op<"encode_expand_lut
|
||||
I64ArrayAttr: $crtDecomposition,
|
||||
I64ArrayAttr: $crtBits,
|
||||
I32Attr : $polySize,
|
||||
I32Attr : $modulusProduct
|
||||
I32Attr : $modulusProduct,
|
||||
BoolAttr: $isSigned
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -60,7 +60,8 @@ def Concrete_EncodeExpandLutForBootstrapOp : Concrete_Op<"encode_expand_lut_for_
|
||||
let arguments = (ins
|
||||
1DTensorOf<[I64]> : $input_lookup_table,
|
||||
I32Attr: $polySize,
|
||||
I32Attr: $outputBits
|
||||
I32Attr: $outputBits,
|
||||
BoolAttr: $isSigned
|
||||
);
|
||||
|
||||
let results = (outs 1DTensorOf<[I64]> : $result);
|
||||
@@ -75,7 +76,8 @@ let summary =
|
||||
I64ArrayAttr: $crtDecomposition,
|
||||
I64ArrayAttr: $crtBits,
|
||||
I32Attr : $polySize,
|
||||
I32Attr : $modulusProduct
|
||||
I32Attr : $modulusProduct,
|
||||
BoolAttr: $isSigned
|
||||
);
|
||||
|
||||
let results = (outs 1DTensorOf<[I64]> : $result);
|
||||
|
||||
@@ -355,9 +355,9 @@ def FHE_ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table", [NoSideEffect]> {
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins FHE_EncryptedIntegerType:$a,
|
||||
let arguments = (ins FHE_AnyEncryptedInteger:$a,
|
||||
TensorOf<[AnyInteger]>:$lut);
|
||||
let results = (outs FHE_EncryptedIntegerType);
|
||||
let results = (outs FHE_AnyEncryptedInteger);
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
|
||||
@@ -26,7 +26,8 @@ def TFHE_EncodeExpandLutForBootstrapOp : TFHE_Op<"encode_expand_lut_for_bootstra
|
||||
let arguments = (ins
|
||||
1DTensorOf<[I64]> : $input_lookup_table,
|
||||
I32Attr: $polySize,
|
||||
I32Attr: $outputBits
|
||||
I32Attr: $outputBits,
|
||||
BoolAttr: $isSigned
|
||||
);
|
||||
|
||||
let results = (outs 1DTensorOf<[I64]> : $result);
|
||||
@@ -41,7 +42,8 @@ def TFHE_EncodeExpandLutForWopPBSOp : TFHE_Op<"encode_expand_lut_for_woppbs"> {
|
||||
I64ArrayAttr: $crtDecomposition,
|
||||
I64ArrayAttr: $crtBits,
|
||||
I32Attr : $polySize,
|
||||
I32Attr : $modulusProduct
|
||||
I32Attr : $modulusProduct,
|
||||
BoolAttr: $isSigned
|
||||
);
|
||||
|
||||
let results = (outs 1DTensorOf<[I64]> : $result);
|
||||
|
||||
@@ -27,7 +27,7 @@ void memref_encode_expand_lut_for_bootstrap(
|
||||
uint64_t output_lut_stride, uint64_t *input_lut_allocated,
|
||||
uint64_t *input_lut_aligned, uint64_t input_lut_offset,
|
||||
uint64_t input_lut_size, uint64_t input_lut_stride, uint32_t poly_size,
|
||||
uint32_t out_MESSAGE_BITS);
|
||||
uint32_t out_MESSAGE_BITS, bool is_signed);
|
||||
|
||||
void memref_encode_expand_lut_for_woppbs(
|
||||
uint64_t *output_lut_allocated, uint64_t *output_lut_aligned,
|
||||
@@ -40,7 +40,7 @@ void memref_encode_expand_lut_for_woppbs(
|
||||
uint64_t crt_decomposition_stride, uint64_t *crt_bits_allocated,
|
||||
uint64_t *crt_bits_aligned, uint64_t crt_bits_offset,
|
||||
uint64_t crt_bits_size, uint64_t crt_bits_stride, uint32_t poly_size,
|
||||
uint32_t modulus_product);
|
||||
uint32_t modulus_product, bool is_signed);
|
||||
|
||||
void memref_encode_plaintext_with_crt(
|
||||
uint64_t *output_allocated, uint64_t *output_aligned,
|
||||
|
||||
@@ -290,7 +290,8 @@ public:
|
||||
// treatment, since it may alias none of the fixed size integer
|
||||
// types
|
||||
llvm::Expected<bool> successOrError =
|
||||
LambdaArgumentAdaptor::tryAddArg<uint64_t, uint32_t, uint16_t, uint8_t,
|
||||
LambdaArgumentAdaptor::tryAddArg<int64_t, int32_t, int16_t, int8_t,
|
||||
uint64_t, uint32_t, uint16_t, uint8_t,
|
||||
size_t>(encryptedArgs, arg, keySet);
|
||||
|
||||
if (!successOrError)
|
||||
|
||||
@@ -238,6 +238,19 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
[](mlir::concretelang::ClientParameters &clientParameters) {
|
||||
return pybind11::bytes(
|
||||
clientParametersSerialize(clientParameters));
|
||||
})
|
||||
.def("output_signs",
|
||||
[](mlir::concretelang::ClientParameters &clientParameters) {
|
||||
std::vector<bool> result;
|
||||
for (auto output : clientParameters.outputs) {
|
||||
if (output.encryption.hasValue()) {
|
||||
result.push_back(
|
||||
output.encryption.getValue().encoding.isSigned);
|
||||
} else {
|
||||
result.push_back(true);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
});
|
||||
|
||||
pybind11::class_<clientlib::KeySet>(m, "KeySet")
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
|
||||
"""Client parameters."""
|
||||
|
||||
from typing import List
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
ClientParameters as _ClientParameters,
|
||||
@@ -35,6 +37,14 @@ class ClientParameters(WrapperCpp):
|
||||
)
|
||||
super().__init__(client_parameters)
|
||||
|
||||
def output_signs(self) -> List[bool]:
|
||||
"""Return the sign information of outputs.
|
||||
|
||||
Returns:
|
||||
List[bool]: list of booleans to indicate whether the outputs are signed or not
|
||||
"""
|
||||
return self.cpp().output_signs()
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
"""Serialize the ClientParameters.
|
||||
|
||||
|
||||
@@ -123,11 +123,14 @@ class ClientSupport(WrapperCpp):
|
||||
|
||||
@staticmethod
|
||||
def decrypt_result(
|
||||
keyset: KeySet, public_result: PublicResult
|
||||
client_parameters: ClientParameters,
|
||||
keyset: KeySet,
|
||||
public_result: PublicResult,
|
||||
) -> Union[int, np.ndarray]:
|
||||
"""Decrypt a public result using the keyset.
|
||||
|
||||
Args:
|
||||
client_parameters (ClientParameters): client parameters for decryption
|
||||
keyset (KeySet): keyset used for decryption
|
||||
public_result: public result to decrypt
|
||||
|
||||
@@ -148,12 +151,20 @@ class ClientSupport(WrapperCpp):
|
||||
lambda_arg = LambdaArgument.wrap(
|
||||
_ClientSupport.decrypt_result(keyset.cpp(), public_result.cpp())
|
||||
)
|
||||
|
||||
output_signs = client_parameters.output_signs()
|
||||
assert len(output_signs) == 1
|
||||
|
||||
is_signed = output_signs[0]
|
||||
if lambda_arg.is_scalar():
|
||||
return lambda_arg.get_scalar()
|
||||
result = lambda_arg.get_scalar()
|
||||
return (
|
||||
result if not is_signed else int(np.array([result]).astype(np.int64)[0])
|
||||
)
|
||||
if lambda_arg.is_tensor():
|
||||
shape = lambda_arg.get_tensor_shape()
|
||||
tensor = np.array(lambda_arg.get_tensor_data()).reshape(shape)
|
||||
return tensor
|
||||
return tensor if not is_signed else tensor.astype(np.int64)
|
||||
raise RuntimeError("unknown return type")
|
||||
|
||||
@staticmethod
|
||||
@@ -171,29 +182,42 @@ class ClientSupport(WrapperCpp):
|
||||
"""
|
||||
if not isinstance(value, ACCEPTED_TYPES):
|
||||
raise TypeError(
|
||||
"value of lambda argument must be either int, numpy.array or numpy.uint{8,16,32,64}"
|
||||
"value of lambda argument must be either int, numpy.array or numpy.(u)int{8,16,32,64}"
|
||||
)
|
||||
if isinstance(value, ACCEPTED_INTS):
|
||||
if isinstance(value, int) and not 0 <= value < np.iinfo(np.uint64).max:
|
||||
if (
|
||||
isinstance(value, int)
|
||||
and not np.iinfo(np.int64).min <= value < np.iinfo(np.uint64).max
|
||||
):
|
||||
raise TypeError(
|
||||
"single integer must be in the range [0, 2**64 - 1] (uint64)"
|
||||
"single integer must be in the range [-2**63, 2**64 - 1]"
|
||||
)
|
||||
if value < 0:
|
||||
value = int(np.int64(value).astype(np.uint64))
|
||||
return LambdaArgument.from_scalar(value)
|
||||
assert isinstance(value, np.ndarray)
|
||||
if value.dtype not in ACCEPTED_NUMPY_UINTS:
|
||||
raise TypeError("numpy.array must be of dtype uint{8,16,32,64}")
|
||||
raise TypeError("numpy.array must be of dtype (u)int{8,16,32,64}")
|
||||
if value.shape == ():
|
||||
if isinstance(value, np.ndarray):
|
||||
# extract the single element
|
||||
value = value.max()
|
||||
# should be a single uint here
|
||||
return LambdaArgument.from_scalar(value)
|
||||
if value.dtype == np.uint8:
|
||||
return LambdaArgument.from_tensor_8(value.flatten().tolist(), value.shape)
|
||||
if value.dtype == np.uint16:
|
||||
return LambdaArgument.from_tensor_16(value.flatten().tolist(), value.shape)
|
||||
if value.dtype == np.uint32:
|
||||
return LambdaArgument.from_tensor_32(value.flatten().tolist(), value.shape)
|
||||
if value.dtype == np.uint64:
|
||||
return LambdaArgument.from_tensor_64(value.flatten().tolist(), value.shape)
|
||||
raise TypeError("numpy.array must be of dtype uint{8,16,32,64}")
|
||||
if value.dtype in [np.uint8, np.int8]:
|
||||
return LambdaArgument.from_tensor_8(
|
||||
value.astype(np.uint8).flatten().tolist(), value.shape
|
||||
)
|
||||
if value.dtype in [np.uint16, np.int16]:
|
||||
return LambdaArgument.from_tensor_16(
|
||||
value.astype(np.uint16).flatten().tolist(), value.shape
|
||||
)
|
||||
if value.dtype in [np.uint32, np.int32]:
|
||||
return LambdaArgument.from_tensor_32(
|
||||
value.astype(np.uint32).flatten().tolist(), value.shape
|
||||
)
|
||||
if value.dtype in [np.uint64, np.int64]:
|
||||
return LambdaArgument.from_tensor_64(
|
||||
value.astype(np.uint64).flatten().tolist(), value.shape
|
||||
)
|
||||
raise TypeError("numpy.array must be of dtype (u)int{8,16,32,64}")
|
||||
|
||||
@@ -58,7 +58,7 @@ class LambdaArgument(WrapperCpp):
|
||||
"""
|
||||
if not isinstance(scalar, ACCEPTED_INTS):
|
||||
raise TypeError(
|
||||
f"scalar must be of type int or numpy.uint, not {type(scalar)}"
|
||||
f"scalar must be of type int or numpy.(u)int, not {type(scalar)}"
|
||||
)
|
||||
return LambdaArgument.wrap(_LambdaArgument.from_scalar(scalar))
|
||||
|
||||
|
||||
@@ -6,7 +6,16 @@ import os
|
||||
import numpy as np
|
||||
|
||||
|
||||
ACCEPTED_NUMPY_UINTS = (np.uint8, np.uint16, np.uint32, np.uint64)
|
||||
ACCEPTED_NUMPY_UINTS = (
|
||||
np.int8,
|
||||
np.int16,
|
||||
np.int32,
|
||||
np.int64,
|
||||
np.uint8,
|
||||
np.uint16,
|
||||
np.uint32,
|
||||
np.uint64,
|
||||
)
|
||||
ACCEPTED_INTS = (int,) + ACCEPTED_NUMPY_UINTS
|
||||
ACCEPTED_TYPES = (np.ndarray,) + ACCEPTED_INTS
|
||||
|
||||
|
||||
@@ -306,6 +306,7 @@ bool fromJSON(const llvm::json::Value j, CircuitGateShape &v,
|
||||
llvm::json::Value toJSON(const Encoding &v) {
|
||||
llvm::json::Object object{
|
||||
{"precision", v.precision},
|
||||
{"isSigned", v.isSigned},
|
||||
};
|
||||
if (!v.crt.empty()) {
|
||||
object.insert({"crt", v.crt});
|
||||
@@ -324,6 +325,12 @@ bool fromJSON(const llvm::json::Value j, Encoding &v, llvm::json::Path p) {
|
||||
return false;
|
||||
}
|
||||
v.precision = precision.getValue();
|
||||
auto isSigned = obj->getBoolean("isSigned");
|
||||
if (!isSigned.hasValue()) {
|
||||
p.report("missing isSigned field");
|
||||
return false;
|
||||
}
|
||||
v.isSigned = isSigned.getValue();
|
||||
auto crt = obj->getArray("crt");
|
||||
if (crt != nullptr) {
|
||||
for (auto dim : *crt) {
|
||||
|
||||
@@ -355,37 +355,62 @@ KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) {
|
||||
if (!encryption.hasValue()) {
|
||||
return StringError("decrypt_lwe: the positional argument is not encrypted");
|
||||
}
|
||||
|
||||
auto crt = encryption->encoding.crt;
|
||||
// CRT encoding - N blocks with crt encoding
|
||||
if (!crt.empty()) {
|
||||
|
||||
if (!crt.empty()) { // The ciphertext used the crt strategy.
|
||||
|
||||
// Decrypt and decode remainders
|
||||
std::vector<int64_t> remainders;
|
||||
// decrypt and decode remainders
|
||||
for (auto modulus : crt) {
|
||||
uint64_t decrypted;
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_engine_decrypt_lwe_ciphertext_u64_raw_ptr_buffers(
|
||||
engine, lweSecretKey, ciphertext, &decrypted));
|
||||
|
||||
auto plaintext = crt::decode(decrypted, modulus);
|
||||
remainders.push_back(plaintext);
|
||||
ciphertext = ciphertext + lweSecretKeyParam.lweSize();
|
||||
}
|
||||
// compute the inverse crt
|
||||
|
||||
// Compute the inverse crt
|
||||
output = crt::iCrt(crt, remainders);
|
||||
return outcome::success();
|
||||
|
||||
// Further decode signed integers
|
||||
if (encryption->encoding.isSigned) {
|
||||
uint64_t maxPos = 1;
|
||||
for (auto prime : encryption->encoding.crt) {
|
||||
maxPos *= prime;
|
||||
}
|
||||
maxPos /= 2;
|
||||
if (output >= maxPos) {
|
||||
output -= maxPos * 2;
|
||||
}
|
||||
}
|
||||
} else { // The ciphertext used the scalar strategy
|
||||
|
||||
// Decrypt
|
||||
uint64_t plaintext;
|
||||
CAPI_ASSERT_ERROR(default_engine_decrypt_lwe_ciphertext_u64_raw_ptr_buffers(
|
||||
engine, lweSecretKey, ciphertext, &plaintext));
|
||||
|
||||
// Decode unsigned integer
|
||||
uint64_t precision = encryption->encoding.precision;
|
||||
output = plaintext >> (64 - precision - 2);
|
||||
auto carry = output % 2;
|
||||
uint64_t mod = (((uint64_t)1) << (precision + 1));
|
||||
output = ((output >> 1) + carry) % mod;
|
||||
|
||||
// Further decode signed integers.
|
||||
if (encryption->encoding.isSigned) {
|
||||
uint64_t maxPos = (((uint64_t)1) << (precision - 1));
|
||||
if (output >= maxPos) { // The output is actually negative.
|
||||
// Set the preceding bits to zero
|
||||
output |= UINT64_MAX << precision;
|
||||
// This makes sure when the value is cast to int64, it has the correct
|
||||
// value
|
||||
};
|
||||
}
|
||||
}
|
||||
// Simple TFHE integers - 1 blocks with one padding bits
|
||||
uint64_t plaintext;
|
||||
|
||||
CAPI_ASSERT_ERROR(default_engine_decrypt_lwe_ciphertext_u64_raw_ptr_buffers(
|
||||
engine, lweSecretKey, ciphertext, &plaintext));
|
||||
|
||||
// Decode
|
||||
uint64_t precision = encryption->encoding.precision;
|
||||
output = plaintext >> (64 - precision - 2);
|
||||
auto carry = output % 2;
|
||||
uint64_t mod = (((uint64_t)1) << (precision + 1));
|
||||
output = ((output >> 1) + carry) % mod;
|
||||
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
@@ -174,16 +174,16 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
|
||||
memref1DType, rewriter.getI64Type()},
|
||||
{});
|
||||
} else if (funcName == memref_encode_expand_lut_for_bootstrap) {
|
||||
funcType =
|
||||
mlir::FunctionType::get(rewriter.getContext(),
|
||||
{memref1DType, memref1DType,
|
||||
rewriter.getI32Type(), rewriter.getI32Type()},
|
||||
{});
|
||||
funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{memref1DType, memref1DType, rewriter.getI32Type(),
|
||||
rewriter.getI32Type(), rewriter.getI1Type()},
|
||||
{});
|
||||
} else if (funcName == memref_encode_expand_lut_for_woppbs) {
|
||||
funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{memref1DType, memref1DType, memref1DType, memref1DType,
|
||||
rewriter.getI32Type(), rewriter.getI32Type()},
|
||||
rewriter.getI32Type(), rewriter.getI32Type(), rewriter.getI1Type()},
|
||||
{});
|
||||
} else {
|
||||
op->emitError("unknwon external function") << funcName;
|
||||
@@ -359,6 +359,9 @@ void encodeExpandLutForBootstrapAddOperands(
|
||||
// output bits
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.outputBitsAttr()));
|
||||
// is_signed
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.isSignedAttr()));
|
||||
}
|
||||
|
||||
void encodeExpandLutForWopPBSAddOperands(
|
||||
@@ -409,6 +412,9 @@ void encodeExpandLutForWopPBSAddOperands(
|
||||
// modulus_product
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.modulusProductAttr()));
|
||||
// is_signed
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.isSignedAttr()));
|
||||
}
|
||||
|
||||
struct BConcreteToCAPIPass : public BConcreteToCAPIBase<BConcreteToCAPIPass> {
|
||||
|
||||
@@ -38,41 +38,42 @@ namespace fhe_to_tfhe_crt_conversion {
|
||||
|
||||
namespace typing {
|
||||
|
||||
/// Converts `FHE::EncryptedInteger` into `Tensor<TFHE::GlweCiphetext>`.
|
||||
mlir::RankedTensorType convertEint(mlir::MLIRContext *context,
|
||||
FHE::EncryptedIntegerType eint,
|
||||
uint64_t crtLength) {
|
||||
/// Converts an encrypted integer into `TFHE::GlweCiphertext`.
|
||||
mlir::RankedTensorType convertEncrypted(mlir::MLIRContext *context,
|
||||
FHE::FheIntegerInterface enc,
|
||||
uint64_t crtLength) {
|
||||
return mlir::RankedTensorType::get(
|
||||
mlir::ArrayRef<int64_t>((int64_t)crtLength),
|
||||
TFHE::GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth()));
|
||||
TFHE::GLWECipherTextType::get(context, -1, -1, -1, enc.getWidth()));
|
||||
}
|
||||
|
||||
/// Converts `Tensor<FHE::EncryptedInteger>` into a
|
||||
/// `Tensor<TFHE::GlweCiphertext>` if the element type is appropriate. Otherwise
|
||||
/// return the input type.
|
||||
mlir::Type maybeConvertEintTensor(mlir::MLIRContext *context,
|
||||
mlir::RankedTensorType maybeEintTensor,
|
||||
uint64_t crtLength) {
|
||||
if (!maybeEintTensor.getElementType().isa<FHE::EncryptedIntegerType>()) {
|
||||
return (mlir::Type)(maybeEintTensor);
|
||||
/// Converts `Tensor<FHE::AnyEncryptedInteger>` into a
|
||||
/// `Tensor<TFHE::GlweCiphertext>` if the element type is appropriate.
|
||||
/// Otherwise return the input type.
|
||||
mlir::Type
|
||||
maybeConvertEncryptedTensor(mlir::MLIRContext *context,
|
||||
mlir::RankedTensorType maybeEncryptedTensor,
|
||||
uint64_t crtLength) {
|
||||
if (!maybeEncryptedTensor.getElementType().isa<FHE::FheIntegerInterface>()) {
|
||||
return (mlir::Type)(maybeEncryptedTensor);
|
||||
}
|
||||
auto eint =
|
||||
maybeEintTensor.getElementType().cast<FHE::EncryptedIntegerType>();
|
||||
auto currentShape = maybeEintTensor.getShape();
|
||||
auto encType =
|
||||
maybeEncryptedTensor.getElementType().cast<FHE::FheIntegerInterface>();
|
||||
auto currentShape = maybeEncryptedTensor.getShape();
|
||||
mlir::SmallVector<int64_t> newShape =
|
||||
mlir::SmallVector<int64_t>(currentShape.begin(), currentShape.end());
|
||||
newShape.push_back((int64_t)crtLength);
|
||||
return mlir::RankedTensorType::get(
|
||||
llvm::ArrayRef<int64_t>(newShape),
|
||||
TFHE::GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth()));
|
||||
TFHE::GLWECipherTextType::get(context, -1, -1, -1, encType.getWidth()));
|
||||
}
|
||||
|
||||
/// Converts the type `FHE::EncryptedInteger` to `Tensor<TFHE::GlweCiphetext>`
|
||||
/// if the input type is appropriate. Otherwise return the input type.
|
||||
mlir::Type maybeConvertEint(mlir::MLIRContext *context, mlir::Type t,
|
||||
uint64_t crtLength) {
|
||||
if (auto eint = t.dyn_cast<FHE::EncryptedIntegerType>())
|
||||
return convertEint(context, eint, crtLength);
|
||||
/// Converts any encrypted type to `TFHE::GlweCiphetext` if the
|
||||
/// input type is appropriate. Otherwise return the input type.
|
||||
mlir::Type maybeConvertEncrypted(mlir::MLIRContext *context, mlir::Type t,
|
||||
uint64_t crtLength) {
|
||||
if (auto eint = t.dyn_cast<FHE::FheIntegerInterface>())
|
||||
return convertEncrypted(context, eint, crtLength);
|
||||
|
||||
return t;
|
||||
}
|
||||
@@ -85,11 +86,11 @@ public:
|
||||
TypeConverter(concretelang::CrtLoweringParameters loweringParameters) {
|
||||
size_t nMods = loweringParameters.nMods;
|
||||
addConversion([](mlir::Type type) { return type; });
|
||||
addConversion([=](FHE::EncryptedIntegerType type) {
|
||||
return convertEint(type.getContext(), type, nMods);
|
||||
addConversion([=](FHE::FheIntegerInterface type) {
|
||||
return convertEncrypted(type.getContext(), type, nMods);
|
||||
});
|
||||
addConversion([=](mlir::RankedTensorType type) {
|
||||
return maybeConvertEintTensor(type.getContext(), type, nMods);
|
||||
return maybeConvertEncryptedTensor(type.getContext(), type, nMods);
|
||||
});
|
||||
addConversion([&](concretelang::RT::FutureType type) {
|
||||
return concretelang::RT::FutureType::get(this->convertType(
|
||||
@@ -517,6 +518,8 @@ struct ApplyLookupTableEintOpPattern
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::TypeConverter *converter = this->getTypeConverter();
|
||||
|
||||
auto originalInputType = op.a().getType().cast<FHE::FheIntegerInterface>();
|
||||
|
||||
mlir::Value newLut =
|
||||
rewriter
|
||||
.create<TFHE::EncodeExpandLutForWopPBSOp>(
|
||||
@@ -530,7 +533,8 @@ struct ApplyLookupTableEintOpPattern
|
||||
rewriter.getI64ArrayAttr(
|
||||
mlir::ArrayRef<int64_t>(loweringParameters.bits)),
|
||||
rewriter.getI32IntegerAttr(loweringParameters.polynomialSize),
|
||||
rewriter.getI32IntegerAttr(loweringParameters.modsProd))
|
||||
rewriter.getI32IntegerAttr(loweringParameters.modsProd),
|
||||
rewriter.getBoolAttr(originalInputType.isSigned()))
|
||||
.getResult();
|
||||
|
||||
// Replace the lut with an encoded / expanded one.
|
||||
|
||||
@@ -37,34 +37,34 @@ namespace fhe_to_tfhe_scalar_conversion {
|
||||
|
||||
namespace typing {
|
||||
|
||||
/// Converts `FHE::EncryptedInteger` into `TFHE::GlweCiphetext`.
|
||||
TFHE::GLWECipherTextType convertEint(mlir::MLIRContext *context,
|
||||
FHE::EncryptedIntegerType eint) {
|
||||
return TFHE::GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth());
|
||||
/// Converts an encrypted integer into `TFHE::GlweCiphetext`.
|
||||
TFHE::GLWECipherTextType convertEncrypted(mlir::MLIRContext *context,
|
||||
FHE::FheIntegerInterface enc) {
|
||||
return TFHE::GLWECipherTextType::get(context, -1, -1, -1, enc.getWidth());
|
||||
}
|
||||
|
||||
/// Converts `Tensor<FHE::EncryptedInteger>` into a
|
||||
/// Converts `Tensor<FHE::AnyEncryptedInteger>` into a
|
||||
/// `Tensor<TFHE::GlweCiphertext>` if the element type is appropriate.
|
||||
/// Otherwise return the input type.
|
||||
mlir::Type maybeConvertEintTensor(mlir::MLIRContext *context,
|
||||
mlir::RankedTensorType maybeEintTensor) {
|
||||
if (!maybeEintTensor.getElementType().isa<FHE::EncryptedIntegerType>()) {
|
||||
return (mlir::Type)(maybeEintTensor);
|
||||
mlir::Type
|
||||
maybeConvertEncryptedTensor(mlir::MLIRContext *context,
|
||||
mlir::RankedTensorType maybeEncryptedTensor) {
|
||||
if (!maybeEncryptedTensor.getElementType().isa<FHE::FheIntegerInterface>()) {
|
||||
return (mlir::Type)(maybeEncryptedTensor);
|
||||
}
|
||||
auto eint =
|
||||
maybeEintTensor.getElementType().cast<FHE::EncryptedIntegerType>();
|
||||
auto currentShape = maybeEintTensor.getShape();
|
||||
auto enc =
|
||||
maybeEncryptedTensor.getElementType().cast<FHE::FheIntegerInterface>();
|
||||
auto currentShape = maybeEncryptedTensor.getShape();
|
||||
return mlir::RankedTensorType::get(
|
||||
currentShape,
|
||||
TFHE::GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth()));
|
||||
TFHE::GLWECipherTextType::get(context, -1, -1, -1, enc.getWidth()));
|
||||
}
|
||||
|
||||
/// Converts the type `FHE::EncryptedInteger` to `TFHE::GlweCiphetext` if the
|
||||
/// Converts any encrypted type to `TFHE::GlweCiphetext` if the
|
||||
/// input type is appropriate. Otherwise return the input type.
|
||||
mlir::Type maybeConvertEint(mlir::MLIRContext *context, mlir::Type t) {
|
||||
if (auto eint = t.dyn_cast<FHE::EncryptedIntegerType>())
|
||||
return convertEint(context, eint);
|
||||
|
||||
mlir::Type maybeConvertEncrypted(mlir::MLIRContext *context, mlir::Type t) {
|
||||
if (auto eint = t.dyn_cast<FHE::FheIntegerInterface>())
|
||||
return convertEncrypted(context, eint);
|
||||
return t;
|
||||
}
|
||||
|
||||
@@ -75,8 +75,8 @@ class TypeConverter : public mlir::TypeConverter {
|
||||
public:
|
||||
TypeConverter() {
|
||||
addConversion([](mlir::Type type) { return type; });
|
||||
addConversion([](FHE::EncryptedIntegerType type) {
|
||||
return convertEint(type.getContext(), type);
|
||||
addConversion([](FHE::FheIntegerInterface type) {
|
||||
return convertEncrypted(type.getContext(), type);
|
||||
});
|
||||
addConversion([](FHE::EncryptedBooleanType type) {
|
||||
return TFHE::GLWECipherTextType::get(
|
||||
@@ -84,7 +84,7 @@ public:
|
||||
mlir::concretelang::FHE::EncryptedBooleanType::getWidth());
|
||||
});
|
||||
addConversion([](mlir::RankedTensorType type) {
|
||||
return maybeConvertEintTensor(type.getContext(), type);
|
||||
return maybeConvertEncryptedTensor(type.getContext(), type);
|
||||
});
|
||||
addConversion([&](concretelang::RT::FutureType type) {
|
||||
return concretelang::RT::FutureType::get(this->convertType(
|
||||
@@ -145,7 +145,7 @@ struct AddEintIntOpPattern : public ScalarOpPattern<FHE::AddEintIntOp> {
|
||||
// Write the plaintext encoding
|
||||
mlir::Value encodedInt = writePlaintextShiftEncoding(
|
||||
op.getLoc(), adaptor.b(),
|
||||
op.getType().cast<FHE::EncryptedIntegerType>().getWidth(), rewriter);
|
||||
op.getType().cast<FHE::FheIntegerInterface>().getWidth(), rewriter);
|
||||
|
||||
// Write the new op
|
||||
rewriter.replaceOpWithNewOp<TFHE::AddGLWEIntOp>(
|
||||
@@ -183,7 +183,7 @@ struct SubEintIntOpPattern : public ScalarOpPattern<FHE::SubEintIntOp> {
|
||||
// Write the plaintext encoding
|
||||
mlir::Value encodedInt = writePlaintextShiftEncoding(
|
||||
op.getLoc(), negative,
|
||||
eintOperand.getType().cast<FHE::EncryptedIntegerType>().getWidth(),
|
||||
eintOperand.getType().cast<FHE::FheIntegerInterface>().getWidth(),
|
||||
rewriter);
|
||||
|
||||
// Write the new op
|
||||
@@ -208,7 +208,7 @@ struct SubIntEintOpPattern : public ScalarOpPattern<FHE::SubIntEintOp> {
|
||||
// Write the plaintext encoding
|
||||
mlir::Value encodedInt = writePlaintextShiftEncoding(
|
||||
op.getLoc(), adaptor.a(),
|
||||
op.b().getType().cast<FHE::EncryptedIntegerType>().getWidth(),
|
||||
op.b().getType().cast<FHE::FheIntegerInterface>().getWidth(),
|
||||
rewriter);
|
||||
|
||||
// Write the new op
|
||||
@@ -290,8 +290,9 @@ struct ApplyLookupTableEintOpPattern
|
||||
FHE::ApplyLookupTableEintOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
auto inputType = op.a().getType().cast<FHE::FheIntegerInterface>();
|
||||
size_t outputBits =
|
||||
op.getResult().getType().cast<FHE::EncryptedIntegerType>().getWidth();
|
||||
op.getResult().getType().cast<FHE::FheIntegerInterface>().getWidth();
|
||||
mlir::Value newLut =
|
||||
rewriter
|
||||
.create<TFHE::EncodeExpandLutForBootstrapOp>(
|
||||
@@ -301,12 +302,36 @@ struct ApplyLookupTableEintOpPattern
|
||||
rewriter.getI64Type()),
|
||||
op.lut(),
|
||||
rewriter.getI32IntegerAttr(loweringParameters.polynomialSize),
|
||||
rewriter.getI32IntegerAttr(outputBits))
|
||||
rewriter.getI32IntegerAttr(outputBits),
|
||||
rewriter.getBoolAttr(inputType.isSigned()))
|
||||
.getResult();
|
||||
|
||||
typing::TypeConverter converter;
|
||||
mlir::Value input = adaptor.a();
|
||||
|
||||
if (inputType.isSigned()) {
|
||||
// If the input is a signed integer, it comes to the bootstrap with a
|
||||
// signed-leveled encoding (compatible with 2s complement semantics).
|
||||
// Unfortunately pbs is not compatible with this encoding, since the
|
||||
// (virtual) msb must be 0 to avoid a lookup in the phantom negative lut.
|
||||
uint64_t constantRaw = (uint64_t)1 << (inputType.getWidth() - 1);
|
||||
// Note that the constant must be encoded with one more bit to ensure the
|
||||
// signed extension used in the plaintext encoding works as expected.
|
||||
mlir::Value constant = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(),
|
||||
rewriter.getIntegerAttr(
|
||||
rewriter.getIntegerType(inputType.getWidth() + 1), constantRaw));
|
||||
mlir::Value encodedConstant = writePlaintextShiftEncoding(
|
||||
op.getLoc(), constant, inputType.getWidth(), rewriter);
|
||||
auto inputOp = rewriter.create<TFHE::AddGLWEIntOp>(
|
||||
op.getLoc(), converter.convertType(input.getType()), input,
|
||||
encodedConstant);
|
||||
input = inputOp;
|
||||
}
|
||||
|
||||
// Insert keyswitch
|
||||
auto ksOp = rewriter.create<TFHE::KeySwitchGLWEOp>(
|
||||
op.getLoc(), adaptor.a().getType(), adaptor.a(), -1, -1);
|
||||
op.getLoc(), getTypeConverter()->convertType(adaptor.a().getType()), input, -1, -1);
|
||||
|
||||
// Insert bootstrap
|
||||
rewriter.replaceOpWithNewOp<TFHE::BootstrapGLWEOp>(
|
||||
|
||||
@@ -1215,7 +1215,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
.isa<mlir::concretelang::FHE::EncryptedIntegerType>()) {
|
||||
.isa<mlir::concretelang::FHE::FheIntegerInterface>()) {
|
||||
norm2SqEquiv = getSqMANP(transposeOp, operands);
|
||||
} else {
|
||||
isDummy = true;
|
||||
@@ -1227,7 +1227,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
else if (auto extractOp = llvm::dyn_cast<mlir::tensor::ExtractOp>(op)) {
|
||||
if (extractOp.result()
|
||||
.getType()
|
||||
.isa<mlir::concretelang::FHE::EncryptedIntegerType>()) {
|
||||
.isa<mlir::concretelang::FHE::FheIntegerInterface>()) {
|
||||
norm2SqEquiv = getSqMANP(extractOp, operands);
|
||||
} else {
|
||||
isDummy = true;
|
||||
@@ -1240,7 +1240,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
.isa<mlir::concretelang::FHE::EncryptedIntegerType>()) {
|
||||
.isa<mlir::concretelang::FHE::FheIntegerInterface>()) {
|
||||
norm2SqEquiv = getSqMANP(extractSliceOp, operands);
|
||||
} else {
|
||||
isDummy = true;
|
||||
@@ -1252,7 +1252,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
.isa<mlir::concretelang::FHE::EncryptedIntegerType>()) {
|
||||
.isa<mlir::concretelang::FHE::FheIntegerInterface>()) {
|
||||
norm2SqEquiv = getSqMANP(insertOp, operands);
|
||||
} else {
|
||||
isDummy = true;
|
||||
@@ -1265,7 +1265,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
.isa<mlir::concretelang::FHE::EncryptedIntegerType>()) {
|
||||
.isa<mlir::concretelang::FHE::FheIntegerInterface>()) {
|
||||
norm2SqEquiv = getSqMANP(insertSliceOp, operands);
|
||||
} else {
|
||||
isDummy = true;
|
||||
@@ -1277,7 +1277,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
.isa<mlir::concretelang::FHE::EncryptedIntegerType>()) {
|
||||
.isa<mlir::concretelang::FHE::FheIntegerInterface>()) {
|
||||
norm2SqEquiv = getSqMANP(fromOp, operands);
|
||||
} else {
|
||||
isDummy = true;
|
||||
@@ -1290,7 +1290,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
.isa<mlir::concretelang::FHE::EncryptedIntegerType>()) {
|
||||
.isa<mlir::concretelang::FHE::FheIntegerInterface>()) {
|
||||
norm2SqEquiv = getSqMANP(reshapeOp, operands);
|
||||
} else {
|
||||
isDummy = true;
|
||||
@@ -1302,7 +1302,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
.isa<mlir::concretelang::FHE::EncryptedIntegerType>()) {
|
||||
.isa<mlir::concretelang::FHE::FheIntegerInterface>()) {
|
||||
norm2SqEquiv = getSqMANP(reshapeOp, operands);
|
||||
} else {
|
||||
isDummy = true;
|
||||
@@ -1410,16 +1410,15 @@ protected:
|
||||
|
||||
// Process all results using MANP attribute from MANP pas
|
||||
for (mlir::OpResult res : op->getResults()) {
|
||||
mlir::concretelang::FHE::EncryptedIntegerType eTy =
|
||||
mlir::concretelang::FHE::FheIntegerInterface eTy =
|
||||
res.getType()
|
||||
.dyn_cast_or_null<
|
||||
mlir::concretelang::FHE::EncryptedIntegerType>();
|
||||
.dyn_cast_or_null<mlir::concretelang::FHE::FheIntegerInterface>();
|
||||
if (eTy == nullptr) {
|
||||
auto tensorTy = res.getType().dyn_cast_or_null<mlir::TensorType>();
|
||||
if (tensorTy != nullptr) {
|
||||
eTy = tensorTy.getElementType()
|
||||
.dyn_cast_or_null<
|
||||
mlir::concretelang::FHE::EncryptedIntegerType>();
|
||||
mlir::concretelang::FHE::FheIntegerInterface>();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -13,14 +13,13 @@ namespace utils {
|
||||
/// Returns `true` if the given value is a scalar or tensor argument of
|
||||
/// a function, for which a MANP of 1 can be assumed.
|
||||
bool isEncryptedValue(mlir::Value value) {
|
||||
return (
|
||||
value.getType().isa<mlir::concretelang::FHE::EncryptedIntegerType>() ||
|
||||
value.getType().isa<mlir::concretelang::FHE::EncryptedBooleanType>() ||
|
||||
return (value.getType().isa<mlir::concretelang::FHE::FheIntegerInterface>() ||
|
||||
value.getType().isa<mlir::concretelang::FHE::EncryptedBooleanType>() ||
|
||||
(value.getType().isa<mlir::TensorType>() &&
|
||||
value.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
.isa<mlir::concretelang::FHE::EncryptedIntegerType>()));
|
||||
.isa<mlir::concretelang::FHE::FheIntegerInterface>()));
|
||||
}
|
||||
|
||||
/// Returns the bit width of `value` if `value` is an encrypted integer,
|
||||
@@ -30,7 +29,7 @@ bool isEncryptedValue(mlir::Value value) {
|
||||
unsigned int getEintPrecision(mlir::Value value) {
|
||||
if (auto ty = value.getType()
|
||||
.dyn_cast_or_null<
|
||||
mlir::concretelang::FHE::EncryptedIntegerType>()) {
|
||||
mlir::concretelang::FHE::FheIntegerInterface>()) {
|
||||
return ty.getWidth();
|
||||
}
|
||||
if (auto ty = value.getType()
|
||||
@@ -41,7 +40,7 @@ unsigned int getEintPrecision(mlir::Value value) {
|
||||
value.getType().dyn_cast_or_null<mlir::TensorType>()) {
|
||||
if (auto ty = tensorTy.getElementType()
|
||||
.dyn_cast_or_null<
|
||||
mlir::concretelang::FHE::EncryptedIntegerType>())
|
||||
mlir::concretelang::FHE::FheIntegerInterface>())
|
||||
return ty.getWidth();
|
||||
}
|
||||
|
||||
|
||||
@@ -228,7 +228,7 @@ mlir::LogicalResult GenGateOp::verify() {
|
||||
}
|
||||
|
||||
::mlir::LogicalResult ApplyLookupTableEintOp::verify() {
|
||||
auto ct = this->a().getType().cast<EncryptedIntegerType>();
|
||||
auto ct = this->a().getType().cast<FheIntegerInterface>();
|
||||
auto lut = this->lut().getType().cast<TensorType>();
|
||||
|
||||
// Check the shape of lut argument
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "concretelang/Runtime/seeder.h"
|
||||
#include <assert.h>
|
||||
#include <cmath>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
@@ -253,7 +254,7 @@ void memref_encode_expand_lut_for_bootstrap(
|
||||
uint64_t output_lut_stride, uint64_t *input_lut_allocated,
|
||||
uint64_t *input_lut_aligned, uint64_t input_lut_offset,
|
||||
uint64_t input_lut_size, uint64_t input_lut_stride, uint32_t poly_size,
|
||||
uint32_t out_MESSAGE_BITS) {
|
||||
uint32_t out_MESSAGE_BITS, bool is_signed) {
|
||||
|
||||
assert(input_lut_stride == 1 && "Runtime: stride not equal to 1, check "
|
||||
"memref_encode_expand_lut_bootstrap");
|
||||
@@ -265,19 +266,41 @@ void memref_encode_expand_lut_for_bootstrap(
|
||||
|
||||
assert((mega_case_size % 2) == 0);
|
||||
|
||||
for (size_t idx = 0; idx < mega_case_size / 2; ++idx) {
|
||||
output_lut_aligned[output_lut_offset + idx] =
|
||||
input_lut_aligned[input_lut_offset] << (64 - out_MESSAGE_BITS - 1);
|
||||
// When the bootstrap is executed on encrypted signed integers, the lut must
|
||||
// be half-rotated. This map takes care about properly indexing into the input
|
||||
// lut depending on what bootstrap gets executed.
|
||||
std::function<size_t(size_t)> indexMap;
|
||||
if (is_signed) {
|
||||
size_t halfInputSize = input_lut_size / 2;
|
||||
indexMap = [=](size_t idx) {
|
||||
if (idx < halfInputSize) {
|
||||
return idx + halfInputSize;
|
||||
} else {
|
||||
return idx - halfInputSize;
|
||||
}
|
||||
};
|
||||
} else {
|
||||
indexMap = [=](size_t idx) { return idx; };
|
||||
}
|
||||
|
||||
// The first lut value should be centered over zero. This means that half of
|
||||
// it should appear at the beginning of the output lut, and half of it at the
|
||||
// end (but negated).
|
||||
for (size_t idx = 0; idx < mega_case_size / 2; ++idx) {
|
||||
output_lut_aligned[output_lut_offset + idx] =
|
||||
input_lut_aligned[input_lut_offset + indexMap(0)]
|
||||
<< (64 - out_MESSAGE_BITS - 1);
|
||||
}
|
||||
for (size_t idx = (input_lut_size - 1) * mega_case_size + mega_case_size / 2;
|
||||
idx < output_lut_size; ++idx) {
|
||||
output_lut_aligned[output_lut_offset + idx] =
|
||||
-(input_lut_aligned[input_lut_offset] << (64 - out_MESSAGE_BITS - 1));
|
||||
-(input_lut_aligned[input_lut_offset + indexMap(0)]
|
||||
<< (64 - out_MESSAGE_BITS - 1));
|
||||
}
|
||||
|
||||
// Treats the other ut values.
|
||||
for (size_t lut_idx = 1; lut_idx < input_lut_size; ++lut_idx) {
|
||||
uint64_t lut_value = input_lut_aligned[input_lut_offset + lut_idx]
|
||||
uint64_t lut_value = input_lut_aligned[input_lut_offset + indexMap(lut_idx)]
|
||||
<< (64 - out_MESSAGE_BITS - 1);
|
||||
size_t start = mega_case_size * (lut_idx - 1) + mega_case_size / 2;
|
||||
for (size_t output_idx = start; output_idx < start + mega_case_size;
|
||||
@@ -306,7 +329,7 @@ void memref_encode_expand_lut_for_woppbs(
|
||||
uint64_t *crt_bits_allocated, uint64_t *crt_bits_aligned,
|
||||
uint64_t crt_bits_offset, uint64_t crt_bits_size, uint64_t crt_bits_stride,
|
||||
// Crypto parameters
|
||||
uint32_t poly_size, uint32_t modulus_product) {
|
||||
uint32_t poly_size, uint32_t modulus_product, bool is_signed) {
|
||||
|
||||
assert(input_lut_stride == 1 && "Runtime: stride not equal to 1, check "
|
||||
"memref_encode_expand_lut_woppbs");
|
||||
@@ -314,22 +337,77 @@ void memref_encode_expand_lut_for_woppbs(
|
||||
"memref_encode_expand_lut_woppbs");
|
||||
assert(modulus_product > input_lut_size);
|
||||
|
||||
// When the woppbs is executed on encrypted signed integers, the index of the
|
||||
// lut elements must be adapted to fit the way signed are encrypted in CRT
|
||||
// (to ensure the lookup falls into the proper case).
|
||||
// This map takes care about properly indexing into the output lut depending
|
||||
// on what bootstrap gets executed.
|
||||
std::function<uint64_t(uint64_t)> indexMap;
|
||||
if (!is_signed) {
|
||||
// When not signed, the integer values are encoded in increasing order. That
|
||||
// is (example of 9 bits values, using crt decomposition [5,7,16]):
|
||||
//
|
||||
// |0 511|
|
||||
// |---------|
|
||||
// |0 511|
|
||||
//
|
||||
// is encoded as
|
||||
//
|
||||
// |0 511| INVALID |
|
||||
// |-------|-----------|
|
||||
// |0 511|512 559|
|
||||
//
|
||||
// Where on top are represented the semantic values, and below, the actual
|
||||
// encoding of values, either on uint64_t or as increasing crt values.
|
||||
//
|
||||
// As a consequence, there is nothing particular to do to map the index of
|
||||
// the input lut to an index of the output lut.
|
||||
indexMap = [=](uint64_t plaintext) { return plaintext; };
|
||||
} else {
|
||||
// When signed, the integer values are encoded in a way that resembles 2s
|
||||
// complement. That is (example of 9 bits values, using crt decomposition
|
||||
// [5,7,16]):
|
||||
//
|
||||
// |0 255|-256 -1|
|
||||
// |---------|----------|
|
||||
// |0 255|256 511|
|
||||
//
|
||||
// is encoded as
|
||||
//
|
||||
// |0 255| INVALID |-256 -1|
|
||||
// |---------|-------------|----------|
|
||||
// |0 255|256 303|304 559|
|
||||
//
|
||||
// Where on top are represented the semantic values, and below, the actual
|
||||
// encoding of values, either on uint64_t or as increasing crt values.
|
||||
//
|
||||
// As a consequence, to map the index of the input lut to an index of the
|
||||
// output lut we must take care of crossing the invalid range in between
|
||||
// positive values and negative values.
|
||||
indexMap = [=](uint64_t plaintext) {
|
||||
if (plaintext >= (input_lut_size / 2)) {
|
||||
plaintext += modulus_product - input_lut_size;
|
||||
}
|
||||
return plaintext;
|
||||
};
|
||||
}
|
||||
|
||||
uint64_t lut_crt_size = output_lut_size / crt_decomposition_size;
|
||||
|
||||
for (uint64_t value = 0; value < input_lut_size; value++) {
|
||||
for (uint64_t index = 0; index < input_lut_size; index++) {
|
||||
uint64_t index_lut = 0;
|
||||
uint64_t tmp = 1;
|
||||
|
||||
for (size_t block = 0; block < crt_decomposition_size; block++) {
|
||||
auto base = crt_decomposition_aligned[crt_decomposition_offset + block];
|
||||
auto bits = crt_bits_aligned[crt_bits_offset + block];
|
||||
index_lut += (((value % base) << bits) / base) * tmp;
|
||||
index_lut += (((indexMap(index) % base) << bits) / base) * tmp;
|
||||
tmp <<= bits;
|
||||
}
|
||||
|
||||
for (size_t block = 0; block < crt_decomposition_size; block++) {
|
||||
auto base = crt_decomposition_aligned[crt_decomposition_offset + block];
|
||||
auto v = encode_crt(input_lut_aligned[input_lut_offset + value], base,
|
||||
auto v = encode_crt(input_lut_aligned[input_lut_offset + index], base,
|
||||
modulus_product);
|
||||
output_lut_aligned[output_lut_offset + block * lut_crt_size + index_lut] =
|
||||
v;
|
||||
|
||||
@@ -58,8 +58,8 @@ llvm::Expected<CircuitGate> gateFromMLIRType(V0FHEContext fheContext,
|
||||
};
|
||||
}
|
||||
if (auto lweTy = type.dyn_cast_or_null<
|
||||
mlir::concretelang::FHE::EncryptedIntegerType>()) {
|
||||
bool sign = lweTy.isSignedInteger();
|
||||
mlir::concretelang::FHE::FheIntegerInterface>()) {
|
||||
bool sign = lweTy.isSigned();
|
||||
std::vector<int64_t> crt;
|
||||
if (fheContext.parameter.largeInteger.has_value()) {
|
||||
crt = fheContext.parameter.largeInteger.value().crtDecomposition;
|
||||
@@ -72,15 +72,14 @@ llvm::Expected<CircuitGate> gateFromMLIRType(V0FHEContext fheContext,
|
||||
{
|
||||
/* .precision = */ lweTy.getWidth(),
|
||||
/* .crt = */ crt,
|
||||
/*.sign = */ sign,
|
||||
},
|
||||
}),
|
||||
/*.shape = */
|
||||
{
|
||||
/*.width = */ (size_t)lweTy.getWidth(),
|
||||
/*.dimensions = */ std::vector<int64_t>(),
|
||||
/*.size = */ 0,
|
||||
/*.sign = */ sign,
|
||||
},
|
||||
{/*.width = */ (size_t)lweTy.getWidth(),
|
||||
/*.dimensions = */ std::vector<int64_t>(),
|
||||
/*.size = */ 0,
|
||||
/*.sign = */ sign},
|
||||
};
|
||||
}
|
||||
if (auto lweTy = type.dyn_cast_or_null<
|
||||
@@ -214,17 +213,12 @@ createClientParametersForV0(V0FHEContext fheContext,
|
||||
auto funcType = (*funcOp).getFunctionType();
|
||||
|
||||
auto inputs = funcType.getInputs();
|
||||
bool hasContext =
|
||||
inputs.empty()
|
||||
? false
|
||||
: inputs.back().isa<mlir::concretelang::Concrete::ContextType>();
|
||||
|
||||
auto gateFromType = [&](mlir::Type ty) {
|
||||
return gateFromMLIRType(fheContext, clientlib::BIG_KEY, inputVariance, ty);
|
||||
};
|
||||
for (auto inType = funcType.getInputs().begin();
|
||||
inType < funcType.getInputs().end() - hasContext; inType++) {
|
||||
auto gate = gateFromType(*inType);
|
||||
for (auto inType : inputs) {
|
||||
auto gate = gateFromType(inType);
|
||||
if (auto err = gate.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @apply_lookup_table(%arg0: tensor<4xi64>) -> tensor<1024xi64> {
|
||||
// CHECK-NEXT: %0 = "BConcrete.encode_expand_lut_for_bootstrap_tensor"(%arg0) {outputBits = 3 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<1024xi64>
|
||||
// CHECK-NEXT: %0 = "BConcrete.encode_expand_lut_for_bootstrap_tensor"(%arg0) {isSigned = true, outputBits = 3 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<1024xi64>
|
||||
// CHECK-NEXT: return %0 : tensor<1024xi64>
|
||||
// CHECK-NEXT: }
|
||||
func.func @apply_lookup_table(%arg0: tensor<4xi64>) -> tensor<1024xi64> {
|
||||
%0 = "Concrete.encode_expand_lut_for_bootstrap"(%arg0) {outputBits = 3 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<1024xi64>
|
||||
%0 = "Concrete.encode_expand_lut_for_bootstrap"(%arg0) {outputBits = 3 : i32, polySize = 1024 : i32, isSigned = true} : (tensor<4xi64>) -> tensor<1024xi64>
|
||||
return %0 : tensor<1024xi64>
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @main(%arg0: tensor<4xi64>) -> tensor<40960xi64> {
|
||||
// CHECK-NEXT: %0 = "BConcrete.encode_expand_lut_for_woppbs_tensor"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64>
|
||||
// CHECK-NEXT: %0 = "BConcrete.encode_expand_lut_for_woppbs_tensor"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64>
|
||||
// CHECK-NEXT: return %0 : tensor<40960xi64>
|
||||
// CHECK-NEXT: }
|
||||
func.func @main(%arg0: tensor<4xi64>) -> tensor<40960xi64> {
|
||||
%0 = "Concrete.encode_expand_lut_for_woppbs"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64>
|
||||
%0 = "Concrete.encode_expand_lut_for_woppbs"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64>
|
||||
return %0 : tensor<40960xi64>
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @apply_lookup_table(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<4xi64>) -> tensor<5x!TFHE.glwe<{_,_,_}{3}>>
|
||||
// CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_woppbs"(%arg1) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64>
|
||||
// CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_woppbs"(%arg1) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64>
|
||||
// CHECK-NEXT: %1 = "TFHE.wop_pbs_glwe"(%arg0, %0) {bootstrapBaseLog = -1 : i32, bootstrapLevel = -1 : i32, circuitBootstrapBaseLog = -1 : i32, circuitBootstrapLevel = -1 : i32, crtDecomposition = [], keyswitchBaseLog = -1 : i32, keyswitchLevel = -1 : i32, packingKeySwitchBaseLog = -1 : i32, packingKeySwitchInputLweDimension = -1 : i32, packingKeySwitchLevel = -1 : i32, packingKeySwitchoutputPolynomialSize = -1 : i32} : (tensor<5x!TFHE.glwe<{_,_,_}{2}>>, tensor<40960xi64>) -> tensor<5x!TFHE.glwe<{_,_,_}{3}>>
|
||||
// CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<{_,_,_}{3}>>
|
||||
func.func @apply_lookup_table(%arg0: !FHE.eint<2>, %arg1: tensor<4xi64>) -> !FHE.eint<3> {
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
// CHECK: func.func @apply_lookup_table_cst(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> {
|
||||
// CHECK-NEXT: %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64>
|
||||
// CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_woppbs"(%cst) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<128xi64>) -> tensor<40960xi64>
|
||||
// CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_woppbs"(%cst) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<128xi64>) -> tensor<40960xi64>
|
||||
// CHECK-NEXT: %1 = "TFHE.wop_pbs_glwe"(%arg0, %0) {bootstrapBaseLog = -1 : i32, bootstrapLevel = -1 : i32, circuitBootstrapBaseLog = -1 : i32, circuitBootstrapLevel = -1 : i32, crtDecomposition = [], keyswitchBaseLog = -1 : i32, keyswitchLevel = -1 : i32, packingKeySwitchBaseLog = -1 : i32, packingKeySwitchInputLweDimension = -1 : i32, packingKeySwitchLevel = -1 : i32, packingKeySwitchoutputPolynomialSize = -1 : i32} : (tensor<5x!TFHE.glwe<{_,_,_}{7}>>, tensor<40960xi64>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
func.func @apply_lookup_table_cst(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @apply_lookup_table(%arg0: !TFHE.glwe<{_,_,_}{2}>, %arg1: tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{3}> {
|
||||
// CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_bootstrap"(%arg1) {outputBits = 3 : i32, polySize = 256 : i32} : (tensor<4xi64>) -> tensor<256xi64>
|
||||
// CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_bootstrap"(%arg1) {isSigned = false, outputBits = 3 : i32, polySize = 256 : i32} : (tensor<4xi64>) -> tensor<256xi64>
|
||||
// CHECK-NEXT: %1 = "TFHE.keyswitch_glwe"(%arg0) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
// CHECK-NEXT: %2 = "TFHE.bootstrap_glwe"(%1, %0) {baseLog = -1 : i32, glweDimension = -1 : i32, level = -1 : i32, polySize = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>, tensor<256xi64>) -> !TFHE.glwe<{_,_,_}{3}>
|
||||
// CHECK-NEXT: return %2 : !TFHE.glwe<{_,_,_}{3}>
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
//CHECK: func.func @apply_lookup_table_cst(%[[A0:.*]]: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> {
|
||||
|
||||
//CHECK-NEXT: %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64>
|
||||
//CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_bootstrap"(%cst) {outputBits = 7 : i32, polySize = 8192 : i32} : (tensor<128xi64>) -> tensor<8192xi64>
|
||||
//CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_bootstrap"(%cst) {isSigned = false, outputBits = 7 : i32, polySize = 8192 : i32} : (tensor<128xi64>) -> tensor<8192xi64>
|
||||
//CHECK-NEXT: %1 = "TFHE.keyswitch_glwe"(%arg0) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
//CHECK-NEXT: %2 = "TFHE.bootstrap_glwe"(%1, %0) {baseLog = -1 : i32, glweDimension = -1 : i32, level = -1 : i32, polySize = -1 : i32} : (!TFHE.glwe<{_,_,_}{7}>, tensor<8192xi64>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
//CHECK-NEXT: return %2 : !TFHE.glwe<{_,_,_}{7}>
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @apply_lookup_table(%arg0: tensor<4xi64>) -> tensor<1024xi64> {
|
||||
// CHECK-NEXT: %0 = "Concrete.encode_expand_lut_for_bootstrap"(%arg0) {outputBits = 3 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<1024xi64>
|
||||
// CHECK-NEXT: %0 = "Concrete.encode_expand_lut_for_bootstrap"(%arg0) {isSigned = true, outputBits = 3 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<1024xi64>
|
||||
// CHECK-NEXT: return %0 : tensor<1024xi64>
|
||||
// CHECK-NEXT: }
|
||||
func.func @apply_lookup_table(%arg1: tensor<4xi64>) -> tensor<1024xi64> {
|
||||
%0 = "TFHE.encode_expand_lut_for_bootstrap"(%arg1) {outputBits = 3 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<1024xi64>
|
||||
%0 = "TFHE.encode_expand_lut_for_bootstrap"(%arg1) {outputBits = 3 : i32, polySize = 1024 : i32, isSigned = true} : (tensor<4xi64>) -> tensor<1024xi64>
|
||||
return %0: tensor<1024xi64>
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @main(%arg0: tensor<4xi64>) -> tensor<40960xi64> {
|
||||
// CHECK-NEXT: %0 = "Concrete.encode_expand_lut_for_woppbs"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64>
|
||||
// CHECK-NEXT: %0 = "Concrete.encode_expand_lut_for_woppbs"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64>
|
||||
// CHECK-NEXT: return %0 : tensor<40960xi64>
|
||||
// CHECK-NEXT: }
|
||||
func.func @main(%arg1: tensor<4xi64>) -> tensor<40960xi64> {
|
||||
%0 = "TFHE.encode_expand_lut_for_woppbs"(%arg1) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64>
|
||||
%0 = "TFHE.encode_expand_lut_for_woppbs"(%arg1) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], isSigned = false, modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64>
|
||||
return %0: tensor<40960xi64>
|
||||
}
|
||||
|
||||
@@ -111,19 +111,14 @@ llvm::Error checkResult(const mlir::concretelang::TensorLambdaArgument<
|
||||
if (!expectedNumElts)
|
||||
return expectedNumElts.takeError();
|
||||
|
||||
auto hasError = false;
|
||||
StreamStringError err("result value differ");
|
||||
for (size_t i = 0; i < *expectedNumElts; i++) {
|
||||
|
||||
if (resValues[i] != expectedValues[i]) {
|
||||
hasError = true;
|
||||
err << " [pos(" << i << "), got " << resValues[i] << " expected "
|
||||
<< expectedValues[i] << "]";
|
||||
if ((uint64_t)resValues[i] != (uint64_t)expectedValues[i]) {
|
||||
return StreamStringError("result value differ at pos(")
|
||||
<< i << "), got " << resValues[i] << " expected "
|
||||
<< expectedValues[i];
|
||||
}
|
||||
}
|
||||
if (hasError) {
|
||||
return err;
|
||||
}
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
|
||||
@@ -11,10 +11,11 @@ def generate(args):
|
||||
print("# /!\ DO NOT EDIT MANUALLY THIS FILE MANUALLY")
|
||||
print("# /!\ THIS FILE HAS BEEN GENERATED")
|
||||
np.random.seed(0)
|
||||
# unsigned_unsigned
|
||||
for p in args.bitwidth:
|
||||
max_value = (2 ** p) - 1
|
||||
random_lut = np.random.randint(max_value+1, size=2**p)
|
||||
print(f"description: apply_lookup_table_{p}bits")
|
||||
print(f"description: unsigned_apply_lookup_table_{p}bits")
|
||||
print("program: |")
|
||||
print(
|
||||
f" func.func @main(%arg0: !FHE.eint<{p}>) -> !FHE.eint<{p}> {{")
|
||||
@@ -41,6 +42,139 @@ def generate(args):
|
||||
print(" outputs:")
|
||||
print(f" - scalar: {random_lut[max_value]}")
|
||||
print("---")
|
||||
# unsigned_signed
|
||||
for p in args.bitwidth:
|
||||
lower_bound = -(2 ** (p-1))
|
||||
upper_bound = (2 ** (p-1)) - 1
|
||||
max_value = (2 ** p) - 1
|
||||
random_lut = np.random.randint(lower_bound, upper_bound, size=2**p)
|
||||
print(f"description: unsigned_signed_apply_lookup_table_{p}bits")
|
||||
print("program: |")
|
||||
print(
|
||||
f" func.func @main(%arg0: !FHE.eint<{p}>) -> !FHE.esint<{p}> {{")
|
||||
print(f" %tlu = arith.constant dense<[{','.join(map(str, random_lut))}]> : tensor<{2**p}xi64>")
|
||||
print(
|
||||
f" %1 = \"FHE.apply_lookup_table\"(%arg0, %tlu): (!FHE.eint<{p}>, tensor<{2**p}xi64>) -> (!FHE.esint<{p}>)")
|
||||
print(f" return %1: !FHE.esint<{p}>")
|
||||
print(" }")
|
||||
if p >= PRECISION_FORCE_CRT:
|
||||
print("encoding: crt")
|
||||
print(f"p-error: {P_ERROR}")
|
||||
print("tests:")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: 0")
|
||||
print(" outputs:")
|
||||
print(f" - scalar: {random_lut[0]}")
|
||||
print(f" signed: true")
|
||||
print(" - inputs:")
|
||||
random_i = np.random.randint(max_value)
|
||||
print(f" - scalar: {random_i}")
|
||||
print(" outputs:")
|
||||
print(f" - scalar: {random_lut[random_i]}")
|
||||
print(f" signed: true")
|
||||
print(" - inputs:")
|
||||
print(f" - scalar: {max_value}")
|
||||
print(" outputs:")
|
||||
print(f" - scalar: {random_lut[max_value]}")
|
||||
print(f" signed: true")
|
||||
print("---")
|
||||
# signed_signed
|
||||
for p in args.bitwidth:
|
||||
lower_bound = -(2 ** (p-1))
|
||||
upper_bound = (2 ** (p-1)) - 1
|
||||
random_lut = np.random.randint(lower_bound, upper_bound, size=2**p)
|
||||
print(f"description: signed_apply_lookup_table_{p}bits")
|
||||
print("program: |")
|
||||
print(
|
||||
f" func.func @main(%arg0: !FHE.esint<{p}>) -> !FHE.esint<{p}> {{")
|
||||
print(f" %tlu = arith.constant dense<[{','.join(map(str, random_lut))}]> : tensor<{2**p}xi64>")
|
||||
print(
|
||||
f" %1 = \"FHE.apply_lookup_table\"(%arg0, %tlu): (!FHE.esint<{p}>, tensor<{2**p}xi64>) -> (!FHE.esint<{p}>)")
|
||||
print(f" return %1: !FHE.esint<{p}>")
|
||||
print(" }")
|
||||
if p >= PRECISION_FORCE_CRT:
|
||||
print("encoding: crt")
|
||||
print(f"p-error: {P_ERROR}")
|
||||
print("tests:")
|
||||
print(" - inputs:")
|
||||
print(f" - scalar: 0")
|
||||
print(f" signed: true")
|
||||
print(" outputs:")
|
||||
print(f" - scalar: {random_lut[0]}")
|
||||
print(f" signed: true")
|
||||
print(" - inputs:")
|
||||
print(f" - scalar: {upper_bound}")
|
||||
print(f" signed: true")
|
||||
print(" outputs:")
|
||||
print(f" - scalar: {random_lut[upper_bound]}")
|
||||
print(f" signed: true")
|
||||
print(" - inputs:")
|
||||
print(f" - scalar: {lower_bound}")
|
||||
print(f" signed: true")
|
||||
print(" outputs:")
|
||||
print(f" - scalar: {random_lut[lower_bound]}")
|
||||
print(f" signed: true")
|
||||
print(" - inputs:")
|
||||
print(f" - scalar: -1")
|
||||
print(f" signed: true")
|
||||
print(" outputs:")
|
||||
print(f" - scalar: {random_lut[-1]}")
|
||||
print(f" signed: true")
|
||||
print(" - inputs:")
|
||||
random_i = np.random.randint(lower_bound, upper_bound)
|
||||
print(f" - scalar: {random_i}")
|
||||
print(f" signed: true")
|
||||
print(" outputs:")
|
||||
print(f" - scalar: {random_lut[random_i]}")
|
||||
print(f" signed: true")
|
||||
print("---")
|
||||
|
||||
# signed_unsigned
|
||||
for p in args.bitwidth:
|
||||
lower_bound = -(2 ** (p-1))
|
||||
upper_bound = (2 ** (p-1)) - 1
|
||||
max_value = (2 ** p) - 1
|
||||
random_lut = np.random.randint(max_value+1, size=2**p)
|
||||
print(f"description: signed_unsigned_apply_lookup_table_{p}bits")
|
||||
print("program: |")
|
||||
print(
|
||||
f" func.func @main(%arg0: !FHE.esint<{p}>) -> !FHE.eint<{p}> {{")
|
||||
print(f" %tlu = arith.constant dense<[{','.join(map(str, random_lut))}]> : tensor<{2**p}xi64>")
|
||||
print(
|
||||
f" %1 = \"FHE.apply_lookup_table\"(%arg0, %tlu): (!FHE.esint<{p}>, tensor<{2**p}xi64>) -> (!FHE.eint<{p}>)")
|
||||
print(f" return %1: !FHE.eint<{p}>")
|
||||
print(" }")
|
||||
if p >= PRECISION_FORCE_CRT:
|
||||
print("encoding: crt")
|
||||
print(f"p-error: {P_ERROR}")
|
||||
print("tests:")
|
||||
print(" - inputs:")
|
||||
print(f" - scalar: 0")
|
||||
print(f" signed: true")
|
||||
print(" outputs:")
|
||||
print(f" - scalar: {random_lut[0]}")
|
||||
print(" - inputs:")
|
||||
print(f" - scalar: {upper_bound}")
|
||||
print(f" signed: true")
|
||||
print(" outputs:")
|
||||
print(f" - scalar: {random_lut[upper_bound]}")
|
||||
print(" - inputs:")
|
||||
print(f" - scalar: {lower_bound}")
|
||||
print(f" signed: true")
|
||||
print(" outputs:")
|
||||
print(f" - scalar: {random_lut[lower_bound]}")
|
||||
print(" - inputs:")
|
||||
print(f" - scalar: -1")
|
||||
print(f" signed: true")
|
||||
print(" outputs:")
|
||||
print(f" - scalar: {random_lut[-1]}")
|
||||
print(" - inputs:")
|
||||
random_i = np.random.randint(lower_bound, upper_bound)
|
||||
print(f" - scalar: {random_i}")
|
||||
print(f" signed: true")
|
||||
print(" outputs:")
|
||||
print(f" - scalar: {random_lut[random_i]}")
|
||||
print("---")
|
||||
|
||||
if __name__ == "__main__":
|
||||
CLI = argparse.ArgumentParser()
|
||||
|
||||
@@ -17,6 +17,7 @@ def main():
|
||||
print("# /!\ DO NOT EDIT MANUALLY THIS FILE MANUALLY")
|
||||
print("# /!\ THIS FILE HAS BEEN GENERATED THANKS THE end_to_end_levelled_gen.py scripts")
|
||||
print("# This reference file aims to test all levelled ops with all bitwidth than we known that the compiler/optimizer support.\n\n")
|
||||
# unsigned
|
||||
for p in range(MIN_PRECISON, MAX_PRECISION+1):
|
||||
if p != 1:
|
||||
print("---")
|
||||
@@ -301,6 +302,579 @@ def main():
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
may_check_error_rate()
|
||||
print("---")
|
||||
# signed
|
||||
for p in range(MIN_PRECISON, MAX_PRECISION+1):
|
||||
print("---")
|
||||
def may_check_error_rate():
|
||||
if p in PRECISIONS_WITH_ERROR_RATES:
|
||||
print(TEST_ERROR_RATES)
|
||||
min_value = -(2 ** (p - 1))
|
||||
max_value = abs(min_value) - 1
|
||||
integer_bitwidth = p + 1
|
||||
max_constant = min((2 ** (57-p)) - 1, max_value)
|
||||
|
||||
# identity
|
||||
print("description: signed_identity_{0}bits".format(p))
|
||||
print("program: |")
|
||||
print(" func.func @main(%arg0: !FHE.esint<{0}>) -> !FHE.esint<{0}> {{".format(p))
|
||||
print(" return %arg0: !FHE.esint<{0}>".format(p))
|
||||
print(" }")
|
||||
print("tests:")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(min_value))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(min_value))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
may_check_error_rate()
|
||||
|
||||
print("---")
|
||||
|
||||
# zero_tensor
|
||||
print("description: signed_zero_tensor_{0}bits".format(p))
|
||||
print("program: |")
|
||||
print(" func.func @main() -> tensor<2x2x4x!FHE.esint<{0}>> {{".format(p))
|
||||
print(" %0 = \"FHE.zero_tensor\"() : () -> tensor<2x2x4x!FHE.esint<{0}>>".format(p))
|
||||
print(" return %0: tensor<2x2x4x!FHE.esint<{0}>>".format(p))
|
||||
print(" }")
|
||||
print("tests:")
|
||||
print(" - outputs:")
|
||||
print(" - tensor: [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]")
|
||||
print(" shape: [2,2,4]")
|
||||
print(" signed: true")
|
||||
may_check_error_rate()
|
||||
|
||||
print("---")
|
||||
|
||||
# add_eint_int_cst
|
||||
print("description: signed_add_eint_int_cst_{0}bits".format(p))
|
||||
print("program: |")
|
||||
print(" func.func @main(%arg0: !FHE.esint<{0}>) -> !FHE.esint<{0}> {{".format(p))
|
||||
print(" %0 = arith.constant 1 : i{0}".format(integer_bitwidth))
|
||||
print(" %1 = \"FHE.add_eint_int\"(%arg0, %0): (!FHE.esint<{0}>, i{1}) -> (!FHE.esint<{0}>)".format(p, integer_bitwidth))
|
||||
print(" return %1: !FHE.esint<{0}>".format(p))
|
||||
print(" }")
|
||||
print("tests:")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(-1))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(max_value-1))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(min_value))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(min_value + 1))
|
||||
print(" signed: true")
|
||||
may_check_error_rate()
|
||||
|
||||
print("---")
|
||||
|
||||
# add_eint_int_arg
|
||||
if p <= 28:
|
||||
# above 28 bits the *arg test doesn't have solution
|
||||
# TODO: Make a test that test that
|
||||
print("description: signed_add_eint_int_arg_{0}bits".format(p))
|
||||
print("program: |")
|
||||
print(" func.func @main(%arg0: !FHE.esint<{0}>, %arg1: i{1}) -> !FHE.esint<{0}> {{".format(p, integer_bitwidth))
|
||||
print(" %0 = \"FHE.add_eint_int\"(%arg0, %arg1): (!FHE.esint<{0}>, i{1}) -> (!FHE.esint<{0}>)".format(p, integer_bitwidth))
|
||||
print(" return %0: !FHE.esint<{0}>".format(p))
|
||||
print(" }")
|
||||
print("tests:")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(min_value))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(min_value))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(min_value))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(1))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(min_value + 1))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(max_value-1))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(1))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(-1))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(-1))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(-1))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(1))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
may_check_error_rate()
|
||||
|
||||
print("---")
|
||||
|
||||
# add_eint
|
||||
print("description: signed_add_eint_{0}bits".format(p))
|
||||
print("program: |")
|
||||
print(" func.func @main(%arg0: !FHE.esint<{0}>, %arg1: !FHE.esint<{0}>) -> !FHE.esint<{0}> {{".format(p))
|
||||
print(" %res = \"FHE.add_eint\"(%arg0, %arg1): (!FHE.esint<{0}>, !FHE.esint<{0}>) -> !FHE.esint<{0}>".format(p))
|
||||
print(" return %res: !FHE.esint<{0}>".format(p))
|
||||
print(" }")
|
||||
print("tests:")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(((2 ** (p - 1)) >> 1) - 1))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(((2 ** (p - 1)) >> 1)))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(-1 if p == 1 else (2 ** (p - 1)) - 1))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(-(2 ** (p - 1))))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(((2 ** (p - 1)) - 1)))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(-1))
|
||||
print(" signed: true")
|
||||
may_check_error_rate()
|
||||
|
||||
print("---")
|
||||
|
||||
# sub_eint_int_cst
|
||||
print("description: signed_sub_eint_int_cst_{0}bits".format(p))
|
||||
print("program: |")
|
||||
print(" func.func @main(%arg0: !FHE.esint<{0}>) -> !FHE.esint<{0}> {{".format(p))
|
||||
print(" %0 = arith.constant 1 : i{0}".format(integer_bitwidth))
|
||||
print(" %1 = \"FHE.sub_eint_int\"(%arg0, %0): (!FHE.esint<{0}>, i{1}) -> (!FHE.esint<{0}>)".format(p, integer_bitwidth))
|
||||
print(" return %1: !FHE.esint<{0}>".format(p))
|
||||
print(" }")
|
||||
print("tests:")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(max_value - 1))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(-1))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(min_value + 1))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(min_value))
|
||||
print(" signed: true")
|
||||
may_check_error_rate()
|
||||
|
||||
print("---")
|
||||
|
||||
# sub_eint_int_arg
|
||||
if p <= 28:
|
||||
# above 28 bits the *arg test doesn't have solution
|
||||
# TODO: Make a test that test that
|
||||
print("description: signed_sub_eint_int_arg_{0}bits".format(p))
|
||||
print("program: |")
|
||||
print(" func.func @main(%arg0: !FHE.esint<{0}>, %arg1: i{1}) -> !FHE.esint<{0}> {{".format(p, integer_bitwidth))
|
||||
print(" %1 = \"FHE.sub_eint_int\"(%arg0, %arg1): (!FHE.esint<{0}>, i{1}) -> (!FHE.esint<{0}>)".format(p, integer_bitwidth))
|
||||
print(" return %1: !FHE.esint<{0}>".format(p))
|
||||
print(" }")
|
||||
print("tests:")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(1))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(max_value - 1))
|
||||
print(" signed: true")
|
||||
if p != 28:
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(2 * max_value))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(-max_value))
|
||||
print(" signed: true")
|
||||
may_check_error_rate()
|
||||
|
||||
print("---")
|
||||
|
||||
# sub_int_eint_cst
|
||||
if p != 1:
|
||||
print("description: signed_sub_int_eint_cst_{0}bits".format(p))
|
||||
print("program: |")
|
||||
print(" func.func @main(%arg0: !FHE.esint<{0}>) -> !FHE.esint<{0}> {{".format(p))
|
||||
print(" %0 = arith.constant 1 : i{0}".format(integer_bitwidth))
|
||||
print(" %1 = \"FHE.sub_int_eint\"(%0, %arg0): (i{1}, !FHE.esint<{0}>) -> (!FHE.esint<{0}>)".format(p, integer_bitwidth))
|
||||
print(" return %1: !FHE.esint<{0}>".format(p))
|
||||
print(" }")
|
||||
print("tests:")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(min_value + 2))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(1))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(min_value + 2))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
may_check_error_rate()
|
||||
|
||||
print("---")
|
||||
|
||||
# sub_int_eint_arg
|
||||
if p <= 28:
|
||||
# above 28 bits the *arg test doesn't have solution
|
||||
# TODO: Make a test that test that
|
||||
print("description: signed_sub_int_eint_arg_{0}bits".format(p))
|
||||
print("program: |")
|
||||
print(" func.func @main(%arg0: i{1}, %arg1: !FHE.esint<{0}>) -> !FHE.esint<{0}> {{".format(p, integer_bitwidth))
|
||||
print(" %1 = \"FHE.sub_int_eint\"(%arg0, %arg1): (i{1}, !FHE.esint<{0}>) -> (!FHE.esint<{0}>)".format(p, integer_bitwidth))
|
||||
print(" return %1: !FHE.esint<{0}>".format(p))
|
||||
print(" }")
|
||||
print("tests:")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(1))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(max_value - 1))
|
||||
print(" signed: true")
|
||||
if p != 28:
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(2 * max_value))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(-max_value))
|
||||
print(" signed: true")
|
||||
may_check_error_rate()
|
||||
|
||||
print("---")
|
||||
|
||||
# sub_eint
|
||||
print("description: signed_sub_eint_{0}bits".format(p))
|
||||
print("program: |")
|
||||
print(" func.func @main(%arg0: !FHE.esint<{0}>, %arg1: !FHE.esint<{0}>) -> !FHE.esint<{0}> {{".format(p))
|
||||
print(" %res = \"FHE.sub_eint\"(%arg0, %arg1): (!FHE.esint<{0}>, !FHE.esint<{0}>) -> !FHE.esint<{0}>".format(p))
|
||||
print(" return %res: !FHE.esint<{0}>".format(p))
|
||||
print(" }")
|
||||
print("tests:")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(1))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(-1))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(min_value))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(min_value))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(min_value + 1))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(min_value + 1))
|
||||
print(" signed: true")
|
||||
may_check_error_rate()
|
||||
|
||||
print("---")
|
||||
|
||||
# mul_eint_int_cst
|
||||
print("description: signed_mul_eint_int_cst_{0}bits".format(p))
|
||||
print("program: |")
|
||||
print(" func.func @main(%arg0: !FHE.esint<{0}>) -> !FHE.esint<{0}> {{".format(p))
|
||||
print(" %0 = arith.constant 2 : i{0}".format(integer_bitwidth))
|
||||
print(" %1 = \"FHE.mul_eint_int\"(%arg0, %0): (!FHE.esint<{0}>, i{1}) -> (!FHE.esint<{0}>)".format(p, integer_bitwidth))
|
||||
print(" return %1: !FHE.esint<{0}>".format(p))
|
||||
print(" }")
|
||||
print("tests:")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
if p != 1:
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(max_value // 2))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(max_value - 1))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(min_value // 2))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(min_value))
|
||||
print(" signed: true")
|
||||
may_check_error_rate()
|
||||
|
||||
print("---")
|
||||
|
||||
# mul_eint_int_arg
|
||||
if p <= 28:
|
||||
# above 28 bits the *arg test doesn't have solution
|
||||
# TODO: Make a test that test that
|
||||
print("description: signed_mul_eint_int_arg_{0}bits".format(p))
|
||||
print("program: |")
|
||||
print(" func.func @main(%arg0: !FHE.esint<{0}>, %arg1: i{1}) -> !FHE.esint<{0}> {{".format(p, integer_bitwidth))
|
||||
print(" %0 = \"FHE.mul_eint_int\"(%arg0, %arg1): (!FHE.esint<{0}>, i{1}) -> (!FHE.esint<{0}>)".format(p, integer_bitwidth))
|
||||
print(" return %0: !FHE.esint<{0}>".format(p))
|
||||
print(" }")
|
||||
print("tests:")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(min_value))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(min_value))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(0))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(min_value))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(1))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(min_value))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(1))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(min_value))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(min_value))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(1))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(1))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(-1))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(min_value + 1))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(-1))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(min_value + 1))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(min_value + 1))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(-1))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(-1))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(min_value + 1))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(max_value))
|
||||
print(" signed: true")
|
||||
if p > 2:
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(3))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(1))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(3))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(3))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(-1))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(-3))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(-3))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(1))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(-3))
|
||||
print(" signed: true")
|
||||
print(" - inputs:")
|
||||
print(" - scalar: {0}".format(-3))
|
||||
print(" signed: true")
|
||||
print(" - scalar: {0}".format(-1))
|
||||
print(" signed: true")
|
||||
print(" outputs:")
|
||||
print(" - scalar: {0}".format(3))
|
||||
print(" signed: true")
|
||||
may_check_error_rate()
|
||||
|
||||
print("---")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -11,11 +11,8 @@ from concrete.compiler import ClientSupport
|
||||
pytest.param([0, 1, 2], id="list"),
|
||||
pytest.param(0.5, id="float"),
|
||||
pytest.param(2**70, id="large int"),
|
||||
pytest.param(-8, id="negative int"),
|
||||
pytest.param("aze", id="str"),
|
||||
pytest.param(np.float64(0.8), id="np.float64"),
|
||||
pytest.param(np.int8(9), id="np.int8"),
|
||||
pytest.param(np.array([1, 2, 3], dtype=np.int64), id="np.array(np.int64)"),
|
||||
],
|
||||
)
|
||||
def test_invalid_arg_type(garbage):
|
||||
|
||||
@@ -108,5 +108,7 @@ def test_client_server_end_to_end(mlir, args, expected_result, keyset_cache):
|
||||
client_parameters, result_serialized
|
||||
)
|
||||
|
||||
output = ClientSupport.decrypt_result(keyset, result_unserialized)
|
||||
output = ClientSupport.decrypt_result(
|
||||
client_parameters, keyset, result_unserialized
|
||||
)
|
||||
assert np.array_equal(output, expected_result)
|
||||
|
||||
@@ -47,7 +47,7 @@ def run(engine, args, compilation_result, keyset_cache):
|
||||
evaluation_keys = key_set.get_evaluation_keys()
|
||||
public_result = engine.server_call(server_lambda, public_arguments, evaluation_keys)
|
||||
# Client
|
||||
result = ClientSupport.decrypt_result(key_set, public_result)
|
||||
result = ClientSupport.decrypt_result(client_parameters, key_set, public_result)
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@@ -42,17 +42,20 @@ TEST(Support, client_parameters_json_serde) {
|
||||
}}};
|
||||
params0.inputs = {
|
||||
{
|
||||
/*.encryption = */ {{clientlib::SMALL_KEY, 0.00, {4, {1, 2, 3, 4}}}},
|
||||
/*.encryption = */ {
|
||||
{clientlib::SMALL_KEY, 0.00, {4, {1, 2, 3, 4}, false}}},
|
||||
/*.shape = */ {32, {1, 2, 3, 4}, 1 * 2 * 3 * 4, false},
|
||||
},
|
||||
{
|
||||
/*.encryption = */ {{clientlib::SMALL_KEY, 0.00, {5, {1, 2, 3, 4}}}},
|
||||
/*.encryption = */ {
|
||||
{clientlib::SMALL_KEY, 0.00, {5, {1, 2, 3, 4}, false}}},
|
||||
/*.shape = */ {8, {4, 4, 4, 4}, 4 * 4 * 4 * 4, false},
|
||||
},
|
||||
};
|
||||
params0.outputs = {
|
||||
{
|
||||
/*.encryption = */ {{clientlib::SMALL_KEY, 0.00, {5, {1, 2, 3, 4}}}},
|
||||
/*.encryption = */ {
|
||||
{clientlib::SMALL_KEY, 0.00, {5, {1, 2, 3, 4}, false}}},
|
||||
/*.shape = */ {8, {4, 4, 4, 4}, 4 * 4 * 4 * 4, false},
|
||||
},
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user