feat(compiler): First draft to support FHE.eint up to 16bits

For now, only levelled ops with user-provided parameters work (take a look at the tests). A rough sketch of the CRT-based representation is given right after the todo list below.

Done:
- Add parameters to the FHE parameters to support CRT-based large integers
- Add command-line options and test options to allow the user to provide these new parameters
- Update the dialects and the pipeline to handle the new FHE parameters for CRT-based large integers
- Update the client parameters and the client library to handle CRT-based large integers

Todo:
- Plug in the optimizer to compute the CRT-based large integer parameters
- Plug in the PBS for CRT-based large integers
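The gist of the CRT-based representation, as a rough illustrative sketch (the moduli below are made up for the example and are not the parameters the optimizer will eventually pick): a 16-bit value is split into small residues, each small enough to live in its own ciphertext, levelled ops are applied block-wise, and the value is recovered with the inverse CRT.

#include <cstdint>
#include <vector>

// Illustrative only: decompose a 16-bit value into residues modulo a set of
// pairwise-coprime moduli whose product exceeds 2^16.
void crtIdeaSketch() {
  std::vector<uint64_t> moduli = {7, 8, 9, 11, 13}; // product = 72072 > 65536
  uint64_t value = 43210;
  std::vector<uint64_t> residues;
  for (auto m : moduli)
    residues.push_back(value % m); // each residue fits in a few bits
  // Levelled ops (add, mul by a clear constant, ...) act independently on each
  // block; the 16-bit result is rebuilt with the inverse CRT (see CRT.h below).
}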
Quentin Bourgerie
2022-06-20 11:01:06 +02:00
parent 58527a44c3
commit 8cd3a3a599
82 changed files with 3192 additions and 1037 deletions

View File

@@ -0,0 +1,46 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CLIENTLIB_CRT_H_
#define CONCRETELANG_CLIENTLIB_CRT_H_
#include <cstdint>
#include <vector>
namespace concretelang {
namespace clientlib {
namespace crt {
/// Compute the product of the moduli of the crt decomposition.
///
/// \param moduli The moduli of the crt decomposition
/// \returns The product of moduli
uint64_t productOfModuli(std::vector<int64_t> moduli);
/// Compute the crt decomposition of `val` according to the given `moduli`.
///
/// \param moduli The moduli to compute the decomposition.
/// \param val The value to decompose.
/// \returns The remainders.
std::vector<int64_t> crt(std::vector<int64_t> moduli, uint64_t val);
/// Compute the inverse of the crt decomposition.
///
/// \param moduli The moduli used to compute the inverse decomposition.
/// \param remainders The remainders of the decomposition.
/// \returns The reconstructed value, modulo the product of the moduli.
uint64_t iCrt(std::vector<int64_t> moduli, std::vector<int64_t> remainders);
/// Encode the plaintext with the given modulus and the product of moduli of the
/// crt decomposition
uint64_t encode(int64_t plaintext, uint64_t modulus, uint64_t product);
/// Decode `val` following the crt encoding for the given `modulus`.
uint64_t decode(uint64_t val, uint64_t modulus);
} // namespace crt
} // namespace clientlib
} // namespace concretelang
#endif
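A minimal usage sketch of these helpers (the moduli and the value are arbitrary, just to show the round trip):

#include "concretelang/ClientLib/CRT.h"
#include <cassert>

void crtRoundTripSketch() {
  using namespace concretelang::clientlib;
  std::vector<int64_t> moduli = {7, 8, 9, 11, 13};  // hypothetical moduli
  uint64_t val = 1234;
  auto remainders = crt::crt(moduli, val);          // residues of val
  uint64_t back = crt::iCrt(moduli, remainders);    // CRT reconstruction
  assert(back == val); // holds as long as val < productOfModuli(moduli)
}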

View File

@@ -7,6 +7,7 @@
#define CONCRETELANG_CLIENTLIB_CLIENTPARAMETERS_H_
#include <map>
#include <optional>
#include <string>
#include <vector>
@@ -31,6 +32,7 @@ typedef size_t DecompositionBaseLog;
typedef size_t PolynomialSize;
typedef size_t Precision;
typedef double Variance;
typedef std::vector<int64_t> CRTDecomposition;
typedef uint64_t LweDimension;
typedef uint64_t GlweDimension;
@@ -87,6 +89,7 @@ static inline bool operator==(const KeyswitchKeyParam &lhs,
struct Encoding {
Precision precision;
CRTDecomposition crt;
};
static inline bool operator==(const Encoding &lhs, const Encoding &rhs) {
return lhs.precision == rhs.precision;
@@ -168,6 +171,52 @@ struct ClientParameters {
}
return secretKey->second;
}
/// bufferSize returns the size of the whole buffer of a gate.
int64_t bufferSize(CircuitGate gate) {
if (!gate.encryption.hasValue()) {
// Value is not encrypted, just return the tensor size
return gate.shape.size;
}
auto shapeSize = gate.shape.size == 0 ? 1 : gate.shape.size;
// Size of the ciphertext
return shapeSize * lweBufferSize(gate);
}
/// lweBufferSize returns the size of one ciphertext of a gate.
int64_t lweBufferSize(CircuitGate gate) {
assert(gate.encryption.hasValue());
auto nbBlocks = gate.encryption->encoding.crt.size();
nbBlocks = nbBlocks == 0 ? 1 : nbBlocks;
auto param = lweSecretKeyParam(gate);
assert(param.has_value());
return param.value().lweSize() * nbBlocks;
}
/// bufferShape returns the shape of the tensor for the given gate. It returns
/// the shape used at the low level, i.e. it includes the dimensions for the ciphertext(s).
std::vector<int64_t> bufferShape(CircuitGate gate) {
if (!gate.encryption.hasValue()) {
// Value is not encrypted, just return the tensor shape
return gate.shape.dimensions;
}
auto lweSecreteKeyParam = lweSecretKeyParam(gate);
assert(lweSecreteKeyParam.has_value());
// Copy the shape
std::vector<int64_t> shape(gate.shape.dimensions);
auto crt = gate.encryption->encoding.crt;
// CRT case: add one dimension equal to the number of blocks
if (!crt.empty()) {
shape.push_back(crt.size());
}
// Add one dimension for the size of ciphertext(s)
shape.push_back(lweSecreteKeyParam.value().lweSize());
return shape;
}
};
static inline bool operator==(const ClientParameters &lhs,
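To make the buffer layout concrete, here is a small sketch that mirrors the computation done by bufferShape and lweBufferSize above for a scalar encrypted gate; the lweDimension and the moduli are placeholders, not real parameters:

#include <cassert>
#include <cstdint>
#include <vector>

void bufferLayoutSketch() {
  int64_t lweSize = 1024 + 1;                   // lweDimension + 1 (placeholder)
  std::vector<int64_t> crt = {7, 8, 9, 11, 13}; // 5 residue blocks (placeholder)
  std::vector<int64_t> shape;                   // scalar gate: no tensor dimensions
  if (!crt.empty())
    shape.push_back(crt.size()); // one dimension for the number of blocks
  shape.push_back(lweSize);      // one dimension for the ciphertext itself
  assert(shape.size() == 2 && shape[0] == 5 && shape[1] == 1025);
  int64_t bufferSize = lweSize * (crt.empty() ? 1 : crt.size()); // lweBufferSize
  assert(bufferSize == 5 * 1025); // without CRT it would simply be 1025
}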

View File

@@ -23,11 +23,24 @@ using concretelang::error::StringError;
class PublicArguments;
inline size_t bitWidthAsWord(size_t exactBitWidth) {
if (exactBitWidth <= 8)
return 8;
if (exactBitWidth <= 16)
return 16;
if (exactBitWidth <= 32)
return 32;
if (exactBitWidth <= 64)
return 64;
assert(false && "Bit width > 64 not supported");
}
/// Temporary object used to hold and encrypt parameters before calling a
/// ClientLambda. Preferably use TypeClientLambda and serializeCall(Args...).
/// Otherwise convert it to a PublicArguments and use
/// serializeCall(PublicArguments, KeySet).
class EncryptedArguments {
public:
EncryptedArguments() : currentPos(0) {}
@@ -73,7 +86,10 @@ public:
/// Add a vector-tensor argument.
outcome::checked<void, StringError> pushArg(std::vector<uint8_t> arg,
KeySet &keySet);
KeySet &keySet) {
return pushArg((uint8_t *)arg.data(),
llvm::ArrayRef<int64_t>{(int64_t)arg.size()}, keySet);
}
/// Add a 1D tensor argument with data and size of the dimension.
template <typename T>
@@ -82,26 +98,20 @@ public:
return pushArg(std::vector<uint8_t>(data, data + dim1), keySet);
}
// Add a tensor argument.
template <typename T>
outcome::checked<void, StringError>
pushArg(const T *data, llvm::ArrayRef<int64_t> shape, KeySet &keySet) {
return pushArg(8 * sizeof(T), static_cast<const void *>(data), shape,
keySet);
}
/// Add a 1D tensor argument.
template <size_t size>
outcome::checked<void, StringError> pushArg(std::array<uint8_t, size> arg,
KeySet &keySet) {
return pushArg(8, (void *)arg.data(), {size}, keySet);
return pushArg((uint8_t *)arg.data(), llvm::ArrayRef<int64_t>{size},
keySet);
}
/// Add a 2D tensor argument.
template <size_t size0, size_t size1>
outcome::checked<void, StringError>
pushArg(std::array<std::array<uint8_t, size1>, size0> arg, KeySet &keySet) {
return pushArg(8, (void *)arg.data(), {size0, size1}, keySet);
return pushArg((uint8_t *)arg.data(), llvm::ArrayRef<int64_t>{size0, size1},
keySet);
}
/// Add a 3D tensor argument.
@@ -109,7 +119,8 @@ public:
outcome::checked<void, StringError>
pushArg(std::array<std::array<std::array<uint8_t, size2>, size1>, size0> arg,
KeySet &keySet) {
return pushArg(8, (void *)arg.data(), {size0, size1, size2}, keySet);
return pushArg((uint8_t *)arg.data(),
llvm::ArrayRef<int64_t>{size0, size1, size2}, keySet);
}
// Generalize by computing shape by template recursion
@@ -125,13 +136,94 @@ public:
template <typename T>
outcome::checked<void, StringError>
pushArg(T *data, llvm::ArrayRef<int64_t> shape, KeySet &keySet) {
return pushArg(8 * sizeof(T), static_cast<const void *>(data), shape,
keySet);
return pushArg(static_cast<const T *>(data), shape, keySet);
}
outcome::checked<void, StringError> pushArg(size_t width, const void *data,
llvm::ArrayRef<int64_t> shape,
KeySet &keySet);
template <typename T>
outcome::checked<void, StringError>
pushArg(const T *data, llvm::ArrayRef<int64_t> shape, KeySet &keySet) {
OUTCOME_TRYV(checkPushTooManyArgs(keySet));
auto pos = currentPos;
CircuitGate input = keySet.inputGate(pos);
// Check the width of data
if (input.shape.width > 64) {
return StringError("argument #")
<< pos << " width > 64 bits is not supported";
}
// Check the shape of tensor
if (input.shape.dimensions.empty()) {
return StringError("argument #") << pos << " is not a tensor";
}
if (shape.size() != input.shape.dimensions.size()) {
return StringError("argument #")
<< pos << " does not have the expected number of dimensions, got "
<< shape.size() << " expected " << input.shape.dimensions.size();
}
// Allocate empty
ciphertextBuffers.resize(ciphertextBuffers.size() + 1);
TensorData &values_and_sizes = ciphertextBuffers.back();
// Check shape
for (size_t i = 0; i < shape.size(); i++) {
if (shape[i] != input.shape.dimensions[i]) {
return StringError("argument #")
<< pos << " does not have the expected dimension #" << i << ", got "
<< shape[i] << " expected " << input.shape.dimensions[i];
}
}
// Set sizes
values_and_sizes.sizes = keySet.clientParameters().bufferShape(input);
if (input.encryption.hasValue()) {
// Allocate values
values_and_sizes.values.resize(
keySet.clientParameters().bufferSize(input));
auto lweSize = keySet.clientParameters().lweBufferSize(input);
auto &values = values_and_sizes.values;
for (size_t i = 0, offset = 0; i < input.shape.size;
i++, offset += lweSize) {
OUTCOME_TRYV(keySet.encrypt_lwe(pos, values.data() + offset, data[i]));
}
} else {
// Allocate values, taking care of the gate bit width
auto bitsPerValue = bitWidthAsWord(input.shape.width);
auto bytesPerValue = bitsPerValue / 8;
auto nbWordPerValue = 8 / bytesPerValue;
// ceil division
auto size = (input.shape.size / nbWordPerValue) +
(input.shape.size % nbWordPerValue != 0);
size = size == 0 ? 1 : size;
values_and_sizes.values.resize(size);
auto v = (uint8_t *)values_and_sizes.values.data();
for (size_t i = 0; i < input.shape.size; i++) {
auto dst = v + i * bytesPerValue;
auto src = (const uint8_t *)&data[i];
for (size_t j = 0; j < bytesPerValue; j++) {
dst[j] = src[j];
}
}
}
// allocated
preparedArgs.push_back(nullptr);
// aligned
preparedArgs.push_back((void *)values_and_sizes.values.data());
// offset
preparedArgs.push_back((void *)0);
// sizes
for (size_t size : values_and_sizes.sizes) {
preparedArgs.push_back((void *)size);
}
// Set the stride for each dimension, equal to the product of the
// following dimensions.
int64_t stride = values_and_sizes.length();
for (size_t size : values_and_sizes.sizes) {
stride = (size == 0 ? 0 : (stride / size));
preparedArgs.push_back((void *)stride);
}
currentPos++;
return outcome::success();
}
/// Recursive case for scalars: extract first scalar argument from
/// parameter pack and forward rest
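For reference, the preparedArgs filled above follow the MLIR memref calling convention (allocated pointer, aligned pointer, offset, then one size and one stride per dimension). A sketch of the stride computation for a hypothetical 2x3 encrypted tensor with lweSize 1025 and no CRT blocks:

#include <cassert>
#include <cstdint>
#include <vector>

void memrefDescriptorSketch() {
  std::vector<int64_t> sizes = {2, 3, 1025};  // bufferShape of the gate (placeholder)
  int64_t stride = 2 * 3 * 1025;              // total number of 64-bit words
  std::vector<int64_t> strides;
  for (int64_t size : sizes) {                // each stride is the product of the
    stride = (size == 0 ? 0 : stride / size); // sizes of the following dimensions
    strides.push_back(stride);
  }
  assert((strides == std::vector<int64_t>{3 * 1025, 1025, 1}));
}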

View File

@@ -37,7 +37,10 @@ public:
static outcome::checked<std::unique_ptr<KeySet>, StringError>
generate(ClientParameters &params, uint64_t seed_msb, uint64_t seed_lsb);
/// isInputEncrypted returns true if the input at the given pos is encrypted.
/// Returns the ClientParameters associated with the KeySet.
ClientParameters clientParameters() { return _clientParameters; }
// isInputEncrypted returns true if the input at the given pos is encrypted.
bool isInputEncrypted(size_t pos);
/// getInputLweSecretKeyParam returns the parameters of the lwe secret key for
@@ -155,6 +158,8 @@ private:
std::map<LweSecretKeyID,
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>>
keyswitchKeys);
clientlib::ClientParameters _clientParameters;
};
} // namespace clientlib

View File

@@ -108,8 +108,7 @@ struct PublicResult {
}
auto buffer = buffers[pos];
auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize();
auto lweSize = clientParameters.lweBufferSize(gate);
std::vector<T> decryptedValues(buffer.length() / lweSize);
for (size_t i = 0; i < decryptedValues.size(); i++) {
auto ciphertext = &buffer.values[i * lweSize];
@@ -120,6 +119,13 @@ struct PublicResult {
return decryptedValues;
}
/// Return the shape of the clear tensor of a result.
outcome::checked<std::vector<int64_t>, StringError>
asClearTextShape(size_t pos) {
OUTCOME_TRY(auto gate, clientParameters.ouput(pos));
return gate.shape.dimensions;
}
// private: TODO tmp
friend class ::concretelang::serverlib::ServerLambda;
ClientParameters clientParameters;

View File

@@ -10,8 +10,17 @@
namespace mlir {
namespace concretelang {
// ApplyLookupTableLowering indicates the strategy used to lower
// FHE.apply_lookup_table ops
enum ApplyLookupTableLowering {
KeySwitchBoostrapLowering,
WopPBSLowering,
};
/// Create a pass to convert `FHE` dialect to `TFHE` dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertFHEToTFHEPass();
std::unique_ptr<OperationPass<ModuleOp>>
createConvertFHEToTFHEPass(ApplyLookupTableLowering lower);
} // namespace concretelang
} // namespace mlir

View File

@@ -23,7 +23,8 @@ using TFHE::GLWECipherTextType;
GLWECipherTextType
convertTypeEncryptedIntegerToGLWE(mlir::MLIRContext *context,
EncryptedIntegerType eint) {
return GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth());
return GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth(),
llvm::ArrayRef<int64_t>());
}
/// Converts the type `t` to `TFHE::GLWECipherText` if `t` is a

View File

@@ -26,7 +26,8 @@ LweCiphertextType convertTypeToLWE(mlir::MLIRContext *context,
auto glwe = type.dyn_cast_or_null<GLWECipherTextType>();
if (glwe != nullptr) {
assert(glwe.getPolynomialSize() == 1);
return LweCiphertextType::get(context, glwe.getDimension(), glwe.getP());
return LweCiphertextType::get(context, glwe.getDimension(), glwe.getP(),
glwe.getCrtDecomposition());
}
auto lwe = type.dyn_cast_or_null<LweCiphertextType>();
if (lwe != nullptr) {
@@ -122,19 +123,10 @@ mlir::Value createConcreteOpFromTFHE(mlir::PatternRewriter &rewriter,
mlir::Value createAddPlainLweCiphertextWithGlwe(
mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Value arg0,
mlir::Value arg1, mlir::OpResult result, mlir::Type encryptedType) {
PlaintextType encoded_type =
convertPlaintextTypeFromType(rewriter.getContext(), encryptedType);
// encode int into plaintext
mlir::Value encoded = rewriter
.create<mlir::concretelang::Concrete::EncodeIntOp>(
loc, encoded_type, arg1)
.plaintext();
// replace op using the encoded plaintext instead of int
auto op =
rewriter
.create<mlir::concretelang::Concrete::AddPlaintextLweCiphertextOp>(
loc, result.getType(), arg0, encoded);
loc, result.getType(), arg0, arg1);
convertOperandAndResultTypes(rewriter, op, convertTypeToLWEIfTFHEType);
@@ -171,21 +163,11 @@ mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter &rewriter,
mlir::Location loc, mlir::Value arg0,
mlir::Value arg1,
mlir::OpResult result) {
auto inType = arg0.getType();
CleartextType encoded_type =
convertCleartextTypeFromType(rewriter.getContext(), inType);
// encode int into plaintext
mlir::Value encoded =
rewriter
.create<mlir::concretelang::Concrete::IntToCleartextOp>(
loc, encoded_type, arg1)
.cleartext();
// replace op using the encoded plaintext instead of int
auto op =
rewriter
.create<mlir::concretelang::Concrete::MulCleartextLweCiphertextOp>(
loc, result.getType(), arg0, encoded);
loc, result.getType(), arg0, arg1);
convertOperandAndResultTypes(rewriter, op, convertTypeToLWEIfTFHEType);

View File

@@ -6,15 +6,44 @@
#ifndef CONCRETELANG_CONVERSION_GLOBALFHECONTEXT_H_
#define CONCRETELANG_CONVERSION_GLOBALFHECONTEXT_H_
#include <cstddef>
#include <cstdint>
#include <vector>
#include "llvm/ADT/Optional.h"
namespace mlir {
namespace concretelang {
typedef std::vector<int64_t> CRTDecomposition;
struct V0FHEConstraint {
size_t norm2;
size_t p;
};
struct PackingKeySwitchParameter {
size_t inputLweDimension;
size_t inputLweCount;
size_t outputPolynomialSize;
size_t level;
size_t baseLog;
};
struct CitcuitBoostrapParameter {
size_t level;
size_t baseLog;
};
struct WopPBSParameter {
PackingKeySwitchParameter packingKeySwitch;
CitcuitBoostrapParameter circuitBootstrap;
};
struct LargeIntegerParameter {
CRTDecomposition crtDecomposition;
WopPBSParameter wopPBS;
};
struct V0Parameter {
size_t glweDimension;
size_t logPolynomialSize;
@@ -24,6 +53,8 @@ struct V0Parameter {
size_t ksLevel;
size_t ksLogBase;
llvm::Optional<LargeIntegerParameter> largeInteger;
V0Parameter() = delete;
V0Parameter(size_t glweDimension, size_t logPolynomialSize, size_t nSmall,
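While the optimizer integration is still a todo, the new large-integer parameters can be filled by hand. A sketch with placeholder (not secure) values of what the new CompilationOptions::largeIntegerParameter field shown further down expects, assuming the header above is included:

// Placeholder values only - not secure parameters.
mlir::concretelang::LargeIntegerParameter sketchLargeIntegerParameter() {
  mlir::concretelang::LargeIntegerParameter p;
  p.crtDecomposition = {7, 8, 9, 11, 13};
  p.wopPBS.circuitBootstrap = {/*level=*/4, /*baseLog=*/9};
  p.wopPBS.packingKeySwitch = {/*inputLweDimension=*/600, /*inputLweCount=*/16,
                               /*outputPolynomialSize=*/1024, /*level=*/2,
                               /*baseLog=*/15};
  return p;
}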

View File

@@ -28,6 +28,11 @@ populateWithTensorTypeConverterPatterns(mlir::RewritePatternSet &patterns,
patterns.getContext(), typeConverter);
addDynamicallyLegalTypeOp<mlir::tensor::ExtractSliceOp>(target,
typeConverter);
// InsertOp
patterns.add<GenericTypeConverterPattern<mlir::tensor::InsertOp>>(
patterns.getContext(), typeConverter);
addDynamicallyLegalTypeOp<mlir::tensor::InsertOp>(target, typeConverter);
// InsertSliceOp
patterns.add<GenericTypeConverterPattern<mlir::tensor::InsertSliceOp>>(
patterns.getContext(), typeConverter);

View File

@@ -20,21 +20,56 @@ def BConcrete_AddLweBuffersOp : BConcrete_Op<"add_lwe_buffer"> {
let results = (outs 1DTensorOf<[I64]>:$result);
}
def BConcrete_AddCRTLweBuffersOp : BConcrete_Op<"add_crt_lwe_buffer"> {
let arguments = (ins
2DTensorOf<[I64]>:$lhs,
2DTensorOf<[I64]>:$rhs,
I64ArrayAttr:$crtDecomposition
);
let results = (outs 2DTensorOf<[I64]>:$result);
}
def BConcrete_AddPlaintextLweBufferOp : BConcrete_Op<"add_plaintext_lwe_buffer"> {
let arguments = (ins 1DTensorOf<[I64]>:$lhs, I64:$rhs);
let results = (outs 1DTensorOf<[I64]>:$result);
}
def BConcrete_AddPlaintextCRTLweBufferOp : BConcrete_Op<"add_plaintext_crt_lwe_buffer"> {
let arguments = (ins
2DTensorOf<[I64]>:$lhs,
AnyInteger:$rhs,
I64ArrayAttr:$crtDecomposition
);
let results = (outs 2DTensorOf<[I64]>:$result);
}
def BConcrete_MulCleartextLweBufferOp : BConcrete_Op<"mul_cleartext_lwe_buffer"> {
let arguments = (ins 1DTensorOf<[I64]>:$lhs, I64:$rhs);
let results = (outs 1DTensorOf<[I64]>:$result);
}
def BConcrete_MulCleartextCRTLweBufferOp : BConcrete_Op<"mul_cleartext_crt_lwe_buffer"> {
let arguments = (ins
2DTensorOf<[I64]>:$lhs,
AnyInteger:$rhs,
I64ArrayAttr:$crtDecomposition
);
let results = (outs 2DTensorOf<[I64]>:$result);
}
def BConcrete_NegateLweBufferOp : BConcrete_Op<"negate_lwe_buffer"> {
let arguments = (ins 1DTensorOf<[I64]>:$ciphertext);
let results = (outs 1DTensorOf<[I64]>:$result);
}
def BConcrete_NegateCRTLweBufferOp : BConcrete_Op<"negate_crt_lwe_buffer"> {
let arguments = (ins
2DTensorOf<[I64]>:$ciphertext,
I64ArrayAttr:$crtDecomposition
);
let results = (outs 2DTensorOf<[I64]>:$result);
}
def BConcrete_FillGlweFromTable : BConcrete_Op<"fill_glwe_from_table"> {
let arguments = (ins
1DTensorOf<[I64]>:$glwe,
@@ -67,5 +102,28 @@ def BConcrete_BootstrapLweBufferOp : BConcrete_Op<"bootstrap_lwe_buffer"> {
let results = (outs 1DTensorOf<[I64]>:$result);
}
// TODO(16bits): hack
def BConcrete_WopPBSCRTLweBufferOp : BConcrete_Op<"wop_pbs_crt_lwe_buffer"> {
let arguments = (ins
2DTensorOf<[I64]>:$ciphertext,
1DTensorOf<[I64]>:$lookupTable,
// Bootstrap parameters
I32Attr : $bootstrapLevel,
I32Attr : $bootstrapBaseLog,
// Keyswitch parameters
I32Attr : $keyswitchLevel,
I32Attr : $keyswitchBaseLog,
// Packing keyswitch key parameters
I32Attr : $packingKeySwitchInputLweDimension,
I32Attr : $packingKeySwitchinputLweCount,
I32Attr : $packingKeySwitchoutputPolynomialSize,
I32Attr : $packingKeySwitchLevel,
I32Attr : $packingKeySwitchBaseLog,
// Circuit bootstrap parameters
I32Attr : $circuitBootstrapLevel,
I32Attr : $circuitBootstrapBaseLog
);
let results = (outs 2DTensorOf<[I64]>:$result);
}
#endif

View File

@@ -14,6 +14,8 @@
namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<ModuleOp>> createAddRuntimeContext();
std::unique_ptr<OperationPass<func::FuncOp>> createEliminateCRTOps();
} // namespace concretelang
} // namespace mlir

View File

@@ -16,4 +16,9 @@ def AddRuntimeContext : Pass<"add-runtime-context", "mlir::ModuleOp"> {
let constructor = "mlir::concretelang::createAddRuntimeContext()";
}
def EliminateCRTOps : Pass<"eliminate-bconcrete-crt-ops", "mlir::func::FuncOp"> {
let summary = "Eliminate the crt bconcrete operators.";
let constructor = "mlir::concretelang::createEliminateCRTOpsPass()";
}
#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES

View File

@@ -34,14 +34,14 @@ def Concrete_AddLweCiphertextsOp : Concrete_Op<"add_lwe_ciphertexts"> {
def Concrete_AddPlaintextLweCiphertextOp : Concrete_Op<"add_plaintext_lwe_ciphertext"> {
let summary = "Returns the sum of a clear integer and a lwe ciphertext";
let arguments = (ins Concrete_LweCiphertextType:$lhs, Concrete_PlaintextType:$rhs);
let arguments = (ins Concrete_LweCiphertextType:$lhs, AnyInteger:$rhs);
let results = (outs Concrete_LweCiphertextType:$result);
}
def Concrete_MulCleartextLweCiphertextOp : Concrete_Op<"mul_cleartext_lwe_ciphertext"> {
let summary = "Returns the product of a clear integer and a lwe ciphertext";
let arguments = (ins Concrete_LweCiphertextType:$lhs, Concrete_CleartextType:$rhs);
let arguments = (ins Concrete_LweCiphertextType:$lhs, AnyInteger:$rhs);
let results = (outs Concrete_LweCiphertextType:$result);
}
@@ -82,18 +82,30 @@ def Concrete_KeySwitchLweOp : Concrete_Op<"keyswitch_lwe"> {
let results = (outs Concrete_LweCiphertextType:$result);
}
def Concrete_EncodeIntOp : Concrete_Op<"encode_int"> {
let summary = "Encodes an integer (for it to later be added to a LWE ciphertext)";
// TODO(16bits): hack
def Concrete_WopPBSLweOp : Concrete_Op<"wop_pbs_lwe"> {
let summary = "";
let arguments = (ins AnyInteger:$i);
let results = (outs Concrete_PlaintextType:$plaintext);
}
def Concrete_IntToCleartextOp : Concrete_Op<"int_to_cleartext", [NoSideEffect]> {
let summary = "Keyswitches a LWE ciphertext";
let arguments = (ins AnyInteger:$i);
let results = (outs Concrete_CleartextType:$cleartext);
let arguments = (ins
Concrete_LweCiphertextType:$ciphertext,
1DTensorOf<[I64]>:$accumulator,
// Bootstrap parameters
I32Attr : $bootstrapLevel,
I32Attr : $bootstrapBaseLog,
// Keyswitch parameters
I32Attr : $keyswitchLevel,
I32Attr : $keyswitchBaseLog,
// Packing keyswitch key parameters
I32Attr : $packingKeySwitchInputLweDimension,
I32Attr : $packingKeySwitchinputLweCount,
I32Attr : $packingKeySwitchoutputPolynomialSize,
I32Attr : $packingKeySwitchLevel,
I32Attr : $packingKeySwitchBaseLog,
// Circuit bootstrap parameters
I32Attr : $circuitBootstrapLevel,
I32Attr : $circuitBootstrapBaseLog
);
let results = (outs Concrete_LweCiphertextType:$result);
}
#endif

View File

@@ -40,7 +40,10 @@ def Concrete_LweCiphertextType : Concrete_Type<"LweCiphertext", [MemRefElementTy
// The dimension of the lwe ciphertext
"signed":$dimension,
// Precision of the lwe ciphertext
"signed":$p
"signed":$p,
// CRT decomposition for large integers
ArrayRefParameter<"int64_t", "CRT decomposition">:$crtDecomposition
);
let hasCustomAssemblyFormat = 1;

View File

@@ -115,4 +115,29 @@ def TFHE_BootstrapGLWEOp : TFHE_Op<"bootstrap_glwe"> {
let results = (outs TFHE_GLWECipherTextType : $result);
}
def TFHE_WopPBSGLWEOp : TFHE_Op<"wop_pbs_glwe"> {
let summary = "";
let arguments = (ins
TFHE_GLWECipherTextType : $ciphertext,
1DTensorOf<[I64]> : $lookupTable,
// Bootstrap parameters
I32Attr : $bootstrapLevel,
I32Attr : $bootstrapBaseLog,
// Keyswitch parameters
I32Attr : $keyswitchLevel,
I32Attr : $keyswitchBaseLog,
// Packing keyswitch key parameters
I32Attr : $packingKeySwitchInputLweDimension,
I32Attr : $packingKeySwitchinputLweCount,
I32Attr : $packingKeySwitchoutputPolynomialSize,
I32Attr : $packingKeySwitchLevel,
I32Attr : $packingKeySwitchBaseLog,
// Circuit bootstrap parameters
I32Attr : $circuitBootstrapLevel,
I32Attr : $circuitBootstrapBaseLog
);
let results = (outs TFHE_GLWECipherTextType:$result);
}
#endif

View File

@@ -6,18 +6,18 @@
include "concretelang/Dialect/TFHE/IR/TFHEDialect.td"
include "mlir/IR/BuiltinTypes.td"
class TFHE_Type<string name, list<Trait> traits = []> : TypeDef<TFHE_Dialect, name, traits> { }
class TFHE_Type<string name, list<Trait> traits = []>
: TypeDef<TFHE_Dialect, name, traits> {}
def TFHE_GLWECipherTextType : TFHE_Type<"GLWECipherText", [MemRefElementTypeInterface]> {
let mnemonic = "glwe";
def TFHE_GLWECipherTextType
: TFHE_Type<"GLWECipherText", [MemRefElementTypeInterface]> {
let mnemonic = "glwe";
let summary = "A GLWE ciphertext";
let summary = "A GLWE ciphertext";
let description = [{
An GLWE cipher text
}];
let description = [{A GLWE ciphertext}];
let parameters = (ins
let parameters = (ins
// The size of the mask
"signed":$dimension,
// Size of the polynome
@@ -25,22 +25,22 @@ def TFHE_GLWECipherTextType : TFHE_Type<"GLWECipherText", [MemRefElementTypeInte
// Number of bits of the ciphertext
"signed":$bits,
// Number of bits of the plain text representation
"signed":$p
"signed":$p,
// CRT decomposition for large integers
ArrayRefParameter<"int64_t", "CRT decomposition">:$crtDecomposition
);
let hasCustomAssemblyFormat = 1;
let hasCustomAssemblyFormat = 1;
let genVerifyDecl = true;
let genVerifyDecl = true;
let extraClassDeclaration = [{
// Returns true if it has unparametrized parameters
bool hasUnparametrizedParameters() {
return getDimension() ==-1 ||
getPolynomialSize() == -1 ||
getBits() == -1 ||
getP() == -1;
};
}];
let extraClassDeclaration = [{
// Returns true if it has unparametrized parameters
bool hasUnparametrizedParameters() {
return getDimension() == -1 || getPolynomialSize() == -1 ||
getBits() == -1 || getP() == -1;
};
}];
}
#endif

View File

@@ -58,6 +58,24 @@ void memref_bootstrap_lwe_u64(
uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
mlir::concretelang::RuntimeContext *context);
uint64_t encode_crt(int64_t plaintext, uint64_t modulus, uint64_t product);
// TODO(16bits): Hackish wrapper for the 16 bits quick win
void memref_wop_pbs_crt_buffer(
// Output memref 2D memref
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset_0,
uint64_t out_offset_1, uint64_t out_size_0, uint64_t out_size_1,
uint64_t out_stride_0, uint64_t out_stride_1,
// Input memref
uint64_t *in_allocated, uint64_t *in_aligned, uint64_t in_offset_0,
uint64_t in_offset_1, uint64_t in_size_0, uint64_t in_size_1,
uint64_t in_stride_0, uint64_t in_stride_1,
// clear text lut
uint64_t *lut_ct_allocated, uint64_t *lut_ct_aligned,
uint64_t lut_ct_offset, uint64_t lut_ct_size, uint64_t lut_ct_stride,
// runtime context that holds the evaluation keys
mlir::concretelang::RuntimeContext *context);
void memref_copy_one_rank(uint64_t *src_allocated, uint64_t *src_aligned,
uint64_t src_offset, uint64_t src_size,
uint64_t src_stride, uint64_t *dst_allocated,

View File

@@ -42,6 +42,11 @@ struct CompilationOptions {
llvm::Optional<mlir::concretelang::V0Parameter> v0Parameter;
/// largeIntegerParameter forces the compiler engine to lower FHE.eint using
/// the large integers strategy with the given parameters.
llvm::Optional<mlir::concretelang::LargeIntegerParameter>
largeIntegerParameter;
bool verifyDiagnostics;
bool autoParallelize;

View File

@@ -94,14 +94,15 @@ buildTensorLambdaResult(clientlib::KeySet &keySet,
clientlib::PublicResult &result) {
llvm::Expected<std::vector<T>> tensorOrError =
typedResult<std::vector<T>>(keySet, result);
if (auto err = tensorOrError.takeError())
return std::move(err);
std::vector<int64_t> tensorDim(result.buffers[0].sizes.begin(),
result.buffers[0].sizes.end() - 1);
auto tensorDim = result.asClearTextShape(0);
if (tensorDim.has_error())
return StreamStringError(tensorDim.error().mesg);
return std::make_unique<TensorLambdaArgument<IntLambdaArgument<T>>>(
*tensorOrError, tensorDim);
*tensorOrError, tensorDim.value());
}
/// Specialization of `typedResult()` for a single result wrapped into

View File

@@ -36,6 +36,7 @@ markFHELinalgForTiling(mlir::MLIRContext &context, mlir::ModuleOp &module,
mlir::LogicalResult
lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
llvm::Optional<V0FHEContext> &fheContext,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult

View File

@@ -1,12 +1,12 @@
add_compile_options( -Werror )
add_compile_options(-Werror)
if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
# using Clang
add_compile_options( -Wno-error=pessimizing-move -Wno-pessimizing-move )
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
add_compile_options(-Wno-error=pessimizing-move -Wno-pessimizing-move)
elseif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
# using GCC
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0)
add_compile_options( -Werror -Wno-error=pessimizing-move -Wno-pessimizing-move )
add_compile_options(-Werror -Wno-error=pessimizing-move -Wno-pessimizing-move)
endif()
endif()
@@ -14,6 +14,7 @@ add_mlir_library(
ConcretelangClientLib
ClientLambda.cpp
ClientParameters.cpp
CRT.cpp
EncryptedArguments.cpp
KeySet.cpp
KeySetCache.cpp
@@ -23,8 +24,6 @@ add_mlir_library(
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/ClientLib
LINK_LIBS
ConcretelangRuntime
LINK_LIBS PUBLIC
Concrete
)

View File

@@ -0,0 +1,99 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <cstddef>
#include <stdio.h>
#include "concretelang/ClientLib/CRT.h"
namespace concretelang {
namespace clientlib {
namespace crt {
uint64_t productOfModuli(std::vector<int64_t> moduli) {
uint64_t product = 1;
for (auto modulus : moduli) {
product *= modulus;
}
return product;
}
std::vector<int64_t> crt(std::vector<int64_t> moduli, uint64_t val) {
std::vector<int64_t> remainders(moduli.size(), 0);
for (size_t i = 0; i < moduli.size(); i++) {
remainders[i] = val % moduli[i];
}
return remainders;
}
// https://www.geeksforgeeks.org/multiplicative-inverse-under-modulo-m/
// Returns modulo inverse of a with respect
// to m using extended Euclid Algorithm
// Assumption: a and m are coprime, i.e.,
// gcd(a, m) = 1
int64_t modInverse(int64_t a, int64_t m) {
int64_t m0 = m;
int64_t y = 0, x = 1;
if (m == 1)
return 0;
while (a > 1) {
// q is quotient
int64_t q = a / m;
int64_t t = m;
// m is remainder now, process same as
// Euclid's algo
m = a % m;
a = t;
t = y;
// Update y and x
y = x - q * y;
x = t;
}
// Make x positive
if (x < 0)
x += m0;
return x;
}
uint64_t iCrt(std::vector<int64_t> moduli, std::vector<int64_t> remainders) {
// Compute the product of moduli
int64_t product = productOfModuli(moduli);
int64_t result = 0;
// Apply the CRT reconstruction formula
for (size_t i = 0; i < remainders.size(); i++) {
int64_t tmp = product / moduli[i];
result += remainders[i] * modInverse(tmp, moduli[i]) * tmp;
}
return result % product;
}
uint64_t encode(int64_t plaintext, uint64_t modulus, uint64_t product) {
// values are represented on the interval [0; product[, so we map the
// plaintext onto this interval
if (plaintext < 0) {
plaintext = product + plaintext;
}
__uint128_t m = plaintext % modulus;
return m * ((__uint128_t)(1) << 64) / modulus;
}
uint64_t decode(uint64_t val, uint64_t modulus) {
auto result = (__uint128_t)val * (__uint128_t)modulus;
result = result + ((result & ((__uint128_t)(1) << 63)) << 1);
result = result / ((__uint128_t)(1) << 64);
return (uint64_t)result % modulus;
}
} // namespace crt
} // namespace clientlib
} // namespace concretelang
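A small, noise-free sanity check of the encode/decode pair above (the modulus and the value are arbitrary; in practice the decrypted value carries noise that the rounding in decode absorbs):

#include "concretelang/ClientLib/CRT.h"
#include <cassert>

void encodeDecodeSketch() {
  using namespace concretelang::clientlib;
  uint64_t modulus = 7;
  uint64_t product = 7 * 8 * 9 * 11 * 13;              // hypothetical CRT product
  uint64_t encoded = crt::encode(3, modulus, product); // ~ 3 * 2^64 / 7
  assert(crt::decode(encoded, modulus) == 3);          // rounding recovers 3
}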

View File

@@ -234,6 +234,9 @@ llvm::json::Value toJSON(const Encoding &v) {
llvm::json::Object object{
{"precision", v.precision},
};
if (!v.crt.empty()) {
object.insert({"crt", v.crt});
}
return object;
}
bool fromJSON(const llvm::json::Value j, Encoding &v, llvm::json::Path p) {
@@ -248,6 +251,18 @@ bool fromJSON(const llvm::json::Value j, Encoding &v, llvm::json::Path p) {
return false;
}
v.precision = precision.getValue();
auto crt = obj->getArray("crt");
if (crt != nullptr) {
for (auto dim : *crt) {
auto iDim = dim.getAsInteger();
if (!iDim.hasValue()) {
p.report("crt values must be integers");
return false;
}
v.crt.push_back(iDim.getValue());
}
}
return true;
}
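With this change, an Encoding that carries a CRT decomposition serializes roughly as {"precision": 16, "crt": [7, 8, 9, 11, 13]} (placeholder values), and the "crt" field is simply omitted when the decomposition is empty. A sketch, assuming Encoding and the toJSON overload above are visible from the ClientParameters header:

#include "concretelang/ClientLib/ClientParameters.h"
#include "llvm/Support/JSON.h"

void encodingJsonSketch() {
  concretelang::clientlib::Encoding e;
  e.precision = 16;                // placeholder precision
  e.crt = {7, 8, 9, 11, 13};       // placeholder moduli
  llvm::json::Value v = toJSON(e); // {"precision":16,"crt":[7,8,9,11,13]}
  (void)v;
}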

View File

@@ -11,17 +11,6 @@ namespace clientlib {
using StringError = concretelang::error::StringError;
size_t bitWidthAsWord(size_t exactBitWidth) {
size_t sortedWordBitWidths[] = {8, 16, 32, 64};
size_t previousWidth = 0;
for (auto currentWidth : sortedWordBitWidths) {
if (previousWidth < exactBitWidth && exactBitWidth <= currentWidth) {
return currentWidth;
}
}
return exactBitWidth;
}
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
EncryptedArguments::exportPublicArguments(ClientParameters clientParameters,
RuntimeContext runtimeContext) {
@@ -33,7 +22,7 @@ outcome::checked<void, StringError>
EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) {
OUTCOME_TRYV(checkPushTooManyArgs(keySet));
auto pos = currentPos++;
CircuitGate input = keySet.inputGate(pos);
OUTCOME_TRY(CircuitGate input, keySet.clientParameters().input(pos));
if (input.shape.size != 0) {
return StringError("argument #") << pos << " is not a scalar";
}
@@ -42,12 +31,11 @@ EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) {
preparedArgs.push_back((void *)arg);
return outcome::success();
}
ciphertextBuffers.resize(ciphertextBuffers.size() + 1); // Allocate empty
// Allocate empty
ciphertextBuffers.resize(ciphertextBuffers.size() + 1);
TensorData &values_and_sizes = ciphertextBuffers.back();
auto lweSize = keySet.getInputLweSecretKeyParam(pos).lweSize();
values_and_sizes.sizes.push_back(lweSize);
values_and_sizes.values.resize(lweSize);
values_and_sizes.sizes = keySet.clientParameters().bufferShape(input);
values_and_sizes.values.resize(keySet.clientParameters().bufferSize(input));
OUTCOME_TRYV(keySet.encrypt_lwe(pos, values_and_sizes.values.data(), arg));
// Note: Since we bufferized lwe ciphertext take care of memref calling
// convention
@@ -57,97 +45,18 @@ EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) {
preparedArgs.push_back((void *)values_and_sizes.values.data());
// offset
preparedArgs.push_back((void *)0);
// size
preparedArgs.push_back((void *)values_and_sizes.values.size());
// stride
preparedArgs.push_back((void *)1);
return outcome::success();
}
outcome::checked<void, StringError>
EncryptedArguments::pushArg(std::vector<uint8_t> arg, KeySet &keySet) {
return pushArg(8, (void *)arg.data(), {(int64_t)arg.size()}, keySet);
}
outcome::checked<void, StringError>
EncryptedArguments::pushArg(size_t width, const void *data,
llvm::ArrayRef<int64_t> shape, KeySet &keySet) {
OUTCOME_TRYV(checkPushTooManyArgs(keySet));
auto pos = currentPos;
CircuitGate input = keySet.inputGate(pos);
// Check the width of data
if (input.shape.width > 64) {
return StringError("argument #")
<< pos << " width > 64 bits is not supported";
}
auto roundedSize = bitWidthAsWord(input.shape.width);
if (width != roundedSize) {
return StringError("argument #") << pos << "width mismatch, got " << width
<< " expected " << roundedSize;
}
// Check the shape of tensor
if (input.shape.dimensions.empty()) {
return StringError("argument #") << pos << "is not a tensor";
}
if (shape.size() != input.shape.dimensions.size()) {
return StringError("argument #")
<< pos << "has not the expected number of dimension, got "
<< shape.size() << " expected " << input.shape.dimensions.size();
}
ciphertextBuffers.resize(ciphertextBuffers.size() + 1); // Allocate empty
TensorData &values_and_sizes = ciphertextBuffers.back();
for (size_t i = 0; i < shape.size(); i++) {
values_and_sizes.sizes.push_back(shape[i]);
if (shape[i] != input.shape.dimensions[i]) {
return StringError("argument #")
<< pos << " has not the expected dimension #" << i << " , got "
<< shape[i] << " expected " << input.shape.dimensions[i];
}
}
if (input.encryption.hasValue()) {
auto lweSize = keySet.getInputLweSecretKeyParam(pos).lweSize();
values_and_sizes.sizes.push_back(lweSize);
// Encrypted tensor: for now we support only 8 bits for encrypted tensor
if (width != 8) {
return StringError("argument #")
<< pos << " width mismatch, expected 8 got " << width;
}
const uint8_t *data8 = (const uint8_t *)data;
// Allocate a buffer for ciphertexts of size of tensor
values_and_sizes.values.resize(input.shape.size * lweSize);
auto &values = values_and_sizes.values;
// Allocate ciphertexts and encrypt, for every values in tensor
for (size_t i = 0, offset = 0; i < input.shape.size;
i++, offset += lweSize) {
OUTCOME_TRYV(keySet.encrypt_lwe(pos, values.data() + offset, data8[i]));
}
} else {
values_and_sizes.values.resize(input.shape.size);
for (size_t i = 0; i < input.shape.size; i++) {
values_and_sizes.values[i] = ((const uint64_t *)data)[i];
}
}
// allocated
preparedArgs.push_back(nullptr);
// aligned
preparedArgs.push_back((void *)values_and_sizes.values.data());
// offset
preparedArgs.push_back((void *)0);
// sizes
for (size_t size : values_and_sizes.sizes) {
for (auto size : values_and_sizes.sizes) {
preparedArgs.push_back((void *)size);
}
// Set the stride for each dimension, equal to the product of the
// following dimensions.
// strides
int64_t stride = values_and_sizes.length();
for (size_t size : values_and_sizes.sizes) {
for (size_t i = 0; i < values_and_sizes.sizes.size() - 1; i++) {
auto size = values_and_sizes.sizes[i];
stride = (size == 0 ? 0 : (stride / size));
preparedArgs.push_back((void *)stride);
}
currentPos++;
preparedArgs.push_back((void *)1);
return outcome::success();
}

View File

@@ -4,6 +4,7 @@
// for license information.
#include "concretelang/ClientLib/KeySet.h"
#include "concretelang/ClientLib/CRT.h"
#include "concretelang/Support/Error.h"
#define CAPI_ERR_TO_STRINGERROR(instr, msg) \
@@ -31,16 +32,16 @@ outcome::checked<std::unique_ptr<KeySet>, StringError>
KeySet::generate(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb) {
auto keySet = std::make_unique<KeySet>();
OUTCOME_TRYV(keySet->generateKeysFromParams(params, seed_msb, seed_lsb));
OUTCOME_TRYV(keySet->setupEncryptionMaterial(params, seed_msb, seed_lsb));
return std::move(keySet);
}
outcome::checked<void, StringError>
KeySet::setupEncryptionMaterial(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb) {
_clientParameters = params;
// Set inputs and outputs LWE secret keys
{
for (auto param : params.inputs) {
@@ -189,9 +190,16 @@ KeySet::allocate_lwe(size_t argPos, uint64_t **ciphertext, uint64_t &size) {
return StringError("allocate_lwe position of argument is too high");
}
auto inputSk = inputs[argPos];
auto encryption = std::get<0>(inputSk).encryption;
if (!encryption.hasValue()) {
return StringError("allocate_lwe argument #")
<< argPos << " is not encrypted";
}
auto numBlocks =
encryption->encoding.crt.empty() ? 1 : encryption->encoding.crt.size();
size = std::get<1>(inputSk).lweSize();
*ciphertext = (uint64_t *)malloc(sizeof(uint64_t) * size);
*ciphertext = (uint64_t *)malloc(sizeof(uint64_t) * size * numBlocks);
return outcome::success();
}
@@ -205,20 +213,40 @@ bool KeySet::isOutputEncrypted(size_t argPos) {
std::get<0>(outputs[argPos]).encryption.hasValue();
}
/// Return the number of bits needed to represent the given value
uint64_t bitWidthOfValue(uint64_t value) { return std::ceil(std::log2(value)); }
outcome::checked<void, StringError>
KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input) {
if (argPos >= inputs.size()) {
return StringError("encrypt_lwe position of argument is too high");
}
auto inputSk = inputs[argPos];
if (!std::get<0>(inputSk).encryption.hasValue()) {
auto encryption = std::get<0>(inputSk).encryption;
if (!encryption.hasValue()) {
return StringError("encrypt_lwe the positional argument is not encrypted");
}
// Encode - TODO we could check if the input value is in the right range
uint64_t plaintext =
input << (64 - (std::get<0>(inputSk).encryption->encoding.precision + 1));
::encrypt_lwe_u64(engine, std::get<2>(inputSk), ciphertext, plaintext,
std::get<0>(inputSk).encryption->variance);
auto encoding = encryption->encoding;
auto lweSecretKeyParam = std::get<1>(inputSk);
auto lweSecretKey = std::get<2>(inputSk);
// CRT encoding - N blocks, one per modulus of the crt decomposition
auto crt = encryption->encoding.crt;
if (!crt.empty()) {
// Put each decomposition into a new ciphertext
auto product = crt::productOfModuli(crt);
for (auto modulus : crt) {
auto plaintext = crt::encode(input, modulus, product);
::encrypt_lwe_u64(engine, lweSecretKey, ciphertext, plaintext,
encryption->variance);
ciphertext = ciphertext + lweSecretKeyParam.lweSize();
}
return outcome::success();
}
// Simple TFHE integer - one block with one padding bit
// TODO we could check if the input value is in the right range
uint64_t plaintext = input << (64 - (encryption->encoding.precision + 1));
::encrypt_lwe_u64(engine, lweSecretKey, ciphertext, plaintext,
encryption->variance);
return outcome::success();
}
@@ -228,13 +256,31 @@ KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) {
return StringError("decrypt_lwe: position of argument is too high");
}
auto outputSk = outputs[argPos];
if (!std::get<0>(outputSk).encryption.hasValue()) {
auto lweSecretKey = std::get<2>(outputSk);
auto lweSecretKeyParam = std::get<1>(outputSk);
auto encryption = std::get<0>(outputSk).encryption;
if (!encryption.hasValue()) {
return StringError("decrypt_lwe: the positional argument is not encrypted");
}
uint64_t plaintext =
::decrypt_lwe_u64(engine, std::get<2>(outputSk), ciphertext);
auto crt = encryption->encoding.crt;
// CRT encoding - N blocks, one per modulus of the crt decomposition
if (!crt.empty()) {
std::vector<int64_t> remainders;
// decrypt and decode remainders
for (auto modulus : crt) {
auto decrypted = ::decrypt_lwe_u64(engine, lweSecretKey, ciphertext);
auto plaintext = crt::decode(decrypted, modulus);
remainders.push_back(plaintext);
ciphertext = ciphertext + lweSecretKeyParam.lweSize();
}
// compute the inverse crt
output = crt::iCrt(crt, remainders);
return outcome::success();
}
// Simple TFHE integer - one block with one padding bit
uint64_t plaintext = ::decrypt_lwe_u64(engine, lweSecretKey, ciphertext);
// Decode
size_t precision = std::get<0>(outputSk).encryption->encoding.precision;
size_t precision = encryption->encoding.precision;
output = plaintext >> (64 - precision - 2);
size_t carry = output % 2;
output = ((output >> 1) + carry) % (1 << (precision + 1));
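For the non-CRT path, the message sits in the top bits of the plaintext with one bit of padding; a noise-free sketch of the encode/decode arithmetic used above, with an arbitrary 3-bit precision and value:

#include <cassert>
#include <cstdint>

void paddingBitEncodingSketch() {
  uint64_t precision = 3;
  uint64_t input = 5;                                   // fits in 3 bits
  uint64_t plaintext = input << (64 - (precision + 1)); // message in the top bits
  // Decode: keep the precision + 2 top bits, then round away the extra bit.
  uint64_t output = plaintext >> (64 - precision - 2);
  uint64_t carry = output % 2;
  output = ((output >> 1) + carry) % (1 << (precision + 1));
  assert(output == input); // with encryption, the carry absorbs the noise bit
}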

View File

@@ -90,7 +90,6 @@ KeySetCache::loadKeys(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb, std::string folderPath) {
// TODO: text dump of all parameter in /hash
auto key_set = std::make_unique<KeySet>();
// Mark the folder as recently used,
// e.g. so the CI can do some cleanup of unused keys.
utime(folderPath.c_str(), nullptr);

View File

@@ -38,6 +38,9 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Function.h"
namespace Concrete = ::mlir::concretelang::Concrete;
namespace BConcrete = ::mlir::concretelang::BConcrete;
namespace {
struct ConcreteToBConcretePass
: public ConcreteToBConcreteBase<ConcreteToBConcretePass> {
@@ -68,9 +71,14 @@ public:
});
addConversion([&](mlir::concretelang::Concrete::LweCiphertextType type) {
assert(type.getDimension() != -1);
llvm::SmallVector<int64_t, 2> shape;
auto crt = type.getCrtDecomposition();
if (!crt.empty()) {
shape.push_back(crt.size());
}
shape.push_back(type.getDimension() + 1);
return mlir::RankedTensorType::get(
{type.getDimension() + 1},
mlir::IntegerType::get(type.getContext(), 64));
shape, mlir::IntegerType::get(type.getContext(), 64));
});
addConversion([&](mlir::concretelang::Concrete::GlweCiphertextType type) {
assert(type.getGlweDimension() != -1);
@@ -91,91 +99,18 @@ public:
mlir::SmallVector<int64_t> newShape;
newShape.reserve(type.getShape().size() + 1);
newShape.append(type.getShape().begin(), type.getShape().end());
auto crt = lwe.getCrtDecomposition();
if (!crt.empty()) {
newShape.push_back(crt.size());
}
newShape.push_back(lwe.getDimension() + 1);
mlir::Type r = mlir::RankedTensorType::get(
newShape, mlir::IntegerType::get(type.getContext(), 64));
return r;
});
addConversion([&](mlir::MemRefType type) {
auto lwe = type.getElementType()
.dyn_cast_or_null<
mlir::concretelang::Concrete::LweCiphertextType>();
if (lwe == nullptr) {
return (mlir::Type)(type);
}
assert(lwe.getDimension() != -1);
mlir::SmallVector<int64_t> newShape;
newShape.reserve(type.getShape().size() + 1);
newShape.append(type.getShape().begin(), type.getShape().end());
newShape.push_back(lwe.getDimension() + 1);
mlir::Type r = mlir::MemRefType::get(
newShape, mlir::IntegerType::get(type.getContext(), 64));
return r;
});
}
};
struct ConcreteEncodeIntOpPattern
: public mlir::OpRewritePattern<mlir::concretelang::Concrete::EncodeIntOp> {
ConcreteEncodeIntOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<mlir::concretelang::Concrete::EncodeIntOp>(
context, benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::concretelang::Concrete::EncodeIntOp op,
mlir::PatternRewriter &rewriter) const override {
{
mlir::Value castedInt = rewriter.create<mlir::arith::ExtUIOp>(
op.getLoc(), rewriter.getIntegerType(64), op->getOperands().front());
mlir::Value constantShiftOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getI64IntegerAttr(64 - op.getType().getP()));
mlir::Type resultType = rewriter.getIntegerType(64);
rewriter.replaceOpWithNewOp<mlir::arith::ShLIOp>(
op, resultType, castedInt, constantShiftOp);
}
return mlir::success();
};
};
struct ConcreteIntToCleartextOpPattern
: public mlir::OpRewritePattern<
mlir::concretelang::Concrete::IntToCleartextOp> {
ConcreteIntToCleartextOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<mlir::concretelang::Concrete::IntToCleartextOp>(
context, benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::concretelang::Concrete::IntToCleartextOp op,
mlir::PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<mlir::arith::ExtUIOp>(
op, rewriter.getIntegerType(64), op->getOperands().front());
return mlir::success();
};
};
/// This rewrite pattern transforms any instance of `Concrete.zero_tensor`
/// operators.
///
/// Example:
///
/// ```mlir
/// %0 = "Concrete.zero_tensor" () :
/// tensor<...x!Concrete.lwe_ciphertext<lweDim,p>>
/// ```
///
/// becomes:
///
/// ```mlir
/// %0 = tensor.generate {
/// ^bb0(... : index):
/// %c0 = arith.constant 0 : i64
/// tensor.yield %z
/// }: tensor<...xlweDim+1xi64>
/// i64>
/// ```
template <typename ZeroOp>
struct ZeroOpPattern : public mlir::OpRewritePattern<ZeroOp> {
ZeroOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
@@ -205,20 +140,7 @@ struct ZeroOpPattern : public mlir::OpRewritePattern<ZeroOp> {
};
};
/// This template rewrite pattern transforms any instance of
/// `ConcreteOp` to an instance of `BConcreteOp`.
///
/// Example:
///
/// %0 = "ConcreteOp"(%arg0, ...) :
/// (!Concrete.lwe_ciphertext<lwe_dimension, p>, ...) ->
/// (!Concrete.lwe_ciphertext<lwe_dimension, p>)
///
/// becomes:
///
/// %0 = "BConcreteOp"(%arg0, ...) : (tensor<dimension+1, i64>>, ..., ) ->
/// (tensor<dimension+1, i64>>)
template <typename ConcreteOp, typename BConcreteOp>
template <typename ConcreteOp, typename BConcreteOp, typename BConcreteCRTOp>
struct LowToBConcrete : public mlir::OpRewritePattern<ConcreteOp> {
LowToBConcrete(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<ConcreteOp>(context, benefit) {}
@@ -236,9 +158,126 @@ struct LowToBConcrete : public mlir::OpRewritePattern<ConcreteOp> {
llvm::ArrayRef<::mlir::NamedAttribute> attributes =
concreteOp.getOperation()->getAttrs();
BConcreteOp bConcreteOp = rewriter.replaceOpWithNewOp<BConcreteOp>(
concreteOp, newResultTy, concreteOp.getOperation()->getOperands(),
attributes);
auto crt = resultTy.getCrtDecomposition();
mlir::Operation *bConcreteOp;
if (crt.empty()) {
bConcreteOp = rewriter.replaceOpWithNewOp<BConcreteOp>(
concreteOp, newResultTy, concreteOp.getOperation()->getOperands(),
attributes);
} else {
auto newAttributes = attributes.vec();
newAttributes.push_back(rewriter.getNamedAttr(
"crtDecomposition", rewriter.getI64ArrayAttr(crt)));
bConcreteOp = rewriter.replaceOpWithNewOp<BConcreteCRTOp>(
concreteOp, newResultTy, concreteOp.getOperation()->getOperands(),
newAttributes);
}
mlir::concretelang::convertOperandAndResultTypes(
rewriter, bConcreteOp, [&](mlir::MLIRContext *, mlir::Type t) {
return converter.convertType(t);
});
return ::mlir::success();
};
};
struct AddPlaintextLweCiphertextOpPattern
: public mlir::OpRewritePattern<Concrete::AddPlaintextLweCiphertextOp> {
AddPlaintextLweCiphertextOpPattern(::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<Concrete::AddPlaintextLweCiphertextOp>(
context, benefit) {}
::mlir::LogicalResult
matchAndRewrite(Concrete::AddPlaintextLweCiphertextOp concreteOp,
::mlir::PatternRewriter &rewriter) const override {
ConcreteToBConcreteTypeConverter converter;
auto loc = concreteOp.getLoc();
mlir::concretelang::Concrete::LweCiphertextType resultTy =
((mlir::Type)concreteOp->getResult(0).getType())
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
auto newResultTy =
converter.convertType(resultTy).cast<mlir::RankedTensorType>();
llvm::ArrayRef<::mlir::NamedAttribute> attributes =
concreteOp.getOperation()->getAttrs();
auto crt = resultTy.getCrtDecomposition();
mlir::Operation *bConcreteOp;
if (crt.empty()) {
// Encode the plaintext value
mlir::Value castedInt = rewriter.create<mlir::arith::ExtUIOp>(
loc, rewriter.getIntegerType(64), concreteOp.rhs());
mlir::Value constantShiftOp = rewriter.create<mlir::arith::ConstantOp>(
loc,
rewriter.getI64IntegerAttr(64 - concreteOp.getType().getP() - 1));
auto encoded = rewriter.create<mlir::arith::ShLIOp>(
loc, rewriter.getI64Type(), castedInt, constantShiftOp);
bConcreteOp =
rewriter.replaceOpWithNewOp<BConcrete::AddPlaintextLweBufferOp>(
concreteOp, newResultTy,
mlir::ValueRange{concreteOp.lhs(), encoded}, attributes);
} else {
// The encoding is done when we eliminate CRT ops
auto newAttributes = attributes.vec();
newAttributes.push_back(rewriter.getNamedAttr(
"crtDecomposition", rewriter.getI64ArrayAttr(crt)));
bConcreteOp =
rewriter.replaceOpWithNewOp<BConcrete::AddPlaintextCRTLweBufferOp>(
concreteOp, newResultTy, concreteOp.getOperation()->getOperands(),
newAttributes);
}
mlir::concretelang::convertOperandAndResultTypes(
rewriter, bConcreteOp, [&](mlir::MLIRContext *, mlir::Type t) {
return converter.convertType(t);
});
return ::mlir::success();
};
};
struct MulCleartextLweCiphertextOpPattern
: public mlir::OpRewritePattern<Concrete::MulCleartextLweCiphertextOp> {
MulCleartextLweCiphertextOpPattern(::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<Concrete::MulCleartextLweCiphertextOp>(
context, benefit) {}
::mlir::LogicalResult
matchAndRewrite(Concrete::MulCleartextLweCiphertextOp concreteOp,
::mlir::PatternRewriter &rewriter) const override {
ConcreteToBConcreteTypeConverter converter;
auto loc = concreteOp.getLoc();
mlir::concretelang::Concrete::LweCiphertextType resultTy =
((mlir::Type)concreteOp->getResult(0).getType())
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
auto newResultTy =
converter.convertType(resultTy).cast<mlir::RankedTensorType>();
llvm::ArrayRef<::mlir::NamedAttribute> attributes =
concreteOp.getOperation()->getAttrs();
auto crt = resultTy.getCrtDecomposition();
mlir::Operation *bConcreteOp;
if (crt.empty()) {
// Extend the cleartext value to 64 bits
mlir::Value castedInt = rewriter.create<mlir::arith::ExtUIOp>(
loc, rewriter.getIntegerType(64), concreteOp.rhs());
bConcreteOp =
rewriter.replaceOpWithNewOp<BConcrete::MulCleartextLweBufferOp>(
concreteOp, newResultTy,
mlir::ValueRange{concreteOp.lhs(), castedInt}, attributes);
} else {
auto newAttributes = attributes.vec();
newAttributes.push_back(rewriter.getNamedAttr(
"crtDecomposition", rewriter.getI64ArrayAttr(crt)));
bConcreteOp =
rewriter.replaceOpWithNewOp<BConcrete::MulCleartextCRTLweBufferOp>(
concreteOp, newResultTy, concreteOp.getOperation()->getOperands(),
newAttributes);
}
mlir::concretelang::convertOperandAndResultTypes(
rewriter, bConcreteOp, [&](mlir::MLIRContext *, mlir::Type t) {
@@ -249,27 +288,6 @@ struct LowToBConcrete : public mlir::OpRewritePattern<ConcreteOp> {
};
};
/// This rewrite pattern transforms any instance of
/// `Concrete.glwe_from_table` operators.
///
/// Example:
///
/// ```mlir
/// %0 = "Concrete.glwe_from_table"(%tlu)
/// : (tensor<$Dxi64>) ->
/// !Concrete.glwe_ciphertext<$polynomialSize,$glweDimension,$p>
/// ```
///
/// with $D = 2^$p
///
/// becomes:
///
/// ```mlir
/// %0 = linalg.init_tensor [polynomialSize*(glweDimension+1)]
/// : tensor<polynomialSize*(glweDimension+1), i64>
/// "BConcrete.fill_glwe_from_table" : (%0, polynomialSize, glweDimension, %tlu)
/// : tensor<polynomialSize*(glweDimension+1), i64>, i64, i64, tensor<$Dxi64>
/// ```
struct GlweFromTablePattern : public mlir::OpRewritePattern<
mlir::concretelang::Concrete::GlweFromTable> {
GlweFromTablePattern(::mlir::MLIRContext *context,
@@ -307,26 +325,6 @@ struct GlweFromTablePattern : public mlir::OpRewritePattern<
};
};
/// This rewrite pattern transforms any instance of
/// `tensor.extract_slice` operators that operates on tensor of lwe ciphertext.
///
/// Example:
///
/// ```mlir
/// %0 = tensor.extract_slice %arg0
/// [offsets...] [sizes...] [strides...]
/// : tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>> to
/// tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
/// ```
///
/// becomes:
///
/// ```mlir
/// %0 = tensor.extract_slice %arg0
/// [offsets..., 0] [sizes..., lweDimension+1] [strides..., 1]
/// : tensor<...xlweDimension+1,i64> to
/// tensor<...xlweDimension+1,i64>
/// ```
struct ExtractSliceOpPattern
: public mlir::OpRewritePattern<mlir::tensor::ExtractSliceOp> {
ExtractSliceOpPattern(::mlir::MLIRContext *context,
@@ -339,29 +337,41 @@ struct ExtractSliceOpPattern
::mlir::PatternRewriter &rewriter) const override {
ConcreteToBConcreteTypeConverter converter;
auto resultTy = extractSliceOp.result().getType();
auto resultEltTy =
auto lweResultTy =
resultTy.cast<mlir::RankedTensorType>()
.getElementType()
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
auto newResultTy = converter.convertType(resultTy);
auto nbBlock = lweResultTy.getCrtDecomposition().size();
auto newResultTy =
converter.convertType(resultTy).cast<mlir::RankedTensorType>();
// add 0 to the static_offsets
mlir::SmallVector<mlir::Attribute> staticOffsets;
staticOffsets.append(extractSliceOp.static_offsets().begin(),
extractSliceOp.static_offsets().end());
if (nbBlock != 0) {
staticOffsets.push_back(rewriter.getI64IntegerAttr(0));
}
staticOffsets.push_back(rewriter.getI64IntegerAttr(0));
// add the lweSize to the sizes
mlir::SmallVector<mlir::Attribute> staticSizes;
staticSizes.append(extractSliceOp.static_sizes().begin(),
extractSliceOp.static_sizes().end());
staticSizes.push_back(
rewriter.getI64IntegerAttr(resultEltTy.getDimension() + 1));
if (nbBlock != 0) {
staticSizes.push_back(rewriter.getI64IntegerAttr(
newResultTy.getDimSize(newResultTy.getRank() - 2)));
}
staticSizes.push_back(rewriter.getI64IntegerAttr(
newResultTy.getDimSize(newResultTy.getRank() - 1)));
// add 1 to the strides
mlir::SmallVector<mlir::Attribute> staticStrides;
staticStrides.append(extractSliceOp.static_strides().begin(),
extractSliceOp.static_strides().end());
if (nbBlock != 0) {
staticStrides.push_back(rewriter.getI64IntegerAttr(1));
}
staticStrides.push_back(rewriter.getI64IntegerAttr(1));
// replace tensor.extract_slice to the new one
@@ -382,26 +392,6 @@ struct ExtractSliceOpPattern
};
};
/// This rewrite pattern transforms any instance of
/// `tensor.extract` operators that operates on tensor of lwe ciphertext.
///
/// Example:
///
/// ```mlir
/// %0 = tensor.extract %t[offsets...]
/// : tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
/// ```
///
/// becomes:
///
/// ```mlir
/// %1 = tensor.extract_slice %arg0
/// [offsets...] [1..., lweDimension+1] [1...]
/// : tensor<...xlweDimension+1,i64> to
/// tensor<1...xlweDimension+1,i64>
/// %0 = linalg.tensor_collapse_shape %0 [[...]] :
/// tensor<1x1xlweDimension+1xi64> into tensor<lweDimension+1xi64>
/// ```
// TODO: since there is a bug on lowering extract_slice with rank reduction we
// add a linalg.tensor_collapse_shape after the extract_slice without rank
// reduction. See
@@ -424,20 +414,29 @@ struct ExtractOpPattern
if (lweResultTy == nullptr) {
return mlir::failure();
}
auto nbBlock = lweResultTy.getCrtDecomposition().size();
auto newResultTy =
converter.convertType(lweResultTy).cast<mlir::RankedTensorType>();
auto rankOfResult = extractOp.indices().size() + 1;
auto rankOfResult = extractOp.indices().size() +
/* for the lwe dimension */ 1 +
/* for the block dimension */
(nbBlock == 0 ? 0 : 1);
// [min..., 0] for static_offsets ()
mlir::SmallVector<mlir::Attribute> staticOffsets(
rankOfResult,
rewriter.getI64IntegerAttr(std::numeric_limits<int64_t>::min()));
if (nbBlock != 0) {
staticOffsets[staticOffsets.size() - 2] = rewriter.getI64IntegerAttr(0);
}
staticOffsets[staticOffsets.size() - 1] = rewriter.getI64IntegerAttr(0);
// [1..., lweDimension+1] for static_sizes
// [1..., lweDimension+1] for static_sizes or
// [1..., nbBlock, lweDimension+1]
mlir::SmallVector<mlir::Attribute> staticSizes(
rankOfResult, rewriter.getI64IntegerAttr(1));
if (nbBlock != 0) {
staticSizes[staticSizes.size() - 2] = rewriter.getI64IntegerAttr(nbBlock);
}
staticSizes[staticSizes.size() - 1] = rewriter.getI64IntegerAttr(
newResultTy.getDimSize(newResultTy.getRank() - 1));
@@ -446,38 +445,45 @@ struct ExtractOpPattern
rankOfResult, rewriter.getI64IntegerAttr(1));
// replace tensor.extract_slice to the new one
mlir::SmallVector<int64_t> extractedSliceShape(
extractOp.indices().size() + 1, 0);
extractedSliceShape.reserve(extractOp.indices().size() + 1);
for (size_t i = 0; i < extractedSliceShape.size() - 1; i++) {
extractedSliceShape[i] = 1;
mlir::SmallVector<int64_t> extractedSliceShape(rankOfResult, 1);
if (nbBlock != 0) {
extractedSliceShape[extractedSliceShape.size() - 2] = nbBlock;
extractedSliceShape[extractedSliceShape.size() - 1] =
newResultTy.getDimSize(1);
} else {
extractedSliceShape[extractedSliceShape.size() - 1] =
newResultTy.getDimSize(0);
}
extractedSliceShape[extractedSliceShape.size() - 1] =
newResultTy.getDimSize(0);
auto extractedSliceType =
mlir::RankedTensorType::get(extractedSliceShape, rewriter.getI64Type());
auto extractedSlice = rewriter.create<mlir::tensor::ExtractSliceOp>(
extractOp.getLoc(), extractedSliceType, extractOp.tensor(),
extractOp.indices(), mlir::SmallVector<mlir::Value>{},
mlir::SmallVector<mlir::Value>{}, rewriter.getArrayAttr(staticOffsets),
rewriter.getArrayAttr(staticSizes),
rewriter.getArrayAttr(staticStrides));
mlir::concretelang::convertOperandAndResultTypes(
rewriter, extractedSlice, [&](mlir::MLIRContext *, mlir::Type t) {
return converter.convertType(t);
});
mlir::ReassociationIndices reassociation;
for (int64_t i = 0; i < extractedSliceType.getRank(); i++) {
for (int64_t i = 0;
i < extractedSliceType.getRank() - (nbBlock == 0 ? 0 : 1); i++) {
reassociation.push_back(i);
}
mlir::SmallVector<mlir::ReassociationIndices> reassocs{reassociation};
if (nbBlock != 0) {
reassocs.push_back({extractedSliceType.getRank() - 1});
}
mlir::tensor::CollapseShapeOp collapseOp =
rewriter.replaceOpWithNewOp<mlir::tensor::CollapseShapeOp>(
extractOp, newResultTy, extractedSlice,
mlir::SmallVector<mlir::ReassociationIndices>{reassociation});
extractOp, newResultTy, extractedSlice, reassocs);
mlir::concretelang::convertOperandAndResultTypes(
rewriter, collapseOp, [&](mlir::MLIRContext *, mlir::Type t) {
@@ -488,26 +494,6 @@ struct ExtractOpPattern
};
};
/// This rewrite pattern transforms any instance of
/// `tensor.insert_slice` operators that operates on tensor of lwe ciphertext.
///
/// Example:
///
/// ```mlir
/// %0 = tensor.insert_slice %arg1
/// into %arg0[offsets...] [sizes...] [strides...]
/// : tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>> into
/// tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
/// ```
///
/// becomes:
///
/// ```mlir
/// %0 = tensor.insert_slice %arg1
/// into %arg0[offsets..., 0] [sizes..., lweDimension+1] [strides..., 1]
/// : tensor<...xlweDimension+1xi64> into
/// tensor<...xlweDimension+1xi64>
/// ```
struct InsertSliceOpPattern
: public mlir::OpRewritePattern<mlir::tensor::InsertSliceOp> {
InsertSliceOpPattern(::mlir::MLIRContext *context,
@@ -520,7 +506,14 @@ struct InsertSliceOpPattern
::mlir::PatternRewriter &rewriter) const override {
ConcreteToBConcreteTypeConverter converter;
auto resultTy = insertSliceOp.result().getType();
auto lweResultTy =
resultTy.cast<mlir::RankedTensorType>()
.getElementType()
.dyn_cast_or_null<mlir::concretelang::Concrete::LweCiphertextType>();
if (lweResultTy == nullptr) {
return mlir::failure();
}
auto nbBlock = lweResultTy.getCrtDecomposition().size();
auto newResultTy =
converter.convertType(resultTy).cast<mlir::RankedTensorType>();
@@ -528,12 +521,19 @@ struct InsertSliceOpPattern
mlir::SmallVector<mlir::Attribute> staticOffsets;
staticOffsets.append(insertSliceOp.static_offsets().begin(),
insertSliceOp.static_offsets().end());
if (nbBlock != 0) {
staticOffsets.push_back(rewriter.getI64IntegerAttr(0));
}
staticOffsets.push_back(rewriter.getI64IntegerAttr(0));
// add lweDimension+1 to static_sizes
mlir::SmallVector<mlir::Attribute> staticSizes;
staticSizes.append(insertSliceOp.static_sizes().begin(),
insertSliceOp.static_sizes().end());
if (nbBlock != 0) {
staticSizes.push_back(rewriter.getI64IntegerAttr(
newResultTy.getDimSize(newResultTy.getRank() - 2)));
}
staticSizes.push_back(rewriter.getI64IntegerAttr(
newResultTy.getDimSize(newResultTy.getRank() - 1)));
@@ -541,6 +541,9 @@ struct InsertSliceOpPattern
mlir::SmallVector<mlir::Attribute> staticStrides;
staticStrides.append(insertSliceOp.static_strides().begin(),
insertSliceOp.static_strides().end());
if (nbBlock != 0) {
staticStrides.push_back(rewriter.getI64IntegerAttr(1));
}
staticStrides.push_back(rewriter.getI64IntegerAttr(1));
// replace tensor.insert_slice with the new one
@@ -560,28 +563,6 @@ struct InsertSliceOpPattern
};
};
/// This rewrite pattern transforms any instance of `tensor.insert`
/// operators that operates on an lwe ciphertexts to a
/// `tensor.insert_slice` op operating on the bufferized representation
/// of the ciphertext.
///
/// Example:
///
/// ```mlir
/// %0 = tensor.insert %arg1
/// into %arg0[offsets...]
/// : !Concrete.lwe_ciphertext<lweDimension,p> into
/// tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
/// ```
///
/// becomes:
///
/// ```mlir
/// %0 = tensor.insert_slice %arg1
/// into %arg0[offsets..., 0] [sizes..., lweDimension+1] [strides..., 1]
/// : tensor<lweDimension+1xi64> into
/// tensor<...xlweDimension+1xi64>
/// ```
struct InsertOpPattern : public mlir::OpRewritePattern<mlir::tensor::InsertOp> {
InsertOpPattern(::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
@@ -591,14 +572,24 @@ struct InsertOpPattern : public mlir::OpRewritePattern<mlir::tensor::InsertOp> {
matchAndRewrite(mlir::tensor::InsertOp insertOp,
::mlir::PatternRewriter &rewriter) const override {
ConcreteToBConcreteTypeConverter converter;
mlir::Type resultTy = insertOp.result().getType();
auto resultTy =
insertOp.result().getType().dyn_cast_or_null<mlir::RankedTensorType>();
auto lweResultTy = resultTy.getElementType()
.dyn_cast_or_null<Concrete::LweCiphertextType>();
if (lweResultTy == nullptr) {
return mlir::failure();
};
auto hasBlock = lweResultTy.getCrtDecomposition().size() != 0;
mlir::RankedTensorType newResultTy =
converter.convertType(resultTy).cast<mlir::RankedTensorType>();
// add 0 to static_offsets
// add zeros to static_offsets
mlir::SmallVector<mlir::OpFoldResult> offsets;
offsets.append(insertOp.indices().begin(), insertOp.indices().end());
offsets.push_back(rewriter.getIndexAttr(0));
if (hasBlock) {
offsets.push_back(rewriter.getIndexAttr(0));
}
// Inserting a smaller tensor into a (potentially) bigger one. Set
// dimensions for all leading dimensions of the target tensor not
@@ -607,6 +598,10 @@ struct InsertOpPattern : public mlir::OpRewritePattern<mlir::tensor::InsertOp> {
rewriter.getI64IntegerAttr(1));
// Add size for the bufferized source element
if (hasBlock) {
sizes.push_back(rewriter.getI64IntegerAttr(
newResultTy.getDimSize(newResultTy.getRank() - 2)));
}
sizes.push_back(rewriter.getI64IntegerAttr(
newResultTy.getDimSize(newResultTy.getRank() - 1)));
@@ -711,26 +706,26 @@ struct FromElementsOpPattern
};
};
/// This template rewrite pattern transforms any instance of
/// `ShapeOp` operators that operates on tensor of lwe ciphertext by adding the
/// lwe size as a size of the tensor result and by adding a trivial
/// reassociation at the end of the reassociations map.
///
/// Example:
///
/// ```mlir
/// %0 = "ShapeOp" %arg0 [reassocations...]
/// : tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>> into
/// tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
/// ```
///
/// becomes:
///
/// ```mlir
/// %0 = "ShapeOp" %arg0 [reassociations..., [inRank or outRank]]
/// : tensor<...xlweDimesion+1xi64> into
/// tensor<...xlweDimesion+1xi64>
/// ```
// This template rewrite pattern transforms any instance of
// `ShapeOp` operators that operate on tensors of lwe ciphertext by adding the
// lwe size as a trailing size of the tensor result and by adding a trivial
// reassociation at the end of the reassociations map.
//
// Example:
//
// ```mlir
// %0 = "ShapeOp" %arg0 [reassocations...]
// : tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>> into
// tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
// ```
//
// becomes:
//
// ```mlir
// %0 = "ShapeOp" %arg0 [reassociations..., [inRank or outRank]]
//   : tensor<...xlweDimension+1xi64> into
//   tensor<...xlweDimension+1xi64>
// ```
template <typename ShapeOp, typename VecTy, bool inRank>
struct TensorShapeOpPattern : public mlir::OpRewritePattern<ShapeOp> {
TensorShapeOpPattern(::mlir::MLIRContext *context,
@@ -741,22 +736,35 @@ struct TensorShapeOpPattern : public mlir::OpRewritePattern<ShapeOp> {
matchAndRewrite(ShapeOp shapeOp,
::mlir::PatternRewriter &rewriter) const override {
ConcreteToBConcreteTypeConverter converter;
auto resultTy = shapeOp.result().getType();
auto resultTy = ((mlir::Type)shapeOp.result().getType()).cast<VecTy>();
auto lweResultTy =
((mlir::Type)resultTy.getElementType())
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
auto newResultTy =
((mlir::Type)converter.convertType(resultTy)).cast<VecTy>();
// add [rank] to reassociations
auto oldReassocs = shapeOp.getReassociationIndices();
mlir::SmallVector<mlir::ReassociationIndices> newReassocs;
newReassocs.append(oldReassocs.begin(), oldReassocs.end());
mlir::ReassociationIndices lweAssoc;
auto reassocTy =
((mlir::Type)converter.convertType(
(inRank ? shapeOp.src() : shapeOp.result()).getType()))
.cast<VecTy>();
lweAssoc.push_back(reassocTy.getRank() - 1);
newReassocs.push_back(lweAssoc);
auto oldReassocs = shapeOp.getReassociationIndices();
mlir::SmallVector<mlir::ReassociationIndices> newReassocs;
newReassocs.append(oldReassocs.begin(), oldReassocs.end());
// add [rank-1] to reassociations if crt decomp
if (!lweResultTy.getCrtDecomposition().empty()) {
mlir::ReassociationIndices lweAssoc;
lweAssoc.push_back(reassocTy.getRank() - 2);
newReassocs.push_back(lweAssoc);
}
// add [rank] to reassociations
{
mlir::ReassociationIndices lweAssoc;
lweAssoc.push_back(reassocTy.getRank() - 1);
newReassocs.push_back(lweAssoc);
}
ShapeOp op = rewriter.replaceOpWithNewOp<ShapeOp>(
shapeOp, newResultTy, shapeOp.src(), newReassocs);
@@ -785,20 +793,6 @@ void insertTensorShapeOpPattern(mlir::MLIRContext &context,
});
}
/// Rewrites `bufferization.alloc_tensor` ops for which the converted type in
/// BConcrete is different from the original type.
///
/// Example:
///
/// ```
/// bufferization.alloc_tensor() : tensor<4x!Concrete.lwe_ciphertext<4096,6>>
/// ```
///
/// becomes:
///
/// ```
/// bufferization.alloc_tensor() : tensor<4x4097xi64>
/// ```
struct AllocTensorOpPattern
: public mlir::OpRewritePattern<mlir::bufferization::AllocTensorOp> {
AllocTensorOpPattern(::mlir::MLIRContext *context,
@@ -877,10 +871,6 @@ void ConcreteToBConcretePass::runOnOperation() {
// Concrete ops are illegal after the conversion
target.addIllegalDialect<mlir::concretelang::Concrete::ConcreteDialect>();
// Add patterns to convert cleartext and plaintext to i64
patterns
.insert<ConcreteEncodeIntOpPattern, ConcreteIntToCleartextOpPattern>(
&getContext());
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
// Add patterns to convert the zero ops to tensor.generate
@@ -894,27 +884,25 @@ void ConcreteToBConcretePass::runOnOperation() {
// BConcrete op
patterns.insert<
LowToBConcrete<mlir::concretelang::Concrete::AddLweCiphertextsOp,
mlir::concretelang::BConcrete::AddLweBuffersOp>,
LowToBConcrete<
mlir::concretelang::Concrete::AddPlaintextLweCiphertextOp,
mlir::concretelang::BConcrete::AddPlaintextLweBufferOp>,
LowToBConcrete<
mlir::concretelang::Concrete::MulCleartextLweCiphertextOp,
mlir::concretelang::BConcrete::MulCleartextLweBufferOp>,
LowToBConcrete<
mlir::concretelang::Concrete::MulCleartextLweCiphertextOp,
mlir::concretelang::BConcrete::MulCleartextLweBufferOp>,
mlir::concretelang::BConcrete::AddLweBuffersOp,
BConcrete::AddCRTLweBuffersOp>,
AddPlaintextLweCiphertextOpPattern, MulCleartextLweCiphertextOpPattern,
LowToBConcrete<mlir::concretelang::Concrete::NegateLweCiphertextOp,
mlir::concretelang::BConcrete::NegateLweBufferOp>,
mlir::concretelang::BConcrete::NegateLweBufferOp,
BConcrete::NegateCRTLweBufferOp>,
LowToBConcrete<mlir::concretelang::Concrete::KeySwitchLweOp,
mlir::concretelang::BConcrete::KeySwitchLweBufferOp,
mlir::concretelang::BConcrete::KeySwitchLweBufferOp>,
LowToBConcrete<mlir::concretelang::Concrete::BootstrapLweOp,
mlir::concretelang::BConcrete::BootstrapLweBufferOp>>(
&getContext());
mlir::concretelang::BConcrete::BootstrapLweBufferOp,
mlir::concretelang::BConcrete::KeySwitchLweBufferOp>,
LowToBConcrete<Concrete::WopPBSLweOp, BConcrete::WopPBSCRTLweBufferOp,
BConcrete::WopPBSCRTLweBufferOp>>(&getContext());
patterns.insert<GlweFromTablePattern>(&getContext());
// Add patterns to rewrite tensor operators that works on encrypted tensors
// Add patterns to rewrite tensor operators that work on encrypted
// tensors
patterns
.insert<ExtractSliceOpPattern, ExtractOpPattern, InsertSliceOpPattern,
InsertOpPattern, FromElementsOpPattern>(&getContext());
@@ -939,7 +927,8 @@ void ConcreteToBConcretePass::runOnOperation() {
patterns.insert<ForOpPattern>(&getContext());
// Add patterns to rewrite some of the memref ops that were introduced by the
// linalg bufferization of encrypted tensor (first conversion of this pass)
// linalg bufferization of encrypted tensor (first conversion of this
// pass)
insertTensorShapeOpPattern<mlir::memref::ExpandShapeOp, mlir::MemRefType,
false>(getContext(), patterns, target);
insertTensorShapeOpPattern<mlir::tensor::ExpandShapeOp, mlir::TensorType,

View File

@@ -26,10 +26,6 @@ namespace FHE = mlir::concretelang::FHE;
namespace TFHE = mlir::concretelang::TFHE;
namespace {
struct FHEToTFHEPass : public FHEToTFHEBase<FHEToTFHEPass> {
void runOnOperation() final;
};
} // namespace
using mlir::concretelang::FHE::EncryptedIntegerType;
using mlir::concretelang::TFHE::GLWECipherTextType;
@@ -84,10 +80,10 @@ public:
/// : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) ->
/// !TFHE.glwe<{_,_,_}{2}>
/// ```
struct ApplyLookupTableEintOpPattern
struct ApplyLookupTableEintOpToKeyswitchBootstrapPattern
: public mlir::OpRewritePattern<FHE::ApplyLookupTableEintOp> {
ApplyLookupTableEintOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
ApplyLookupTableEintOpToKeyswitchBootstrapPattern(
mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<FHE::ApplyLookupTableEintOp>(context,
benefit) {}
@@ -115,6 +111,51 @@ struct ApplyLookupTableEintOpPattern
};
};
/// This rewrite pattern transforms any instance of `FHE.apply_lookup_table`
/// operators.
///
/// Example:
///
/// ```mlir
/// %0 = "FHE.apply_lookup_table"(%ct, %lut): (!FHE.eint<2>, tensor<4xi64>)
/// ->(!FHE.eint<2>)
/// ```
///
/// becomes:
///
/// ```mlir
/// %0 = "TFHE.wop_pbs_glwe"(%ct, %lut)
/// : (!TFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) ->
/// (!TFHE.glwe<{_,_,_}{2}>)
/// ```
struct ApplyLookupTableEintOpToWopPBSPattern
: public mlir::OpRewritePattern<FHE::ApplyLookupTableEintOp> {
ApplyLookupTableEintOpToWopPBSPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<FHE::ApplyLookupTableEintOp>(context,
benefit) {}
::mlir::LogicalResult
matchAndRewrite(FHE::ApplyLookupTableEintOp lutOp,
mlir::PatternRewriter &rewriter) const override {
FHEToTFHETypeConverter converter;
auto inputTy = converter.convertType(lutOp.a().getType())
.cast<TFHE::GLWECipherTextType>();
auto resultTy = converter.convertType(lutOp.getType());
// %0 = "TFHE.wop_pbs_glwe"(%ct, %lut)
// : (!TFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) ->
// (!TFHE.glwe<{_,_,_}{2}>)
auto wopPBS = rewriter.replaceOpWithNewOp<TFHE::WopPBSGLWEOp>(
lutOp, resultTy, lutOp.a(), lutOp.lut(), -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1);
mlir::concretelang::convertOperandAndResultTypes(
rewriter, wopPBS, [&](mlir::MLIRContext *, mlir::Type t) {
return converter.convertType(t);
});
return ::mlir::success();
};
};
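The eleven `-1` operands above are placeholders: the wop PBS parameters are not known at this stage and are filled in later by the TFHE global parametrization pass (see the `WopPBSGLWEOpPattern` further down). For orientation, a hedged sketch of what those placeholders stand for, written as a plain C++ struct for illustration only; the field names follow the accessors used in the TFHEToConcrete lowering below (lightly normalized), and this struct is not part of the compiler API:

```cpp
// Illustration only: the wop PBS parameters that later passes fill in,
// grouped as in the TFHE.wop_pbs_glwe operand list.
struct WopPBSParameters {                 // hypothetical helper, not compiler API
  int64_t bootstrapLevel, bootstrapBaseLog;               // bootstrap
  int64_t keyswitchLevel, keyswitchBaseLog;               // keyswitch
  int64_t packingKeySwitchInputLweDimension,              // packing keyswitch
      packingKeySwitchInputLweCount, packingKeySwitchOutputPolynomialSize,
      packingKeySwitchLevel, packingKeySwitchBaseLog;
  int64_t circuitBootstrapLevel, circuitBootstrapBaseLog; // circuit bootstrap
};
```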
/// This rewrite pattern transforms any instance of `FHE.sub_eint_int`
/// operators to a negation and an addition.
struct SubEintIntOpPattern : public mlir::OpRewritePattern<FHE::SubEintIntOp> {
@@ -194,97 +235,118 @@ struct SubEintOpPattern : public mlir::OpRewritePattern<FHE::SubEintOp> {
};
};
void FHEToTFHEPass::runOnOperation() {
auto op = this->getOperation();
struct FHEToTFHEPass : public FHEToTFHEBase<FHEToTFHEPass> {
mlir::ConversionTarget target(getContext());
FHEToTFHETypeConverter converter;
FHEToTFHEPass(mlir::concretelang::ApplyLookupTableLowering lutLowerStrategy)
: lutLowerStrategy(lutLowerStrategy) {}
// Mark ops from the target dialect as legal operations
target.addLegalDialect<mlir::concretelang::TFHE::TFHEDialect>();
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
void runOnOperation() {
auto op = this->getOperation();
// Make sure that no ops from `FHE` remain after the lowering
target.addIllegalDialect<mlir::concretelang::FHE::FHEDialect>();
mlir::ConversionTarget target(getContext());
FHEToTFHETypeConverter converter;
// Make sure that no ops `linalg.generic` that have illegal types
target
.addDynamicallyLegalOp<mlir::linalg::GenericOp, mlir::tensor::GenerateOp>(
[&](mlir::Operation *op) {
return (
converter.isLegal(op->getOperandTypes()) &&
converter.isLegal(op->getResultTypes()) &&
converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
});
// Mark ops from the target dialect as legal operations
target.addLegalDialect<mlir::concretelang::TFHE::TFHEDialect>();
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
// Make sure that func has legal signature
target.addDynamicallyLegalOp<mlir::func::FuncOp>(
[&](mlir::func::FuncOp funcOp) {
return converter.isSignatureLegal(funcOp.getFunctionType()) &&
converter.isLegal(&funcOp.getBody());
});
// Make sure that no ops from `FHE` remain after the lowering
target.addIllegalDialect<mlir::concretelang::FHE::FHEDialect>();
// Add all patterns required to lower all ops from `FHE` to
// `TFHE`
mlir::RewritePatternSet patterns(&getContext());
// Make sure that `linalg.generic` and `tensor.generate` ops have legal types
target.addDynamicallyLegalOp<mlir::linalg::GenericOp,
mlir::tensor::GenerateOp>(
[&](mlir::Operation *op) {
return (
converter.isLegal(op->getOperandTypes()) &&
converter.isLegal(op->getResultTypes()) &&
converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
});
populateWithGeneratedFHEToTFHE(patterns);
// Make sure that func has legal signature
target.addDynamicallyLegalOp<mlir::func::FuncOp>(
[&](mlir::func::FuncOp funcOp) {
return converter.isSignatureLegal(funcOp.getFunctionType()) &&
converter.isLegal(&funcOp.getBody());
});
patterns.add<
mlir::concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>>(
patterns.getContext(), converter);
// Add all patterns required to lower all ops from `FHE` to
// `TFHE`
mlir::RewritePatternSet patterns(&getContext());
patterns.add<ApplyLookupTableEintOpPattern>(&getContext());
patterns.add<SubEintOpPattern>(&getContext());
patterns.add<SubEintIntOpPattern>(&getContext());
populateWithGeneratedFHEToTFHE(patterns);
patterns.add<RegionOpTypeConverterPattern<mlir::linalg::GenericOp,
FHEToTFHETypeConverter>>(
&getContext(), converter);
patterns.add<
mlir::concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>>(
patterns.getContext(), converter);
patterns.add<
mlir::concretelang::GenericTypeConverterPattern<mlir::linalg::YieldOp>>(
patterns.getContext(), converter);
switch (lutLowerStrategy) {
case mlir::concretelang::KeySwitchBoostrapLowering:
patterns.add<ApplyLookupTableEintOpToKeyswitchBootstrapPattern>(
&getContext());
break;
case mlir::concretelang::WopPBSLowering:
patterns.add<ApplyLookupTableEintOpToWopPBSPattern>(&getContext());
break;
}
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
FHEToTFHETypeConverter>>(
&getContext(), converter);
patterns.add<SubEintOpPattern>(&getContext());
patterns.add<SubEintIntOpPattern>(&getContext());
patterns.add<
RegionOpTypeConverterPattern<mlir::scf::ForOp, FHEToTFHETypeConverter>>(
&getContext(), converter);
patterns.add<mlir::concretelang::GenericTypeAndOpConverterPattern<
mlir::concretelang::FHE::ZeroTensorOp,
mlir::concretelang::TFHE::ZeroTensorGLWEOp>>(&getContext(), converter);
patterns.add<RegionOpTypeConverterPattern<mlir::linalg::GenericOp,
FHEToTFHETypeConverter>>(
&getContext(), converter);
mlir::concretelang::populateWithTensorTypeConverterPatterns(patterns, target,
converter);
patterns.add<
mlir::concretelang::GenericTypeConverterPattern<mlir::linalg::YieldOp>>(
patterns.getContext(), converter);
mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
patterns, converter);
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
FHEToTFHETypeConverter>>(
&getContext(), converter);
// Conversion of RT Dialect Ops
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::DataflowTaskOp>>(patterns.getContext(),
converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::DataflowTaskOp>(target, converter);
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::DataflowYieldOp>>(patterns.getContext(),
converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::DataflowYieldOp>(target, converter);
patterns.add<
RegionOpTypeConverterPattern<mlir::scf::ForOp, FHEToTFHETypeConverter>>(
&getContext(), converter);
patterns.add<mlir::concretelang::GenericTypeAndOpConverterPattern<
mlir::concretelang::FHE::ZeroTensorOp,
mlir::concretelang::TFHE::ZeroTensorGLWEOp>>(&getContext(), converter);
// Apply conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
this->signalPassFailure();
mlir::concretelang::populateWithTensorTypeConverterPatterns(
patterns, target, converter);
mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
patterns, converter);
// Conversion of RT Dialect Ops
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::DataflowTaskOp>>(patterns.getContext(),
converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::DataflowTaskOp>(target, converter);
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::DataflowYieldOp>>(patterns.getContext(),
converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::DataflowYieldOp>(target, converter);
// Apply conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns))
.failed()) {
this->signalPassFailure();
}
}
}
private:
mlir::concretelang::ApplyLookupTableLowering lutLowerStrategy;
};
} // namespace
namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<ModuleOp>> createConvertFHEToTFHEPass() {
return std::make_unique<FHEToTFHEPass>();
std::unique_ptr<OperationPass<ModuleOp>>
createConvertFHEToTFHEPass(ApplyLookupTableLowering lower) {
return std::make_unique<FHEToTFHEPass>(lower);
}
} // namespace concretelang
} // namespace mlir
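A hedged sketch of how a pipeline could pick the lowering strategy; the pass and the two `ApplyLookupTableLowering` values come from this file, while the header path and the `useCRT` flag are assumptions for illustration:

```cpp
#include "mlir/Pass/PassManager.h"
#include "concretelang/Conversion/Passes.h" // assumed header for the pass declaration

// Choose the lookup-table lowering: the classical keyswitch+bootstrap path, or
// the wop PBS path used for CRT-based large integers.
void addFHEToTFHE(mlir::PassManager &pm, bool useCRT) {
  auto strategy = useCRT ? mlir::concretelang::WopPBSLowering
                         : mlir::concretelang::KeySwitchBoostrapLowering;
  pm.addPass(mlir::concretelang::createConvertFHEToTFHEPass(strategy));
}
```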

View File

@@ -59,12 +59,17 @@ public:
auto dimension = cryptoParameters.getNBigGlweDimension();
auto polynomialSize = 1;
auto precision = (signed)type.getP();
auto crtDecomposition =
cryptoParameters.largeInteger.hasValue()
? cryptoParameters.largeInteger->crtDecomposition
: mlir::concretelang::CRTDecomposition{};
if ((int)dimension == type.getDimension() &&
(int)polynomialSize == type.getPolynomialSize()) {
return type;
}
return TFHE::GLWECipherTextType::get(type.getContext(), dimension,
polynomialSize, bits, precision);
polynomialSize, bits, precision,
crtDecomposition);
}
TFHE::GLWECipherTextType glweLookupTableType(GLWECipherTextType &type) {
@@ -73,7 +78,7 @@ public:
auto polynomialSize = cryptoParameters.getPolynomialSize();
auto precision = (signed)type.getP();
return TFHE::GLWECipherTextType::get(type.getContext(), dimension,
polynomialSize, bits, precision);
polynomialSize, bits, precision, {});
}
TFHE::GLWECipherTextType glweIntraPBSType(GLWECipherTextType &type) {
@@ -82,7 +87,7 @@ public:
auto polynomialSize = 1;
auto precision = (signed)type.getP();
return TFHE::GLWECipherTextType::get(type.getContext(), dimension,
polynomialSize, bits, precision);
polynomialSize, bits, precision, {});
}
mlir::concretelang::V0Parameter cryptoParameters;
@@ -154,6 +159,49 @@ private:
mlir::concretelang::V0Parameter &cryptoParameters;
};
struct WopPBSGLWEOpPattern : public mlir::OpRewritePattern<TFHE::WopPBSGLWEOp> {
WopPBSGLWEOpPattern(mlir::MLIRContext *context,
TFHEGlobalParametrizationTypeConverter &converter,
mlir::concretelang::V0Parameter &cryptoParameters,
mlir::PatternBenefit benefit =
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<TFHE::WopPBSGLWEOp>(context, benefit),
converter(converter), cryptoParameters(cryptoParameters) {}
mlir::LogicalResult
matchAndRewrite(TFHE::WopPBSGLWEOp wopPBSOp,
mlir::PatternRewriter &rewriter) const override {
auto newOp = rewriter.replaceOpWithNewOp<TFHE::WopPBSGLWEOp>(
wopPBSOp, converter.convertType(wopPBSOp.result().getType()),
wopPBSOp.ciphertext(), wopPBSOp.lookupTable(),
// Bootstrap parameters
cryptoParameters.brLevel, cryptoParameters.brLogBase,
// Keyswitch parameters
cryptoParameters.ksLevel, cryptoParameters.ksLogBase,
// Packing keyswitch key parameters
cryptoParameters.largeInteger->wopPBS.packingKeySwitch
.inputLweDimension,
cryptoParameters.largeInteger->wopPBS.packingKeySwitch.inputLweCount,
cryptoParameters.largeInteger->wopPBS.packingKeySwitch
.outputPolynomialSize,
cryptoParameters.largeInteger->wopPBS.packingKeySwitch.level,
cryptoParameters.largeInteger->wopPBS.packingKeySwitch.baseLog,
// Circuit bootstrap parameters
cryptoParameters.largeInteger->wopPBS.circuitBootstrap.level,
cryptoParameters.largeInteger->wopPBS.circuitBootstrap.baseLog);
rewriter.startRootUpdate(newOp);
auto ciphertextType =
wopPBSOp.ciphertext().getType().cast<TFHE::GLWECipherTextType>();
newOp.ciphertext().setType(converter.glweInterPBSType(ciphertextType));
rewriter.finalizeRootUpdate(newOp);
return mlir::success();
};
private:
TFHEGlobalParametrizationTypeConverter &converter;
mlir::concretelang::V0Parameter &cryptoParameters;
};
/// This rewrite pattern transforms any instance of `TFHE.glwe_from_table` by
/// parametrizing the GLWE return type and padding the table if the precision
/// has been changed.
@@ -271,6 +319,16 @@ void TFHEGlobalParametrizationPass::runOnOperation() {
return converter.isLegal(op->getResultTypes());
});
// Parametrize wop pbs
patterns.add<WopPBSGLWEOpPattern>(&getContext(), converter,
cryptoParameters);
target.addDynamicallyLegalOp<TFHE::WopPBSGLWEOp>(
[&](TFHE::WopPBSGLWEOp op) {
return !op.getType()
.cast<TFHE::GLWECipherTextType>()
.hasUnparametrizedParameters();
});
// Add all patterns to convert TFHE types
populateWithTFHEOpTypeConversionPatterns(patterns, target, converter);
patterns.add<RegionOpTypeConverterPattern<

View File

@@ -109,6 +109,45 @@ private:
mlir::TypeConverter &converter;
};
struct WopPBSGLWEOpPattern : public mlir::OpRewritePattern<TFHE::WopPBSGLWEOp> {
WopPBSGLWEOpPattern(mlir::MLIRContext *context,
mlir::TypeConverter &converter,
mlir::PatternBenefit benefit = 100)
: mlir::OpRewritePattern<TFHE::WopPBSGLWEOp>(context, benefit),
converter(converter) {}
mlir::LogicalResult
matchAndRewrite(TFHE::WopPBSGLWEOp wopOp,
mlir::PatternRewriter &rewriter) const override {
mlir::Type resultType = converter.convertType(wopOp.getType());
auto newOp = rewriter.replaceOpWithNewOp<Concrete::WopPBSLweOp>(
wopOp, resultType, wopOp.ciphertext(), wopOp.lookupTable(),
// Bootstrap parameters
wopOp.bootstrapLevel(), wopOp.bootstrapBaseLog(),
// Keyswitch parameters
wopOp.keyswitchLevel(), wopOp.keyswitchBaseLog(),
// Packing keyswitch key parameters
wopOp.packingKeySwitchInputLweDimension(),
wopOp.packingKeySwitchinputLweCount(),
wopOp.packingKeySwitchoutputPolynomialSize(),
wopOp.packingKeySwitchLevel(), wopOp.packingKeySwitchBaseLog(),
// Circuit bootstrap parameters
wopOp.circuitBootstrapLevel(), wopOp.circuitBootstrapBaseLog());
rewriter.startRootUpdate(newOp);
newOp.ciphertext().setType(
converter.convertType(wopOp.ciphertext().getType()));
rewriter.finalizeRootUpdate(newOp);
return ::mlir::success();
}
private:
mlir::TypeConverter &converter;
};
void TFHEToConcretePass::runOnOperation() {
auto op = this->getOperation();
@@ -147,6 +186,7 @@ void TFHEToConcretePass::runOnOperation() {
mlir::concretelang::Concrete::ZeroTensorLWEOp>>(&getContext(), converter);
patterns.add<GLWEFromTableOpPattern>(&getContext());
patterns.add<BootstrapGLWEOpPattern>(&getContext(), converter);
patterns.add<WopPBSGLWEOpPattern>(&getContext(), converter);
target.addDynamicallyLegalOp<Concrete::BootstrapLweOp>(
[&](Concrete::BootstrapLweOp op) {
return (converter.isLegal(op->getOperandTypes()) &&

View File

@@ -36,23 +36,23 @@ namespace {} // namespace
namespace {
mlir::Type getDynamic1DMemrefWithUnknownOffset(mlir::RewriterBase &rewriter) {
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
size_t rank) {
mlir::MLIRContext *ctx = rewriter.getContext();
return mlir::MemRefType::get(
{-1}, rewriter.getI64Type(),
mlir::AffineMap::get(1, 1,
mlir::getAffineDimExpr(0, ctx) +
mlir::getAffineSymbolExpr(0, ctx)));
std::vector<int64_t> shape(rank, -1);
return mlir::MemRefType::get(shape, rewriter.getI64Type(),
rewriter.getMultiDimIdentityMap(rank));
}
/// Returns `memref.cast %0 : memref<AxT> to memref<?xT>` if %0 a 1D memref
mlir::Value getCasted1DMemRef(mlir::RewriterBase &rewriter, mlir::Location loc,
mlir::Value value) {
// Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>`
mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Location loc,
mlir::Value value) {
mlir::Type valueType = value.getType();
if (valueType.isa<mlir::MemRefType>()) {
if (auto memrefTy = valueType.dyn_cast_or_null<mlir::MemRefType>()) {
return rewriter.create<mlir::memref::CastOp>(
loc, getDynamic1DMemrefWithUnknownOffset(rewriter), value);
loc,
getDynamicMemrefWithUnknownOffset(rewriter, memrefTy.getShape().size()),
value);
} else {
return value;
}
@@ -69,10 +69,13 @@ char memref_bootstrap_lwe_u64[] = "memref_bootstrap_lwe_u64";
char memref_expand_lut_in_trivial_glwe_ct_u64[] =
"memref_expand_lut_in_trivial_glwe_ct_u64";
char memref_wop_pbs_crt_buffer[] = "memref_wop_pbs_crt_buffer";
mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
mlir::Operation *op, mlir::RewriterBase &rewriter, char const *funcName) {
auto memref1DType = getDynamic1DMemrefWithUnknownOffset(rewriter);
auto memref1DType = getDynamicMemrefWithUnknownOffset(rewriter, 1);
auto memref2DType = getDynamicMemrefWithUnknownOffset(rewriter, 2);
auto contextType =
mlir::concretelang::Concrete::ContextType::get(rewriter.getContext());
@@ -109,6 +112,15 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
memref1DType,
},
{});
} else if (funcName == memref_wop_pbs_crt_buffer) {
funcType = mlir::FunctionType::get(rewriter.getContext(),
{
memref2DType,
memref2DType,
memref1DType,
contextType,
},
{});
} else {
op->emitError("unknwon external function") << funcName;
return mlir::failure();
@@ -185,7 +197,7 @@ struct BufferizableWithCallOpInterface
// The first operand is the result
mlir::SmallVector<mlir::Value, 3> operands{
getCasted1DMemRef(rewriter, loc, *outMemref),
getCastedMemRef(rewriter, loc, *outMemref),
};
// For all tensor operand get the corresponding casted buffer
for (auto &operand : op->getOpOperands()) {
@@ -194,7 +206,7 @@ struct BufferizableWithCallOpInterface
} else {
auto memrefOperand =
bufferization::getBuffer(rewriter, operand.get(), options);
operands.push_back(getCasted1DMemRef(rewriter, loc, memrefOperand));
operands.push_back(getCastedMemRef(rewriter, loc, memrefOperand));
}
}
// Append the context argument
@@ -264,14 +276,14 @@ struct BufferizableGlweFromTableOpInterface
auto loc = op->getLoc();
auto castOp = cast<BConcrete::FillGlweFromTable>(op);
auto glweOp = getCasted1DMemRef(
rewriter, loc,
bufferization::getBuffer(rewriter, castOp->getOpOperand(0).get(),
options));
auto lutOp = getCasted1DMemRef(
rewriter, loc,
bufferization::getBuffer(rewriter, castOp->getOpOperand(1).get(),
options));
auto glweOp =
getCastedMemRef(rewriter, loc,
bufferization::getBuffer(
rewriter, castOp->getOpOperand(0).get(), options));
auto lutOp =
getCastedMemRef(rewriter, loc,
bufferization::getBuffer(
rewriter, castOp->getOpOperand(1).get(), options));
auto polySizeOp = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(), rewriter.getI32IntegerAttr(castOp.polynomialSize()));
@@ -326,6 +338,10 @@ void mlir::concretelang::BConcrete::
BConcrete::BootstrapLweBufferOp::attachInterface<
BufferizableWithCallOpInterface<BConcrete::BootstrapLweBufferOp,
memref_bootstrap_lwe_u64, true>>(*ctx);
// TODO(16bits): hack
BConcrete::WopPBSCRTLweBufferOp::attachInterface<
BufferizableWithCallOpInterface<BConcrete::WopPBSCRTLweBufferOp,
memref_wop_pbs_crt_buffer, true>>(*ctx);
BConcrete::FillGlweFromTable::attachInterface<
BufferizableGlweFromTableOpInterface>(*ctx);
});

View File

@@ -1,6 +1,7 @@
add_mlir_dialect_library(ConcretelangBConcreteTransforms
BufferizableOpInterfaceImpl.cpp
AddRuntimeContext.cpp
EliminateCRTOps.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/BConcrete
@@ -18,4 +19,4 @@ add_mlir_dialect_library(ConcretelangBConcreteTransforms
MLIRMemRefDialect
MLIRPass
MLIRTransforms
)
)

View File

@@ -0,0 +1,561 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "concretelang/ClientLib/CRT.h"
#include "concretelang/Conversion/Tools.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h"
#include "concretelang/Dialect/BConcrete/Transforms/Passes.h"
namespace arith = mlir::arith;
namespace tensor = mlir::tensor;
namespace bufferization = mlir::bufferization;
namespace scf = mlir::scf;
namespace BConcrete = mlir::concretelang::BConcrete;
namespace crt = concretelang::clientlib::crt;
namespace {
char encode_crt[] = "encode_crt";
// This template rewrite pattern transforms any instance of
// `BConcreteCRTOp` operators to `BConcreteOp` on
// each block.
//
// Example:
//
// ```mlir
// %0 = "BConcreteCRTOp"(%arg0, %arg1) {crtDecomposition = [...]}
// : (tensor<nbBlocksxlweSizexi64>, tensor<nbBlocksxlweSizexi64>) ->
// (tensor<nbBlocksxlweSizexi64>)
// ```
//
// becomes:
//
// ```mlir
// %c0 = arith.constant 0 : index
// %c1 = arith.constant 1 : index
// %cB = arith.constant nbBlocks : index
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
// (tensor<nbBlocksxlweSizexi64>) {
// %blockArg = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
// : tensor<lweSizexi64>
// %tmp = "BConcreteOp"(%blockArg)
// : (tensor<lweSizexi64>) -> (tensor<lweSizexi64>)
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, 1]
// : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
// }
// ```
template <typename BConcreteCRTOp, typename BConcreteOp>
struct BConcreteCRTUnaryOpPattern
: public mlir::OpRewritePattern<BConcreteCRTOp> {
BConcreteCRTUnaryOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<BConcreteCRTOp>(context, benefit) {}
mlir::LogicalResult
matchAndRewrite(BConcreteCRTOp op,
mlir::PatternRewriter &rewriter) const override {
auto resultTy =
((mlir::Type)op.getResult().getType()).cast<mlir::RankedTensorType>();
auto loc = op.getLoc();
assert(resultTy.getShape().size() == 2);
auto shape = resultTy.getShape();
// %c0 = arith.constant 0 : index
// %c1 = arith.constant 1 : index
// %cB = arith.constant nbBlocks : index
auto c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
auto cB = rewriter.create<arith::ConstantIndexOp>(loc, shape[0]);
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
op.getLoc(), resultTy, mlir::ValueRange{});
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
// (tensor<nbBlocksxlweSizexi64>) {
rewriter.replaceOpWithNewOp<scf::ForOp>(
op, c0, cB, c1, init,
[&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value i,
mlir::ValueRange iterArgs) {
// [%i, 0]
mlir::SmallVector<mlir::OpFoldResult> offsets{
i, rewriter.getI64IntegerAttr(0)};
// [1, lweSize]
mlir::SmallVector<mlir::OpFoldResult> sizes{
rewriter.getI64IntegerAttr(1),
rewriter.getI64IntegerAttr(shape[1])};
// [1, 1]
mlir::SmallVector<mlir::OpFoldResult> strides{
rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1)};
auto blockTy = mlir::RankedTensorType::get({shape[1]},
resultTy.getElementType());
// %blockArg = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
// : tensor<lweSizexi64>
auto blockArg = builder.create<tensor::ExtractSliceOp>(
loc, blockTy, op.ciphertext(), offsets, sizes, strides);
// %tmp = "BConcrete.add_lwe_buffer"(%blockArg0, %blockArg1)
// : (tensor<lweSizexi64>, tensor<lweSizexi64>) ->
// (tensor<lweSizexi64>)
auto tmp = builder.create<BConcreteOp>(loc, blockTy, blockArg);
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1,
// 1] : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
auto res = builder.create<tensor::InsertSliceOp>(
loc, tmp, iterArgs[0], offsets, sizes, strides);
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
builder.create<scf::YieldOp>(loc, (mlir::Value)res);
});
return mlir::success();
}
};
// This template rewrite pattern transforms any instance of
// `BConcreteCRTOp` operators to `BConcreteOp` on
// each block.
//
// Example:
//
// ```mlir
// %0 = "BConcreteCRTOp"(%arg0, %arg1) {crtDecomposition = [...]}
// : (tensor<nbBlocksxlweSizexi64>, tensor<nbBlocksxlweSizexi64>) ->
// (tensor<nbBlocksxlweSizexi64>)
// ```
//
// becomes:
//
// ```mlir
// %c0 = arith.constant 0 : index
// %c1 = arith.constant 1 : index
// %cB = arith.constant nbBlocks : index
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
// (tensor<nbBlocksxlweSizexi64>) {
// %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
// : tensor<lweSizexi64>
// %blockArg1 = tensor.extract_slice %arg1[%i, 0] [1, lweSize] [1, 1]
// : tensor<lweSizexi64>
// %tmp = "BConcreteOp"(%blockArg0, %blockArg1)
// : (tensor<lweSizexi64>, tensor<lweSizexi64>) ->
// (tensor<lweSizexi64>)
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, 1]
// : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
// }
// ```
template <typename BConcreteCRTOp, typename BConcreteOp>
struct BConcreteCRTBinaryOpPattern
: public mlir::OpRewritePattern<BConcreteCRTOp> {
BConcreteCRTBinaryOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<BConcreteCRTOp>(context, benefit) {}
mlir::LogicalResult
matchAndRewrite(BConcreteCRTOp op,
mlir::PatternRewriter &rewriter) const override {
auto resultTy =
((mlir::Type)op.getResult().getType()).cast<mlir::RankedTensorType>();
auto loc = op.getLoc();
assert(resultTy.getShape().size() == 2);
auto shape = resultTy.getShape();
// %c0 = arith.constant 0 : index
// %c1 = arith.constant 1 : index
// %cB = arith.constant nbBlocks : index
auto c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
auto cB = rewriter.create<arith::ConstantIndexOp>(loc, shape[0]);
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
op.getLoc(), resultTy, mlir::ValueRange{});
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
// (tensor<nbBlocksxlweSizexi64>) {
rewriter.replaceOpWithNewOp<scf::ForOp>(
op, c0, cB, c1, init,
[&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value i,
mlir::ValueRange iterArgs) {
// [%i, 0]
mlir::SmallVector<mlir::OpFoldResult> offsets{
i, rewriter.getI64IntegerAttr(0)};
// [1, lweSize]
mlir::SmallVector<mlir::OpFoldResult> sizes{
rewriter.getI64IntegerAttr(1),
rewriter.getI64IntegerAttr(shape[1])};
// [1, 1]
mlir::SmallVector<mlir::OpFoldResult> strides{
rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1)};
auto blockTy = mlir::RankedTensorType::get({shape[1]},
resultTy.getElementType());
// %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
// : tensor<lweSizexi64>
auto blockArg0 = builder.create<tensor::ExtractSliceOp>(
loc, blockTy, op.lhs(), offsets, sizes, strides);
// %blockArg1 = tensor.extract_slice %arg1[%i, 0] [1, lweSize] [1, 1]
// : tensor<lweSizexi64>
auto blockArg1 = builder.create<tensor::ExtractSliceOp>(
loc, blockTy, op.rhs(), offsets, sizes, strides);
// %tmp = "BConcrete.add_lwe_buffer"(%blockArg0, %blockArg1)
// : (tensor<lweSizexi64>, tensor<lweSizexi64>) ->
// (tensor<lweSizexi64>)
auto tmp =
builder.create<BConcreteOp>(loc, blockTy, blockArg0, blockArg1);
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1,
// 1] : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
auto res = builder.create<tensor::InsertSliceOp>(
loc, tmp, iterArgs[0], offsets, sizes, strides);
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
builder.create<scf::YieldOp>(loc, (mlir::Value)res);
});
return mlir::success();
}
};
// This rewrite pattern transforms any instance of
// `AddPlaintextCRTLweBufferOp` to `AddPlaintextLweBufferOp` applied on each
// block, after encoding the plaintext for each modulus of the crt
// decomposition.
//
// Example:
//
// ```mlir
// %0 = "BConcreteCRTOp"(%arg0, %x) {crtDecomposition = [d0...dn]}
// : (tensor<nbBlocksxlweSizexi64>, i64) -> (tensor<nbBlocksxlweSizexi64>)
// ```
//
// becomes:
//
// ```mlir
// // Encode the plaintext for each modulus of the decomposition: folded to
// // constants when %x is a compile-time constant, otherwise lowered to calls
// // to the runtime `encode_crt` function (%d0...%dn and %product are i64
// // constants holding the moduli and their product)
// %x64 = arith.extsi %x : iN to i64
// %x0 = func.call @encode_crt(%x64, %d0, %product) : (i64, i64, i64) -> i64
// ...
// %xn = func.call @encode_crt(%x64, %dn, %product) : (i64, i64, i64) -> i64
// %x_decomp = tensor.from_elements %x0, ..., %xn : tensor<nbBlocksxi64>
// // Loop on blocks
// %c0 = arith.constant 0 : index
// %c1 = arith.constant 1 : index
// %cB = arith.constant nbBlocks : index
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
// (tensor<nbBlocksxlweSizexi64>) {
// %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
// : tensor<lweSizexi64>
// %blockArg1 = tensor.extract %x_decomp[%i] : tensor<nbBlocksxi64>
// %tmp = "BConcreteOp"(%blockArg0, %blockArg1)
// : (tensor<lweSizexi64>, i64) -> (tensor<lweSizexi64>)
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, 1]
// : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
// }
// ```
struct AddPlaintextCRTLweBufferOpPattern
: public mlir::OpRewritePattern<BConcrete::AddPlaintextCRTLweBufferOp> {
AddPlaintextCRTLweBufferOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<BConcrete::AddPlaintextCRTLweBufferOp>(context,
benefit) {
}
mlir::LogicalResult
matchAndRewrite(BConcrete::AddPlaintextCRTLweBufferOp op,
mlir::PatternRewriter &rewriter) const override {
auto resultTy =
((mlir::Type)op.getResult().getType()).cast<mlir::RankedTensorType>();
auto loc = op.getLoc();
assert(resultTy.getShape().size() == 2);
auto shape = resultTy.getShape();
auto rhs = op.rhs();
mlir::SmallVector<mlir::Value, 5> plaintextElements;
uint64_t moduliProduct = 1;
for (mlir::Attribute di : op.crtDecomposition()) {
moduliProduct *= di.cast<mlir::IntegerAttr>().getValue().getZExtValue();
}
if (auto cst =
mlir::dyn_cast_or_null<arith::ConstantIntOp>(rhs.getDefiningOp())) {
auto apCst = cst.getValue().cast<mlir::IntegerAttr>().getValue();
auto value = apCst.getSExtValue();
// constant value, encode at compile time
for (mlir::Attribute di : op.crtDecomposition()) {
auto modulus = di.cast<mlir::IntegerAttr>().getValue().getZExtValue();
auto encoded = crt::encode(value, modulus, moduliProduct);
plaintextElements.push_back(
rewriter.create<arith::ConstantIntOp>(loc, encoded, 64));
}
} else {
// dynamic value, encode at runtime
if (insertForwardDeclaration(
op, rewriter, encode_crt,
mlir::FunctionType::get(rewriter.getContext(),
{rewriter.getI64Type(),
rewriter.getI64Type(),
rewriter.getI64Type()},
{rewriter.getI64Type()}))
.failed()) {
return mlir::failure();
}
auto extOp =
rewriter.create<arith::ExtSIOp>(loc, rewriter.getI64Type(), rhs);
auto moduliProductOp =
rewriter.create<arith::ConstantIntOp>(loc, moduliProduct, 64);
for (mlir::Attribute di : op.crtDecomposition()) {
auto modulus = di.cast<mlir::IntegerAttr>().getValue().getZExtValue();
auto modulusOp =
rewriter.create<arith::ConstantIntOp>(loc, modulus, 64);
plaintextElements.push_back(
rewriter
.create<mlir::func::CallOp>(
loc, encode_crt, mlir::TypeRange{rewriter.getI64Type()},
mlir::ValueRange{extOp, modulusOp, moduliProductOp})
.getResult(0));
}
}
// %x_decomp = tensor.from_elements %x0, ..., %xn : tensor<nbBlocksxi64>
auto x_decomp =
rewriter.create<tensor::FromElementsOp>(loc, plaintextElements);
// %c0 = arith.constant 0 : index
// %c1 = arith.constant 1 : index
// %cB = arith.constant nbBlocks : index
auto c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
auto cB = rewriter.create<arith::ConstantIndexOp>(loc, shape[0]);
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
op.getLoc(), resultTy, mlir::ValueRange{});
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
// (tensor<nbBlocksxlweSizexi64>) {
rewriter.replaceOpWithNewOp<scf::ForOp>(
op, c0, cB, c1, init,
[&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value i,
mlir::ValueRange iterArgs) {
// [%i, 0]
mlir::SmallVector<mlir::OpFoldResult> offsets{
i, rewriter.getI64IntegerAttr(0)};
// [1, lweSize]
mlir::SmallVector<mlir::OpFoldResult> sizes{
rewriter.getI64IntegerAttr(1),
rewriter.getI64IntegerAttr(shape[1])};
// [1, 1]
mlir::SmallVector<mlir::OpFoldResult> strides{
rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1)};
auto blockTy = mlir::RankedTensorType::get({shape[1]},
resultTy.getElementType());
// %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
// : tensor<lweSizexi64>
auto blockArg0 = builder.create<tensor::ExtractSliceOp>(
loc, blockTy, op.lhs(), offsets, sizes, strides);
// %blockArg1 = tensor.extract %x_decomp[%i] : tensor<nbBlocksxi64>
auto blockArg1 = builder.create<tensor::ExtractOp>(loc, x_decomp, i);
// %tmp = "BConcreteOp"(%blockArg0, %blockArg1)
// : (tensor<lweSizexi64>, i64) -> (tensor<lweSizexi64>)
auto tmp = builder.create<BConcrete::AddPlaintextLweBufferOp>(
loc, blockTy, blockArg0, blockArg1);
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1,
// 1] : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
auto res = builder.create<tensor::InsertSliceOp>(
loc, tmp, iterArgs[0], offsets, sizes, strides);
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
builder.create<scf::YieldOp>(loc, (mlir::Value)res);
});
return mlir::success();
}
};
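The two branches above amount to a per-modulus encoding of the plaintext. A minimal standalone sketch of that step, reusing the `crt::encode` helper already included in this file; the moduli are made up and this is not how the pass itself is structured:

```cpp
#include <cstdint>
#include <vector>

#include "concretelang/ClientLib/CRT.h" // declares concretelang::clientlib::crt::encode

namespace crt = concretelang::clientlib::crt;

// Encode one plaintext once per modulus of the crt decomposition, as the
// constant branch of the pattern does at compile time.
std::vector<uint64_t> encodeForEachBlock(int64_t value,
                                         const std::vector<int64_t> &moduli) {
  uint64_t product = 1;
  for (auto m : moduli)
    product *= m;
  std::vector<uint64_t> encoded;
  for (auto m : moduli)
    encoded.push_back(crt::encode(value, m, product));
  return encoded;
}
```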
// This rewrite pattern transforms any instance of
// `MulCleartextCRTLweBufferOp` to `MulCleartextLweBufferOp` applied on each
// block; the same cleartext multiplies every block of the crt decomposition.
//
// Example:
//
// ```mlir
// %0 = "BConcreteCRTOp"(%arg0, %x) {crtDecomposition = [d0...dn]}
// : (tensor<nbBlocksxlweSizexi64>, i64) -> (tensor<nbBlocksxlweSizexi64>)
// ```
//
// becomes:
//
// ```mlir
// // The cleartext is shared by all blocks, it is only extended to i64
// %x64 = arith.extui %x : iN to i64
// // Loop on blocks
// %c0 = arith.constant 0 : index
// %c1 = arith.constant 1 : index
// %cB = arith.constant nbBlocks : index
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
// (tensor<nbBlocksxlweSizexi64>) {
// %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
// : tensor<lweSizexi64>
//   %tmp = "BConcrete.mul_cleartext_lwe_buffer"(%blockArg0, %x64)
//     : (tensor<lweSizexi64>, i64) -> (tensor<lweSizexi64>)
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, 1]
// : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
// }
// ```
struct MulCleartextCRTLweBufferOpPattern
: public mlir::OpRewritePattern<BConcrete::MulCleartextCRTLweBufferOp> {
MulCleartextCRTLweBufferOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<BConcrete::MulCleartextCRTLweBufferOp>(context,
benefit) {
}
mlir::LogicalResult
matchAndRewrite(BConcrete::MulCleartextCRTLweBufferOp op,
mlir::PatternRewriter &rewriter) const override {
auto resultTy =
((mlir::Type)op.getResult().getType()).cast<mlir::RankedTensorType>();
auto loc = op.getLoc();
assert(resultTy.getShape().size() == 2);
auto shape = resultTy.getShape();
// %c0 = arith.constant 0 : index
// %c1 = arith.constant 1 : index
// %cB = arith.constant nbBlocks : index
auto c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
auto cB = rewriter.create<arith::ConstantIndexOp>(loc, shape[0]);
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
op.getLoc(), resultTy, mlir::ValueRange{});
auto rhs = rewriter.create<arith::ExtUIOp>(op.getLoc(),
rewriter.getI64Type(), op.rhs());
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
// (tensor<nbBlocksxlweSizexi64>) {
rewriter.replaceOpWithNewOp<scf::ForOp>(
op, c0, cB, c1, init,
[&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value i,
mlir::ValueRange iterArgs) {
// [%i, 0]
mlir::SmallVector<mlir::OpFoldResult> offsets{
i, rewriter.getI64IntegerAttr(0)};
// [1, lweSize]
mlir::SmallVector<mlir::OpFoldResult> sizes{
rewriter.getI64IntegerAttr(1),
rewriter.getI64IntegerAttr(shape[1])};
// [1, 1]
mlir::SmallVector<mlir::OpFoldResult> strides{
rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1)};
auto blockTy = mlir::RankedTensorType::get({shape[1]},
resultTy.getElementType());
// %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
// : tensor<lweSizexi64>
auto blockArg0 = builder.create<tensor::ExtractSliceOp>(
loc, blockTy, op.lhs(), offsets, sizes, strides);
// %tmp = BConcrete.mul_cleartext_lwe_buffer(%blockArg0, %x)
// : (tensor<lweSizexi64>, i64) -> (tensor<lweSizexi64>)
auto tmp = builder.create<BConcrete::MulCleartextLweBufferOp>(
loc, blockTy, blockArg0, rhs);
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1,
// 1] : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
auto res = builder.create<tensor::InsertSliceOp>(
loc, tmp, iterArgs[0], offsets, sizes, strides);
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
builder.create<scf::YieldOp>(loc, (mlir::Value)res);
});
return mlir::success();
}
};
struct EliminateCRTOpsPass : public EliminateCRTOpsBase<EliminateCRTOpsPass> {
void runOnOperation() final;
};
void EliminateCRTOpsPass::runOnOperation() {
auto op = getOperation();
mlir::ConversionTarget target(getContext());
mlir::RewritePatternSet patterns(&getContext());
// add_crt_lwe_buffers
target.addIllegalOp<BConcrete::AddCRTLweBuffersOp>();
patterns.add<BConcreteCRTBinaryOpPattern<BConcrete::AddCRTLweBuffersOp,
BConcrete::AddLweBuffersOp>>(
&getContext());
// add_plaintext_crt_lwe_buffers
target.addIllegalOp<BConcrete::AddPlaintextCRTLweBufferOp>();
patterns.add<AddPlaintextCRTLweBufferOpPattern>(&getContext());
// mul_cleartext_crt_lwe_buffer
target.addIllegalOp<BConcrete::MulCleartextCRTLweBufferOp>();
patterns.add<MulCleartextCRTLweBufferOpPattern>(&getContext());
target.addIllegalOp<BConcrete::NegateCRTLweBufferOp>();
patterns.add<BConcreteCRTUnaryOpPattern<BConcrete::NegateCRTLweBufferOp,
BConcrete::NegateLweBufferOp>>(
&getContext());
// These dialects are used to transform crt ops to bconcrete ops
target
.addLegalDialect<arith::ArithmeticDialect, tensor::TensorDialect,
scf::SCFDialect, bufferization::BufferizationDialect,
mlir::func::FuncDialect, BConcrete::BConcreteDialect>();
// Apply the conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
this->signalPassFailure();
return;
}
}
} // namespace
namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<func::FuncOp>> createEliminateCRTOps() {
return std::make_unique<EliminateCRTOpsPass>();
}
} // namespace concretelang
} // namespace mlir
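Stepping back, all the patterns in this pass only route ciphertext blocks; the arithmetic meaning of a block comes from the CRT representation itself: a value is stored as its residues modulo pairwise-coprime moduli and can be recovered as long as it stays below their product. A minimal standalone sketch of that representation with made-up moduli (naive reconstruction by search, not the clientlib implementation):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Decompose a value into its residues modulo the given moduli.
std::vector<uint64_t> decompose(uint64_t val, const std::vector<uint64_t> &moduli) {
  std::vector<uint64_t> remainders;
  for (auto m : moduli)
    remainders.push_back(val % m);
  return remainders;
}

// Recover the unique value below the product of the moduli that has the given
// residues (naive linear search, fine for a sketch).
uint64_t recompose(const std::vector<uint64_t> &remainders,
                   const std::vector<uint64_t> &moduli) {
  uint64_t product = 1;
  for (auto m : moduli)
    product *= m;
  for (uint64_t candidate = 0; candidate < product; candidate++) {
    bool match = true;
    for (size_t i = 0; i < moduli.size(); i++)
      match = match && (candidate % moduli[i]) == remainders[i];
    if (match)
      return candidate;
  }
  return 0; // unreachable when the moduli are pairwise coprime
}

int main() {
  // 7 * 8 * 9 = 504 > 2^8, so {7, 8, 9} can carry an 8-bit value.
  std::vector<uint64_t> moduli{7, 8, 9};
  auto r = decompose(27, moduli); // {6, 3, 0}
  assert(recompose(r, moduli) == 27);
  return 0;
}
```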

View File

@@ -25,6 +25,13 @@ void ConcreteDialect::initialize() {
>();
}
void printSigned(mlir::AsmPrinter &p, signed i) {
if (i == -1)
p << "_";
else
p << i;
}
mlir::Type GlweCiphertextType::parse(mlir::AsmParser &parser) {
if (parser.parseLess())
return Type();
@@ -50,42 +57,61 @@ mlir::Type GlweCiphertextType::parse(mlir::AsmParser &parser) {
void GlweCiphertextType::print(mlir::AsmPrinter &p) const {
p << "<";
if (getImpl()->glweDimension == -1)
p << "_";
else
p << getImpl()->glweDimension;
printSigned(p, getGlweDimension());
p << ",";
if (getImpl()->polynomialSize == -1)
p << "_";
else
p << getImpl()->polynomialSize;
printSigned(p, getPolynomialSize());
p << ",";
if (getImpl()->p == -1)
p << "_";
else
p << getImpl()->p;
printSigned(p, getP());
p << ">";
}
void LweCiphertextType::print(mlir::AsmPrinter &p) const {
p << "<";
if (getDimension() == -1)
p << "_";
else
p << getDimension();
// decomposition parameters if any
auto crt = getCrtDecomposition();
if (!crt.empty()) {
p << "crt=[";
for (auto c : crt.drop_back(1)) {
printSigned(p, c);
p << ",";
}
printSigned(p, crt.back());
p << "]";
p << ",";
}
printSigned(p, getDimension());
p << ",";
if (getP() == -1)
p << "_";
else
p << getP();
printSigned(p, getP());
p << ">";
}
mlir::Type LweCiphertextType::parse(mlir::AsmParser &parser) {
if (parser.parseLess())
return mlir::Type();
// Parse for the crt decomposition if any
std::vector<int64_t> crtDecomposition;
if (!parser.parseOptionalKeyword("crt")) {
if (parser.parseEqual() || parser.parseLSquare())
return mlir::Type();
while (true) {
int64_t c = -1;
if (parser.parseOptionalKeyword("_") && parser.parseInteger(c)) {
return mlir::Type();
}
crtDecomposition.push_back(c);
if (parser.parseOptionalComma()) {
if (parser.parseRSquare()) {
return mlir::Type();
} else {
break;
}
}
}
if (parser.parseComma())
return mlir::Type();
}
int dimension = -1;
if (parser.parseOptionalKeyword("_") && parser.parseInteger(dimension))
return mlir::Type();
@@ -99,7 +125,7 @@ mlir::Type LweCiphertextType::parse(mlir::AsmParser &parser) {
mlir::Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
return getChecked(loc, loc.getContext(), dimension, p);
return getChecked(loc, loc.getContext(), dimension, p, crtDecomposition);
}
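For reference, the printer and parser above round-trip textual forms like the following; the snippet is illustrative only (the numeric values are arbitrary) and simply records the syntax derived from the code in this hunk:

```cpp
// Derived from LweCiphertextType::print/parse above; values are arbitrary.
const char *lweWithCrt = "!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],600,16>";
const char *lweNoCrt = "!Concrete.lwe_ciphertext<600,4>";
const char *lweUnset = "!Concrete.lwe_ciphertext<_,4>"; // dimension not set yet
```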
void CleartextType::print(mlir::AsmPrinter &p) const {

View File

@@ -16,29 +16,10 @@ namespace concretelang {
namespace {
/// Get the integer value that the cleartext was created from if it exists.
llvm::Optional<mlir::Value>
getIntegerFromCleartextIfExists(mlir::Value cleartext) {
assert(
cleartext.getType().isa<mlir::concretelang::Concrete::CleartextType>());
// Cleartext are supposed to be created from integers
auto intToCleartextOp = cleartext.getDefiningOp();
if (intToCleartextOp == nullptr)
return {};
if (llvm::isa<Concrete::IntToCleartextOp>(intToCleartextOp)) {
// We want to match when the integer value is constant
return intToCleartextOp->getOperand(0);
}
return {};
}
/// Get the constant integer that the cleartext was created from if it exists.
llvm::Optional<IntegerAttr>
getConstantIntFromCleartextIfExists(mlir::Value cleartext) {
auto cleartextInt = getIntegerFromCleartextIfExists(cleartext);
if (!cleartextInt.hasValue())
return {};
auto constantOp = cleartextInt.getValue().getDefiningOp();
auto constantOp = cleartext.getDefiningOp();
if (constantOp == nullptr)
return {};
if (llvm::isa<arith::ConstantOp>(constantOp)) {

View File

@@ -32,7 +32,8 @@ void TFHEDialect::initialize() {
/// - The bits parameter is 64 (we support only this for v0)
::mlir::LogicalResult GLWECipherTextType::verify(
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
signed dimension, signed polynomialSize, signed bits, signed p) {
signed dimension, signed polynomialSize, signed bits, signed p,
llvm::ArrayRef<int64_t>) {
if (bits != -1 && bits != 64) {
emitError() << "GLWE bits parameter can only be 64";
return ::mlir::failure();

View File

@@ -40,6 +40,10 @@ mlir::LogicalResult _verifyGLWEIntegerOperator(mlir::OpState &op,
emitOpErrorForIncompatibleGLWEParameter(op, "p");
return mlir::failure();
}
if (a.getCrtDecomposition() != result.getCrtDecomposition()) {
emitOpErrorForIncompatibleGLWEParameter(op, "crt");
return mlir::failure();
}
// verify consistency of width of inputs
if ((int)b.getWidth() > a.getP() + 1) {
@@ -107,6 +111,11 @@ mlir::LogicalResult verifyBinaryGLWEOperator(Operator &op) {
emitOpErrorForIncompatibleGLWEParameter(op, "p");
return mlir::failure();
}
if (a.getCrtDecomposition() != b.getCrtDecomposition() ||
a.getCrtDecomposition() != result.getCrtDecomposition()) {
emitOpErrorForIncompatibleGLWEParameter(op, "crt");
return mlir::failure();
}
return mlir::success();
}
@@ -137,6 +146,10 @@ mlir::LogicalResult verifyUnaryGLWEOperator(Operator &op) {
emitOpErrorForIncompatibleGLWEParameter(op, "p");
return mlir::failure();
}
if (a.getCrtDecomposition() != result.getCrtDecomposition()) {
emitOpErrorForIncompatibleGLWEParameter(op, "crt");
return mlir::failure();
}
return mlir::success();
}

View File

@@ -9,28 +9,35 @@ namespace mlir {
namespace concretelang {
namespace TFHE {
void printSigned(mlir::AsmPrinter &p, signed i) {
if (i == -1)
p << "_";
else
p << i;
}
void GLWECipherTextType::print(mlir::AsmPrinter &p) const {
p << "<{";
if (getDimension() == -1)
p << "_";
else
p << getDimension();
p << ",";
if (getPolynomialSize() == -1)
p << "_";
else
p << getPolynomialSize();
p << ",";
if (getBits() == -1)
p << "_";
else
p << getBits();
p << "}";
p << "<";
auto crt = getCrtDecomposition();
if (!crt.empty()) {
p << "crt=[";
for (auto c : crt.drop_back(1)) {
printSigned(p, c);
p << ",";
}
printSigned(p, crt.back());
p << "]";
}
p << "{";
if (getP() == -1)
p << "_";
else
p << getP();
printSigned(p, getDimension());
p << ",";
printSigned(p, getPolynomialSize());
p << ",";
printSigned(p, getBits());
p << "}";
p << "{";
printSigned(p, getP());
p << "}>";
}
@@ -38,9 +45,31 @@ mlir::Type GLWECipherTextType::parse(AsmParser &parser) {
if (parser.parseLess())
return mlir::Type();
// First parameters block
// Parse for the crt decomposition if any
std::vector<int64_t> crtDecomposition;
if (!parser.parseOptionalKeyword("crt")) {
if (parser.parseEqual() || parser.parseLSquare())
return mlir::Type();
while (true) {
signed c = -1;
if (parser.parseOptionalKeyword("_") && parser.parseInteger(c)) {
return mlir::Type();
}
crtDecomposition.push_back(c);
if (parser.parseOptionalComma()) {
if (parser.parseRSquare()) {
return mlir::Type();
} else {
break;
}
}
}
}
if (parser.parseLBrace())
return mlir::Type();
// First parameters block
int dimension = -1;
if (parser.parseOptionalKeyword("_") && parser.parseInteger(dimension))
return mlir::Type();
@@ -69,7 +98,8 @@ mlir::Type GLWECipherTextType::parse(AsmParser &parser) {
if (parser.parseGreater())
return mlir::Type();
Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
return getChecked(loc, loc.getContext(), dimension, polynomialSize, bits, p);
return getChecked(loc, loc.getContext(), dimension, polynomialSize, bits, p,
llvm::ArrayRef<int64_t>(crtDecomposition));
}
} // namespace TFHE
} // namespace concretelang

View File

@@ -26,10 +26,11 @@ target_link_libraries(
ConcretelangRuntime
PUBLIC
Concrete
ConcretelangClientLib
pthread m dl
$<TARGET_OBJECTS:mlir_c_runner_utils>
)
install(TARGETS ConcretelangRuntime omp EXPORT ConcretelangRuntime)
install(EXPORT ConcretelangRuntime DESTINATION "./")

View File

@@ -3,11 +3,13 @@
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "concretelang/Runtime/wrappers.h"
#include <assert.h>
#include <stdio.h>
#include <string.h>
#include "concretelang/ClientLib/CRT.h"
#include "concretelang/Runtime/wrappers.h"
void memref_expand_lut_in_trivial_glwe_ct_u64(
uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
@@ -98,6 +100,10 @@ void memref_bootstrap_lwe_u64(
glwe_ct_aligned + glwe_ct_offset);
}
uint64_t encode_crt(int64_t plaintext, uint64_t modulus, uint64_t product) {
return concretelang::clientlib::crt::encode(plaintext, modulus, product);
}
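The wrapper above forwards to the client library's CRT encoder. As background, here is a minimal, self-contained sketch (not the clientlib implementation) of the residue arithmetic behind a crtDecomposition such as [2, 3, 5, 7, 11]: a value is split into one remainder per modulus (one ciphertext block per modulus) and recombined with the Chinese Remainder Theorem. How each residue is then scaled into a 64-bit plaintext by encode is deliberately not reproduced here.
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>
int main() {
  std::vector<uint64_t> moduli = {2, 3, 5, 7, 11};
  // Product of the moduli: 2310, so values in [0, 2310) are representable.
  uint64_t product = std::accumulate(moduli.begin(), moduli.end(), uint64_t{1},
                                     std::multiplies<uint64_t>());
  uint64_t value = 1234 % product;
  // Decompose: one remainder per modulus.
  std::vector<uint64_t> remainders;
  for (auto m : moduli)
    remainders.push_back(value % m);
  // Recombine with the classic CRT formula.
  uint64_t recombined = 0;
  for (size_t i = 0; i < moduli.size(); i++) {
    uint64_t ni = product / moduli[i];
    // Modular inverse of ni modulo moduli[i], by brute force (moduli are tiny).
    uint64_t inv = 1;
    while ((ni * inv) % moduli[i] != 1)
      inv++;
    recombined = (recombined + (remainders[i] * ni) % product * inv) % product;
  }
  return recombined == value ? 0 : 1; // exits with 0
}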
void memref_copy_one_rank(uint64_t *src_allocated, uint64_t *src_aligned,
uint64_t src_offset, uint64_t src_size,
uint64_t src_stride, uint64_t *dst_allocated,

View File

@@ -38,171 +38,166 @@ TensorData multi_arity_call_dynamic_rank(void *(*func)(void *...),
using concretelang::clientlib::MemRefDescriptor;
constexpr auto convert = concretelang::clientlib::tensorDataFromMemRef;
switch (rank) {
case 0: {
case 1: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<1> (*)(void *...)>(func), args);
return convert(1, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 1: {
case 2: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<2> (*)(void *...)>(func), args);
return convert(2, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 2: {
case 3: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<3> (*)(void *...)>(func), args);
return convert(3, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 3: {
case 4: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<4> (*)(void *...)>(func), args);
return convert(4, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 4: {
case 5: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<5> (*)(void *...)>(func), args);
return convert(5, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 5: {
case 6: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<6> (*)(void *...)>(func), args);
return convert(6, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 6: {
case 7: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<7> (*)(void *...)>(func), args);
return convert(7, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 7: {
case 8: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<8> (*)(void *...)>(func), args);
return convert(8, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 8: {
case 9: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<9> (*)(void *...)>(func), args);
return convert(9, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 9: {
case 10: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<10> (*)(void *...)>(func), args);
return convert(10, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 10: {
case 11: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<11> (*)(void *...)>(func), args);
return convert(11, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 11: {
case 12: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<12> (*)(void *...)>(func), args);
return convert(12, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 12: {
case 13: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<13> (*)(void *...)>(func), args);
return convert(13, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 13: {
case 14: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<14> (*)(void *...)>(func), args);
return convert(14, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 14: {
case 15: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<15> (*)(void *...)>(func), args);
return convert(15, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 15: {
case 16: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<16> (*)(void *...)>(func), args);
return convert(16, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 16: {
case 17: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<17> (*)(void *...)>(func), args);
return convert(17, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 17: {
case 18: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<18> (*)(void *...)>(func), args);
return convert(18, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 18: {
case 19: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<19> (*)(void *...)>(func), args);
return convert(19, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 19: {
case 20: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<20> (*)(void *...)>(func), args);
return convert(20, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 20: {
case 21: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<21> (*)(void *...)>(func), args);
return convert(21, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 21: {
case 22: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<22> (*)(void *...)>(func), args);
return convert(22, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 22: {
case 23: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<23> (*)(void *...)>(func), args);
return convert(23, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 23: {
case 24: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<24> (*)(void *...)>(func), args);
return convert(24, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 24: {
case 25: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<25> (*)(void *...)>(func), args);
return convert(25, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 25: {
case 26: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<26> (*)(void *...)>(func), args);
return convert(26, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 26: {
case 27: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<27> (*)(void *...)>(func), args);
return convert(27, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 27: {
case 28: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<28> (*)(void *...)>(func), args);
return convert(28, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 28: {
case 29: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<29> (*)(void *...)>(func), args);
return convert(29, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 29: {
case 30: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<30> (*)(void *...)>(func), args);
return convert(30, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 30: {
case 31: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<31> (*)(void *...)>(func), args);
return convert(31, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 31: {
case 32: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<32> (*)(void *...)>(func), args);
return convert(32, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
case 32: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<33> (*)(void *...)>(func), args);
return convert(33, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
}
default:
assert(false);

View File

@@ -69,12 +69,6 @@ ServerLambda::load(std::string funcName, std::string outputPath) {
return ServerLambda::loadFromModule(module, funcName);
}
TensorData dynamicCall(void *(*func)(void *...),
std::vector<void *> &preparedArgs, CircuitGate &output) {
size_t rank = output.shape.dimensions.size();
return multi_arity_call_dynamic_rank(func, preparedArgs, rank);
}
std::unique_ptr<clientlib::PublicResult>
ServerLambda::call(PublicArguments &args, EvaluationKeys &evaluationKeys) {
std::vector<void *> preparedArgs(args.preparedArgs.begin(),
@@ -84,9 +78,12 @@ ServerLambda::call(PublicArguments &args, EvaluationKeys &evaluationKeys) {
runtimeContext.evaluationKeys = evaluationKeys;
preparedArgs.push_back((void *)&runtimeContext);
return clientlib::PublicResult::fromBuffers(
clientParameters,
{dynamicCall(this->func, preparedArgs, clientParameters.outputs[0])});
assert(clientParameters.outputs.size() == 1 &&
"ServerLambda::call is implemented for only one output");
auto output = args.clientParameters.outputs[0];
auto rank = args.clientParameters.bufferShape(output).size();
auto result = multi_arity_call_dynamic_rank(func, preparedArgs, rank);
return clientlib::PublicResult::fromBuffers(clientParameters, {result});
}

View File

@@ -43,8 +43,8 @@ TensorData multi_arity_call_dynamic_rank(void *(*func)(void *...),
constexpr auto convert = concretelang::clientlib::tensorDataFromMemRef;
switch (rank) {""")
for tensor_rank in range(0, 33):
memref_rank = tensor_rank + 1
for tensor_rank in range(1, 33):
memref_rank = tensor_rank
print(f""" case {tensor_rank}: {{
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<{memref_rank}> (*)(void *...)>(func), args);

View File

@@ -166,30 +166,39 @@ CompilerEngine::getConcreteOptimizerDescription(CompilationResult &res) {
return std::move(descriptions->begin()->second);
}
/// set the fheContext field if the v0Constraint can be computed
llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) {
auto descrOrErr = getConcreteOptimizerDescription(res);
if (auto err = descrOrErr.takeError()) {
return err;
}
// The function is non-crypto and without constraint override
if (!descrOrErr.get().hasValue()) {
if (compilerOptions.v0Parameter.hasValue()) {
// parameters come from the compiler options
V0Parameter v0Params = compilerOptions.v0Parameter.value();
if (compilerOptions.largeIntegerParameter.hasValue()) {
v0Params.largeInteger = compilerOptions.largeIntegerParameter;
}
res.fheContext.emplace(mlir::concretelang::V0FHEContext{{0, 0}, v0Params});
return llvm::Error::success();
}
auto descr = std::move(descrOrErr.get().getValue());
auto config = this->compilerOptions.optimizerConfig;
auto fheParams = (compilerOptions.v0Parameter.hasValue())
? compilerOptions.v0Parameter
: getParameter(descr, config);
if (!fheParams) {
return StreamStringError()
<< "Could not determine V0 parameters for 2-norm of "
<< (*descrOrErr)->constraint.norm2 << " and p of "
<< (*descrOrErr)->constraint.p;
// compute parameters
else {
auto descr = getConcreteOptimizerDescription(res);
if (auto err = descr.takeError()) {
return err;
}
if (!descr.get().hasValue()) {
return llvm::Error::success();
}
auto optV0Params =
getParameter(descr.get().value(), compilerOptions.optimizerConfig);
if (!optV0Params) {
return StreamStringError()
<< "Could not determine V0 parameters for 2-norm of "
<< (*descr)->constraint.norm2 << " and p of "
<< (*descr)->constraint.p;
}
res.fheContext.emplace(mlir::concretelang::V0FHEContext{
descr.get().value().constraint, optV0Params.getValue()});
}
res.fheContext.emplace(
mlir::concretelang::V0FHEContext{descr.constraint, fheParams.getValue()});
return llvm::Error::success();
}
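For reference, a hedged sketch of the large-integer parameters a user supplies when bypassing the optimizer. The field names mirror the command-line handling later in this commit; the numeric values are placeholders rather than recommended cryptographic parameters, and the snippet is a fragment that assumes the compiler headers declaring LargeIntegerParameter are already included.
mlir::concretelang::LargeIntegerParameter largeInteger;
largeInteger.crtDecomposition = {2, 3, 5, 7, 11};
largeInteger.wopPBS.packingKeySwitch.inputLweDimension = 481;     // placeholder
largeInteger.wopPBS.packingKeySwitch.inputLweCount = 2;           // placeholder
largeInteger.wopPBS.packingKeySwitch.outputPolynomialSize = 1024; // placeholder
largeInteger.wopPBS.packingKeySwitch.level = 9;                   // placeholder
largeInteger.wopPBS.packingKeySwitch.baseLog = 2;                 // placeholder
largeInteger.wopPBS.circuitBootstrap.level = 4;                   // placeholder
largeInteger.wopPBS.circuitBootstrap.baseLog = 9;                 // placeholder
// Together with a user-provided V0Parameter, this is what
// compilerOptions.largeIntegerParameter carries into determineFHEParameters,
// where it is copied onto v0Params.largeInteger above.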
@@ -282,7 +291,7 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
// FHE -> TFHE
if (mlir::concretelang::pipeline::lowerFHEToTFHE(mlirContext, module,
enablePass)
res.fheContext, enablePass)
.failed()) {
return errorDiag("Lowering from FHE to TFHE failed");
}

View File

@@ -110,21 +110,13 @@ JITLambda::call(clientlib::PublicArguments &args,
// Prepare the outputs vector to store the output value of the lambda.
auto numOutputs = 0;
for (auto &output : args.clientParameters.outputs) {
if (output.shape.dimensions.empty()) {
auto shape = args.clientParameters.bufferShape(output);
if (shape.size() == 0) {
// scalar gate
if (output.encryption.hasValue()) {
// encrypted scalar : memref<lweSizexi64>
numOutputs += numArgOfRankedMemrefCallingConvention(1);
} else {
// clear scalar
numOutputs += 1;
}
numOutputs += 1;
} else {
// memref gate : rank+1 if the output is encrypted for the lwe size
// dimension
auto rank = output.shape.dimensions.size() +
(output.encryption.hasValue() ? 1 : 0);
numOutputs += numArgOfRankedMemrefCallingConvention(rank);
// buffer gate
numOutputs += numArgOfRankedMemrefCallingConvention(shape.size());
}
}
std::vector<void *> outputs(numOutputs);
@@ -148,7 +140,6 @@ JITLambda::call(clientlib::PublicArguments &args,
for (auto &out : outputs) {
rawArgs[i++] = &out;
}
// Invoke
if (auto err = invokeRaw(rawArgs)) {
return std::move(err);
@@ -159,14 +150,14 @@ JITLambda::call(clientlib::PublicArguments &args,
{
size_t outputOffset = 0;
for (auto &output : args.clientParameters.outputs) {
if (output.shape.dimensions.empty() && !output.encryption.hasValue()) {
// clear scalar
auto shape = args.clientParameters.bufferShape(output);
if (shape.size() == 0) {
// scalar gate
buffers.push_back(
clientlib::tensorDataFromScalar((uint64_t)outputs[outputOffset++]));
} else {
// encrypted scalar, and tensor gate are memref
auto rank = output.shape.dimensions.size() +
(output.encryption.hasValue() ? 1 : 0);
// buffer gate
auto rank = shape.size();
auto allocated = (uint64_t *)outputs[outputOffset++];
auto aligned = (uint64_t *)outputs[outputOffset++];
auto offset = (size_t)outputs[outputOffset++];

View File

@@ -182,6 +182,7 @@ markFHELinalgForTiling(mlir::MLIRContext &context, mlir::ModuleOp &module,
mlir::LogicalResult
lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
llvm::Optional<V0FHEContext> &fheContext,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);
pipelinePrinting("FHEToTFHE", pm, context);
@@ -192,8 +193,14 @@ lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
// to linalg.generic operations
addPotentiallyNestedPass(pm, mlir::createLinalgGeneralizationPass(),
enablePass);
addPotentiallyNestedPass(pm, mlir::concretelang::createConvertFHEToTFHEPass(),
enablePass);
mlir::concretelang::ApplyLookupTableLowering lowerStrategy =
mlir::concretelang::KeySwitchBoostrapLowering;
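// When CRT-based large-integer parameters are present, lookup tables are
// lowered through the WoP-PBS path instead of keyswitch + bootstrap.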
if (fheContext.hasValue() && fheContext->parameter.largeInteger.hasValue()) {
lowerStrategy = mlir::concretelang::WopPBSLowering;
}
addPotentiallyNestedPass(
pm, mlir::concretelang::createConvertFHEToTFHEPass(lowerStrategy),
enablePass);
return pm.run(module.getOperation());
}
@@ -260,6 +267,8 @@ lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);
pipelinePrinting("BConcreteToStd", pm, context);
addPotentiallyNestedPass(pm, mlir::concretelang::createEliminateCRTOps(),
enablePass);
addPotentiallyNestedPass(pm, mlir::concretelang::createAddRuntimeContext(),
enablePass);
return pm.run(module.getOperation());

View File

@@ -3,6 +3,7 @@
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <llvm/ADT/Optional.h>
#include <llvm/ADT/STLExtras.h>
#include <llvm/Support/Error.h>
@@ -12,6 +13,7 @@
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
#include "concretelang/Support/Error.h"
#include "concretelang/Support/V0Curves.h"
namespace mlir {
@@ -20,6 +22,7 @@ namespace concretelang {
using ::concretelang::clientlib::BIG_KEY;
using ::concretelang::clientlib::CircuitGate;
using ::concretelang::clientlib::ClientParameters;
using ::concretelang::clientlib::Encoding;
using ::concretelang::clientlib::EncryptionGate;
using ::concretelang::clientlib::LweSecretKeyID;
using ::concretelang::clientlib::Precision;
@@ -52,23 +55,21 @@ llvm::Expected<CircuitGate> gateFromMLIRType(LweSecretKeyID secretKeyID,
},
};
}
if (auto lweType = type.dyn_cast_or_null<
mlir::concretelang::Concrete::LweCiphertextType>()) {
// TODO - Get the width from the LWECiphertextType instead of global
// precision (could be possible after merge concrete-ciphertext-parameter)
size_t precision = (size_t)lweType.getP();
if (auto lweTy = type.dyn_cast_or_null<
mlir::concretelang::Concrete::LweCiphertextType>()) {
return CircuitGate{
/* .encryption = */ llvm::Optional<EncryptionGate>({
/* .secretKeyID = */ secretKeyID,
/* .variance = */ variance,
/* .encoding = */
{
/* .precision = */ precision,
/* .precision = */ lweTy.getP(),
/* .crt = */ lweTy.getCrtDecomposition().vec(),
},
}),
/*.shape = */
{
/*.width = */ precision,
/*.width = */ lweTy.getP(),
/*.dimensions = */ std::vector<int64_t>(),
/*.size = */ 0,
},
@@ -99,6 +100,7 @@ createClientParametersForV0(V0FHEContext fheContext,
auto v0Param = fheContext.parameter;
Variance encryptionVariance = v0Curve->getVariance(
v0Param.glweDimension, v0Param.getPolynomialSize(), 64);
Variance keyswitchVariance = v0Curve->getVariance(1, v0Param.nSmall, 64);
// Static client parameters from global parameters for v0
ClientParameters c;
@@ -138,10 +140,9 @@ createClientParametersForV0(V0FHEContext fheContext,
return op.getName() == functionName;
});
if (funcOp == rangeOps.end()) {
return llvm::make_error<llvm::StringError>(
"cannot find the function for generate client parameters '" +
functionName + "'",
llvm::inconvertibleErrorCode());
return StreamStringError(
"cannot find the function for generate client parameters: ")
<< functionName;
}
// Create input and output circuit gate parameters

View File

@@ -170,7 +170,8 @@ llvm::cl::opt<std::string>
llvm::cl::list<uint64_t>
jitArgs("jit-args",
llvm::cl::desc("Value of arguments to pass to the main func"),
llvm::cl::value_desc("argument(uint64)"), llvm::cl::ZeroOrMore);
llvm::cl::value_desc("argument(uint64)"), llvm::cl::ZeroOrMore,
llvm::cl::MiscFlags::CommaSeparated);
llvm::cl::opt<std::string> jitKeySetCachePath(
"jit-keyset-cache-path",
@@ -210,6 +211,31 @@ llvm::cl::list<int64_t> v0Parameter(
"logPolynomialSize, nSmall, brLevel, brLobBase, ksLevel, ksLogBase]"),
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
llvm::cl::list<int64_t> largeIntegerCRTDecomposition(
"large-integer-crt-decomposition",
llvm::cl::desc(
"Use the large integer to lower FHE.eint with the given decomposition, "
"must be used with the other large-integers options (experimental)"),
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
llvm::cl::list<int64_t> largeIntegerPackingKeyswitch(
"large-integer-packing-keyswitch",
llvm::cl::desc(
"Use the large integer to lower FHE.eint with the given parameters for "
"packing keyswitch, must be used with the other large-integers options "
"(experimental) [inputLweDimension, inputLweCount, "
"outputPolynomialSize, level, baseLog]"),
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
llvm::cl::list<int64_t> largeIntegerCircuitBootstrap(
"large-integer-circuit-bootstrap",
llvm::cl::desc(
"Use the large integer to lower FHE.eint with the given parameters for "
"the cicuit boostrap, must be used with the other large-integers "
"options "
"(experimental) [level, baseLog]"),
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
} // namespace cmdline
namespace llvm {
@@ -257,6 +283,7 @@ cmdlineCompilationOptions() {
if (!cmdline::fhelinalgTileSizes.empty())
options.fhelinalgTileSizes.emplace(cmdline::fhelinalgTileSizes);
// Setup the v0 parameter options
if (!cmdline::v0Parameter.empty()) {
if (cmdline::v0Parameter.size() != 7) {
return llvm::make_error<llvm::StringError>(
@@ -270,10 +297,44 @@ cmdlineCompilationOptions() {
cmdline::v0Parameter[6]);
}
if (!cmdline::v0Constraint.empty() && !cmdline::optimizerV0) {
return llvm::make_error<llvm::StringError>(
"You must use --v0-constraint with --optimizer-v0-strategy",
llvm::inconvertibleErrorCode());
// Setup the large integer options
if (!cmdline::largeIntegerCRTDecomposition.empty() ||
!cmdline::largeIntegerPackingKeyswitch.empty() ||
!cmdline::largeIntegerCircuitBootstrap.empty()) {
if (cmdline::largeIntegerCRTDecomposition.empty() ||
cmdline::largeIntegerPackingKeyswitch.empty() ||
cmdline::largeIntegerCircuitBootstrap.empty()) {
return llvm::make_error<llvm::StringError>(
"The large-integers options should all be set",
llvm::inconvertibleErrorCode());
}
if (cmdline::largeIntegerPackingKeyswitch.size() != 5) {
return llvm::make_error<llvm::StringError>(
"The large-integers-packing-keyswitch must be a list of 5 integer",
llvm::inconvertibleErrorCode());
}
if (cmdline::largeIntegerCircuitBootstrap.size() != 2) {
return llvm::make_error<llvm::StringError>(
"The large-integers-packing-keyswitch must be a list of 2 integer",
llvm::inconvertibleErrorCode());
}
options.largeIntegerParameter = mlir::concretelang::LargeIntegerParameter();
options.largeIntegerParameter->crtDecomposition =
cmdline::largeIntegerCRTDecomposition;
options.largeIntegerParameter->wopPBS.packingKeySwitch.inputLweDimension =
cmdline::largeIntegerPackingKeyswitch[0];
options.largeIntegerParameter->wopPBS.packingKeySwitch.inputLweCount =
cmdline::largeIntegerPackingKeyswitch[1];
options.largeIntegerParameter->wopPBS.packingKeySwitch
.outputPolynomialSize = cmdline::largeIntegerPackingKeyswitch[2];
options.largeIntegerParameter->wopPBS.packingKeySwitch.level =
cmdline::largeIntegerPackingKeyswitch[3];
options.largeIntegerParameter->wopPBS.packingKeySwitch.baseLog =
cmdline::largeIntegerPackingKeyswitch[4];
options.largeIntegerParameter->wopPBS.circuitBootstrap.level =
cmdline::largeIntegerCircuitBootstrap[0];
options.largeIntegerParameter->wopPBS.circuitBootstrap.baseLog =
cmdline::largeIntegerCircuitBootstrap[1];
}
options.optimizerConfig.p_error = cmdline::pbsErrorProbability;
@@ -502,9 +563,9 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
}
if (cmdline::action == Action::COMPILE) {
auto err =
outputLib->emitArtifacts(/*sharedLib=*/true, /*staticLib=*/true,
/*clientParameters=*/true, /*cppHeader=*/true);
auto err = outputLib->emitArtifacts(
/*sharedLib=*/true, /*staticLib=*/true,
/*clientParameters=*/true, /*cppHeader=*/true);
if (err) {
return mlir::failure();
}

View File

@@ -1,10 +1,19 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
//CHECK: func.func @add_glwe(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<2049xi64>) -> tensor<2049xi64> {
//CHECK: func @add_lwe_ciphertexts(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<2049xi64>) -> tensor<2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.add_lwe_buffer"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, tensor<2049xi64>) -> tensor<2049xi64>
//CHECK: return %[[V0]] : tensor<2049xi64>
//CHECK: }
func.func @add_glwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> {
func.func @add_lwe_ciphertexts(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> {
%0 = "Concrete.add_lwe_ciphertexts"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7>
return %0 : !Concrete.lwe_ciphertext<2048,7>
}
//CHECK: func @add_crt_lwe_ciphertexts(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.add_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> tensor<5x2049xi64>
//CHECK: return %[[V0]] : tensor<5x2049xi64>
//CHECK: }
func.func @add_crt_lwe_ciphertexts(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>, %arg1: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7> {
%0 = "Concrete.add_lwe_ciphertexts"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>, !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>
return %0 : !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>
}

View File

@@ -1,30 +1,40 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
//CHECK: func.func @add_glwe_const_int(%arg0: tensor<1025xi64>) -> tensor<1025xi64> {
//CHECK: func.func @add_glwe_const_int(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> {
//CHECK: %c1_i8 = arith.constant 1 : i8
//CHECK: %0 = arith.extui %c1_i8 : i8 to i64
//CHECK: %[[V0:.*]] = arith.extui %c1_i8 : i8 to i64
//CHECK: %c56_i64 = arith.constant 56 : i64
//CHECK: %1 = arith.shli %0, %c56_i64 : i64
//CHECK: %2 = "BConcrete.add_plaintext_lwe_buffer"(%arg0, %1) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: return %2 : tensor<1025xi64>
//CHECK: %[[V1:.*]] = arith.shli %[[V0]], %c56_i64 : i64
//CHECK: %[[V2:.*]] = "BConcrete.add_plaintext_lwe_buffer"(%[[A0]], %[[V1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: return %[[V2]] : tensor<1025xi64>
//CHECK: }
func.func @add_glwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
%0 = arith.constant 1 : i8
%1 = "Concrete.encode_int"(%0) : (i8) -> !Concrete.plaintext<8>
%2 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %1) : (!Concrete.lwe_ciphertext<1024,7>, !Concrete.plaintext<8>) -> !Concrete.lwe_ciphertext<1024,7>
%2 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<1024,7>, i8) -> !Concrete.lwe_ciphertext<1024,7>
return %2 : !Concrete.lwe_ciphertext<1024,7>
}
//CHECK: func.func @add_glwe_int(%arg0: tensor<1025xi64>, %arg1: i5) -> tensor<1025xi64> {
//CHECK: %0 = arith.extui %arg1 : i5 to i64
//CHECK: func.func @add_glwe_int(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: i5) -> tensor<1025xi64> {
//CHECK: %[[V0:.*]] = arith.extui %[[A1]] : i5 to i64
//CHECK: %c59_i64 = arith.constant 59 : i64
//CHECK: %1 = arith.shli %0, %c59_i64 : i64
//CHECK: %2 = "BConcrete.add_plaintext_lwe_buffer"(%arg0, %1) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: return %2 : tensor<1025xi64>
//CHECK: %[[V1:.*]] = arith.shli %[[V0]], %c59_i64 : i64
//CHECK: %[[V2:.*]] = "BConcrete.add_plaintext_lwe_buffer"(%[[A0]], %[[V1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: return %[[V2]] : tensor<1025xi64>
//CHECK: }
func.func @add_glwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4> {
%0 = "Concrete.encode_int"(%arg1) : (i5) -> !Concrete.plaintext<5>
%1 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<1024,4>, !Concrete.plaintext<5>) -> !Concrete.lwe_ciphertext<1024,4>
%1 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<1024,4>, i5) -> !Concrete.lwe_ciphertext<1024,4>
return %1 : !Concrete.lwe_ciphertext<1024,4>
}
//CHECK: func.func @add_plaintext_lwe_ciphertext(%[[A0:.*]]: tensor<5x1025xi64>) -> tensor<5x1025xi64> {
//CHECK: %c1_i8 = arith.constant 1 : i8
//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_crt_lwe_buffer"(%[[A0]], %c1_i8) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>, i8) -> tensor<5x1025xi64>
//CHECK: return %[[V0]] : tensor<5x1025xi64>
//CHECK: }
func.func @add_plaintext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7> {
%0 = arith.constant 1 : i8
%2 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>, i8) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>
return %2 : !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>
}

View File

@@ -6,3 +6,10 @@
func.func @identity(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
return %arg0 : !Concrete.lwe_ciphertext<1024,7>
}
// CHECK: func.func @identity_crt(%arg0: tensor<5x1025xi64>) -> tensor<5x1025xi64> {
// CHECK-NEXT: return %arg0 : tensor<5x1025xi64>
// CHECK-NEXT: }
func.func @identity_crt(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7> {
return %arg0 : !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>
}

View File

@@ -1,26 +1,33 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
//CHECK: func.func @mul_lwe_const_int(%arg0: tensor<1025xi64>) -> tensor<1025xi64> {
//CHECK: func.func @mul_lwe_const_int(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> {
//CHECK: %c1_i8 = arith.constant 1 : i8
//CHECK: %0 = arith.extui %c1_i8 : i8 to i64
//CHECK: %1 = "BConcrete.mul_cleartext_lwe_buffer"(%arg0, %0) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: return %1 : tensor<1025xi64>
//CHECK: %[[V0:.*]] = arith.extui %c1_i8 : i8 to i64
//CHECK: %[[V1:.*]] = "BConcrete.mul_cleartext_lwe_buffer"(%[[A0]], %[[V0]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: return %[[V1]] : tensor<1025xi64>
//CHECK: }
func.func @mul_lwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
%0 = arith.constant 1 : i8
%1 = "Concrete.int_to_cleartext"(%0) : (i8) -> !Concrete.cleartext<8>
%2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %1) : (!Concrete.lwe_ciphertext<1024,7>, !Concrete.cleartext<8>) -> !Concrete.lwe_ciphertext<1024,7>
%2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<1024,7>, i8) -> !Concrete.lwe_ciphertext<1024,7>
return %2 : !Concrete.lwe_ciphertext<1024,7>
}
//CHECK: func.func @mul_lwe_int(%arg0: tensor<1025xi64>, %arg1: i5) -> tensor<1025xi64> {
//CHECK: %0 = arith.extui %arg1 : i5 to i64
//CHECK: %1 = "BConcrete.mul_cleartext_lwe_buffer"(%arg0, %0) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: return %1 : tensor<1025xi64>
//CHECK: func.func @mul_lwe_int(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: i5) -> tensor<1025xi64> {
//CHECK: %[[V0:.*]] = arith.extui %[[A1]] : i5 to i64
//CHECK: %[[V1:.*]] = "BConcrete.mul_cleartext_lwe_buffer"(%[[A0]], %[[V0]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: return %[[V1]] : tensor<1025xi64>
//CHECK: }
func.func @mul_lwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4> {
%0 = "Concrete.int_to_cleartext"(%arg1) : (i5) -> !Concrete.cleartext<5>
%1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<1024,4>, !Concrete.cleartext<5>) -> !Concrete.lwe_ciphertext<1024,4>
%1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<1024,4>, i5) -> !Concrete.lwe_ciphertext<1024,4>
return %1 : !Concrete.lwe_ciphertext<1024,4>
}
//CHECK: func.func @mul_cleartext_lwe_ciphertext_crt(%[[A0:.*]]: tensor<5x1025xi64>, %[[A1:.*]]: i5) -> tensor<5x1025xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>, i5) -> tensor<5x1025xi64>
//CHECK: return %[[V0]] : tensor<5x1025xi64>
//CHECK: }
func.func @mul_cleartext_lwe_ciphertext_crt(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4> {
%1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>, i5) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>
return %1 : !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>
}

View File

@@ -8,3 +8,12 @@ func.func @neg_lwe(%arg0: !Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_cip
%0 = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4>
return %0 : !Concrete.lwe_ciphertext<1024,4>
}
//CHECK: func.func @negate_lwe_ciphertext(%[[A0:.*]]: tensor<5x1025xi64>) -> tensor<5x1025xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.negate_crt_lwe_buffer"(%[[A0]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>) -> tensor<5x1025xi64>
//CHECK: return %[[V0]] : tensor<5x1025xi64>
//CHECK: }
func.func @negate_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4> {
%0 = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>
return %0 : !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>
}

View File

@@ -1,24 +1,20 @@
// RUN: concretecompiler --split-input-file --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
// CHECK: func
// DISABLED-CHECK: func.func @tensor_collapse_shape(%arg0: tensor<2x3x4x5x6x1025xi64>) -> tensor<720x1025xi64> {
// DISABLED-CHECK-NEXT: %0 = bufferization.to_memref %arg0 : memref<2x3x4x5x6x1025xi64>
// DISABLED-CHECK-NEXT: %1 = memref.collapse_shape %0 [[_:\[\[0, 1, 2, 3, 4\], \[5\]\]]] : memref<2x3x4x5x6x1025xi64> into memref<720x1025xi64>
// DISABLED-CHECK-NEXT: %2 = bufferization.to_tensor %1 : memref<720x1025xi64>
// DISABLED-CHECK-NEXT: return %2 : tensor<720x1025xi64>
//CHECK: func.func @tensor_collapse_shape(%[[A0:.*]]: tensor<2x3x4x5x6x1025xi64>) -> tensor<720x1025xi64> {
//CHECK: %[[V0:.*]] = tensor.collapse_shape %[[A0]] [[_:\[\[0, 1, 2, 3, 4\], \[5\]\]]] : tensor<2x3x4x5x6x1025xi64> into tensor<720x1025xi64>
//CHECK: return %[[V0]] : tensor<720x1025xi64>
//CHECK: }
func.func @tensor_collapse_shape(%arg0: tensor<2x3x4x5x6x!Concrete.lwe_ciphertext<1024,4>>) -> tensor<720x!Concrete.lwe_ciphertext<1024,4>> {
%0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3, 4]] {MANP = 1 : ui1}: tensor<2x3x4x5x6x!Concrete.lwe_ciphertext<1024,4>> into tensor<720x!Concrete.lwe_ciphertext<1024,4>>
return %0 : tensor<720x!Concrete.lwe_ciphertext<1024,4>>
}
// -----
// DISABLED-CHECK: func.func @tensor_collatenspse_shape(%arg0: tensor<2x3x5x1025xi64>) -> tensor<5x6x1025xi64> {
// DISABLED-CHECK-NEXT: %0 = bufferization.to_memref %arg0 : memref<2x3x5x1025xi64>
// DISABLED-CHECK-NEXT: %1 = memref.collapse_shape %0 [[_:\[\[0, 1, 2\], \[3\]\]]] : memref<2x3x5x1025xi64> into memref<30x1025xi64>
// DISABLED-CHECK-NEXT: %2 = memref.expand_shape %1 [[_:\[\[0, 1\], \[2\]\]]] : memref<30x1025xi64> into memref<5x6x1025xi64>
// DISABLED-CHECK-NEXT: %3 = bufferization.to_tensor %2 : memref<5x6x1025xi64>
// DISABLED-CHECK-NEXT: return %3 : tensor<5x6x1025xi64>
//CHECK: func.func @tensor_collatenspse_shape(%[[A0:.*]]: tensor<2x3x5x1025xi64>) -> tensor<5x6x1025xi64> {
//CHECK: %[[V0:.*]] = tensor.collapse_shape %[[A0]] [[_:\[\[0, 1, 2\], \[3\]\]]] : tensor<2x3x5x1025xi64> into tensor<30x1025xi64>
//CHECK: %[[V1:.*]] = tensor.expand_shape %[[V0]] [[_:\[\[0, 1\], \[2\]\]]] : tensor<30x1025xi64> into tensor<5x6x1025xi64>
//CHECK: return %[[V1]] : tensor<5x6x1025xi64>
//CHECK: }
func.func @tensor_collatenspse_shape(%arg0: tensor<2x3x5x!Concrete.lwe_ciphertext<1024,4>>) -> tensor<5x6x!Concrete.lwe_ciphertext<1024,4>> {
%0 = tensor.collapse_shape %arg0 [[0, 1, 2]] {MANP = 1 : ui1}: tensor<2x3x5x!Concrete.lwe_ciphertext<1024,4>> into tensor<30x!Concrete.lwe_ciphertext<1024,4>>
%1 = tensor.expand_shape %0 [[0, 1]] {MANP = 1 : ui1}: tensor<30x!Concrete.lwe_ciphertext<1024,4>> into tensor<5x6x!Concrete.lwe_ciphertext<1024,4>>
@@ -26,13 +22,31 @@ func.func @tensor_collatenspse_shape(%arg0: tensor<2x3x5x!Concrete.lwe_ciphertex
}
// -----
// DISABLED-CHECK: func.func @tensor_collatenspse_shape(%arg0: tensor<2x3x2x3x4x1025xi64>) -> tensor<6x2x12x1025xi64> {
// DISABLED-CHECK-NEXT: %0 = bufferization.to_memref %arg0 : memref<2x3x2x3x4x1025xi64>
// DISABLED-CHECK-NEXT: %1 = memref.collapse_shape %0 [[_:\[\[0, 1\], \[2\], \[3, 4\], \[5\]\]]] : memref<2x3x2x3x4x1025xi64> into memref<6x2x12x1025xi64>
// DISABLED-CHECK-NEXT: %2 = bufferization.to_tensor %1 : memref<6x2x12x1025xi64>
// DISABLED-CHECK-NEXT: return %2 : tensor<6x2x12x1025xi64>
//CHECK: func.func @tensor_collatenspse_shape(%[[A0:.*]]: tensor<2x3x2x3x4x1025xi64>) -> tensor<6x2x12x1025xi64> {
//CHECK: %[[V0:.*]] = tensor.collapse_shape %[[A0]] [[_:\[\[0, 1\], \[2\], \[3, 4\], \[5\]\]]] : tensor<2x3x2x3x4x1025xi64> into tensor<6x2x12x1025xi64>
//CHECK: return %[[V0]] : tensor<6x2x12x1025xi64>
//CHECK: }
func.func @tensor_collatenspse_shape(%arg0: tensor<2x3x2x3x4x!Concrete.lwe_ciphertext<1024,4>>) -> tensor<6x2x12x!Concrete.lwe_ciphertext<1024,4>> {
%0 = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] {MANP = 1 : ui1}: tensor<2x3x2x3x4x!Concrete.lwe_ciphertext<1024,4>> into tensor<6x2x12x!Concrete.lwe_ciphertext<1024,4>>
return %0 : tensor<6x2x12x!Concrete.lwe_ciphertext<1024,4>>
}
// -----
//CHECK: func.func @tensor_collapse_shape_crt(%[[A0:.*]]: tensor<2x3x4x5x6x5x1025xi64>) -> tensor<720x5x1025xi64> {
//CHECK: %[[V0:.*]] = tensor.collapse_shape %[[A0]] [[_:\[\[0, 1, 2, 3, 4\], \[5\], \[6\]\]]] : tensor<2x3x4x5x6x5x1025xi64> into tensor<720x5x1025xi64>
//CHECK: return %[[V0]] : tensor<720x5x1025xi64>
//CHECK: }
func.func @tensor_collapse_shape_crt(%arg0: tensor<2x3x4x5x6x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>>) -> tensor<720x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>> {
%0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3, 4]] {MANP = 1 : ui1}: tensor<2x3x4x5x6x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>> into tensor<720x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>>
return %0 : tensor<720x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>>
}
// -----
//CHECK: func.func @tensor_expand_shape_crt(%[[A0:.*]]: tensor<30x1025xi64>) -> tensor<5x6x1025xi64> {
//CHECK: %[[V0:.*]] = tensor.expand_shape %[[A0]] [[_:\[\[0, 1\], \[2\]\]]] : tensor<30x1025xi64> into tensor<5x6x1025xi64>
//CHECK: return %[[V0]] : tensor<5x6x1025xi64>
//CHECK: }
func.func @tensor_expand_shape_crt(%arg0: tensor<30x!Concrete.lwe_ciphertext<1024,4>>) -> tensor<5x6x!Concrete.lwe_ciphertext<1024,4>> {
%0 = tensor.expand_shape %arg0 [[0, 1]] {MANP = 1 : ui1}: tensor<30x!Concrete.lwe_ciphertext<1024,4>> into tensor<5x6x!Concrete.lwe_ciphertext<1024,4>>
return %0 : tensor<5x6x!Concrete.lwe_ciphertext<1024,4>>
}

View File

@@ -1,7 +1,15 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
// CHECK: func.func @tensor_identity(%arg0: tensor<2x3x4x1025xi64>) -> tensor<2x3x4x1025xi64> {
// CHECK-NEXT: return %arg0 : tensor<2x3x4x1025xi64>
// CHECK-NEXT: }
func.func @tensor_identity(%arg0: tensor<2x3x4x!Concrete.lwe_ciphertext<1024,7>>) -> tensor<2x3x4x!Concrete.lwe_ciphertext<1024,7>> {
return %arg0 : tensor<2x3x4x!Concrete.lwe_ciphertext<1024,7>>
}
// CHECK: func.func @tensor_identity_crt(%arg0: tensor<2x3x4x5x1025xi64>) -> tensor<2x3x4x5x1025xi64> {
// CHECK-NEXT: return %arg0 : tensor<2x3x4x5x1025xi64>
// CHECK-NEXT: }
func.func @tensor_identity_crt(%arg0: tensor<2x3x4x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>>) -> tensor<2x3x4x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>> {
return %arg0 : tensor<2x3x4x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>>
}

View File

@@ -1,22 +1,22 @@
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
// CHECK-LABEL: func.func @add_glwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7>
//CHECK: func.func @add_glwe_const_int(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
//CHECK: %c1_i8 = arith.constant 1 : i8
//CHECK: %[[V0:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[A0]], %c1_i8) : (!Concrete.lwe_ciphertext<1024,7>, i8) -> !Concrete.lwe_ciphertext<1024,7>
//CHECK: return %[[V0]] : !Concrete.lwe_ciphertext<1024,7>
//CHECK: }
func.func @add_glwe_const_int(%arg0: !TFHE.glwe<{1024,1,64}{7}>) -> !TFHE.glwe<{1024,1,64}{7}> {
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8
// CHECK-NEXT: %[[V2:.*]] = "Concrete.encode_int"(%[[V1]]) : (i8) -> !Concrete.plaintext<8>
// CHECK-NEXT: %[[V3:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %[[V2]]) : (!Concrete.lwe_ciphertext<1024,7>, !Concrete.plaintext<8>) -> !Concrete.lwe_ciphertext<1024,7>
// CHECK-NEXT: return %[[V3]] : !Concrete.lwe_ciphertext<1024,7>
%0 = arith.constant 1 : i8
%1 = "TFHE.add_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,1,64}{7}>, i8) -> (!TFHE.glwe<{1024,1,64}{7}>)
return %1: !TFHE.glwe<{1024,1,64}{7}>
}
// CHECK-LABEL: func.func @add_glwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4>
//CHECK: func.func @add_glwe_int(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,4>, %[[A1:.*]]: i5) -> !Concrete.lwe_ciphertext<1024,4> {
//CHECK: %[[V0:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[A0]], %[[A1]]) : (!Concrete.lwe_ciphertext<1024,4>, i5) -> !Concrete.lwe_ciphertext<1024,4>
//CHECK: return %[[V0]] : !Concrete.lwe_ciphertext<1024,4>
//CHECK: }
func.func @add_glwe_int(%arg0: !TFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !TFHE.glwe<{1024,1,64}{4}> {
// CHECK-NEXT: %[[V1:.*]] = "Concrete.encode_int"(%arg1) : (i5) -> !Concrete.plaintext<5>
// CHECK-NEXT: %[[V2:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %[[V1]]) : (!Concrete.lwe_ciphertext<1024,4>, !Concrete.plaintext<5>) -> !Concrete.lwe_ciphertext<1024,4>
// CHECK-NEXT: return %[[V2]] : !Concrete.lwe_ciphertext<1024,4>
%1 = "TFHE.add_glwe_int"(%arg0, %arg1): (!TFHE.glwe<{1024,1,64}{4}>, i5) -> (!TFHE.glwe<{1024,1,64}{4}>)
return %1: !TFHE.glwe<{1024,1,64}{4}>
}

View File

@@ -1,22 +1,22 @@
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
// CHECK-LABEL: func.func @mul_glwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7>
//CHECK: func.func @mul_glwe_const_int(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
//CHECK: %c1_i8 = arith.constant 1 : i8
//CHECK: %[[V0:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%[[A0]], %c1_i8) : (!Concrete.lwe_ciphertext<1024,7>, i8) -> !Concrete.lwe_ciphertext<1024,7>
//CHECK: return %[[V0]] : !Concrete.lwe_ciphertext<1024,7>
//CHECK: }
func.func @mul_glwe_const_int(%arg0: !TFHE.glwe<{1024,1,64}{7}>) -> !TFHE.glwe<{1024,1,64}{7}> {
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8
// CHECK-NEXT: %[[V2:.*]] = "Concrete.int_to_cleartext"(%[[V1]]) : (i8) -> !Concrete.cleartext<8>
// CHECK-NEXT: %[[V3:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %[[V2]]) : (!Concrete.lwe_ciphertext<1024,7>, !Concrete.cleartext<8>) -> !Concrete.lwe_ciphertext<1024,7>
// CHECK-NEXT: return %[[V3]] : !Concrete.lwe_ciphertext<1024,7>
%0 = arith.constant 1 : i8
%1 = "TFHE.mul_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,1,64}{7}>, i8) -> (!TFHE.glwe<{1024,1,64}{7}>)
return %1: !TFHE.glwe<{1024,1,64}{7}>
}
// CHECK-LABEL: func.func @mul_glwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4>
//CHECK: func.func @mul_glwe_int(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,4>, %[[A1:.*]]: i5) -> !Concrete.lwe_ciphertext<1024,4> {
//CHECK: %[[V0:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%[[A0]], %[[A1]]) : (!Concrete.lwe_ciphertext<1024,4>, i5) -> !Concrete.lwe_ciphertext<1024,4>
//CHECK: return %[[V0]] : !Concrete.lwe_ciphertext<1024,4>
//CHECK: }
func.func @mul_glwe_int(%arg0: !TFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !TFHE.glwe<{1024,1,64}{4}> {
// CHECK-NEXT: %[[V1:.*]] = "Concrete.int_to_cleartext"(%arg1) : (i5) -> !Concrete.cleartext<5>
// CHECK-NEXT: %[[V2:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %[[V1]]) : (!Concrete.lwe_ciphertext<1024,4>, !Concrete.cleartext<5>) -> !Concrete.lwe_ciphertext<1024,4>
// CHECK-NEXT: return %[[V2]] : !Concrete.lwe_ciphertext<1024,4>
%1 = "TFHE.mul_glwe_int"(%arg0, %arg1): (!TFHE.glwe<{1024,1,64}{4}>, i5) -> (!TFHE.glwe<{1024,1,64}{4}>)
return %1: !TFHE.glwe<{1024,1,64}{4}>
}

View File

@@ -1,23 +1,23 @@
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
// CHECK-LABEL: func.func @sub_const_int_glwe(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7>
//CHECK: func.func @sub_const_int_glwe(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
//CHECK: %c1_i8 = arith.constant 1 : i8
//CHECK: %[[V0:.*]] = "Concrete.negate_lwe_ciphertext"(%[[A0]]) : (!Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7>
//CHECK: %[[V1:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[V0]], %c1_i8) : (!Concrete.lwe_ciphertext<1024,7>, i8) -> !Concrete.lwe_ciphertext<1024,7>
//CHECK: return %[[V1]] : !Concrete.lwe_ciphertext<1024,7>
//CHECK: }
func.func @sub_const_int_glwe(%arg0: !TFHE.glwe<{1024,1,64}{7}>) -> !TFHE.glwe<{1024,1,64}{7}> {
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8
// CHECK-NEXT: %[[NEG:.*]] = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7>
// CHECK-NEXT: %[[V2:.*]] = "Concrete.encode_int"(%[[V1]]) : (i8) -> !Concrete.plaintext<8>
// CHECK-NEXT: %[[V3:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[NEG]], %[[V2]]) : (!Concrete.lwe_ciphertext<1024,7>, !Concrete.plaintext<8>) -> !Concrete.lwe_ciphertext<1024,7>
// CHECK-NEXT: return %[[V3]] : !Concrete.lwe_ciphertext<1024,7>
%0 = arith.constant 1 : i8
%1 = "TFHE.sub_int_glwe"(%0, %arg0): (i8, !TFHE.glwe<{1024,1,64}{7}>) -> (!TFHE.glwe<{1024,1,64}{7}>)
return %1: !TFHE.glwe<{1024,1,64}{7}>
}
// CHECK-LABEL: func.func @sub_int_glwe(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4>
//CHECK: func.func @sub_int_glwe(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,4>, %[[A1:.*]]: i5) -> !Concrete.lwe_ciphertext<1024,4> {
//CHECK: %[[V0:.*]] = "Concrete.negate_lwe_ciphertext"(%[[A0]]) : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4>
//CHECK: %[[V1:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[V0]], %[[A1]]) : (!Concrete.lwe_ciphertext<1024,4>, i5) -> !Concrete.lwe_ciphertext<1024,4>
//CHECK: return %[[V1]] : !Concrete.lwe_ciphertext<1024,4>
//CHECK: }
func.func @sub_int_glwe(%arg0: !TFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !TFHE.glwe<{1024,1,64}{4}> {
// CHECK-NEXT: %[[NEG:.*]] = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4>
// CHECK-NEXT: %[[V1:.*]] = "Concrete.encode_int"(%arg1) : (i5) -> !Concrete.plaintext<5>
// CHECK-NEXT: %[[V2:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[NEG]], %[[V1]]) : (!Concrete.lwe_ciphertext<1024,4>, !Concrete.plaintext<5>) -> !Concrete.lwe_ciphertext<1024,4>
// CHECK-NEXT: return %[[V2]] : !Concrete.lwe_ciphertext<1024,4>
%1 = "TFHE.sub_int_glwe"(%arg1, %arg0): (i5, !TFHE.glwe<{1024,1,64}{4}>) -> (!TFHE.glwe<{1024,1,64}{4}>)
return %1: !TFHE.glwe<{1024,1,64}{4}>
}

View File

@@ -9,6 +9,15 @@ func.func @add_lwe_ciphertexts(%arg0: tensor<2049xi64>, %arg1: tensor<2049xi64>)
return %0 : tensor<2049xi64>
}
//CHECK: func.func @add_crt_lwe_ciphertexts(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.add_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> tensor<5x2049xi64>
//CHECK: return %[[V0]] : tensor<5x2049xi64>
//CHECK: }
func.func @add_crt_lwe_ciphertexts(%arg0: tensor<5x2049xi64>, %arg1: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
%0 = "BConcrete.add_crt_lwe_buffer"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> ( tensor<5x2049xi64>)
return %0 : tensor<5x2049xi64>
}
//CHECK: func.func @add_plaintext_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: i64) -> tensor<2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_lwe_buffer"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, i64) -> tensor<2049xi64>
//CHECK: return %[[V0]] : tensor<2049xi64>
@@ -18,7 +27,16 @@ func.func @add_plaintext_lwe_ciphertext(%arg0: tensor<2049xi64>, %arg1: i64) ->
return %0 : tensor<2049xi64>
}
//CHECK: func.func @mul_cleartext_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: i64) -> tensor<2049xi64> {
//CHECK: func.func @add_plaintext_crt_lwe_ciphertext(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: i64) -> tensor<5x2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> tensor<5x2049xi64>
//CHECK: return %[[V0]] : tensor<5x2049xi64>
//CHECK: }
func.func @add_plaintext_crt_lwe_ciphertext(%arg0: tensor<5x2049xi64>, %arg1: i64) -> tensor<5x2049xi64> {
%0 = "BConcrete.add_plaintext_crt_lwe_buffer"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> ( tensor<5x2049xi64>)
return %0 : tensor<5x2049xi64>
}
//CHECK: func @mul_cleartext_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: i64) -> tensor<2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_lwe_buffer"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, i64) -> tensor<2049xi64>
//CHECK: return %[[V0]] : tensor<2049xi64>
//CHECK: }
@@ -27,6 +45,15 @@ func.func @mul_cleartext_lwe_ciphertext(%arg0: tensor<2049xi64>, %arg1: i64) ->
return %0 : tensor<2049xi64>
}
//CHECK: func.func @mul_cleartext_crt_lwe_ciphertext(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: i64) -> tensor<5x2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> tensor<5x2049xi64>
//CHECK: return %[[V0]] : tensor<5x2049xi64>
//CHECK: }
func.func @mul_cleartext_crt_lwe_ciphertext(%arg0: tensor<5x2049xi64>, %arg1: i64) -> tensor<5x2049xi64> {
%0 = "BConcrete.mul_cleartext_crt_lwe_buffer"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> (tensor<5x2049xi64>)
return %0 : tensor<5x2049xi64>
}
//CHECK: func.func @negate_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>) -> tensor<2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.negate_lwe_buffer"(%[[A0]]) : (tensor<2049xi64>) -> tensor<2049xi64>
//CHECK: return %[[V0]] : tensor<2049xi64>
@@ -36,6 +63,15 @@ func.func @negate_lwe_ciphertext(%arg0: tensor<2049xi64>) -> tensor<2049xi64> {
return %0 : tensor<2049xi64>
}
//CHECK: func.func @negate_crt_lwe_ciphertext(%[[A0:.*]]: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.negate_crt_lwe_buffer"(%[[A0]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>) -> tensor<5x2049xi64>
//CHECK: return %[[V0]] : tensor<5x2049xi64>
//CHECK: }
func.func @negate_crt_lwe_ciphertext(%arg0: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
%0 = "BConcrete.negate_crt_lwe_buffer"(%arg0) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>) -> (tensor<5x2049xi64>)
return %0 : tensor<5x2049xi64>
}
//CHECK: func.func @bootstrap_lwe(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<4096xi64>) -> tensor<2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.bootstrap_lwe_buffer"(%[[A0]], %[[A1]]) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (tensor<2049xi64>, tensor<4096xi64>) -> tensor<2049xi64>
//CHECK: return %[[V0]] : tensor<2049xi64>

View File

@@ -1,14 +1,12 @@
// RUN: concretecompiler --optimize-concrete=false --action=dump-concrete %s 2>&1| FileCheck %s
// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext_0(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7>
//CHECK: func.func @mul_cleartext_lwe_ciphertext_0(%[[A0:.*]]: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> {
//CHECK: %c0_i7 = arith.constant 0 : i7
//CHECK: %[[V0:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%[[A0]], %c0_i7) : (!Concrete.lwe_ciphertext<2048,7>, i7) -> !Concrete.lwe_ciphertext<2048,7>
//CHECK: return %[[V0]] : !Concrete.lwe_ciphertext<2048,7>
//CHECK: }
func.func @mul_cleartext_lwe_ciphertext_0(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> {
// CHECK-NEXT: %[[V0:.*]] = arith.constant 0 : i7
// CHECK-NEXT: %[[V1:.*]] = "Concrete.int_to_cleartext"(%[[V0]]) : (i7) -> !Concrete.cleartext<7>
// CHECK-NEXT: %[[V2:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %[[V1]]) : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7>
// CHECK-NEXT: return %[[V2]] : !Concrete.lwe_ciphertext<2048,7>
%0 = arith.constant 0 : i7
%1 = "Concrete.int_to_cleartext"(%0) : (i7) -> !Concrete.cleartext<7>
%2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> (!Concrete.lwe_ciphertext<2048,7>)
%2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %0): (!Concrete.lwe_ciphertext<2048,7>, i7) -> (!Concrete.lwe_ciphertext<2048,7>)
return %2: !Concrete.lwe_ciphertext<2048,7>
}

View File

@@ -9,21 +9,21 @@ func.func @add_lwe_ciphertexts(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !
return %1: !Concrete.lwe_ciphertext<2048,7>
}
// CHECK-LABEL: func.func @add_plaintext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.plaintext<5>) -> !Concrete.lwe_ciphertext<2048,7>
func.func @add_plaintext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.plaintext<5>) -> !Concrete.lwe_ciphertext<2048,7> {
// CHECK-NEXT: %[[V1:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.plaintext<5>) -> !Concrete.lwe_ciphertext<2048,7>
// CHECK-LABEL: func.func @add_plaintext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: i5) -> !Concrete.lwe_ciphertext<2048,7>
func.func @add_plaintext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: i5) -> !Concrete.lwe_ciphertext<2048,7> {
// CHECK-NEXT: %[[V1:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, i5) -> !Concrete.lwe_ciphertext<2048,7>
// CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7>
%1 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %arg1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.plaintext<5>) -> (!Concrete.lwe_ciphertext<2048,7>)
%1 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %arg1): (!Concrete.lwe_ciphertext<2048,7>, i5) -> (!Concrete.lwe_ciphertext<2048,7>)
return %1: !Concrete.lwe_ciphertext<2048,7>
}
// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7>
func.func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7> {
// CHECK-NEXT: %[[V1:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7>
// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: i7) -> !Concrete.lwe_ciphertext<2048,7>
func.func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: i7) -> !Concrete.lwe_ciphertext<2048,7> {
// CHECK-NEXT: %[[V1:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, i7) -> !Concrete.lwe_ciphertext<2048,7>
// CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7>
%1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> (!Concrete.lwe_ciphertext<2048,7>)
%1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1): (!Concrete.lwe_ciphertext<2048,7>, i7) -> (!Concrete.lwe_ciphertext<2048,7>)
return %1: !Concrete.lwe_ciphertext<2048,7>
}
@@ -38,9 +38,9 @@ func.func @negate_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Co
// CHECK-LABEL: func.func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7>
func.func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7> {
// CHECK-NEXT: %[[V1:.*]] = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7>
// CHECK-NEXT: %[[V1:.*]] = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, level = -1 : i32} : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7>
// CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7>
%1 = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7>
%1 = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, level = -1 : i32} : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7>
return %1: !Concrete.lwe_ciphertext<2048,7>
}
@@ -51,22 +51,3 @@ func.func @keyswitch_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.l
%1 = "Concrete.keyswitch_lwe"(%arg0){baseLog = 2 : i32, level = 3 : i32}: (!Concrete.lwe_ciphertext<2048,7>) -> (!Concrete.lwe_ciphertext<2048,7>)
return %1: !Concrete.lwe_ciphertext<2048,7>
}
// CHECK-LABEL: func.func @encode_int(%arg0: i6) -> !Concrete.plaintext<6>
func.func @encode_int(%arg0: i6) -> (!Concrete.plaintext<6>) {
// CHECK-NEXT: %[[V1:.*]] = "Concrete.encode_int"(%arg0) : (i6) -> !Concrete.plaintext<6>
// CHECK-NEXT: return %[[V1]] : !Concrete.plaintext<6>
%0 = "Concrete.encode_int"(%arg0): (i6) -> !Concrete.plaintext<6>
return %0: !Concrete.plaintext<6>
}
// CHECK-LABEL: func.func @int_to_cleartext() -> !Concrete.cleartext<6>
func.func @int_to_cleartext() -> !Concrete.cleartext<6> {
// CHECK-NEXT: %[[V0:.*]] = arith.constant 5 : i6
// CHECK-NEXT: %[[V1:.*]] = "Concrete.int_to_cleartext"(%[[V0]]) : (i6) -> !Concrete.cleartext<6>
// CHECK-NEXT: return %[[V1]] : !Concrete.cleartext<6>
%0 = arith.constant 5 : i6
%1 = "Concrete.int_to_cleartext"(%0) : (i6) -> !Concrete.cleartext<6>
return %1 : !Concrete.cleartext<6>
}

View File

@@ -1,14 +1,14 @@
// RUN: concretecompiler --action=dump-concrete %s 2>&1| FileCheck %s
// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7>
func.func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7> {
// CHECK-NEXT: %[[V1:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7>
// CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7>
// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: i7) -> !Concrete.lwe_ciphertext<2048,7>
func.func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: i7) -> !Concrete.lwe_ciphertext<2048,7> {
// CHECK-NEXT: %[[V1:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, i7) -> !Concrete.lwe_ciphertext<2048,7>
// CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7>
%1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> (!Concrete.lwe_ciphertext<2048,7>)
return %1: !Concrete.lwe_ciphertext<2048,7>
}
%1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1): (!Concrete.lwe_ciphertext<2048,7>, i7) -> (!Concrete.lwe_ciphertext<2048,7>)
return %1: !Concrete.lwe_ciphertext<2048,7>
}
// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext_0(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7>
func.func @mul_cleartext_lwe_ciphertext_0(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> {
@@ -16,8 +16,7 @@ func.func @mul_cleartext_lwe_ciphertext_0(%arg0: !Concrete.lwe_ciphertext<2048,7
// CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7>
%0 = arith.constant 0 : i7
%1 = "Concrete.int_to_cleartext"(%0) : (i7) -> !Concrete.cleartext<7>
%2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> (!Concrete.lwe_ciphertext<2048,7>)
%2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %0): (!Concrete.lwe_ciphertext<2048,7>, i7) -> (!Concrete.lwe_ciphertext<2048,7>)
return %2: !Concrete.lwe_ciphertext<2048,7>
}
@@ -27,7 +26,6 @@ func.func @mul_cleartext_lwe_ciphertext_minus_1(%arg0: !Concrete.lwe_ciphertext<
// CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7>
%0 = arith.constant -1 : i7
%1 = "Concrete.int_to_cleartext"(%0) : (i7) -> !Concrete.cleartext<7>
%2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> (!Concrete.lwe_ciphertext<2048,7>)
%2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %0): (!Concrete.lwe_ciphertext<2048,7>, i7) -> (!Concrete.lwe_ciphertext<2048,7>)
return %2: !Concrete.lwe_ciphertext<2048,7>
}

View File

@@ -13,7 +13,13 @@ func.func @type_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Conc
return %arg0: !Concrete.lwe_ciphertext<2048,7>
}
// CHECK-LABEL: func.func @type_cleartext(%arg0: !Concrete.cleartext<5>) -> !Concrete.cleartext<5>
// CHECK-LABEL: func @type_lwe_ciphertext_with_crt(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>
func.func @type_lwe_ciphertext_with_crt(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7> {
// CHECK-NEXT: return %arg0 : !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>
return %arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>
}
// CHECK-LABEL: func @type_cleartext(%arg0: !Concrete.cleartext<5>) -> !Concrete.cleartext<5>
func.func @type_cleartext(%arg0: !Concrete.cleartext<5>) -> !Concrete.cleartext<5> {
// CHECK-NEXT: return %arg0 : !Concrete.cleartext<5>
return %arg0: !Concrete.cleartext<5>

View File

@@ -52,3 +52,20 @@ func.func @add_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>, %arg1: !TFHE.glwe<{1024,
return %1: !TFHE.glwe<{1024,12,64}{7}>
}
// -----
// GLWE crt parameter result
func.func @add_glwe(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>, %arg1: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}> {
// expected-error @+1 {{'TFHE.add_glwe' op should have the same GLWE 'crt' parameter}}
%1 = "TFHE.add_glwe"(%arg0, %arg1): (!TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>, !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> (!TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>)
return %1: !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>
}
// -----
// GLWE crt parameter inputs
func.func @add_glwe(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>, %arg1: !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>) -> !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}> {
// expected-error @+1 {{'TFHE.add_glwe' op should have the same GLWE 'crt' parameter}}
%1 = "TFHE.add_glwe"(%arg0, %arg1): (!TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>, !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>) -> (!TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>)
return %1: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>
}

View File

@@ -30,6 +30,16 @@ func.func @add_glwe_int(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,
// -----
// GLWE crt parameter
func.func @add_glwe_int(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}> {
%0 = arith.constant 1 : i8
// expected-error @+1 {{'TFHE.add_glwe_int' op should have the same GLWE 'crt' parameter}}
%1 = "TFHE.add_glwe_int"(%arg0, %0): (!TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>, i8) -> (!TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>)
return %1: !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>
}
// -----
// integer width doesn't match GLWE parameter
func.func @add_glwe_int(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,12,64}{7}> {
%0 = arith.constant 1 : i9

View File

@@ -30,6 +30,16 @@ func.func @mul_glwe_int(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,
// -----
// GLWE crt parameter
func.func @mul_glwe_int(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}> {
%0 = arith.constant 1 : i8
// expected-error @+1 {{'TFHE.mul_glwe_int' op should have the same GLWE 'crt' parameter}}
%1 = "TFHE.mul_glwe_int"(%arg0, %0): (!TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>, i8) -> (!TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>)
return %1: !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>
}
// -----
// integer width doesn't match GLWE parameter
func.func @mul_glwe_int(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,12,64}{7}> {
%0 = arith.constant 1 : i9

View File

@@ -27,6 +27,15 @@ func.func @neg_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,11,6
// -----
// GLWE crt parameter
func.func @neg_glwe(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}> {
// expected-error @+1 {{'TFHE.neg_glwe' op should have the same GLWE 'crt' parameter}}
%1 = "TFHE.neg_glwe"(%arg0): (!TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> (!TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>)
return %1: !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>
}
// -----
// integer width doesn't match GLWE parameter
func.func @neg_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,11,64}{7}> {
// expected-error @+1 {{'TFHE.neg_glwe' op should have the same GLWE 'polynomialSize' parameter}}

View File

@@ -28,6 +28,17 @@ func.func @sub_int_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,
return %1: !TFHE.glwe<{1024,11,64}{7}>
}
// -----
// GLWE crt parameter
func.func @sub_int_glwe(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}> {
%0 = arith.constant 1 : i8
// expected-error @+1 {{'TFHE.sub_int_glwe' op should have the same GLWE 'crt' parameter}}
%1 = "TFHE.sub_int_glwe"(%0, %arg0): (i8, !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> (!TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>)
return %1: !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>
}
// -----
// integer width doesn't match GLWE parameter

View File

@@ -11,3 +11,15 @@ func.func @glwe_1(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> {
// CHECK-LABEL: return %arg0 : !TFHE.glwe<{_,_,_}{7}>
return %arg0: !TFHE.glwe<{_,_,_}{7}>
}
// CHECK-LABEL: func.func @glwe_crt(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{_,_,_}{7}>) -> !TFHE.glwe<crt=[2,3,5,7,11]{_,_,_}{7}>
func.func @glwe_crt(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{_,_,_}{7}>) -> !TFHE.glwe<crt=[2,3,5,7,11]{_,_,_}{7}> {
// CHECK-LABEL: return %arg0 : !TFHE.glwe<crt=[2,3,5,7,11]{_,_,_}{7}>
return %arg0: !TFHE.glwe<crt=[2,3,5,7,11]{_,_,_}{7}>
}
// CHECK-LABEL: func.func @glwe_crt_undef(%arg0: !TFHE.glwe<crt=[_,_,_,_,_]{_,_,_}{7}>) -> !TFHE.glwe<crt=[_,_,_,_,_]{_,_,_}{7}>
func.func @glwe_crt_undef(%arg0: !TFHE.glwe<crt=[_,_,_,_,_]{_,_,_}{7}>) -> !TFHE.glwe<crt=[_,_,_,_,_]{_,_,_}{7}> {
// CHECK-LABEL: return %arg0 : !TFHE.glwe<crt=[_,_,_,_,_]{_,_,_}{7}>
return %arg0: !TFHE.glwe<crt=[_,_,_,_,_]{_,_,_}{7}>
}

View File

@@ -67,7 +67,7 @@ llvm::Error checkResult(ScalarDesc &desc,
}
if (desc.value != res64->getValue()) {
return StreamStringError("unexpected result value: got ")
<< res64->getValue() << "expected " << desc.value;
<< res64->getValue() << " expected " << desc.value;
}
return llvm::Error::success();
}
@@ -204,6 +204,12 @@ template <> struct llvm::yaml::MappingTraits<EndToEndDesc> {
desc.v0Constraint->p = v0constraint[0];
desc.v0Constraint->norm2 = v0constraint[1];
}
mlir::concretelang::LargeIntegerParameter largeInteger;
io.mapOptional("large-integer-crt-decomposition",
largeInteger.crtDecomposition);
if (!largeInteger.crtDecomposition.empty()) {
desc.largeIntegerParameter = largeInteger;
}
}
};

View File

@@ -14,7 +14,7 @@ struct TensorDescription {
ValueWidth width;
};
struct ScalarDesc {
uint64_t value;
int64_t value;
ValueWidth width;
};
@@ -43,6 +43,8 @@ struct EndToEndDesc {
std::vector<TestDescription> tests;
llvm::Optional<mlir::concretelang::V0Parameter> v0Parameter;
llvm::Optional<mlir::concretelang::V0FHEConstraint> v0Constraint;
llvm::Optional<mlir::concretelang::LargeIntegerParameter>
largeIntegerParameter;
};
llvm::Expected<mlir::concretelang::LambdaArgument *>

View File

@@ -1,3 +1,17 @@
## Part of the Concrete Compiler Project, under the BSD3 License with Zama
## Exceptions. See
## https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
## for license information.
##
## This file contains programs and references to end-to-end test the compiler.
## The programs in this file aim to test the `tensor` dialect on encrypted integers.
##
## Operators:
## - tensor.extract
## - tensor.insert
## - tensor.extract_slice
## - tensor.insert_slice
description: identity
program: |
func.func @main(%t: tensor<2x10x!FHE.eint<6>>) -> tensor<2x10x!FHE.eint<6>> {
@@ -14,6 +28,24 @@ tests:
0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
shape: [2,10]
---
description: identity_16bits
program: |
func.func @main(%t: tensor<2x10x!FHE.eint<16>>) -> tensor<2x10x!FHE.eint<16>> {
return %t : tensor<2x10x!FHE.eint<16>>
}
tests:
- inputs:
- tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769,
38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0]
shape: [2,10]
outputs:
- tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769,
38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0]
shape: [2,10]
v0-constraint: [16, 0]
v0-parameter: [1,11,565,1,23,5,3]
large-integer-crt-decomposition: [7,8,9,11,13]
---
description: extract
program: |
func.func @main(%t: tensor<2x10x!FHE.eint<6>>, %i: index, %j: index) ->
@@ -59,6 +91,157 @@ tests:
outputs:
- scalar: 9
---
description: extract_16bits
program: |
func.func @main(%t: tensor<2x10x!FHE.eint<16>>, %i: index, %j: index) ->
!FHE.eint<16> {
%c = tensor.extract %t[%i, %j] : tensor<2x10x!FHE.eint<16>>
return %c : !FHE.eint<16>
}
tests:
- inputs:
- tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769,
38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0]
shape: [2,10]
- scalar: 0
- scalar: 0
outputs:
- scalar: 65535
- inputs:
- tensor: [63, 12, 7, 43, 52, 9, 26, 34, 22, 0,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
shape: [2,10]
- scalar: 0
- scalar: 9
outputs:
- scalar: 0
- inputs:
- tensor: [63, 12, 7, 43, 52, 9, 26, 34, 22, 0,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
shape: [2,10]
- scalar: 1
- scalar: 0
outputs:
- scalar: 0
- inputs:
- tensor: [63, 12, 7, 43, 52, 9, 26, 34, 22, 0,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
shape: [2,10]
- scalar: 1
- scalar: 9
outputs:
- scalar: 9
v0-constraint: [16, 0]
v0-parameter: [1,11,565,1,23,5,3]
large-integer-crt-decomposition: [7,8,9,11,13]
---
description: insert
program: |
func.func @main(%t: tensor<2x10x!FHE.eint<6>>, %i: index, %j: index, %x: !FHE.eint<6>) -> tensor<2x10x!FHE.eint<6>> {
%r = tensor.insert %x into %t[%i, %j] : tensor<2x10x!FHE.eint<6>>
return %r : tensor<2x10x!FHE.eint<6>>
}
tests:
- inputs:
- tensor: [63, 12, 7, 43, 52, 9, 26, 34, 22, 0,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
shape: [2,10]
- scalar: 0
- scalar: 0
- scalar: 42
outputs:
- tensor: [42, 12, 7, 43, 52, 9, 26, 34, 22, 0,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
shape: [2,10]
- inputs:
- tensor: [63, 12, 7, 43, 52, 9, 26, 34, 22, 0,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
shape: [2,10]
- scalar: 0
- scalar: 1
- scalar: 42
outputs:
- tensor: [63, 42, 7, 43, 52, 9, 26, 34, 22, 0,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
shape: [2,10]
- inputs:
- tensor: [63, 12, 7, 43, 52, 9, 26, 34, 22, 0,
42, 1, 2, 3, 4, 5, 6, 7, 8, 9]
shape: [2,10]
- scalar: 1
- scalar: 0
- scalar: 42
outputs:
- tensor: [63, 12, 7, 43, 52, 9, 26, 34, 22, 0,
42, 1, 2, 3, 4, 5, 6, 7, 8, 9]
shape: [2,10]
- inputs:
- tensor: [63, 12, 7, 43, 52, 9, 26, 34, 22, 0,
42, 1, 2, 3, 4, 5, 6, 7, 8, 9]
shape: [2,10]
- scalar: 1
- scalar: 9
- scalar: 42
outputs:
- tensor: [63, 12, 7, 43, 52, 9, 26, 34, 22, 0,
42, 1, 2, 3, 4, 5, 6, 7, 8, 42]
shape: [2,10]
---
description: insert_16bits
program: |
func.func @main(%t: tensor<2x10x!FHE.eint<16>>, %i: index, %j: index, %x: !FHE.eint<16>) -> tensor<2x10x!FHE.eint<16>> {
%r = tensor.insert %x into %t[%i, %j] : tensor<2x10x!FHE.eint<16>>
return %r : tensor<2x10x!FHE.eint<16>>
}
tests:
- inputs:
- tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769,
38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0]
shape: [2,10]
- scalar: 0
- scalar: 0
- scalar: 42
outputs:
- tensor: [42, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769,
38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0]
shape: [2,10]
- inputs:
- tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769,
38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0]
shape: [2,10]
- scalar: 0
- scalar: 1
- scalar: 42
outputs:
- tensor: [65535, 42, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769,
38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0]
shape: [2,10]
- inputs:
- tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769,
38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0]
shape: [2,10]
- scalar: 1
- scalar: 0
- scalar: 42
outputs:
- tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769,
42, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0]
shape: [2,10]
- inputs:
- tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769,
38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0]
shape: [2,10]
- scalar: 1
- scalar: 9
- scalar: 42
outputs:
- tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769,
38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 42]
shape: [2,10]
v0-constraint: [16, 0]
v0-parameter: [1,11,565,1,23,5,3]
large-integer-crt-decomposition: [7,8,9,11,13]
---
description: extract_slice
program: |
func.func @main(%t: tensor<2x10x!FHE.eint<6>>) -> tensor<1x5x!FHE.eint<6>> {
@@ -91,6 +274,24 @@ tests:
- tensor: [ 5, 6, 7, 8, 9]
shape: [5]
---
description: extract_slice_16bits
program: |
func.func @main(%t: tensor<2x10x!FHE.eint<16>>) -> tensor<1x5x!FHE.eint<16>> {
%r = tensor.extract_slice %t[1, 5][1, 5][1, 1] : tensor<2x10x!FHE.eint<16>> to tensor<1x5x!FHE.eint<16>>
return %r : tensor<1x5x!FHE.eint<16>>
}
tests:
- inputs:
- tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769,
38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0]
shape: [2,10]
outputs:
- tensor: [64814, 65491, 4271, 9294, 0]
shape: [1,5]
v0-constraint: [16, 0]
v0-parameter: [1,11,565,1,23,5,3]
large-integer-crt-decomposition: [7,8,9,11,13]
---
description: extract_slice_stride
program: |
func.func @main(%t: tensor<2x10x!FHE.eint<6>>) -> tensor<1x5x!FHE.eint<6>> {
@@ -122,6 +323,24 @@ tests:
- tensor: [3, 2, 1]
shape: [3]
---
description: extract_slice_stride_16bits
program: |
func.func @main(%t: tensor<2x10x!FHE.eint<16>>) -> tensor<1x5x!FHE.eint<16>> {
%r = tensor.extract_slice %t[1, 0][1, 5][1, 2] : tensor<2x10x!FHE.eint<16>> to tensor<1x5x!FHE.eint<16>>
return %r : tensor<1x5x!FHE.eint<16>>
}
tests:
- inputs:
- tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769,
38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0]
shape: [2,10]
outputs:
- tensor: [38786, 65112, 60515, 65491, 9294]
shape: [1,5]
v0-constraint: [16, 0]
v0-parameter: [1,11,565,1,23,5,3]
large-integer-crt-decomposition: [7,8,9,11,13]
---
description: insert_slice
program: |
func.func @main(%t0: tensor<2x10x!FHE.eint<6>>, %t1: tensor<2x2x!FHE.eint<6>>) -> tensor<2x10x!FHE.eint<6>> {
@@ -179,3 +398,25 @@ tests:
- tensor: [0, 1, 2,
3, 4, 5]
shape: [2, 3]
---
description: insert_slice_16bits
program: |
func.func @main(%t0: tensor<2x10x!FHE.eint<16>>, %t1: tensor<2x2x!FHE.eint<16>>) -> tensor<2x10x!FHE.eint<16>> {
%r = tensor.insert_slice %t1 into %t0[0, 5][2, 2][1, 1] : tensor<2x2x!FHE.eint<16>> into tensor<2x10x!FHE.eint<16>>
return %r : tensor<2x10x!FHE.eint<16>>
}
tests:
- inputs:
- tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769,
38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0]
shape: [2,10]
- tensor: [1000, 1001,
1002, 1003]
shape: [2,2]
outputs:
- tensor: [65535, 46706, 18752, 55384, 55709, 1000, 1001, 57650, 45551, 5769,
38786, 36362, 65112, 5748, 60515, 1002, 1003, 4271, 9294, 0]
shape: [2,10]
v0-constraint: [16, 0]
v0-parameter: [1,11,565,1,23,5,3]
large-integer-crt-decomposition: [7,8,9,11,13]

View File

@@ -9,6 +9,28 @@ tests:
outputs:
- scalar: 1
---
description: identity_16bits
program: |
func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> {
return %arg0: !FHE.eint<16>
}
tests:
- inputs:
- scalar: 1
outputs:
- scalar: 1
- inputs:
- scalar: 0
outputs:
- scalar: 0
- inputs:
- scalar: 72071
outputs:
- scalar: 72071
v0-constraint: [16, 0]
v0-parameter: [1,11,565,1,23,5,3]
large-integer-crt-decomposition: [7,8,9,11,13]
---
description: zero_tensor
program: |
func.func @main() -> tensor<2x2x4x!FHE.eint<6>> {
@@ -20,6 +42,20 @@ tests:
- tensor: [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]
shape: [2,2,4]
---
description: zero_tensor_16bits
program: |
func.func @main() -> tensor<2x2x4x!FHE.eint<16>> {
%0 = "FHE.zero_tensor"() : () -> tensor<2x2x4x!FHE.eint<16>>
return %0 : tensor<2x2x4x!FHE.eint<16>>
}
tests:
- outputs:
- tensor: [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]
shape: [2,2,4]
v0-constraint: [16, 0]
v0-parameter: [1,11,565,1,23,5,3]
large-integer-crt-decomposition: [7,8,9,11,13]
---
description: add_eint_int_cst
program: |
func.func @main(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
@@ -32,18 +68,30 @@ tests:
- scalar: 0
outputs:
- scalar: 1
---
description: add_eint_int_cst_16bits
program: |
func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> {
%0 = arith.constant 1 : i17
%1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.eint<16>, i17) -> (!FHE.eint<16>)
return %1: !FHE.eint<16>
}
tests:
- inputs:
- scalar: 0
outputs:
- scalar: 1
outputs:
- scalar: 2
- inputs:
- scalar: 2
- scalar: 72070
outputs:
- scalar: 3
- scalar: 72071
- inputs:
- scalar: 3
- scalar: 72071
outputs:
- scalar: 4
- scalar: 0
v0-constraint: [16, 0]
v0-parameter: [1,11,565,1,23,5,3]
large-integer-crt-decomposition: [7,8,9,11,13]
---
description: add_eint_int_arg
program: |
@@ -63,30 +111,75 @@ tests:
outputs:
- scalar: 3
---
description: sub_int_eint_cst
description: add_eint_int_arg_16bits
program: |
func.func @main(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
%0 = arith.constant 7 : i3
%1 = "FHE.sub_int_eint"(%0, %arg0): (i3, !FHE.eint<2>) -> (!FHE.eint<2>)
func.func @main(%arg0: !FHE.eint<16>, %arg1: i17) -> !FHE.eint<16> {
%1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<16>, i17) -> (!FHE.eint<16>)
return %1: !FHE.eint<16>
}
tests:
- inputs:
- scalar: 0
- scalar: 0
outputs:
- scalar: 0
- inputs:
- scalar: 0
- scalar: 5
outputs:
- scalar: 5
- inputs:
- scalar: 36036
- scalar: 36035
outputs:
- scalar: 72071
v0-constraint: [16, 0]
v0-parameter: [1,11,565,1,23,5,3]
large-integer-crt-decomposition: [7,8,9,11,13]
---
description: add_eint
program: |
func.func @main(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2> {
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<2>)
return %1: !FHE.eint<2>
}
tests:
- inputs:
- scalar: 0
- scalar: 1
outputs:
- scalar: 6
- scalar: 1
- inputs:
- scalar: 1
- scalar: 2
outputs:
- scalar: 3
---
description: add_eint_16bits
program: |
func.func @main(%arg0: !FHE.eint<16>, %arg1: !FHE.eint<16>) -> !FHE.eint<16> {
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<16>, !FHE.eint<16>) -> (!FHE.eint<16>)
return %1: !FHE.eint<16>
}
tests:
- inputs:
- scalar: 0
- scalar: 0
outputs:
- scalar: 0
- inputs:
- scalar: 0
- scalar: 5
outputs:
- scalar: 5
- inputs:
- scalar: 3
- scalar: 36036
- scalar: 36035
outputs:
- scalar: 4
- inputs:
- scalar: 4
outputs:
- scalar: 3
- scalar: 72071
v0-constraint: [16, 0]
v0-parameter: [1,11,565,1,23,5,3]
large-integer-crt-decomposition: [7,8,9,11,13]
---
description: sub_eint_int_cst
program: |
@@ -105,6 +198,22 @@ tests:
outputs:
- scalar: 0
---
description: sub_eint_int_cst_16bits
program: |
func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> {
%0 = arith.constant 7 : i17
%1 = "FHE.sub_eint_int"(%arg0, %0): (!FHE.eint<16>, i17) -> (!FHE.eint<16>)
return %1: !FHE.eint<16>
}
tests:
- inputs:
- scalar: 7
outputs:
- scalar: 0
v0-constraint: [16, 0]
v0-parameter: [1,11,565,1,23,5,3]
large-integer-crt-decomposition: [7,8,9,11,13]
---
description: sub_eint_int_arg
program: |
func.func @main(%arg0: !FHE.eint<4>, %arg1: i5) -> !FHE.eint<4> {
@@ -128,6 +237,22 @@ tests:
outputs:
- scalar: 3
---
description: sub_eint_int_arg_16bits_fixme
program: |
func.func @main(%arg0: !FHE.eint<16>, %arg1: i17) -> !FHE.eint<16> {
%1 = "FHE.sub_eint_int"(%arg0, %arg1): (!FHE.eint<16>, i17) -> (!FHE.eint<16>)
return %1: !FHE.eint<16>
}
tests:
- inputs:
- scalar: 72071
- scalar: 2
outputs:
- scalar: 72069
v0-constraint: [16, 0]
v0-parameter: [1,11,565,1,23,5,3]
large-integer-crt-decomposition: [7,8,9,11,13]
---
description: sub_eint
program: |
func.func @main(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<4> {
@@ -151,6 +276,32 @@ tests:
outputs:
- scalar: 3
---
description: sub_eint_16bits
program: |
func.func @main(%arg0: !FHE.eint<16>, %arg1: !FHE.eint<16>) -> !FHE.eint<16> {
%1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.eint<16>, !FHE.eint<16>) -> (!FHE.eint<16>)
return %1: !FHE.eint<16>
}
tests:
- inputs:
- scalar: 0
- scalar: 0
outputs:
- scalar: 0
- inputs:
- scalar: 72071
- scalar: 72071
outputs:
- scalar: 0
- inputs:
- scalar: 7
- scalar: 4
outputs:
- scalar: 3
v0-constraint: [16, 0]
v0-parameter: [1,11,565,1,23,5,3]
large-integer-crt-decomposition: [7,8,9,11,13]
---
description: sub_int_eint_arg
program: |
func.func @main(%arg0: i3, %arg1: !FHE.eint<2>) -> !FHE.eint<2> {
@@ -174,6 +325,79 @@ tests:
outputs:
- scalar: 5
---
description: sub_int_eint_arg_16bits
program: |
func.func @main(%arg0: i17, %arg1: !FHE.eint<16>) -> !FHE.eint<16> {
%1 = "FHE.sub_int_eint"(%arg0, %arg1): (i17, !FHE.eint<16>) -> (!FHE.eint<16>)
return %1: !FHE.eint<16>
}
tests:
- inputs:
- scalar: 0
- scalar: 0
outputs:
- scalar: 0
- inputs:
- scalar: 0
- scalar: -1
outputs:
- scalar: 1
#- inputs:
# - scalar: 72071
# - scalar: 0
# outputs:
# - scalar: 72071
#- inputs:
# - scalar: 7
# - scalar: 4
# outputs:
# - scalar: 3
v0-constraint: [16, 0]
v0-parameter: [1,11,565,1,23,5,3]
large-integer-crt-decomposition: [7,8,9,11,13]
---
description: sub_int_eint_cst
program: |
func.func @main(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
%0 = arith.constant 7 : i3
%1 = "FHE.sub_int_eint"(%0, %arg0): (i3, !FHE.eint<2>) -> (!FHE.eint<2>)
return %1 : !FHE.eint<2>
}
tests:
- inputs:
- scalar: 1
outputs:
- scalar: 6
- inputs:
- scalar: 2
outputs:
- scalar: 5
---
description: sub_int_eint_cst_16bits
program: |
func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> {
%0 = arith.constant -1 : i17
%1 = "FHE.sub_int_eint"(%0, %arg0): (i17, !FHE.eint<16>) -> (!FHE.eint<16>)
return %1: !FHE.eint<16>
}
tests:
- inputs:
- scalar: 72071
outputs:
- scalar: 0
- inputs:
- scalar: 0
outputs:
- scalar: 72071
- inputs:
- scalar: 32000
outputs:
- scalar: 40071
v0-constraint: [16, 0]
v0-parameter: [1,11,565,1,23,5,3]
large-integer-crt-decomposition: [7,8,9,11,13]
---
description: neg_eint
program: |
func.func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
@@ -198,6 +422,29 @@ tests:
outputs:
- scalar: 6
---
description: neg_eint_16bits
program: |
func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> {
%1 = "FHE.neg_eint"(%arg0): (!FHE.eint<16>) -> (!FHE.eint<16>)
return %1: !FHE.eint<16>
}
tests:
- inputs:
- scalar: 0
outputs:
- scalar: 0
- inputs:
- scalar: 1
outputs:
- scalar: 72071
- inputs:
- scalar: 72071
outputs:
- scalar: 1
v0-constraint: [16, 0]
v0-parameter: [1,11,565,1,23,5,3]
large-integer-crt-decomposition: [7,8,9,11,13]
---
description: neg_eint_3bits
program: |
func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
@@ -247,6 +494,30 @@ tests:
outputs:
- scalar: 4
---
description: mul_eint_int_cst_16bits
program: |
func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> {
%0 = arith.constant 2 : i17
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<16>, i17) -> (!FHE.eint<16>)
return %1: !FHE.eint<16>
}
tests:
- inputs:
- scalar: 0
outputs:
- scalar: 0
- inputs:
- scalar: 1
outputs:
- scalar: 2
- inputs:
- scalar: 36035
outputs:
- scalar: 72070
v0-constraint: [16, 0]
v0-parameter: [1,11,565,1,23,5,3]
large-integer-crt-decomposition: [7,8,9,11,13]
---
description: mul_eint_int_arg
program: |
func.func @main(%arg0: !FHE.eint<2>, %arg1: i3) -> !FHE.eint<2> {
@@ -270,28 +541,31 @@ tests:
outputs:
- scalar: 4
---
description: add_eint
description: mul_eint_int_arg_16bits
program: |
func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
func.func @main(%arg0: !FHE.eint<16>, %arg1: i17) -> !FHE.eint<16> {
%1 = "FHE.mul_eint_int"(%arg0, %arg1): (!FHE.eint<16>, i17) -> (!FHE.eint<16>)
return %1: !FHE.eint<16>
}
tests:
- inputs:
- scalar: 1
- scalar: 2
- scalar: 0
- scalar: 87
outputs:
- scalar: 3
- inputs:
- scalar: 4
- scalar: 5
outputs:
- scalar: 9
- scalar: 0
- inputs:
- scalar: 1
- scalar: 1
- scalar: 72071
outputs:
- scalar: 72071
- inputs:
- scalar: 2
- scalar: 3572
outputs:
- scalar: 7144
v0-constraint: [16, 0]
v0-parameter: [1,11,565,1,23,5,3]
large-integer-crt-decomposition: [7,8,9,11,13]
---
description: apply_lookup_table_1_bits
program: |

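One detail worth calling out about the 16-bit tests above: with large-integer-crt-decomposition [7,8,9,11,13], the expected values wrap modulo the product of the moduli, 7·8·9·11·13 = 72072, not modulo 2^16. That is why add_eint_int_cst_16bits maps 72071 + 1 to 0 and neg_eint_16bits maps 1 to 72071. A minimal arithmetic check, independent of the compiler:

```cpp
// Plain arithmetic behind the 16-bit expected values in the tests above.
#include <cassert>
#include <cstdint>

int main() {
  const uint64_t product = 7ull * 8 * 9 * 11 * 13; // 72072, the CRT message space
  assert(product == 72072);
  assert((72071 + 1) % product == 0);              // add_eint_int_cst_16bits: 72071 + 1 -> 0
  assert((product - 1) % product == 72071);        // neg_eint_16bits: -1 mod 72072 -> 72071
  return 0;
}
```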
View File

@@ -1,3 +1,229 @@
description: add_eint_int_term_to_term
program: |
// Returns the term to term addition of `%a0` with `%a1`
func.func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!FHE.eint<6>> {
%res = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4xi7>) -> tensor<4x!FHE.eint<6>>
return %res : tensor<4x!FHE.eint<6>>
}
tests:
- inputs:
- tensor: [31, 6, 12, 9]
shape: [4]
width: 8
- tensor: [32, 9, 2, 3]
shape: [4]
width: 8
outputs:
- tensor: [63, 15, 14, 12]
shape: [4]
---
description: add_eint_int_term_to_term_16bits
program: |
// Returns the term to term addition of `%a0` with `%a1`
func.func @main(%a0: tensor<4x!FHE.eint<16>>, %a1: tensor<4xi17>) -> tensor<4x!FHE.eint<16>> {
%res = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<16>>, tensor<4xi17>) -> tensor<4x!FHE.eint<16>>
return %res : tensor<4x!FHE.eint<16>>
}
tests:
- inputs:
- tensor: [32767, 1276, 10212, 0]
shape: [4]
- tensor: [32768, 20967, 3, 0]
shape: [4]
outputs:
- tensor: [65535, 22243, 10215, 0]
shape: [4]
v0-constraint: [16, 0]
v0-parameter: [1,11,565,1,23,5,3]
large-integer-crt-decomposition: [7,8,9,11,13]
---
description: add_eint_term_to_term
program: |
func.func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>> {
%res = "FHELinalg.add_eint"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>>
return %res : tensor<4x!FHE.eint<6>>
}
tests:
- inputs:
- tensor: [31, 6, 12, 9]
shape: [4]
- tensor: [32, 9, 2, 3]
shape: [4]
outputs:
- tensor: [63, 15, 14, 12]
shape: [4]
---
description: add_eint_term_to_term_16bits
program: |
func.func @main(%a0: tensor<4x!FHE.eint<16>>, %a1: tensor<4x!FHE.eint<16>>) -> tensor<4x!FHE.eint<16>> {
%res = "FHELinalg.add_eint"(%a0, %a1) : (tensor<4x!FHE.eint<16>>, tensor<4x!FHE.eint<16>>) -> tensor<4x!FHE.eint<16>>
return %res : tensor<4x!FHE.eint<16>>
}
tests:
- inputs:
- tensor: [32767, 1276, 10212, 0]
shape: [4]
- tensor: [32768, 20967, 3, 0]
shape: [4]
outputs:
- tensor: [65535, 22243, 10215, 0]
shape: [4]
v0-constraint: [16, 0]
v0-parameter: [1,11,565,1,23,5,3]
large-integer-crt-decomposition: [7,8,9,11,13]
---
description: sub_int_eint_term_to_term
program: |
// Returns the term to term subtraction of `%a0` with `%a1`
func.func @main(%a0: tensor<4xi5>, %a1: tensor<4x!FHE.eint<4>>) -> tensor<4x!FHE.eint<4>> {
%res = "FHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4xi5>, tensor<4x!FHE.eint<4>>) -> tensor<4x!FHE.eint<4>>
return %res : tensor<4x!FHE.eint<4>>
}
tests:
- inputs:
- tensor: [15, 9, 12, 9]
shape: [4]
width: 8
- tensor: [15, 6, 2, 3]
shape: [4]
width: 8
outputs:
- tensor: [0, 3, 10, 6]
shape: [4]
---
description: sub_int_eint_term_to_term_16bits
program: |
// Returns the term to term subtraction of `%a0` with `%a1`
func.func @main(%a0: tensor<4xi17>, %a1: tensor<4x!FHE.eint<16>>) -> tensor<4x!FHE.eint<16>> {
%res = "FHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4xi17>, tensor<4x!FHE.eint<16>>) -> tensor<4x!FHE.eint<16>>
return %res : tensor<4x!FHE.eint<16>>
}
tests:
- inputs:
- tensor: [65535, 22243, 10215, 0]
shape: [4]
- tensor: [65535, 1276, 10212, 0]
shape: [4]
outputs:
- tensor: [0, 20967, 3, 0]
shape: [4]
v0-constraint: [16, 0]
v0-parameter: [1,11,565,1,23,5,3]
large-integer-crt-decomposition: [7,8,9,11,13]
---
description: sub_eint_int_term_to_term
program: |
func.func @main(%a0: tensor<4xi5>, %a1: tensor<4x!FHE.eint<4>>) -> tensor<4x!FHE.eint<4>> {
%res = "FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<4x!FHE.eint<4>>, tensor<4xi5>) -> tensor<4x!FHE.eint<4>>
return %res : tensor<4x!FHE.eint<4>>
}
tests:
- inputs:
- tensor: [15, 6, 2, 3]
shape: [4]
width: 8
- tensor: [15, 9, 12, 9]
shape: [4]
width: 8
outputs:
- tensor: [0, 3, 10, 6]
shape: [4]
---
description: sub_eint_int_term_to_term_16bits
program: |
func.func @main(%a0: tensor<4xi17>, %a1: tensor<4x!FHE.eint<16>>) -> tensor<4x!FHE.eint<16>> {
%res = "FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<4x!FHE.eint<16>>, tensor<4xi17>) -> tensor<4x!FHE.eint<16>>
return %res : tensor<4x!FHE.eint<16>>
}
tests:
- inputs:
- tensor: [65535, 1276, 10212, 0]
shape: [4]
- tensor: [65535, 22243, 10215, 0]
shape: [4]
outputs:
- tensor: [0, 20967, 3, 0]
shape: [4]
v0-constraint: [16, 0]
v0-parameter: [1,11,565,1,23,5,3]
large-integer-crt-decomposition: [7,8,9,11,13]
---
description: sub_eint_term_to_term
program: |
func.func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>> {
%res = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>>
return %res : tensor<4x!FHE.eint<6>>
}
tests:
- inputs:
- tensor: [31, 6, 12, 9]
shape: [4]
width: 8
- tensor: [4, 2, 9, 3]
shape: [4]
width: 8
outputs:
- tensor: [27, 4, 3, 6]
shape: [4]
---
description: sub_eint_term_to_term_16bits
program: |
func.func @main(%a0: tensor<4x!FHE.eint<16>>, %a1: tensor<4x!FHE.eint<16>>) -> tensor<4x!FHE.eint<16>> {
%res = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<4x!FHE.eint<16>>, tensor<4x!FHE.eint<16>>) -> tensor<4x!FHE.eint<16>>
return %res : tensor<4x!FHE.eint<16>>
}
tests:
- inputs:
- tensor: [65535, 22243, 10215, 0]
shape: [4]
- tensor: [65535, 1276, 10212, 0]
shape: [4]
outputs:
- tensor: [0, 20967, 3, 0]
shape: [4]
v0-constraint: [16, 0]
v0-parameter: [1,11,565,1,23,5,3]
large-integer-crt-decomposition: [7,8,9,11,13]
---
description: mul_eint_int_term_to_term
program: |
// Returns the term to term multiplication of `%a0` with `%a1`
func.func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!FHE.eint<6>> {
%res = "FHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4xi7>) -> tensor<4x!FHE.eint<6>>
return %res : tensor<4x!FHE.eint<6>>
}
tests:
- inputs:
- tensor: [31, 6, 12, 9]
shape: [4]
width: 8
- tensor: [2, 3, 2, 3]
shape: [4]
width: 8
outputs:
- tensor: [62, 18, 24, 27]
shape: [4]
---
description: mul_eint_int_term_to_term_16bits
program: |
// Returns the term to term multiplication of `%a0` with `%a1`
func.func @main(%a0: tensor<4x!FHE.eint<16>>, %a1: tensor<4xi17>) -> tensor<4x!FHE.eint<16>> {
%res = "FHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<16>>, tensor<4xi17>) -> tensor<4x!FHE.eint<16>>
return %res : tensor<4x!FHE.eint<16>>
}
tests:
- inputs:
- tensor: [1, 65535, 12, 0]
shape: [4]
- tensor: [65535, 1, 1987, 0]
shape: [4]
outputs:
- tensor: [65535, 65535, 23844, 0]
shape: [4]
v0-constraint: [16, 0]
v0-parameter: [1,11,565,1,23,5,3]
large-integer-crt-decomposition: [7,8,9,11,13]
---
description: transpose1d
program: |
func.func @main(%input: tensor<3x!FHE.eint<6>>) -> tensor<3x!FHE.eint<6>> {

View File

@@ -20,6 +20,9 @@ void compile_and_run(EndToEndDesc desc, LambdaSupport support) {
if (desc.v0Parameter.hasValue()) {
options.v0Parameter = *desc.v0Parameter;
}
if (desc.largeIntegerParameter.hasValue()) {
options.largeIntegerParameter = *desc.largeIntegerParameter;
}
/* 0 - Enable parallel testing where required */
#ifdef CONCRETELANG_PARALLEL_TESTING_ENABLED

View File

@@ -70,37 +70,6 @@ func.func @main(%arg0: tensor<3x2x!FHE.eint<4>>) -> tensor<3x!FHE.eint<4>> {
EXPECT_EQ((*res)[2], 6);
}
TEST(End2EndJit_FHELinalg, add_eint_int_term_to_term) {
checkedJit(lambda, R"XXX(
// Returns the term to term addition of `%a0` with `%a1`
func.func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!FHE.eint<6>> {
%res = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4xi7>) -> tensor<4x!FHE.eint<6>>
return %res : tensor<4x!FHE.eint<6>>
}
)XXX");
std::vector<uint8_t> a0{31, 6, 12, 9};
std::vector<uint8_t> a1{32, 9, 2, 3};
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>
arg0(a0);
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>
arg1(a1);
llvm::Expected<std::vector<uint64_t>> res =
lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
ASSERT_EXPECTED_SUCCESS(res);
ASSERT_EQ(res->size(), (size_t)4);
for (size_t i = 0; i < 4; i++) {
EXPECT_EQ((*res)[i], (uint64_t)a0[i] + a1[i]);
}
}
// Same as add_eint_int_term_to_term test above, but returning a lambda argument
TEST(End2EndJit_FHELinalg, add_eint_int_term_to_term_ret_lambda_argument) {
@@ -371,40 +340,6 @@ TEST(End2EndJit_FHELinalg, add_eint_int_matrix_line_missing_dim) {
// FHELinalg add_eint ///////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
TEST(End2EndJit_FHELinalg, add_eint_term_to_term) {
checkedJit(lambda, R"XXX(
// Returns the term to term addition of `%a0` with `%a1`
func.func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>> {
%res = "FHELinalg.add_eint"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>>
return %res : tensor<4x!FHE.eint<6>>
}
)XXX");
std::vector<uint8_t> a0{31, 6, 12, 9};
std::vector<uint8_t> a1{32, 9, 2, 3};
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>
arg0(a0);
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>
arg1(a1);
llvm::Expected<std::vector<uint64_t>> res =
lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
ASSERT_EXPECTED_SUCCESS(res);
ASSERT_EQ(res->size(), (uint64_t)4);
for (size_t i = 0; i < 4; i++) {
EXPECT_EQ((*res)[i], (uint64_t)a0[i] + a1[i])
<< "result differ at pos " << i << ", expect " << a0[i] + a1[i]
<< " got " << (*res)[i];
}
}
TEST(End2EndJit_FHELinalg, add_eint_term_to_term_broadcast) {
checkedJit(lambda, R"XXX(
@@ -854,36 +789,6 @@ TEST(End2EndJit_FHELinalg, sub_int_eint_matrix_line_missing_dim) {
// FHELinalg sub_eint_int ///////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
TEST(End2EndJit_FHELinalg, sub_eint_int_term_to_term) {
checkedJit(lambda, R"XXX(
func.func @main(%a0: tensor<4xi5>, %a1: tensor<4x!FHE.eint<4>>) -> tensor<4x!FHE.eint<4>> {
%res = "FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<4x!FHE.eint<4>>, tensor<4xi5>) -> tensor<4x!FHE.eint<4>>
return %res : tensor<4x!FHE.eint<4>>
}
)XXX");
std::vector<uint8_t> a0{31, 6, 2, 3};
std::vector<uint8_t> a1{32, 9, 12, 9};
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>
arg0(a0);
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>
arg1(a1);
llvm::Expected<std::vector<uint64_t>> res =
lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
ASSERT_EXPECTED_SUCCESS(res);
ASSERT_EQ(res->size(), (uint64_t)4);
for (size_t i = 0; i < 4; i++) {
EXPECT_EQ((*res)[i], (uint64_t)a1[i] - a0[i]);
}
}
TEST(End2EndJit_FHELinalg, sub_eint_int_term_to_term_broadcast) {
checkedJit(lambda, R"XXX(

View File

@@ -6,8 +6,13 @@
#include <gtest/gtest.h>
#define ASSERT_LLVM_ERROR(err) \
if (err) { \
ASSERT_TRUE(false) << llvm::toString(err); \
{ \
llvm::Error e = err; \
if (e) { \
handleAllErrors(std::move(e), [](const llvm::ErrorInfoBase &ei) { \
ASSERT_TRUE(false) << ei.message(); \
}); \
} \
}
// Checks that the value `val` is not in an error state. Returns

View File

@@ -8,6 +8,7 @@ add_unittest(
unit_tests_concretelang_clientlib
ClientParameters.cpp
CRT.cpp
KeySet.cpp
)

View File

@@ -0,0 +1,50 @@
#include <gtest/gtest.h>
#include "concretelang/ClientLib/CRT.h"
#include "tests_tools/assert.h"
namespace {
namespace crt = concretelang::clientlib::crt;
typedef std::vector<int64_t> CRTModuli;
// Fixture to instantiate the tests with different CRT moduli
class CRTTest : public ::testing::TestWithParam<CRTModuli> {};
TEST_P(CRTTest, crt_iCrt) {
auto moduli = GetParam();
// Max representable value from moduli
uint64_t maxValue = 1;
for (auto modulus : moduli)
maxValue *= modulus;
maxValue = maxValue - 1;
std::vector<uint64_t> valuesToTest{0, maxValue / 2, maxValue};
for (auto a : valuesToTest) {
auto remainders = crt::crt(moduli, a);
auto b = crt::iCrt(moduli, remainders);
ASSERT_EQ(a, b);
}
}
std::vector<CRTModuli> generateAllParameters() {
return {
// The default moduli for 16-bit integers
{7, 8, 9, 11, 13},
};
}
INSTANTIATE_TEST_SUITE_P(CRTSuite, CRTTest,
::testing::ValuesIn(generateAllParameters()),
[](const testing::TestParamInfo<CRTModuli> info) {
auto moduli = info.param;
std::string desc("mod");
if (!moduli.empty()) {
for (auto b : moduli) {
desc = desc + "_" + std::to_string(b);
}
}
return desc;
});
} // namespace
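For readers who want the arithmetic behind the crt/iCrt round trip this test exercises, here is a minimal, self-contained sketch with the default 16-bit moduli. The helpers `decompose`, `recombine` and `modInverse` are illustrative stand-ins written for this note, not the clientlib functions.

```cpp
// Self-contained sketch of a CRT round trip with the default 16-bit moduli.
// decompose/recombine/modInverse are illustrative names, not the clientlib API.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Modular inverse of a modulo m via the extended Euclidean algorithm
// (assumes gcd(a, m) == 1, which holds for pairwise-coprime moduli).
static int64_t modInverse(int64_t a, int64_t m) {
  int64_t t = 0, newT = 1, r = m, newR = a % m;
  while (newR != 0) {
    int64_t q = r / newR;
    int64_t tmpT = t - q * newT; t = newT; newT = tmpT;
    int64_t tmpR = r - q * newR; r = newR; newR = tmpR;
  }
  return t < 0 ? t + m : t;
}

// Remainders of val modulo each modulus.
static std::vector<int64_t> decompose(const std::vector<int64_t> &moduli, uint64_t val) {
  std::vector<int64_t> remainders;
  for (int64_t m : moduli)
    remainders.push_back(static_cast<int64_t>(val % static_cast<uint64_t>(m)));
  return remainders;
}

// Inverse CRT: x = sum_i r_i * (M/m_i) * ((M/m_i)^-1 mod m_i)  (mod M).
static uint64_t recombine(const std::vector<int64_t> &moduli,
                          const std::vector<int64_t> &remainders) {
  uint64_t product = 1;
  for (int64_t m : moduli)
    product *= static_cast<uint64_t>(m);
  uint64_t acc = 0;
  for (std::size_t i = 0; i < moduli.size(); ++i) {
    uint64_t mi = product / static_cast<uint64_t>(moduli[i]);
    uint64_t inv = static_cast<uint64_t>(
        modInverse(static_cast<int64_t>(mi % moduli[i]), moduli[i]));
    acc = (acc + (static_cast<uint64_t>(remainders[i]) * mi) % product * inv) % product;
  }
  return acc;
}

int main() {
  const std::vector<int64_t> moduli{7, 8, 9, 11, 13}; // product = 72072 > 2^16
  const std::vector<uint64_t> values{0, 65535, 72071};
  for (uint64_t value : values)
    assert(recombine(moduli, decompose(moduli, value)) == value); // lossless below 72072
  return 0;
}
```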

View File

@@ -42,7 +42,7 @@ TEST(Support, client_parameters_json_serde) {
}}};
params0.inputs = {
{
/*.encryption = */ {{clientlib::SMALL_KEY, 0.00, {4}}},
/*.encryption = */ {{clientlib::SMALL_KEY, 0.00, {4, {1, 2, 3, 4}}}},
/*.shape = */ {32, {1, 2, 3, 4}, 1 * 2 * 3 * 4},
},
{

View File

@@ -8,14 +8,14 @@
namespace clientlib = concretelang::clientlib;
// Define a fixture for instantiate test with client parameters
class ClientParametersTest
class KeySetTest
: public ::testing::TestWithParam<clientlib::ClientParameters> {
protected:
clientlib::ClientParameters clientParameters;
};
// Test case encrypt and decrypt
TEST_P(ClientParametersTest, encrypt_decrypt) {
TEST_P(KeySetTest, encrypt_decrypt) {
auto clientParameters = GetParam();
@@ -29,7 +29,7 @@ TEST_P(ClientParametersTest, encrypt_decrypt) {
ASSERT_OUTCOME_HAS_VALUE(keySet->allocate_lwe(0, &ciphertext, size));
// Encrypt
uint64_t input = 3;
uint64_t input = 0;
ASSERT_OUTCOME_HAS_VALUE(keySet->encrypt_lwe(0, ciphertext, input));
// Decrypt
@@ -45,9 +45,9 @@ TEST_P(ClientParametersTest, encrypt_decrypt) {
/// Create a client parameters with just one secret key of `dimension` and with
/// one input scalar gate and one output scalar gate on the same key
clientlib::ClientParameters
generateClientParameterOneScalarOneScalar(clientlib::LweDimension dimension,
clientlib::Precision precision) {
clientlib::ClientParameters generateClientParameterOneScalarOneScalar(
clientlib::LweDimension dimension, clientlib::Precision precision,
clientlib::CRTDecomposition crtDecomposition) {
// One secret key with the given dimension
clientlib::ClientParameters params;
params.secretKeys.insert({clientlib::SMALL_KEY, {/*.dimension =*/dimension}});
@@ -56,6 +56,7 @@ generateClientParameterOneScalarOneScalar(clientlib::LweDimension dimension,
clientlib::EncryptionGate encryption;
encryption.secretKeyID = clientlib::SMALL_KEY;
encryption.encoding.precision = precision;
encryption.encoding.crt = crtDecomposition;
clientlib::CircuitGate gate;
gate.encryption = encryption;
params.inputs.push_back(gate);
@@ -74,13 +75,23 @@ std::vector<clientlib::ClientParameters> generateAllParameters() {
llvm::for_each(llvm::enumerate(precisions),
[](auto p) { p.value() = p.index() + 1; });
// All crt decomposition to test
std::vector<clientlib::CRTDecomposition> crtDecompositions{
// An empty crt decomposition means no decomposition
{},
// The default decomposition for 16 bits
{7, 8, 9, 11, 13},
};
// All client parameters to test
std::vector<clientlib::ClientParameters> parameters;
for (auto dimension : lweDimensions) {
for (auto precision : precisions) {
parameters.push_back(
generateClientParameterOneScalarOneScalar(dimension, precision));
for (auto crtDecomposition : crtDecompositions) {
parameters.push_back(generateClientParameterOneScalarOneScalar(
dimension, precision, crtDecomposition));
}
}
}
@@ -88,13 +99,21 @@ std::vector<clientlib::ClientParameters> generateAllParameters() {
}
INSTANTIATE_TEST_SUITE_P(
OneScalarOnScalar, ClientParametersTest,
::testing::ValuesIn(generateAllParameters()),
OneScalarOnScalar, KeySetTest, ::testing::ValuesIn(generateAllParameters()),
[](const testing::TestParamInfo<clientlib::ClientParameters> info) {
auto cp = info.param;
auto input_0 = cp.inputs[0];
return std::string("lweDimension_") +
std::to_string(cp.lweSecretKeyParam(input_0).value().dimension) +
"_precision_" +
std::to_string(input_0.encryption.getValue().encoding.precision);
auto paramDescription =
std::string("lweDimension_") +
std::to_string(cp.lweSecretKeyParam(input_0).value().dimension) +
"_precision_" +
std::to_string(input_0.encryption.getValue().encoding.precision);
auto crt = input_0.encryption.getValue().encoding.crt;
if (!crt.empty()) {
paramDescription = paramDescription + "_crt_";
for (auto b : crt) {
paramDescription = paramDescription + "_" + std::to_string(b);
}
}
return paramDescription;
});