mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(compiler): First draft to support FHE.eint up to 16bits
For now what it works are only levelled ops with user parameters. (take a look to the tests) Done: - Add parameters to the fhe parameters to support CRT-based large integers - Add command line options and tests options to allows the user to give those new parameters - Update the dialects and pipeline to handle new fhe parameters for CRT-based large integers - Update the client parameters and the client library to handle the CRT-based large integers Todo: - Plug the optimizer to compute the CRT-based large interger parameters - Plug the pbs for the CRT-based large integer
This commit is contained in:
46
compiler/include/concretelang/ClientLib/CRT.h
Normal file
46
compiler/include/concretelang/ClientLib/CRT.h
Normal file
@@ -0,0 +1,46 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_CRT_H_
|
||||
#define CONCRETELANG_CLIENTLIB_CRT_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
namespace crt {
|
||||
|
||||
/// Compute the product of the moduli of the crt decomposition.
|
||||
///
|
||||
/// \param moduli The moduli of the crt decomposition
|
||||
/// \returns The product of moduli
|
||||
uint64_t productOfModuli(std::vector<int64_t> moduli);
|
||||
|
||||
/// Compute the crt decomposition of a `val` according the given `moduli`.
|
||||
///
|
||||
/// \param moduli The moduli to compute the decomposition.
|
||||
/// \param val The value to decompose.
|
||||
/// \returns The remainders.
|
||||
std::vector<int64_t> crt(std::vector<int64_t> moduli, uint64_t val);
|
||||
|
||||
/// Compute the inverse of the crt decomposition.
|
||||
///
|
||||
/// \param moduli The moduli used to compute the inverse decomposition.
|
||||
/// \param remainders The remainders of the decomposition.
|
||||
uint64_t iCrt(std::vector<int64_t> moduli, std::vector<int64_t> remainders);
|
||||
|
||||
/// Encode the plaintext with the given modulus and the product of moduli of the
|
||||
/// crt decomposition
|
||||
uint64_t encode(int64_t plaintext, uint64_t modulus, uint64_t product);
|
||||
|
||||
/// Decode follow the crt encoding
|
||||
uint64_t decode(uint64_t val, uint64_t modulus);
|
||||
|
||||
} // namespace crt
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -7,6 +7,7 @@
|
||||
#define CONCRETELANG_CLIENTLIB_CLIENTPARAMETERS_H_
|
||||
|
||||
#include <map>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@@ -31,6 +32,7 @@ typedef size_t DecompositionBaseLog;
|
||||
typedef size_t PolynomialSize;
|
||||
typedef size_t Precision;
|
||||
typedef double Variance;
|
||||
typedef std::vector<int64_t> CRTDecomposition;
|
||||
|
||||
typedef uint64_t LweDimension;
|
||||
typedef uint64_t GlweDimension;
|
||||
@@ -87,6 +89,7 @@ static inline bool operator==(const KeyswitchKeyParam &lhs,
|
||||
|
||||
struct Encoding {
|
||||
Precision precision;
|
||||
CRTDecomposition crt;
|
||||
};
|
||||
static inline bool operator==(const Encoding &lhs, const Encoding &rhs) {
|
||||
return lhs.precision == rhs.precision;
|
||||
@@ -168,6 +171,52 @@ struct ClientParameters {
|
||||
}
|
||||
return secretKey->second;
|
||||
}
|
||||
|
||||
/// bufferSize returns the size of the whole buffer of a gate.
|
||||
int64_t bufferSize(CircuitGate gate) {
|
||||
if (!gate.encryption.hasValue()) {
|
||||
// Value is not encrypted just returns the tensor size
|
||||
return gate.shape.size;
|
||||
}
|
||||
auto shapeSize = gate.shape.size == 0 ? 1 : gate.shape.size;
|
||||
// Size of the ciphertext
|
||||
return shapeSize * lweBufferSize(gate);
|
||||
}
|
||||
|
||||
/// lweBufferSize returns the size of one ciphertext of a gate.
|
||||
int64_t lweBufferSize(CircuitGate gate) {
|
||||
assert(gate.encryption.hasValue());
|
||||
auto nbBlocks = gate.encryption->encoding.crt.size();
|
||||
nbBlocks = nbBlocks == 0 ? 1 : nbBlocks;
|
||||
|
||||
auto param = lweSecretKeyParam(gate);
|
||||
assert(param.has_value());
|
||||
return param.value().lweSize() * nbBlocks;
|
||||
}
|
||||
|
||||
/// bufferShape returns the shape of the tensor for the given gate. It returns
|
||||
/// the shape used at low-level, i.e. contains the dimensions for ciphertexts.
|
||||
std::vector<int64_t> bufferShape(CircuitGate gate) {
|
||||
if (!gate.encryption.hasValue()) {
|
||||
// Value is not encrypted just returns the tensor shape
|
||||
return gate.shape.dimensions;
|
||||
}
|
||||
auto lweSecreteKeyParam = lweSecretKeyParam(gate);
|
||||
assert(lweSecreteKeyParam.has_value());
|
||||
|
||||
// Copy the shape
|
||||
std::vector<int64_t> shape(gate.shape.dimensions);
|
||||
|
||||
auto crt = gate.encryption->encoding.crt;
|
||||
|
||||
// CRT case: Add one dimension equals to the number of blocks
|
||||
if (!crt.empty()) {
|
||||
shape.push_back(crt.size());
|
||||
}
|
||||
// Add one dimension for the size of ciphertext(s)
|
||||
shape.push_back(lweSecreteKeyParam.value().lweSize());
|
||||
return shape;
|
||||
}
|
||||
};
|
||||
|
||||
static inline bool operator==(const ClientParameters &lhs,
|
||||
|
||||
@@ -23,11 +23,24 @@ using concretelang::error::StringError;
|
||||
|
||||
class PublicArguments;
|
||||
|
||||
inline size_t bitWidthAsWord(size_t exactBitWidth) {
|
||||
if (exactBitWidth <= 8)
|
||||
return 8;
|
||||
if (exactBitWidth <= 16)
|
||||
return 16;
|
||||
if (exactBitWidth <= 32)
|
||||
return 32;
|
||||
if (exactBitWidth <= 64)
|
||||
return 64;
|
||||
assert(false && "Bit witdh > 64 not supported");
|
||||
}
|
||||
|
||||
/// Temporary object used to hold and encrypt parameters before calling a
|
||||
/// ClientLambda. Use preferably TypeClientLambda and serializeCall(Args...).
|
||||
/// Otherwise convert it to a PublicArguments and use
|
||||
/// serializeCall(PublicArguments, KeySet).
|
||||
class EncryptedArguments {
|
||||
|
||||
public:
|
||||
EncryptedArguments() : currentPos(0) {}
|
||||
|
||||
@@ -73,7 +86,10 @@ public:
|
||||
|
||||
/// Add a vector-tensor argument.
|
||||
outcome::checked<void, StringError> pushArg(std::vector<uint8_t> arg,
|
||||
KeySet &keySet);
|
||||
KeySet &keySet) {
|
||||
return pushArg((uint8_t *)arg.data(),
|
||||
llvm::ArrayRef<int64_t>{(int64_t)arg.size()}, keySet);
|
||||
}
|
||||
|
||||
/// Add a 1D tensor argument with data and size of the dimension.
|
||||
template <typename T>
|
||||
@@ -82,26 +98,20 @@ public:
|
||||
return pushArg(std::vector<uint8_t>(data, data + dim1), keySet);
|
||||
}
|
||||
|
||||
// Add a tensor argument.
|
||||
template <typename T>
|
||||
outcome::checked<void, StringError>
|
||||
pushArg(const T *data, llvm::ArrayRef<int64_t> shape, KeySet &keySet) {
|
||||
return pushArg(8 * sizeof(T), static_cast<const void *>(data), shape,
|
||||
keySet);
|
||||
}
|
||||
|
||||
/// Add a 1D tensor argument.
|
||||
template <size_t size>
|
||||
outcome::checked<void, StringError> pushArg(std::array<uint8_t, size> arg,
|
||||
KeySet &keySet) {
|
||||
return pushArg(8, (void *)arg.data(), {size}, keySet);
|
||||
return pushArg((uint8_t *)arg.data(), llvm::ArrayRef<int64_t>{size},
|
||||
keySet);
|
||||
}
|
||||
|
||||
/// Add a 2D tensor argument.
|
||||
template <size_t size0, size_t size1>
|
||||
outcome::checked<void, StringError>
|
||||
pushArg(std::array<std::array<uint8_t, size1>, size0> arg, KeySet &keySet) {
|
||||
return pushArg(8, (void *)arg.data(), {size0, size1}, keySet);
|
||||
return pushArg((uint8_t *)arg.data(), llvm::ArrayRef<int64_t>{size0, size1},
|
||||
keySet);
|
||||
}
|
||||
|
||||
/// Add a 3D tensor argument.
|
||||
@@ -109,7 +119,8 @@ public:
|
||||
outcome::checked<void, StringError>
|
||||
pushArg(std::array<std::array<std::array<uint8_t, size2>, size1>, size0> arg,
|
||||
KeySet &keySet) {
|
||||
return pushArg(8, (void *)arg.data(), {size0, size1, size2}, keySet);
|
||||
return pushArg((uint8_t *)arg.data(),
|
||||
llvm::ArrayRef<int64_t>{size0, size1, size2}, keySet);
|
||||
}
|
||||
|
||||
// Generalize by computing shape by template recursion
|
||||
@@ -125,13 +136,94 @@ public:
|
||||
template <typename T>
|
||||
outcome::checked<void, StringError>
|
||||
pushArg(T *data, llvm::ArrayRef<int64_t> shape, KeySet &keySet) {
|
||||
return pushArg(8 * sizeof(T), static_cast<const void *>(data), shape,
|
||||
keySet);
|
||||
return pushArg(static_cast<const T *>(data), shape, keySet);
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError> pushArg(size_t width, const void *data,
|
||||
llvm::ArrayRef<int64_t> shape,
|
||||
KeySet &keySet);
|
||||
template <typename T>
|
||||
outcome::checked<void, StringError>
|
||||
pushArg(const T *data, llvm::ArrayRef<int64_t> shape, KeySet &keySet) {
|
||||
OUTCOME_TRYV(checkPushTooManyArgs(keySet));
|
||||
auto pos = currentPos;
|
||||
CircuitGate input = keySet.inputGate(pos);
|
||||
// Check the width of data
|
||||
if (input.shape.width > 64) {
|
||||
return StringError("argument #")
|
||||
<< pos << " width > 64 bits is not supported";
|
||||
}
|
||||
// Check the shape of tensor
|
||||
if (input.shape.dimensions.empty()) {
|
||||
return StringError("argument #") << pos << "is not a tensor";
|
||||
}
|
||||
if (shape.size() != input.shape.dimensions.size()) {
|
||||
return StringError("argument #")
|
||||
<< pos << "has not the expected number of dimension, got "
|
||||
<< shape.size() << " expected " << input.shape.dimensions.size();
|
||||
}
|
||||
// Allocate empty
|
||||
ciphertextBuffers.resize(ciphertextBuffers.size() + 1);
|
||||
TensorData &values_and_sizes = ciphertextBuffers.back();
|
||||
|
||||
// Check shape
|
||||
for (size_t i = 0; i < shape.size(); i++) {
|
||||
if (shape[i] != input.shape.dimensions[i]) {
|
||||
return StringError("argument #")
|
||||
<< pos << " has not the expected dimension #" << i << " , got "
|
||||
<< shape[i] << " expected " << input.shape.dimensions[i];
|
||||
}
|
||||
}
|
||||
// Set sizes
|
||||
values_and_sizes.sizes = keySet.clientParameters().bufferShape(input);
|
||||
|
||||
if (input.encryption.hasValue()) {
|
||||
// Allocate values
|
||||
values_and_sizes.values.resize(
|
||||
keySet.clientParameters().bufferSize(input));
|
||||
auto lweSize = keySet.clientParameters().lweBufferSize(input);
|
||||
auto &values = values_and_sizes.values;
|
||||
for (size_t i = 0, offset = 0; i < input.shape.size;
|
||||
i++, offset += lweSize) {
|
||||
OUTCOME_TRYV(keySet.encrypt_lwe(pos, values.data() + offset, data[i]));
|
||||
}
|
||||
} else {
|
||||
// Allocate values take care of gate bitwidth
|
||||
auto bitsPerValue = bitWidthAsWord(input.shape.width);
|
||||
auto bytesPerValue = bitsPerValue / 8;
|
||||
auto nbWordPerValue = 8 / bytesPerValue;
|
||||
// ceil division
|
||||
auto size = (input.shape.size / nbWordPerValue) +
|
||||
(input.shape.size % nbWordPerValue != 0);
|
||||
size = size == 0 ? 1 : size;
|
||||
values_and_sizes.values.resize(size);
|
||||
auto v = (uint8_t *)values_and_sizes.values.data();
|
||||
for (size_t i = 0; i < input.shape.size; i++) {
|
||||
auto dst = v + i * bytesPerValue;
|
||||
auto src = (const uint8_t *)&data[i];
|
||||
for (size_t j = 0; j < bytesPerValue; j++) {
|
||||
dst[j] = src[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
// allocated
|
||||
preparedArgs.push_back(nullptr);
|
||||
// aligned
|
||||
preparedArgs.push_back((void *)values_and_sizes.values.data());
|
||||
// offset
|
||||
preparedArgs.push_back((void *)0);
|
||||
// sizes
|
||||
for (size_t size : values_and_sizes.sizes) {
|
||||
preparedArgs.push_back((void *)size);
|
||||
}
|
||||
|
||||
// Set the stride for each dimension, equal to the product of the
|
||||
// following dimensions.
|
||||
int64_t stride = values_and_sizes.length();
|
||||
for (size_t size : values_and_sizes.sizes) {
|
||||
stride = (size == 0 ? 0 : (stride / size));
|
||||
preparedArgs.push_back((void *)stride);
|
||||
}
|
||||
currentPos++;
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
/// Recursive case for scalars: extract first scalar argument from
|
||||
/// parameter pack and forward rest
|
||||
|
||||
@@ -37,7 +37,10 @@ public:
|
||||
static outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
generate(ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb);
|
||||
|
||||
/// isInputEncrypted return true if the input at the given pos is encrypted.
|
||||
/// Returns the ClientParameters associated with the KeySet.
|
||||
ClientParameters clientParameters() { return _clientParameters; }
|
||||
|
||||
// isInputEncrypted return true if the input at the given pos is encrypted.
|
||||
bool isInputEncrypted(size_t pos);
|
||||
|
||||
/// getInputLweSecretKeyParam returns the parameters of the lwe secret key for
|
||||
@@ -155,6 +158,8 @@ private:
|
||||
std::map<LweSecretKeyID,
|
||||
std::pair<KeyswitchKeyParam, std::shared_ptr<LweKeyswitchKey>>>
|
||||
keyswitchKeys);
|
||||
|
||||
clientlib::ClientParameters _clientParameters;
|
||||
};
|
||||
|
||||
} // namespace clientlib
|
||||
|
||||
@@ -108,8 +108,7 @@ struct PublicResult {
|
||||
}
|
||||
|
||||
auto buffer = buffers[pos];
|
||||
auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize();
|
||||
|
||||
auto lweSize = clientParameters.lweBufferSize(gate);
|
||||
std::vector<T> decryptedValues(buffer.length() / lweSize);
|
||||
for (size_t i = 0; i < decryptedValues.size(); i++) {
|
||||
auto ciphertext = &buffer.values[i * lweSize];
|
||||
@@ -120,6 +119,13 @@ struct PublicResult {
|
||||
return decryptedValues;
|
||||
}
|
||||
|
||||
/// Return the shape of the clear tensor of a result.
|
||||
outcome::checked<std::vector<int64_t>, StringError>
|
||||
asClearTextShape(size_t pos) {
|
||||
OUTCOME_TRY(auto gate, clientParameters.ouput(pos));
|
||||
return gate.shape.dimensions;
|
||||
}
|
||||
|
||||
// private: TODO tmp
|
||||
friend class ::concretelang::serverlib::ServerLambda;
|
||||
ClientParameters clientParameters;
|
||||
|
||||
@@ -10,8 +10,17 @@
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
// ApplyLookupTableLowering indicates the strategy to lower an
|
||||
// FHE.apply_loopup_table ops
|
||||
enum ApplyLookupTableLowering {
|
||||
KeySwitchBoostrapLowering,
|
||||
WopPBSLowering,
|
||||
};
|
||||
|
||||
/// Create a pass to convert `FHE` dialect to `TFHE` dialect.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertFHEToTFHEPass();
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertFHEToTFHEPass(ApplyLookupTableLowering lower);
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -23,7 +23,8 @@ using TFHE::GLWECipherTextType;
|
||||
GLWECipherTextType
|
||||
convertTypeEncryptedIntegerToGLWE(mlir::MLIRContext *context,
|
||||
EncryptedIntegerType eint) {
|
||||
return GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth());
|
||||
return GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth(),
|
||||
llvm::ArrayRef<int64_t>());
|
||||
}
|
||||
|
||||
/// Converts the type `t` to `TFHE::GlweCiphetext` if `t` is a
|
||||
|
||||
@@ -26,7 +26,8 @@ LweCiphertextType convertTypeToLWE(mlir::MLIRContext *context,
|
||||
auto glwe = type.dyn_cast_or_null<GLWECipherTextType>();
|
||||
if (glwe != nullptr) {
|
||||
assert(glwe.getPolynomialSize() == 1);
|
||||
return LweCiphertextType::get(context, glwe.getDimension(), glwe.getP());
|
||||
return LweCiphertextType::get(context, glwe.getDimension(), glwe.getP(),
|
||||
glwe.getCrtDecomposition());
|
||||
}
|
||||
auto lwe = type.dyn_cast_or_null<LweCiphertextType>();
|
||||
if (lwe != nullptr) {
|
||||
@@ -122,19 +123,10 @@ mlir::Value createConcreteOpFromTFHE(mlir::PatternRewriter &rewriter,
|
||||
mlir::Value createAddPlainLweCiphertextWithGlwe(
|
||||
mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Value arg0,
|
||||
mlir::Value arg1, mlir::OpResult result, mlir::Type encryptedType) {
|
||||
PlaintextType encoded_type =
|
||||
convertPlaintextTypeFromType(rewriter.getContext(), encryptedType);
|
||||
// encode int into plaintext
|
||||
mlir::Value encoded = rewriter
|
||||
.create<mlir::concretelang::Concrete::EncodeIntOp>(
|
||||
loc, encoded_type, arg1)
|
||||
.plaintext();
|
||||
|
||||
// replace op using the encoded plaintext instead of int
|
||||
auto op =
|
||||
rewriter
|
||||
.create<mlir::concretelang::Concrete::AddPlaintextLweCiphertextOp>(
|
||||
loc, result.getType(), arg0, encoded);
|
||||
loc, result.getType(), arg0, arg1);
|
||||
|
||||
convertOperandAndResultTypes(rewriter, op, convertTypeToLWEIfTFHEType);
|
||||
|
||||
@@ -171,21 +163,11 @@ mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter &rewriter,
|
||||
mlir::Location loc, mlir::Value arg0,
|
||||
mlir::Value arg1,
|
||||
mlir::OpResult result) {
|
||||
auto inType = arg0.getType();
|
||||
CleartextType encoded_type =
|
||||
convertCleartextTypeFromType(rewriter.getContext(), inType);
|
||||
// encode int into plaintext
|
||||
mlir::Value encoded =
|
||||
rewriter
|
||||
.create<mlir::concretelang::Concrete::IntToCleartextOp>(
|
||||
loc, encoded_type, arg1)
|
||||
.cleartext();
|
||||
|
||||
// replace op using the encoded plaintext instead of int
|
||||
auto op =
|
||||
rewriter
|
||||
.create<mlir::concretelang::Concrete::MulCleartextLweCiphertextOp>(
|
||||
loc, result.getType(), arg0, encoded);
|
||||
loc, result.getType(), arg0, arg1);
|
||||
|
||||
convertOperandAndResultTypes(rewriter, op, convertTypeToLWEIfTFHEType);
|
||||
|
||||
|
||||
@@ -6,15 +6,44 @@
|
||||
#ifndef CONCRETELANG_CONVERSION_GLOBALFHECONTEXT_H_
|
||||
#define CONCRETELANG_CONVERSION_GLOBALFHECONTEXT_H_
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
#include "llvm/ADT/Optional.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
typedef std::vector<int64_t> CRTDecomposition;
|
||||
|
||||
struct V0FHEConstraint {
|
||||
size_t norm2;
|
||||
size_t p;
|
||||
};
|
||||
|
||||
struct PackingKeySwitchParameter {
|
||||
size_t inputLweDimension;
|
||||
size_t inputLweCount;
|
||||
size_t outputPolynomialSize;
|
||||
size_t level;
|
||||
size_t baseLog;
|
||||
};
|
||||
|
||||
struct CitcuitBoostrapParameter {
|
||||
size_t level;
|
||||
size_t baseLog;
|
||||
};
|
||||
|
||||
struct WopPBSParameter {
|
||||
PackingKeySwitchParameter packingKeySwitch;
|
||||
CitcuitBoostrapParameter circuitBootstrap;
|
||||
};
|
||||
|
||||
struct LargeIntegerParameter {
|
||||
CRTDecomposition crtDecomposition;
|
||||
WopPBSParameter wopPBS;
|
||||
};
|
||||
|
||||
struct V0Parameter {
|
||||
size_t glweDimension;
|
||||
size_t logPolynomialSize;
|
||||
@@ -24,6 +53,8 @@ struct V0Parameter {
|
||||
size_t ksLevel;
|
||||
size_t ksLogBase;
|
||||
|
||||
llvm::Optional<LargeIntegerParameter> largeInteger;
|
||||
|
||||
V0Parameter() = delete;
|
||||
|
||||
V0Parameter(size_t glweDimension, size_t logPolynomialSize, size_t nSmall,
|
||||
|
||||
@@ -28,6 +28,11 @@ populateWithTensorTypeConverterPatterns(mlir::RewritePatternSet &patterns,
|
||||
patterns.getContext(), typeConverter);
|
||||
addDynamicallyLegalTypeOp<mlir::tensor::ExtractSliceOp>(target,
|
||||
typeConverter);
|
||||
|
||||
// InsertOp
|
||||
patterns.add<GenericTypeConverterPattern<mlir::tensor::InsertOp>>(
|
||||
patterns.getContext(), typeConverter);
|
||||
addDynamicallyLegalTypeOp<mlir::tensor::InsertOp>(target, typeConverter);
|
||||
// InsertSliceOp
|
||||
patterns.add<GenericTypeConverterPattern<mlir::tensor::InsertSliceOp>>(
|
||||
patterns.getContext(), typeConverter);
|
||||
|
||||
@@ -20,21 +20,56 @@ def BConcrete_AddLweBuffersOp : BConcrete_Op<"add_lwe_buffer"> {
|
||||
let results = (outs 1DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
def BConcrete_AddCRTLweBuffersOp : BConcrete_Op<"add_crt_lwe_buffer"> {
|
||||
let arguments = (ins
|
||||
2DTensorOf<[I64]>:$lhs,
|
||||
2DTensorOf<[I64]>:$rhs,
|
||||
I64ArrayAttr:$crtDecomposition
|
||||
);
|
||||
let results = (outs 2DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
def BConcrete_AddPlaintextLweBufferOp : BConcrete_Op<"add_plaintext_lwe_buffer"> {
|
||||
let arguments = (ins 1DTensorOf<[I64]>:$lhs, I64:$rhs);
|
||||
let results = (outs 1DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
def BConcrete_AddPlaintextCRTLweBufferOp : BConcrete_Op<"add_plaintext_crt_lwe_buffer"> {
|
||||
let arguments = (ins
|
||||
2DTensorOf<[I64]>:$lhs,
|
||||
AnyInteger:$rhs,
|
||||
I64ArrayAttr:$crtDecomposition
|
||||
);
|
||||
let results = (outs 2DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
def BConcrete_MulCleartextLweBufferOp : BConcrete_Op<"mul_cleartext_lwe_buffer"> {
|
||||
let arguments = (ins 1DTensorOf<[I64]>:$lhs, I64:$rhs);
|
||||
let results = (outs 1DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
def BConcrete_MulCleartextCRTLweBufferOp : BConcrete_Op<"mul_cleartext_crt_lwe_buffer"> {
|
||||
let arguments = (ins
|
||||
2DTensorOf<[I64]>:$lhs,
|
||||
AnyInteger:$rhs,
|
||||
I64ArrayAttr:$crtDecomposition
|
||||
);
|
||||
let results = (outs 2DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
def BConcrete_NegateLweBufferOp : BConcrete_Op<"negate_lwe_buffer"> {
|
||||
let arguments = (ins 1DTensorOf<[I64]>:$ciphertext);
|
||||
let results = (outs 1DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
def BConcrete_NegateCRTLweBufferOp : BConcrete_Op<"negate_crt_lwe_buffer"> {
|
||||
let arguments = (ins
|
||||
2DTensorOf<[I64]>:$ciphertext,
|
||||
I64ArrayAttr:$crtDecomposition
|
||||
);
|
||||
let results = (outs 2DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
def BConcrete_FillGlweFromTable : BConcrete_Op<"fill_glwe_from_table"> {
|
||||
let arguments = (ins
|
||||
1DTensorOf<[I64]>:$glwe,
|
||||
@@ -67,5 +102,28 @@ def BConcrete_BootstrapLweBufferOp : BConcrete_Op<"bootstrap_lwe_buffer"> {
|
||||
let results = (outs 1DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
// TODO(16bits): hack
|
||||
def BConcrete_WopPBSCRTLweBufferOp : BConcrete_Op<"wop_pbs_crt_lwe_buffer"> {
|
||||
let arguments = (ins
|
||||
2DTensorOf<[I64]>:$ciphertext,
|
||||
1DTensorOf<[I64]>:$lookupTable,
|
||||
// Bootstrap parameters
|
||||
I32Attr : $bootstrapLevel,
|
||||
I32Attr : $bootstrapBaseLog,
|
||||
// Keyswitch parameters
|
||||
I32Attr : $keyswitchLevel,
|
||||
I32Attr : $keyswitchBaseLog,
|
||||
// Packing keyswitch key parameters
|
||||
I32Attr : $packingKeySwitchInputLweDimension,
|
||||
I32Attr : $packingKeySwitchinputLweCount,
|
||||
I32Attr : $packingKeySwitchoutputPolynomialSize,
|
||||
I32Attr : $packingKeySwitchLevel,
|
||||
I32Attr : $packingKeySwitchBaseLog,
|
||||
// Circuit bootstrap parameters
|
||||
I32Attr : $circuitBootstrapLevel,
|
||||
I32Attr : $circuitBootstrapBaseLog
|
||||
);
|
||||
let results = (outs 2DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createAddRuntimeContext();
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createEliminateCRTOps();
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -16,4 +16,9 @@ def AddRuntimeContext : Pass<"add-runtime-context", "mlir::ModuleOp"> {
|
||||
let constructor = "mlir::concretelang::createAddRuntimeContext()";
|
||||
}
|
||||
|
||||
def EliminateCRTOps : Pass<"eliminate-bconcrete-crt-ops", "mlir::func::FuncOp"> {
|
||||
let summary = "Eliminate the crt bconcrete operators.";
|
||||
let constructor = "mlir::concretelang::createEliminateCRTOpsPass()";
|
||||
}
|
||||
|
||||
#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES
|
||||
|
||||
@@ -34,14 +34,14 @@ def Concrete_AddLweCiphertextsOp : Concrete_Op<"add_lwe_ciphertexts"> {
|
||||
def Concrete_AddPlaintextLweCiphertextOp : Concrete_Op<"add_plaintext_lwe_ciphertext"> {
|
||||
let summary = "Returns the sum of a clear integer and a lwe ciphertext";
|
||||
|
||||
let arguments = (ins Concrete_LweCiphertextType:$lhs, Concrete_PlaintextType:$rhs);
|
||||
let arguments = (ins Concrete_LweCiphertextType:$lhs, AnyInteger:$rhs);
|
||||
let results = (outs Concrete_LweCiphertextType:$result);
|
||||
}
|
||||
|
||||
def Concrete_MulCleartextLweCiphertextOp : Concrete_Op<"mul_cleartext_lwe_ciphertext"> {
|
||||
let summary = "Returns the product of a clear integer and a lwe ciphertext";
|
||||
|
||||
let arguments = (ins Concrete_LweCiphertextType:$lhs, Concrete_CleartextType:$rhs);
|
||||
let arguments = (ins Concrete_LweCiphertextType:$lhs, AnyInteger:$rhs);
|
||||
let results = (outs Concrete_LweCiphertextType:$result);
|
||||
}
|
||||
|
||||
@@ -82,18 +82,30 @@ def Concrete_KeySwitchLweOp : Concrete_Op<"keyswitch_lwe"> {
|
||||
let results = (outs Concrete_LweCiphertextType:$result);
|
||||
}
|
||||
|
||||
def Concrete_EncodeIntOp : Concrete_Op<"encode_int"> {
|
||||
let summary = "Encodes an integer (for it to later be added to a LWE ciphertext)";
|
||||
// TODO(16bits): hack
|
||||
def Concrete_WopPBSLweOp : Concrete_Op<"wop_pbs_lwe"> {
|
||||
let summary = "";
|
||||
|
||||
let arguments = (ins AnyInteger:$i);
|
||||
let results = (outs Concrete_PlaintextType:$plaintext);
|
||||
}
|
||||
|
||||
def Concrete_IntToCleartextOp : Concrete_Op<"int_to_cleartext", [NoSideEffect]> {
|
||||
let summary = "Keyswitches a LWE ciphertext";
|
||||
|
||||
let arguments = (ins AnyInteger:$i);
|
||||
let results = (outs Concrete_CleartextType:$cleartext);
|
||||
let arguments = (ins
|
||||
Concrete_LweCiphertextType:$ciphertext,
|
||||
1DTensorOf<[I64]>:$accumulator,
|
||||
// Bootstrap parameters
|
||||
I32Attr : $bootstrapLevel,
|
||||
I32Attr : $bootstrapBaseLog,
|
||||
// Keyswitch parameters
|
||||
I32Attr : $keyswitchLevel,
|
||||
I32Attr : $keyswitchBaseLog,
|
||||
// Packing keyswitch key parameters
|
||||
I32Attr : $packingKeySwitchInputLweDimension,
|
||||
I32Attr : $packingKeySwitchinputLweCount,
|
||||
I32Attr : $packingKeySwitchoutputPolynomialSize,
|
||||
I32Attr : $packingKeySwitchLevel,
|
||||
I32Attr : $packingKeySwitchBaseLog,
|
||||
// Circuit bootstrap parameters
|
||||
I32Attr : $circuitBootstrapLevel,
|
||||
I32Attr : $circuitBootstrapBaseLog
|
||||
);
|
||||
let results = (outs Concrete_LweCiphertextType:$result);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -40,7 +40,10 @@ def Concrete_LweCiphertextType : Concrete_Type<"LweCiphertext", [MemRefElementTy
|
||||
// The dimension of the lwe ciphertext
|
||||
"signed":$dimension,
|
||||
// Precision of the lwe ciphertext
|
||||
"signed":$p
|
||||
"signed":$p,
|
||||
// CRT decomposition for large integers
|
||||
ArrayRefParameter<"int64_t", "CRT decomposition">:$crtDecomposition
|
||||
|
||||
);
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
@@ -115,4 +115,29 @@ def TFHE_BootstrapGLWEOp : TFHE_Op<"bootstrap_glwe"> {
|
||||
let results = (outs TFHE_GLWECipherTextType : $result);
|
||||
}
|
||||
|
||||
def TFHE_WopPBSGLWEOp : TFHE_Op<"wop_pbs_glwe"> {
|
||||
let summary = "";
|
||||
|
||||
let arguments = (ins
|
||||
TFHE_GLWECipherTextType : $ciphertext,
|
||||
1DTensorOf<[I64]> : $lookupTable,
|
||||
// Bootstrap parameters
|
||||
I32Attr : $bootstrapLevel,
|
||||
I32Attr : $bootstrapBaseLog,
|
||||
// Keyswitch parameters
|
||||
I32Attr : $keyswitchLevel,
|
||||
I32Attr : $keyswitchBaseLog,
|
||||
// Packing keyswitch key parameters
|
||||
I32Attr : $packingKeySwitchInputLweDimension,
|
||||
I32Attr : $packingKeySwitchinputLweCount,
|
||||
I32Attr : $packingKeySwitchoutputPolynomialSize,
|
||||
I32Attr : $packingKeySwitchLevel,
|
||||
I32Attr : $packingKeySwitchBaseLog,
|
||||
// Circuit bootstrap parameters
|
||||
I32Attr : $circuitBootstrapLevel,
|
||||
I32Attr : $circuitBootstrapBaseLog
|
||||
);
|
||||
let results = (outs TFHE_GLWECipherTextType:$result);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -6,18 +6,18 @@
|
||||
include "concretelang/Dialect/TFHE/IR/TFHEDialect.td"
|
||||
include "mlir/IR/BuiltinTypes.td"
|
||||
|
||||
class TFHE_Type<string name, list<Trait> traits = []> : TypeDef<TFHE_Dialect, name, traits> { }
|
||||
class TFHE_Type<string name, list<Trait> traits = []>
|
||||
: TypeDef<TFHE_Dialect, name, traits> {}
|
||||
|
||||
def TFHE_GLWECipherTextType : TFHE_Type<"GLWECipherText", [MemRefElementTypeInterface]> {
|
||||
let mnemonic = "glwe";
|
||||
def TFHE_GLWECipherTextType
|
||||
: TFHE_Type<"GLWECipherText", [MemRefElementTypeInterface]> {
|
||||
let mnemonic = "glwe";
|
||||
|
||||
let summary = "A GLWE ciphertext";
|
||||
let summary = "A GLWE ciphertext";
|
||||
|
||||
let description = [{
|
||||
An GLWE cipher text
|
||||
}];
|
||||
let description = [{An GLWE cipher text}];
|
||||
|
||||
let parameters = (ins
|
||||
let parameters = (ins
|
||||
// The size of the mask
|
||||
"signed":$dimension,
|
||||
// Size of the polynome
|
||||
@@ -25,22 +25,22 @@ def TFHE_GLWECipherTextType : TFHE_Type<"GLWECipherText", [MemRefElementTypeInte
|
||||
// Number of bits of the ciphertext
|
||||
"signed":$bits,
|
||||
// Number of bits of the plain text representation
|
||||
"signed":$p
|
||||
"signed":$p,
|
||||
// CRT decomposition for large integers
|
||||
ArrayRefParameter<"int64_t", "CRT decomposition">:$crtDecomposition
|
||||
);
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
let genVerifyDecl = true;
|
||||
let genVerifyDecl = true;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Returns true if has an unparametrized parameters
|
||||
bool hasUnparametrizedParameters() {
|
||||
return getDimension() ==-1 ||
|
||||
getPolynomialSize() == -1 ||
|
||||
getBits() == -1 ||
|
||||
getP() == -1;
|
||||
};
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
// Returns true if has an unparametrized parameters
|
||||
bool hasUnparametrizedParameters() {
|
||||
return getDimension() == -1 || getPolynomialSize() == -1 ||
|
||||
getBits() == -1 || getP() == -1;
|
||||
};
|
||||
}];
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -58,6 +58,24 @@ void memref_bootstrap_lwe_u64(
|
||||
uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
uint64_t encode_crt(int64_t plaintext, uint64_t modulus, uint64_t product);
|
||||
|
||||
// TODO(16bits): Hackish wrapper for the 16 bits quick win
|
||||
void memref_wop_pbs_crt_buffer(
|
||||
// Output memref 2D memref
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset_0,
|
||||
uint64_t out_offset_1, uint64_t out_size_0, uint64_t out_size_1,
|
||||
uint64_t out_stride_0, uint64_t out_stride_1,
|
||||
// Input memref
|
||||
uint64_t *in_allocated, uint64_t *in_aligned, uint64_t in_offset_0,
|
||||
uint64_t in_offset_1, uint64_t in_size_0, uint64_t in_size_1,
|
||||
uint64_t in_stride_0, uint64_t in_stride_1,
|
||||
// clear text lut
|
||||
uint64_t *lut_ct_allocated, uint64_t *lut_ct_aligned,
|
||||
uint64_t lut_ct_offset, uint64_t lut_ct_size, uint64_t lut_ct_stride,
|
||||
// runtime context that hold evluation keys
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
void memref_copy_one_rank(uint64_t *src_allocated, uint64_t *src_aligned,
|
||||
uint64_t src_offset, uint64_t src_size,
|
||||
uint64_t src_stride, uint64_t *dst_allocated,
|
||||
|
||||
@@ -42,6 +42,11 @@ struct CompilationOptions {
|
||||
|
||||
llvm::Optional<mlir::concretelang::V0Parameter> v0Parameter;
|
||||
|
||||
/// largeIntegerParameter force the compiler engine to lower FHE.eint using
|
||||
/// the large integers strategy with the given parameters.
|
||||
llvm::Optional<mlir::concretelang::LargeIntegerParameter>
|
||||
largeIntegerParameter;
|
||||
|
||||
bool verifyDiagnostics;
|
||||
|
||||
bool autoParallelize;
|
||||
|
||||
@@ -94,14 +94,15 @@ buildTensorLambdaResult(clientlib::KeySet &keySet,
|
||||
clientlib::PublicResult &result) {
|
||||
llvm::Expected<std::vector<T>> tensorOrError =
|
||||
typedResult<std::vector<T>>(keySet, result);
|
||||
|
||||
if (auto err = tensorOrError.takeError())
|
||||
return std::move(err);
|
||||
std::vector<int64_t> tensorDim(result.buffers[0].sizes.begin(),
|
||||
result.buffers[0].sizes.end() - 1);
|
||||
|
||||
auto tensorDim = result.asClearTextShape(0);
|
||||
if (tensorDim.has_error())
|
||||
return StreamStringError(tensorDim.error().mesg);
|
||||
|
||||
return std::make_unique<TensorLambdaArgument<IntLambdaArgument<T>>>(
|
||||
*tensorOrError, tensorDim);
|
||||
*tensorOrError, tensorDim.value());
|
||||
}
|
||||
|
||||
/// pecialization of `typedResult()` for a single result wrapped into
|
||||
|
||||
@@ -36,6 +36,7 @@ markFHELinalgForTiling(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
llvm::Optional<V0FHEContext> &fheContext,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
add_compile_options( -Werror )
|
||||
add_compile_options(-Werror)
|
||||
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
|
||||
if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
|
||||
# using Clang
|
||||
add_compile_options( -Wno-error=pessimizing-move -Wno-pessimizing-move )
|
||||
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
|
||||
add_compile_options(-Wno-error=pessimizing-move -Wno-pessimizing-move)
|
||||
elseif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
|
||||
# using GCC
|
||||
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0)
|
||||
add_compile_options( -Werror -Wno-error=pessimizing-move -Wno-pessimizing-move )
|
||||
add_compile_options(-Werror -Wno-error=pessimizing-move -Wno-pessimizing-move)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -14,6 +14,7 @@ add_mlir_library(
|
||||
ConcretelangClientLib
|
||||
ClientLambda.cpp
|
||||
ClientParameters.cpp
|
||||
CRT.cpp
|
||||
EncryptedArguments.cpp
|
||||
KeySet.cpp
|
||||
KeySetCache.cpp
|
||||
@@ -23,8 +24,6 @@ add_mlir_library(
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/ClientLib
|
||||
|
||||
LINK_LIBS
|
||||
ConcretelangRuntime
|
||||
LINK_LIBS PUBLIC
|
||||
Concrete
|
||||
)
|
||||
|
||||
99
compiler/lib/ClientLib/CRT.cpp
Normal file
99
compiler/lib/ClientLib/CRT.cpp
Normal file
@@ -0,0 +1,99 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <cstddef>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "concretelang/ClientLib/CRT.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
namespace crt {
|
||||
uint64_t productOfModuli(std::vector<int64_t> moduli) {
|
||||
uint64_t product = 1;
|
||||
for (auto modulus : moduli) {
|
||||
product *= modulus;
|
||||
}
|
||||
return product;
|
||||
}
|
||||
|
||||
std::vector<int64_t> crt(std::vector<int64_t> moduli, uint64_t val) {
|
||||
std::vector<int64_t> remainders(moduli.size(), 0);
|
||||
|
||||
for (size_t i = 0; i < moduli.size(); i++) {
|
||||
remainders[i] = val % moduli[i];
|
||||
}
|
||||
return remainders;
|
||||
}
|
||||
|
||||
// https://www.geeksforgeeks.org/multiplicative-inverse-under-modulo-m/
|
||||
// Returns modulo inverse of a with respect
|
||||
// to m using extended Euclid Algorithm
|
||||
// Assumption: a and m are coprimes, i.e.,
|
||||
// gcd(a, m) = 1
|
||||
int64_t modInverse(int64_t a, int64_t m) {
|
||||
int64_t m0 = m;
|
||||
int64_t y = 0, x = 1;
|
||||
|
||||
if (m == 1)
|
||||
return 0;
|
||||
|
||||
while (a > 1) {
|
||||
// q is quotient
|
||||
int64_t q = a / m;
|
||||
int64_t t = m;
|
||||
|
||||
// m is remainder now, process same as
|
||||
// Euclid's algo
|
||||
m = a % m;
|
||||
a = t;
|
||||
t = y;
|
||||
|
||||
// Update y and x
|
||||
y = x - q * y;
|
||||
x = t;
|
||||
}
|
||||
|
||||
// Make x positive
|
||||
if (x < 0)
|
||||
x += m0;
|
||||
|
||||
return x;
|
||||
}
|
||||
|
||||
uint64_t iCrt(std::vector<int64_t> moduli, std::vector<int64_t> remainders) {
|
||||
// Compute the product of moduli
|
||||
int64_t product = productOfModuli(moduli);
|
||||
|
||||
int64_t result = 0;
|
||||
|
||||
// Apply above formula
|
||||
for (size_t i = 0; i < remainders.size(); i++) {
|
||||
int tmp = product / moduli[i];
|
||||
result += remainders[i] * modInverse(tmp, moduli[i]) * tmp;
|
||||
}
|
||||
|
||||
return result % product;
|
||||
}
|
||||
|
||||
uint64_t encode(int64_t plaintext, uint64_t modulus, uint64_t product) {
|
||||
// values are represented on the interval [0; product[ so we represent
|
||||
// plantext on this interval
|
||||
if (plaintext < 0) {
|
||||
plaintext = product + plaintext;
|
||||
}
|
||||
__uint128_t m = plaintext % modulus;
|
||||
return m * ((__uint128_t)(1) << 64) / modulus;
|
||||
}
|
||||
|
||||
uint64_t decode(uint64_t val, uint64_t modulus) {
|
||||
auto result = (__uint128_t)val * (__uint128_t)modulus;
|
||||
result = result + ((result & ((__uint128_t)(1) << 63)) << 1);
|
||||
result = result / ((__uint128_t)(1) << 64);
|
||||
return (uint64_t)result % modulus;
|
||||
}
|
||||
} // namespace crt
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
@@ -234,6 +234,9 @@ llvm::json::Value toJSON(const Encoding &v) {
|
||||
llvm::json::Object object{
|
||||
{"precision", v.precision},
|
||||
};
|
||||
if (!v.crt.empty()) {
|
||||
object.insert({"crt", v.crt});
|
||||
}
|
||||
return object;
|
||||
}
|
||||
bool fromJSON(const llvm::json::Value j, Encoding &v, llvm::json::Path p) {
|
||||
@@ -248,6 +251,18 @@ bool fromJSON(const llvm::json::Value j, Encoding &v, llvm::json::Path p) {
|
||||
return false;
|
||||
}
|
||||
v.precision = precision.getValue();
|
||||
auto crt = obj->getArray("crt");
|
||||
if (crt != nullptr) {
|
||||
for (auto dim : *crt) {
|
||||
auto iDim = dim.getAsInteger();
|
||||
if (!iDim.hasValue()) {
|
||||
p.report("dimensions must be integer");
|
||||
return false;
|
||||
}
|
||||
v.crt.push_back(iDim.getValue());
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -11,17 +11,6 @@ namespace clientlib {
|
||||
|
||||
using StringError = concretelang::error::StringError;
|
||||
|
||||
size_t bitWidthAsWord(size_t exactBitWidth) {
|
||||
size_t sortedWordBitWidths[] = {8, 16, 32, 64};
|
||||
size_t previousWidth = 0;
|
||||
for (auto currentWidth : sortedWordBitWidths) {
|
||||
if (previousWidth < exactBitWidth && exactBitWidth <= currentWidth) {
|
||||
return currentWidth;
|
||||
}
|
||||
}
|
||||
return exactBitWidth;
|
||||
}
|
||||
|
||||
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
|
||||
EncryptedArguments::exportPublicArguments(ClientParameters clientParameters,
|
||||
RuntimeContext runtimeContext) {
|
||||
@@ -33,7 +22,7 @@ outcome::checked<void, StringError>
|
||||
EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) {
|
||||
OUTCOME_TRYV(checkPushTooManyArgs(keySet));
|
||||
auto pos = currentPos++;
|
||||
CircuitGate input = keySet.inputGate(pos);
|
||||
OUTCOME_TRY(CircuitGate input, keySet.clientParameters().input(pos));
|
||||
if (input.shape.size != 0) {
|
||||
return StringError("argument #") << pos << " is not a scalar";
|
||||
}
|
||||
@@ -42,12 +31,11 @@ EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) {
|
||||
preparedArgs.push_back((void *)arg);
|
||||
return outcome::success();
|
||||
}
|
||||
ciphertextBuffers.resize(ciphertextBuffers.size() + 1); // Allocate empty
|
||||
// Allocate empty
|
||||
ciphertextBuffers.resize(ciphertextBuffers.size() + 1);
|
||||
TensorData &values_and_sizes = ciphertextBuffers.back();
|
||||
auto lweSize = keySet.getInputLweSecretKeyParam(pos).lweSize();
|
||||
values_and_sizes.sizes.push_back(lweSize);
|
||||
values_and_sizes.values.resize(lweSize);
|
||||
|
||||
values_and_sizes.sizes = keySet.clientParameters().bufferShape(input);
|
||||
values_and_sizes.values.resize(keySet.clientParameters().bufferSize(input));
|
||||
OUTCOME_TRYV(keySet.encrypt_lwe(pos, values_and_sizes.values.data(), arg));
|
||||
// Note: Since we bufferized lwe ciphertext take care of memref calling
|
||||
// convention
|
||||
@@ -57,97 +45,18 @@ EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) {
|
||||
preparedArgs.push_back((void *)values_and_sizes.values.data());
|
||||
// offset
|
||||
preparedArgs.push_back((void *)0);
|
||||
// size
|
||||
preparedArgs.push_back((void *)values_and_sizes.values.size());
|
||||
// stride
|
||||
preparedArgs.push_back((void *)1);
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
EncryptedArguments::pushArg(std::vector<uint8_t> arg, KeySet &keySet) {
|
||||
return pushArg(8, (void *)arg.data(), {(int64_t)arg.size()}, keySet);
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
EncryptedArguments::pushArg(size_t width, const void *data,
|
||||
llvm::ArrayRef<int64_t> shape, KeySet &keySet) {
|
||||
OUTCOME_TRYV(checkPushTooManyArgs(keySet));
|
||||
auto pos = currentPos;
|
||||
CircuitGate input = keySet.inputGate(pos);
|
||||
// Check the width of data
|
||||
if (input.shape.width > 64) {
|
||||
return StringError("argument #")
|
||||
<< pos << " width > 64 bits is not supported";
|
||||
}
|
||||
auto roundedSize = bitWidthAsWord(input.shape.width);
|
||||
if (width != roundedSize) {
|
||||
return StringError("argument #") << pos << "width mismatch, got " << width
|
||||
<< " expected " << roundedSize;
|
||||
}
|
||||
// Check the shape of tensor
|
||||
if (input.shape.dimensions.empty()) {
|
||||
return StringError("argument #") << pos << "is not a tensor";
|
||||
}
|
||||
if (shape.size() != input.shape.dimensions.size()) {
|
||||
return StringError("argument #")
|
||||
<< pos << "has not the expected number of dimension, got "
|
||||
<< shape.size() << " expected " << input.shape.dimensions.size();
|
||||
}
|
||||
ciphertextBuffers.resize(ciphertextBuffers.size() + 1); // Allocate empty
|
||||
TensorData &values_and_sizes = ciphertextBuffers.back();
|
||||
for (size_t i = 0; i < shape.size(); i++) {
|
||||
values_and_sizes.sizes.push_back(shape[i]);
|
||||
if (shape[i] != input.shape.dimensions[i]) {
|
||||
return StringError("argument #")
|
||||
<< pos << " has not the expected dimension #" << i << " , got "
|
||||
<< shape[i] << " expected " << input.shape.dimensions[i];
|
||||
}
|
||||
}
|
||||
if (input.encryption.hasValue()) {
|
||||
auto lweSize = keySet.getInputLweSecretKeyParam(pos).lweSize();
|
||||
values_and_sizes.sizes.push_back(lweSize);
|
||||
|
||||
// Encrypted tensor: for now we support only 8 bits for encrypted tensor
|
||||
if (width != 8) {
|
||||
return StringError("argument #")
|
||||
<< pos << " width mismatch, expected 8 got " << width;
|
||||
}
|
||||
const uint8_t *data8 = (const uint8_t *)data;
|
||||
|
||||
// Allocate a buffer for ciphertexts of size of tensor
|
||||
values_and_sizes.values.resize(input.shape.size * lweSize);
|
||||
auto &values = values_and_sizes.values;
|
||||
// Allocate ciphertexts and encrypt, for every values in tensor
|
||||
for (size_t i = 0, offset = 0; i < input.shape.size;
|
||||
i++, offset += lweSize) {
|
||||
OUTCOME_TRYV(keySet.encrypt_lwe(pos, values.data() + offset, data8[i]));
|
||||
}
|
||||
} else {
|
||||
values_and_sizes.values.resize(input.shape.size);
|
||||
for (size_t i = 0; i < input.shape.size; i++) {
|
||||
values_and_sizes.values[i] = ((const uint64_t *)data)[i];
|
||||
}
|
||||
}
|
||||
// allocated
|
||||
preparedArgs.push_back(nullptr);
|
||||
// aligned
|
||||
preparedArgs.push_back((void *)values_and_sizes.values.data());
|
||||
// offset
|
||||
preparedArgs.push_back((void *)0);
|
||||
// sizes
|
||||
for (size_t size : values_and_sizes.sizes) {
|
||||
for (auto size : values_and_sizes.sizes) {
|
||||
preparedArgs.push_back((void *)size);
|
||||
}
|
||||
|
||||
// Set the stride for each dimension, equal to the product of the
|
||||
// following dimensions.
|
||||
// strides
|
||||
int64_t stride = values_and_sizes.length();
|
||||
for (size_t size : values_and_sizes.sizes) {
|
||||
for (size_t i = 0; i < values_and_sizes.sizes.size() - 1; i++) {
|
||||
auto size = values_and_sizes.sizes[i];
|
||||
stride = (size == 0 ? 0 : (stride / size));
|
||||
preparedArgs.push_back((void *)stride);
|
||||
}
|
||||
currentPos++;
|
||||
preparedArgs.push_back((void *)1);
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
// for license information.
|
||||
|
||||
#include "concretelang/ClientLib/KeySet.h"
|
||||
#include "concretelang/ClientLib/CRT.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
|
||||
#define CAPI_ERR_TO_STRINGERROR(instr, msg) \
|
||||
@@ -31,16 +32,16 @@ outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
KeySet::generate(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
auto keySet = std::make_unique<KeySet>();
|
||||
|
||||
OUTCOME_TRYV(keySet->generateKeysFromParams(params, seed_msb, seed_lsb));
|
||||
OUTCOME_TRYV(keySet->setupEncryptionMaterial(params, seed_msb, seed_lsb));
|
||||
|
||||
return std::move(keySet);
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
KeySet::setupEncryptionMaterial(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
_clientParameters = params;
|
||||
|
||||
// Set inputs and outputs LWE secret keys
|
||||
{
|
||||
for (auto param : params.inputs) {
|
||||
@@ -189,9 +190,16 @@ KeySet::allocate_lwe(size_t argPos, uint64_t **ciphertext, uint64_t &size) {
|
||||
return StringError("allocate_lwe position of argument is too high");
|
||||
}
|
||||
auto inputSk = inputs[argPos];
|
||||
auto encryption = std::get<0>(inputSk).encryption;
|
||||
if (!encryption.hasValue()) {
|
||||
return StringError("allocate_lwe argument #")
|
||||
<< argPos << "is not encypeted";
|
||||
}
|
||||
auto numBlocks =
|
||||
encryption->encoding.crt.empty() ? 1 : encryption->encoding.crt.size();
|
||||
|
||||
size = std::get<1>(inputSk).lweSize();
|
||||
*ciphertext = (uint64_t *)malloc(sizeof(uint64_t) * size);
|
||||
*ciphertext = (uint64_t *)malloc(sizeof(uint64_t) * size * numBlocks);
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
@@ -205,20 +213,40 @@ bool KeySet::isOutputEncrypted(size_t argPos) {
|
||||
std::get<0>(outputs[argPos]).encryption.hasValue();
|
||||
}
|
||||
|
||||
/// Return the number of bits to represents the given value
|
||||
uint64_t bitWidthOfValue(uint64_t value) { return std::ceil(std::log2(value)); }
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input) {
|
||||
if (argPos >= inputs.size()) {
|
||||
return StringError("encrypt_lwe position of argument is too high");
|
||||
}
|
||||
auto inputSk = inputs[argPos];
|
||||
if (!std::get<0>(inputSk).encryption.hasValue()) {
|
||||
auto encryption = std::get<0>(inputSk).encryption;
|
||||
if (!encryption.hasValue()) {
|
||||
return StringError("encrypt_lwe the positional argument is not encrypted");
|
||||
}
|
||||
// Encode - TODO we could check if the input value is in the right range
|
||||
uint64_t plaintext =
|
||||
input << (64 - (std::get<0>(inputSk).encryption->encoding.precision + 1));
|
||||
::encrypt_lwe_u64(engine, std::get<2>(inputSk), ciphertext, plaintext,
|
||||
std::get<0>(inputSk).encryption->variance);
|
||||
auto encoding = encryption->encoding;
|
||||
auto lweSecretKeyParam = std::get<1>(inputSk);
|
||||
auto lweSecretKey = std::get<2>(inputSk);
|
||||
// CRT encoding - N blocks with crt encoding
|
||||
auto crt = encryption->encoding.crt;
|
||||
if (!crt.empty()) {
|
||||
// Put each decomposition into a new ciphertext
|
||||
auto product = crt::productOfModuli(crt);
|
||||
for (auto modulus : crt) {
|
||||
auto plaintext = crt::encode(input, modulus, product);
|
||||
::encrypt_lwe_u64(engine, lweSecretKey, ciphertext, plaintext,
|
||||
encryption->variance);
|
||||
ciphertext = ciphertext + lweSecretKeyParam.lweSize();
|
||||
}
|
||||
return outcome::success();
|
||||
}
|
||||
// Simple TFHE integers - 1 blocks with one padding bits
|
||||
// TODO we could check if the input value is in the right range
|
||||
uint64_t plaintext = input << (64 - (encryption->encoding.precision + 1));
|
||||
::encrypt_lwe_u64(engine, lweSecretKey, ciphertext, plaintext,
|
||||
encryption->variance);
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
@@ -228,13 +256,31 @@ KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) {
|
||||
return StringError("decrypt_lwe: position of argument is too high");
|
||||
}
|
||||
auto outputSk = outputs[argPos];
|
||||
if (!std::get<0>(outputSk).encryption.hasValue()) {
|
||||
auto lweSecretKey = std::get<2>(outputSk);
|
||||
auto lweSecretKeyParam = std::get<1>(outputSk);
|
||||
auto encryption = std::get<0>(outputSk).encryption;
|
||||
if (!encryption.hasValue()) {
|
||||
return StringError("decrypt_lwe: the positional argument is not encrypted");
|
||||
}
|
||||
uint64_t plaintext =
|
||||
::decrypt_lwe_u64(engine, std::get<2>(outputSk), ciphertext);
|
||||
auto crt = encryption->encoding.crt;
|
||||
// CRT encoding - N blocks with crt encoding
|
||||
if (!crt.empty()) {
|
||||
std::vector<int64_t> remainders;
|
||||
// decrypt and decode remainders
|
||||
for (auto modulus : crt) {
|
||||
auto decrypted = ::decrypt_lwe_u64(engine, lweSecretKey, ciphertext);
|
||||
auto plaintext = crt::decode(decrypted, modulus);
|
||||
remainders.push_back(plaintext);
|
||||
ciphertext = ciphertext + lweSecretKeyParam.lweSize();
|
||||
}
|
||||
// compute the inverse crt
|
||||
output = crt::iCrt(crt, remainders);
|
||||
return outcome::success();
|
||||
}
|
||||
// Simple TFHE integers - 1 blocks with one padding bits
|
||||
uint64_t plaintext = ::decrypt_lwe_u64(engine, lweSecretKey, ciphertext);
|
||||
// Decode
|
||||
size_t precision = std::get<0>(outputSk).encryption->encoding.precision;
|
||||
size_t precision = encryption->encoding.precision;
|
||||
output = plaintext >> (64 - precision - 2);
|
||||
size_t carry = output % 2;
|
||||
output = ((output >> 1) + carry) % (1 << (precision + 1));
|
||||
|
||||
@@ -90,7 +90,6 @@ KeySetCache::loadKeys(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb, std::string folderPath) {
|
||||
// TODO: text dump of all parameter in /hash
|
||||
auto key_set = std::make_unique<KeySet>();
|
||||
|
||||
// Mark the folder as recently use.
|
||||
// e.g. so the CI can do some cleanup of unused keys.
|
||||
utime(folderPath.c_str(), nullptr);
|
||||
|
||||
@@ -38,6 +38,9 @@
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/IR/Function.h"
|
||||
|
||||
namespace Concrete = ::mlir::concretelang::Concrete;
|
||||
namespace BConcrete = ::mlir::concretelang::BConcrete;
|
||||
|
||||
namespace {
|
||||
struct ConcreteToBConcretePass
|
||||
: public ConcreteToBConcreteBase<ConcreteToBConcretePass> {
|
||||
@@ -68,9 +71,14 @@ public:
|
||||
});
|
||||
addConversion([&](mlir::concretelang::Concrete::LweCiphertextType type) {
|
||||
assert(type.getDimension() != -1);
|
||||
llvm::SmallVector<int64_t, 2> shape;
|
||||
auto crt = type.getCrtDecomposition();
|
||||
if (!crt.empty()) {
|
||||
shape.push_back(crt.size());
|
||||
}
|
||||
shape.push_back(type.getDimension() + 1);
|
||||
return mlir::RankedTensorType::get(
|
||||
{type.getDimension() + 1},
|
||||
mlir::IntegerType::get(type.getContext(), 64));
|
||||
shape, mlir::IntegerType::get(type.getContext(), 64));
|
||||
});
|
||||
addConversion([&](mlir::concretelang::Concrete::GlweCiphertextType type) {
|
||||
assert(type.getGlweDimension() != -1);
|
||||
@@ -91,91 +99,18 @@ public:
|
||||
mlir::SmallVector<int64_t> newShape;
|
||||
newShape.reserve(type.getShape().size() + 1);
|
||||
newShape.append(type.getShape().begin(), type.getShape().end());
|
||||
auto crt = lwe.getCrtDecomposition();
|
||||
if (!crt.empty()) {
|
||||
newShape.push_back(crt.size());
|
||||
}
|
||||
newShape.push_back(lwe.getDimension() + 1);
|
||||
mlir::Type r = mlir::RankedTensorType::get(
|
||||
newShape, mlir::IntegerType::get(type.getContext(), 64));
|
||||
return r;
|
||||
});
|
||||
addConversion([&](mlir::MemRefType type) {
|
||||
auto lwe = type.getElementType()
|
||||
.dyn_cast_or_null<
|
||||
mlir::concretelang::Concrete::LweCiphertextType>();
|
||||
if (lwe == nullptr) {
|
||||
return (mlir::Type)(type);
|
||||
}
|
||||
assert(lwe.getDimension() != -1);
|
||||
mlir::SmallVector<int64_t> newShape;
|
||||
newShape.reserve(type.getShape().size() + 1);
|
||||
newShape.append(type.getShape().begin(), type.getShape().end());
|
||||
newShape.push_back(lwe.getDimension() + 1);
|
||||
mlir::Type r = mlir::MemRefType::get(
|
||||
newShape, mlir::IntegerType::get(type.getContext(), 64));
|
||||
return r;
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct ConcreteEncodeIntOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::Concrete::EncodeIntOp> {
|
||||
ConcreteEncodeIntOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<mlir::concretelang::Concrete::EncodeIntOp>(
|
||||
context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::Concrete::EncodeIntOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
{
|
||||
mlir::Value castedInt = rewriter.create<mlir::arith::ExtUIOp>(
|
||||
op.getLoc(), rewriter.getIntegerType(64), op->getOperands().front());
|
||||
mlir::Value constantShiftOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getI64IntegerAttr(64 - op.getType().getP()));
|
||||
|
||||
mlir::Type resultType = rewriter.getIntegerType(64);
|
||||
rewriter.replaceOpWithNewOp<mlir::arith::ShLIOp>(
|
||||
op, resultType, castedInt, constantShiftOp);
|
||||
}
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
struct ConcreteIntToCleartextOpPattern
|
||||
: public mlir::OpRewritePattern<
|
||||
mlir::concretelang::Concrete::IntToCleartextOp> {
|
||||
ConcreteIntToCleartextOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<mlir::concretelang::Concrete::IntToCleartextOp>(
|
||||
context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::Concrete::IntToCleartextOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<mlir::arith::ExtUIOp>(
|
||||
op, rewriter.getIntegerType(64), op->getOperands().front());
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
/// This rewrite pattern transforms any instance of `Concrete.zero_tensor`
|
||||
/// operators.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %0 = "Concrete.zero_tensor" () :
|
||||
/// tensor<...x!Concrete.lwe_ciphertext<lweDim,p>>
|
||||
/// ```
|
||||
///
|
||||
/// becomes:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %0 = tensor.generate {
|
||||
/// ^bb0(... : index):
|
||||
/// %c0 = arith.constant 0 : i64
|
||||
/// tensor.yield %z
|
||||
/// }: tensor<...xlweDim+1xi64>
|
||||
/// i64>
|
||||
/// ```
|
||||
template <typename ZeroOp>
|
||||
struct ZeroOpPattern : public mlir::OpRewritePattern<ZeroOp> {
|
||||
ZeroOpPattern(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
|
||||
@@ -205,20 +140,7 @@ struct ZeroOpPattern : public mlir::OpRewritePattern<ZeroOp> {
|
||||
};
|
||||
};
|
||||
|
||||
/// This template rewrite pattern transforms any instance of
|
||||
/// `ConcreteOp` to an instance of `BConcreteOp`.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// %0 = "ConcreteOp"(%arg0, ...) :
|
||||
/// (!Concrete.lwe_ciphertext<lwe_dimension, p>, ...) ->
|
||||
/// (!Concrete.lwe_ciphertext<lwe_dimension, p>)
|
||||
///
|
||||
/// becomes:
|
||||
///
|
||||
/// %0 = "BConcreteOp"(%arg0, ...) : (tensor<dimension+1, i64>>, ..., ) ->
|
||||
/// (tensor<dimension+1, i64>>)
|
||||
template <typename ConcreteOp, typename BConcreteOp>
|
||||
template <typename ConcreteOp, typename BConcreteOp, typename BConcreteCRTOp>
|
||||
struct LowToBConcrete : public mlir::OpRewritePattern<ConcreteOp> {
|
||||
LowToBConcrete(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<ConcreteOp>(context, benefit) {}
|
||||
@@ -236,9 +158,126 @@ struct LowToBConcrete : public mlir::OpRewritePattern<ConcreteOp> {
|
||||
llvm::ArrayRef<::mlir::NamedAttribute> attributes =
|
||||
concreteOp.getOperation()->getAttrs();
|
||||
|
||||
BConcreteOp bConcreteOp = rewriter.replaceOpWithNewOp<BConcreteOp>(
|
||||
concreteOp, newResultTy, concreteOp.getOperation()->getOperands(),
|
||||
attributes);
|
||||
auto crt = resultTy.getCrtDecomposition();
|
||||
mlir::Operation *bConcreteOp;
|
||||
if (crt.empty()) {
|
||||
bConcreteOp = rewriter.replaceOpWithNewOp<BConcreteOp>(
|
||||
concreteOp, newResultTy, concreteOp.getOperation()->getOperands(),
|
||||
attributes);
|
||||
} else {
|
||||
auto newAttributes = attributes.vec();
|
||||
newAttributes.push_back(rewriter.getNamedAttr(
|
||||
"crtDecomposition", rewriter.getI64ArrayAttr(crt)));
|
||||
bConcreteOp = rewriter.replaceOpWithNewOp<BConcreteCRTOp>(
|
||||
concreteOp, newResultTy, concreteOp.getOperation()->getOperands(),
|
||||
newAttributes);
|
||||
}
|
||||
|
||||
mlir::concretelang::convertOperandAndResultTypes(
|
||||
rewriter, bConcreteOp, [&](mlir::MLIRContext *, mlir::Type t) {
|
||||
return converter.convertType(t);
|
||||
});
|
||||
|
||||
return ::mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
struct AddPlaintextLweCiphertextOpPattern
|
||||
: public mlir::OpRewritePattern<Concrete::AddPlaintextLweCiphertextOp> {
|
||||
AddPlaintextLweCiphertextOpPattern(::mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<Concrete::AddPlaintextLweCiphertextOp>(
|
||||
context, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(Concrete::AddPlaintextLweCiphertextOp concreteOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
ConcreteToBConcreteTypeConverter converter;
|
||||
auto loc = concreteOp.getLoc();
|
||||
mlir::concretelang::Concrete::LweCiphertextType resultTy =
|
||||
((mlir::Type)concreteOp->getResult(0).getType())
|
||||
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
|
||||
auto newResultTy =
|
||||
converter.convertType(resultTy).cast<mlir::RankedTensorType>();
|
||||
|
||||
llvm::ArrayRef<::mlir::NamedAttribute> attributes =
|
||||
concreteOp.getOperation()->getAttrs();
|
||||
|
||||
auto crt = resultTy.getCrtDecomposition();
|
||||
mlir::Operation *bConcreteOp;
|
||||
if (crt.empty()) {
|
||||
// Encode the plaintext value
|
||||
mlir::Value castedInt = rewriter.create<mlir::arith::ExtUIOp>(
|
||||
loc, rewriter.getIntegerType(64), concreteOp.rhs());
|
||||
mlir::Value constantShiftOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
loc,
|
||||
rewriter.getI64IntegerAttr(64 - concreteOp.getType().getP() - 1));
|
||||
auto encoded = rewriter.create<mlir::arith::ShLIOp>(
|
||||
loc, rewriter.getI64Type(), castedInt, constantShiftOp);
|
||||
bConcreteOp =
|
||||
rewriter.replaceOpWithNewOp<BConcrete::AddPlaintextLweBufferOp>(
|
||||
concreteOp, newResultTy,
|
||||
mlir::ValueRange{concreteOp.lhs(), encoded}, attributes);
|
||||
} else {
|
||||
// The encoding is done when we eliminate CRT ops
|
||||
auto newAttributes = attributes.vec();
|
||||
newAttributes.push_back(rewriter.getNamedAttr(
|
||||
"crtDecomposition", rewriter.getI64ArrayAttr(crt)));
|
||||
bConcreteOp =
|
||||
rewriter.replaceOpWithNewOp<BConcrete::AddPlaintextCRTLweBufferOp>(
|
||||
concreteOp, newResultTy, concreteOp.getOperation()->getOperands(),
|
||||
newAttributes);
|
||||
}
|
||||
|
||||
mlir::concretelang::convertOperandAndResultTypes(
|
||||
rewriter, bConcreteOp, [&](mlir::MLIRContext *, mlir::Type t) {
|
||||
return converter.convertType(t);
|
||||
});
|
||||
|
||||
return ::mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
struct MulCleartextLweCiphertextOpPattern
|
||||
: public mlir::OpRewritePattern<Concrete::MulCleartextLweCiphertextOp> {
|
||||
MulCleartextLweCiphertextOpPattern(::mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<Concrete::MulCleartextLweCiphertextOp>(
|
||||
context, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(Concrete::MulCleartextLweCiphertextOp concreteOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
ConcreteToBConcreteTypeConverter converter;
|
||||
auto loc = concreteOp.getLoc();
|
||||
mlir::concretelang::Concrete::LweCiphertextType resultTy =
|
||||
((mlir::Type)concreteOp->getResult(0).getType())
|
||||
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
|
||||
auto newResultTy =
|
||||
converter.convertType(resultTy).cast<mlir::RankedTensorType>();
|
||||
|
||||
llvm::ArrayRef<::mlir::NamedAttribute> attributes =
|
||||
concreteOp.getOperation()->getAttrs();
|
||||
|
||||
auto crt = resultTy.getCrtDecomposition();
|
||||
mlir::Operation *bConcreteOp;
|
||||
if (crt.empty()) {
|
||||
// Encode the plaintext value
|
||||
mlir::Value castedInt = rewriter.create<mlir::arith::ExtUIOp>(
|
||||
loc, rewriter.getIntegerType(64), concreteOp.rhs());
|
||||
bConcreteOp =
|
||||
rewriter.replaceOpWithNewOp<BConcrete::MulCleartextLweBufferOp>(
|
||||
concreteOp, newResultTy,
|
||||
mlir::ValueRange{concreteOp.lhs(), castedInt}, attributes);
|
||||
} else {
|
||||
auto newAttributes = attributes.vec();
|
||||
newAttributes.push_back(rewriter.getNamedAttr(
|
||||
"crtDecomposition", rewriter.getI64ArrayAttr(crt)));
|
||||
bConcreteOp =
|
||||
rewriter.replaceOpWithNewOp<BConcrete::MulCleartextCRTLweBufferOp>(
|
||||
concreteOp, newResultTy, concreteOp.getOperation()->getOperands(),
|
||||
newAttributes);
|
||||
}
|
||||
|
||||
mlir::concretelang::convertOperandAndResultTypes(
|
||||
rewriter, bConcreteOp, [&](mlir::MLIRContext *, mlir::Type t) {
|
||||
@@ -249,27 +288,6 @@ struct LowToBConcrete : public mlir::OpRewritePattern<ConcreteOp> {
|
||||
};
|
||||
};
|
||||
|
||||
/// This rewrite pattern transforms any instance of
|
||||
/// `Concrete.glwe_from_table` operators.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %0 = "Concrete.glwe_from_table"(%tlu)
|
||||
/// : (tensor<$Dxi64>) ->
|
||||
/// !Concrete.glwe_ciphertext<$polynomialSize,$glweDimension,$p>
|
||||
/// ```
|
||||
///
|
||||
/// with $D = 2^$p
|
||||
///
|
||||
/// becomes:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %0 = linalg.init_tensor [polynomialSize*(glweDimension+1)]
|
||||
/// : tensor<polynomialSize*(glweDimension+1), i64>
|
||||
/// "BConcrete.fill_glwe_from_table" : (%0, polynomialSize, glweDimension, %tlu)
|
||||
/// : tensor<polynomialSize*(glweDimension+1), i64>, i64, i64, tensor<$Dxi64>
|
||||
/// ```
|
||||
struct GlweFromTablePattern : public mlir::OpRewritePattern<
|
||||
mlir::concretelang::Concrete::GlweFromTable> {
|
||||
GlweFromTablePattern(::mlir::MLIRContext *context,
|
||||
@@ -307,26 +325,6 @@ struct GlweFromTablePattern : public mlir::OpRewritePattern<
|
||||
};
|
||||
};
|
||||
|
||||
/// This rewrite pattern transforms any instance of
|
||||
/// `tensor.extract_slice` operators that operates on tensor of lwe ciphertext.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %0 = tensor.extract_slice %arg0
|
||||
/// [offsets...] [sizes...] [strides...]
|
||||
/// : tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>> to
|
||||
/// tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
|
||||
/// ```
|
||||
///
|
||||
/// becomes:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %0 = tensor.extract_slice %arg0
|
||||
/// [offsets..., 0] [sizes..., lweDimension+1] [strides..., 1]
|
||||
/// : tensor<...xlweDimension+1,i64> to
|
||||
/// tensor<...xlweDimension+1,i64>
|
||||
/// ```
|
||||
struct ExtractSliceOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::tensor::ExtractSliceOp> {
|
||||
ExtractSliceOpPattern(::mlir::MLIRContext *context,
|
||||
@@ -339,29 +337,41 @@ struct ExtractSliceOpPattern
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
ConcreteToBConcreteTypeConverter converter;
|
||||
auto resultTy = extractSliceOp.result().getType();
|
||||
auto resultEltTy =
|
||||
auto lweResultTy =
|
||||
resultTy.cast<mlir::RankedTensorType>()
|
||||
.getElementType()
|
||||
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
|
||||
auto newResultTy = converter.convertType(resultTy);
|
||||
auto nbBlock = lweResultTy.getCrtDecomposition().size();
|
||||
auto newResultTy =
|
||||
converter.convertType(resultTy).cast<mlir::RankedTensorType>();
|
||||
|
||||
// add 0 to the static_offsets
|
||||
mlir::SmallVector<mlir::Attribute> staticOffsets;
|
||||
staticOffsets.append(extractSliceOp.static_offsets().begin(),
|
||||
extractSliceOp.static_offsets().end());
|
||||
if (nbBlock != 0) {
|
||||
staticOffsets.push_back(rewriter.getI64IntegerAttr(0));
|
||||
}
|
||||
staticOffsets.push_back(rewriter.getI64IntegerAttr(0));
|
||||
|
||||
// add the lweSize to the sizes
|
||||
mlir::SmallVector<mlir::Attribute> staticSizes;
|
||||
staticSizes.append(extractSliceOp.static_sizes().begin(),
|
||||
extractSliceOp.static_sizes().end());
|
||||
staticSizes.push_back(
|
||||
rewriter.getI64IntegerAttr(resultEltTy.getDimension() + 1));
|
||||
if (nbBlock != 0) {
|
||||
staticSizes.push_back(rewriter.getI64IntegerAttr(
|
||||
newResultTy.getDimSize(newResultTy.getRank() - 2)));
|
||||
}
|
||||
staticSizes.push_back(rewriter.getI64IntegerAttr(
|
||||
newResultTy.getDimSize(newResultTy.getRank() - 1)));
|
||||
|
||||
// add 1 to the strides
|
||||
mlir::SmallVector<mlir::Attribute> staticStrides;
|
||||
staticStrides.append(extractSliceOp.static_strides().begin(),
|
||||
extractSliceOp.static_strides().end());
|
||||
if (nbBlock != 0) {
|
||||
staticStrides.push_back(rewriter.getI64IntegerAttr(1));
|
||||
}
|
||||
staticStrides.push_back(rewriter.getI64IntegerAttr(1));
|
||||
|
||||
// replace tensor.extract_slice to the new one
|
||||
@@ -382,26 +392,6 @@ struct ExtractSliceOpPattern
|
||||
};
|
||||
};
|
||||
|
||||
/// This rewrite pattern transforms any instance of
|
||||
/// `tensor.extract` operators that operates on tensor of lwe ciphertext.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %0 = tensor.extract %t[offsets...]
|
||||
/// : tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
|
||||
/// ```
|
||||
///
|
||||
/// becomes:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %1 = tensor.extract_slice %arg0
|
||||
/// [offsets...] [1..., lweDimension+1] [1...]
|
||||
/// : tensor<...xlweDimension+1,i64> to
|
||||
/// tensor<1...xlweDimension+1,i64>
|
||||
/// %0 = linalg.tensor_collapse_shape %0 [[...]] :
|
||||
/// tensor<1x1xlweDimension+1xi64> into tensor<lweDimension+1xi64>
|
||||
/// ```
|
||||
// TODO: since they are a bug on lowering extract_slice with rank reduction we
|
||||
// add a linalg.tensor_collapse_shape after the extract_slice without rank
|
||||
// reduction. See
|
||||
@@ -424,20 +414,29 @@ struct ExtractOpPattern
|
||||
if (lweResultTy == nullptr) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
auto nbBlock = lweResultTy.getCrtDecomposition().size();
|
||||
auto newResultTy =
|
||||
converter.convertType(lweResultTy).cast<mlir::RankedTensorType>();
|
||||
auto rankOfResult = extractOp.indices().size() + 1;
|
||||
|
||||
auto rankOfResult = extractOp.indices().size() +
|
||||
/* for the lwe dimension */ 1 +
|
||||
/* for the block dimension */
|
||||
(nbBlock == 0 ? 0 : 1);
|
||||
// [min..., 0] for static_offsets ()
|
||||
mlir::SmallVector<mlir::Attribute> staticOffsets(
|
||||
rankOfResult,
|
||||
rewriter.getI64IntegerAttr(std::numeric_limits<int64_t>::min()));
|
||||
if (nbBlock != 0) {
|
||||
staticOffsets[staticOffsets.size() - 2] = rewriter.getI64IntegerAttr(0);
|
||||
}
|
||||
staticOffsets[staticOffsets.size() - 1] = rewriter.getI64IntegerAttr(0);
|
||||
|
||||
// [1..., lweDimension+1] for static_sizes
|
||||
// [1..., lweDimension+1] for static_sizes or
|
||||
// [1..., nbBlock, lweDimension+1]
|
||||
mlir::SmallVector<mlir::Attribute> staticSizes(
|
||||
rankOfResult, rewriter.getI64IntegerAttr(1));
|
||||
if (nbBlock != 0) {
|
||||
staticSizes[staticSizes.size() - 2] = rewriter.getI64IntegerAttr(nbBlock);
|
||||
}
|
||||
staticSizes[staticSizes.size() - 1] = rewriter.getI64IntegerAttr(
|
||||
newResultTy.getDimSize(newResultTy.getRank() - 1));
|
||||
|
||||
@@ -446,38 +445,45 @@ struct ExtractOpPattern
|
||||
rankOfResult, rewriter.getI64IntegerAttr(1));
|
||||
|
||||
// replace tensor.extract_slice to the new one
|
||||
mlir::SmallVector<int64_t> extractedSliceShape(
|
||||
extractOp.indices().size() + 1, 0);
|
||||
extractedSliceShape.reserve(extractOp.indices().size() + 1);
|
||||
for (size_t i = 0; i < extractedSliceShape.size() - 1; i++) {
|
||||
extractedSliceShape[i] = 1;
|
||||
mlir::SmallVector<int64_t> extractedSliceShape(rankOfResult, 1);
|
||||
if (nbBlock != 0) {
|
||||
extractedSliceShape[extractedSliceShape.size() - 2] = nbBlock;
|
||||
extractedSliceShape[extractedSliceShape.size() - 1] =
|
||||
newResultTy.getDimSize(1);
|
||||
} else {
|
||||
extractedSliceShape[extractedSliceShape.size() - 1] =
|
||||
newResultTy.getDimSize(0);
|
||||
}
|
||||
extractedSliceShape[extractedSliceShape.size() - 1] =
|
||||
newResultTy.getDimSize(0);
|
||||
|
||||
auto extractedSliceType =
|
||||
mlir::RankedTensorType::get(extractedSliceShape, rewriter.getI64Type());
|
||||
|
||||
auto extractedSlice = rewriter.create<mlir::tensor::ExtractSliceOp>(
|
||||
extractOp.getLoc(), extractedSliceType, extractOp.tensor(),
|
||||
extractOp.indices(), mlir::SmallVector<mlir::Value>{},
|
||||
mlir::SmallVector<mlir::Value>{}, rewriter.getArrayAttr(staticOffsets),
|
||||
rewriter.getArrayAttr(staticSizes),
|
||||
rewriter.getArrayAttr(staticStrides));
|
||||
|
||||
mlir::concretelang::convertOperandAndResultTypes(
|
||||
rewriter, extractedSlice, [&](mlir::MLIRContext *, mlir::Type t) {
|
||||
return converter.convertType(t);
|
||||
});
|
||||
|
||||
mlir::ReassociationIndices reassociation;
|
||||
for (int64_t i = 0; i < extractedSliceType.getRank(); i++) {
|
||||
for (int64_t i = 0;
|
||||
i < extractedSliceType.getRank() - (nbBlock == 0 ? 0 : 1); i++) {
|
||||
reassociation.push_back(i);
|
||||
}
|
||||
|
||||
mlir::SmallVector<mlir::ReassociationIndices> reassocs{reassociation};
|
||||
|
||||
if (nbBlock != 0) {
|
||||
reassocs.push_back({extractedSliceType.getRank() - 1});
|
||||
}
|
||||
|
||||
mlir::tensor::CollapseShapeOp collapseOp =
|
||||
rewriter.replaceOpWithNewOp<mlir::tensor::CollapseShapeOp>(
|
||||
extractOp, newResultTy, extractedSlice,
|
||||
mlir::SmallVector<mlir::ReassociationIndices>{reassociation});
|
||||
extractOp, newResultTy, extractedSlice, reassocs);
|
||||
|
||||
mlir::concretelang::convertOperandAndResultTypes(
|
||||
rewriter, collapseOp, [&](mlir::MLIRContext *, mlir::Type t) {
|
||||
@@ -488,26 +494,6 @@ struct ExtractOpPattern
|
||||
};
|
||||
};
|
||||
|
||||
/// This rewrite pattern transforms any instance of
|
||||
/// `tensor.insert_slice` operators that operates on tensor of lwe ciphertext.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %0 = tensor.insert_slice %arg1
|
||||
/// into %arg0[offsets...] [sizes...] [strides...]
|
||||
/// : tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>> into
|
||||
/// tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
|
||||
/// ```
|
||||
///
|
||||
/// becomes:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %0 = tensor.insert_slice %arg1
|
||||
/// into %arg0[offsets..., 0] [sizes..., lweDimension+1] [strides..., 1]
|
||||
/// : tensor<...xlweDimension+1xi64> into
|
||||
/// tensor<...xlweDimension+1xi64>
|
||||
/// ```
|
||||
struct InsertSliceOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::tensor::InsertSliceOp> {
|
||||
InsertSliceOpPattern(::mlir::MLIRContext *context,
|
||||
@@ -520,7 +506,14 @@ struct InsertSliceOpPattern
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
ConcreteToBConcreteTypeConverter converter;
|
||||
auto resultTy = insertSliceOp.result().getType();
|
||||
|
||||
auto lweResultTy =
|
||||
resultTy.cast<mlir::RankedTensorType>()
|
||||
.getElementType()
|
||||
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
|
||||
if (lweResultTy == nullptr) {
|
||||
return mlir::failure();
|
||||
}
|
||||
auto nbBlock = lweResultTy.getCrtDecomposition().size();
|
||||
auto newResultTy =
|
||||
converter.convertType(resultTy).cast<mlir::RankedTensorType>();
|
||||
|
||||
@@ -528,12 +521,19 @@ struct InsertSliceOpPattern
|
||||
mlir::SmallVector<mlir::Attribute> staticOffsets;
|
||||
staticOffsets.append(insertSliceOp.static_offsets().begin(),
|
||||
insertSliceOp.static_offsets().end());
|
||||
if (nbBlock != 0) {
|
||||
staticOffsets.push_back(rewriter.getI64IntegerAttr(0));
|
||||
}
|
||||
staticOffsets.push_back(rewriter.getI64IntegerAttr(0));
|
||||
|
||||
// add lweDimension+1 to static_sizes
|
||||
mlir::SmallVector<mlir::Attribute> staticSizes;
|
||||
staticSizes.append(insertSliceOp.static_sizes().begin(),
|
||||
insertSliceOp.static_sizes().end());
|
||||
if (nbBlock != 0) {
|
||||
staticSizes.push_back(rewriter.getI64IntegerAttr(
|
||||
newResultTy.getDimSize(newResultTy.getRank() - 2)));
|
||||
}
|
||||
staticSizes.push_back(rewriter.getI64IntegerAttr(
|
||||
newResultTy.getDimSize(newResultTy.getRank() - 1)));
|
||||
|
||||
@@ -541,6 +541,9 @@ struct InsertSliceOpPattern
|
||||
mlir::SmallVector<mlir::Attribute> staticStrides;
|
||||
staticStrides.append(insertSliceOp.static_strides().begin(),
|
||||
insertSliceOp.static_strides().end());
|
||||
if (nbBlock != 0) {
|
||||
staticStrides.push_back(rewriter.getI64IntegerAttr(1));
|
||||
}
|
||||
staticStrides.push_back(rewriter.getI64IntegerAttr(1));
|
||||
|
||||
// replace tensor.insert_slice with the new one
|
||||
@@ -560,28 +563,6 @@ struct InsertSliceOpPattern
|
||||
};
|
||||
};
|
||||
|
||||
/// This rewrite pattern transforms any instance of `tensor.insert`
|
||||
/// operators that operates on an lwe ciphertexts to a
|
||||
/// `tensor.insert_slice` op operating on the bufferized representation
|
||||
/// of the ciphertext.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %0 = tensor.insert %arg1
|
||||
/// into %arg0[offsets...]
|
||||
/// : !Concrete.lwe_ciphertext<lweDimension,p> into
|
||||
/// tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
|
||||
/// ```
|
||||
///
|
||||
/// becomes:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %0 = tensor.insert_slice %arg1
|
||||
/// into %arg0[offsets..., 0] [sizes..., lweDimension+1] [strides..., 1]
|
||||
/// : tensor<lweDimension+1xi64> into
|
||||
/// tensor<...xlweDimension+1xi64>
|
||||
/// ```
|
||||
struct InsertOpPattern : public mlir::OpRewritePattern<mlir::tensor::InsertOp> {
|
||||
InsertOpPattern(::mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
@@ -591,14 +572,24 @@ struct InsertOpPattern : public mlir::OpRewritePattern<mlir::tensor::InsertOp> {
|
||||
matchAndRewrite(mlir::tensor::InsertOp insertOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
ConcreteToBConcreteTypeConverter converter;
|
||||
mlir::Type resultTy = insertOp.result().getType();
|
||||
auto resultTy =
|
||||
insertOp.result().getType().dyn_cast_or_null<mlir::RankedTensorType>();
|
||||
auto lweResultTy = resultTy.getElementType()
|
||||
.dyn_cast_or_null<Concrete::LweCiphertextType>();
|
||||
if (lweResultTy == nullptr) {
|
||||
return mlir::failure();
|
||||
};
|
||||
auto hasBlock = lweResultTy.getCrtDecomposition().size() != 0;
|
||||
mlir::RankedTensorType newResultTy =
|
||||
converter.convertType(resultTy).cast<mlir::RankedTensorType>();
|
||||
|
||||
// add 0 to static_offsets
|
||||
// add zeros to static_offsets
|
||||
mlir::SmallVector<mlir::OpFoldResult> offsets;
|
||||
offsets.append(insertOp.indices().begin(), insertOp.indices().end());
|
||||
offsets.push_back(rewriter.getIndexAttr(0));
|
||||
if (hasBlock) {
|
||||
offsets.push_back(rewriter.getIndexAttr(0));
|
||||
}
|
||||
|
||||
// Inserting a smaller tensor into a (potentially) bigger one. Set
|
||||
// dimensions for all leading dimensions of the target tensor not
|
||||
@@ -607,6 +598,10 @@ struct InsertOpPattern : public mlir::OpRewritePattern<mlir::tensor::InsertOp> {
|
||||
rewriter.getI64IntegerAttr(1));
|
||||
|
||||
// Add size for the bufferized source element
|
||||
if (hasBlock) {
|
||||
sizes.push_back(rewriter.getI64IntegerAttr(
|
||||
newResultTy.getDimSize(newResultTy.getRank() - 2)));
|
||||
}
|
||||
sizes.push_back(rewriter.getI64IntegerAttr(
|
||||
newResultTy.getDimSize(newResultTy.getRank() - 1)));
|
||||
|
||||
@@ -711,26 +706,26 @@ struct FromElementsOpPattern
|
||||
};
|
||||
};
|
||||
|
||||
/// This template rewrite pattern transforms any instance of
|
||||
/// `ShapeOp` operators that operates on tensor of lwe ciphertext by adding the
|
||||
/// lwe size as a size of the tensor result and by adding a trivial
|
||||
/// reassociation at the end of the reassociations map.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %0 = "ShapeOp" %arg0 [reassocations...]
|
||||
/// : tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>> into
|
||||
/// tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
|
||||
/// ```
|
||||
///
|
||||
/// becomes:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %0 = "ShapeOp" %arg0 [reassociations..., [inRank or outRank]]
|
||||
/// : tensor<...xlweDimesion+1xi64> into
|
||||
/// tensor<...xlweDimesion+1xi64>
|
||||
/// ```
|
||||
// This template rewrite pattern transforms any instance of
|
||||
// `ShapeOp` operators that operates on tensor of lwe ciphertext by adding the
|
||||
// lwe size as a size of the tensor result and by adding a trivial
|
||||
// reassociation at the end of the reassociations map.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ```mlir
|
||||
// %0 = "ShapeOp" %arg0 [reassocations...]
|
||||
// : tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>> into
|
||||
// tensor<...x!Concrete.lwe_ciphertext<lweDimension,p>>
|
||||
// ```
|
||||
//
|
||||
// becomes:
|
||||
//
|
||||
// ```mlir
|
||||
// %0 = "ShapeOp" %arg0 [reassociations..., [inRank or outRank]]
|
||||
// : tensor<...xlweDimesion+1xi64> into
|
||||
// tensor<...xlweDimesion+1xi64>
|
||||
// ```
|
||||
template <typename ShapeOp, typename VecTy, bool inRank>
|
||||
struct TensorShapeOpPattern : public mlir::OpRewritePattern<ShapeOp> {
|
||||
TensorShapeOpPattern(::mlir::MLIRContext *context,
|
||||
@@ -741,22 +736,35 @@ struct TensorShapeOpPattern : public mlir::OpRewritePattern<ShapeOp> {
|
||||
matchAndRewrite(ShapeOp shapeOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
ConcreteToBConcreteTypeConverter converter;
|
||||
auto resultTy = shapeOp.result().getType();
|
||||
auto resultTy = ((mlir::Type)shapeOp.result().getType()).cast<VecTy>();
|
||||
auto lweResultTy =
|
||||
((mlir::Type)resultTy.getElementType())
|
||||
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
|
||||
|
||||
auto newResultTy =
|
||||
((mlir::Type)converter.convertType(resultTy)).cast<VecTy>();
|
||||
|
||||
// add [rank] to reassociations
|
||||
auto oldReassocs = shapeOp.getReassociationIndices();
|
||||
mlir::SmallVector<mlir::ReassociationIndices> newReassocs;
|
||||
newReassocs.append(oldReassocs.begin(), oldReassocs.end());
|
||||
mlir::ReassociationIndices lweAssoc;
|
||||
auto reassocTy =
|
||||
((mlir::Type)converter.convertType(
|
||||
(inRank ? shapeOp.src() : shapeOp.result()).getType()))
|
||||
.cast<VecTy>();
|
||||
lweAssoc.push_back(reassocTy.getRank() - 1);
|
||||
newReassocs.push_back(lweAssoc);
|
||||
|
||||
auto oldReassocs = shapeOp.getReassociationIndices();
|
||||
mlir::SmallVector<mlir::ReassociationIndices> newReassocs;
|
||||
newReassocs.append(oldReassocs.begin(), oldReassocs.end());
|
||||
// add [rank-1] to reassociations if crt decomp
|
||||
if (!lweResultTy.getCrtDecomposition().empty()) {
|
||||
mlir::ReassociationIndices lweAssoc;
|
||||
lweAssoc.push_back(reassocTy.getRank() - 2);
|
||||
newReassocs.push_back(lweAssoc);
|
||||
}
|
||||
|
||||
// add [rank] to reassociations
|
||||
{
|
||||
mlir::ReassociationIndices lweAssoc;
|
||||
lweAssoc.push_back(reassocTy.getRank() - 1);
|
||||
newReassocs.push_back(lweAssoc);
|
||||
}
|
||||
|
||||
ShapeOp op = rewriter.replaceOpWithNewOp<ShapeOp>(
|
||||
shapeOp, newResultTy, shapeOp.src(), newReassocs);
|
||||
@@ -785,20 +793,6 @@ void insertTensorShapeOpPattern(mlir::MLIRContext &context,
|
||||
});
|
||||
}
|
||||
|
||||
/// Rewrites `bufferization.alloc_tensor` ops for which the converted type in
|
||||
/// BConcrete is different from the original type.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```
|
||||
/// bufferization.alloc_tensor() : tensor<4x!Concrete.lwe_ciphertext<4096,6>>
|
||||
/// ```
|
||||
///
|
||||
/// becomes:
|
||||
///
|
||||
/// ```
|
||||
/// bufferization.alloc_tensor() : tensor<4x4097xi64>
|
||||
/// ```
|
||||
struct AllocTensorOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::bufferization::AllocTensorOp> {
|
||||
AllocTensorOpPattern(::mlir::MLIRContext *context,
|
||||
@@ -877,10 +871,6 @@ void ConcreteToBConcretePass::runOnOperation() {
|
||||
// Add Concrete ops are illegal after the conversion
|
||||
target.addIllegalDialect<mlir::concretelang::Concrete::ConcreteDialect>();
|
||||
|
||||
// Add patterns to convert cleartext and plaintext to i64
|
||||
patterns
|
||||
.insert<ConcreteEncodeIntOpPattern, ConcreteIntToCleartextOpPattern>(
|
||||
&getContext());
|
||||
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
|
||||
|
||||
// Add patterns to convert the zero ops to tensor.generate
|
||||
@@ -894,27 +884,25 @@ void ConcreteToBConcretePass::runOnOperation() {
|
||||
// BConcrete op
|
||||
patterns.insert<
|
||||
LowToBConcrete<mlir::concretelang::Concrete::AddLweCiphertextsOp,
|
||||
mlir::concretelang::BConcrete::AddLweBuffersOp>,
|
||||
LowToBConcrete<
|
||||
mlir::concretelang::Concrete::AddPlaintextLweCiphertextOp,
|
||||
mlir::concretelang::BConcrete::AddPlaintextLweBufferOp>,
|
||||
LowToBConcrete<
|
||||
mlir::concretelang::Concrete::MulCleartextLweCiphertextOp,
|
||||
mlir::concretelang::BConcrete::MulCleartextLweBufferOp>,
|
||||
LowToBConcrete<
|
||||
mlir::concretelang::Concrete::MulCleartextLweCiphertextOp,
|
||||
mlir::concretelang::BConcrete::MulCleartextLweBufferOp>,
|
||||
mlir::concretelang::BConcrete::AddLweBuffersOp,
|
||||
BConcrete::AddCRTLweBuffersOp>,
|
||||
AddPlaintextLweCiphertextOpPattern, MulCleartextLweCiphertextOpPattern,
|
||||
LowToBConcrete<mlir::concretelang::Concrete::NegateLweCiphertextOp,
|
||||
mlir::concretelang::BConcrete::NegateLweBufferOp>,
|
||||
mlir::concretelang::BConcrete::NegateLweBufferOp,
|
||||
BConcrete::NegateCRTLweBufferOp>,
|
||||
LowToBConcrete<mlir::concretelang::Concrete::KeySwitchLweOp,
|
||||
mlir::concretelang::BConcrete::KeySwitchLweBufferOp,
|
||||
mlir::concretelang::BConcrete::KeySwitchLweBufferOp>,
|
||||
LowToBConcrete<mlir::concretelang::Concrete::BootstrapLweOp,
|
||||
mlir::concretelang::BConcrete::BootstrapLweBufferOp>>(
|
||||
&getContext());
|
||||
mlir::concretelang::BConcrete::BootstrapLweBufferOp,
|
||||
mlir::concretelang::BConcrete::KeySwitchLweBufferOp>,
|
||||
LowToBConcrete<Concrete::WopPBSLweOp, BConcrete::WopPBSCRTLweBufferOp,
|
||||
BConcrete::WopPBSCRTLweBufferOp>>(&getContext());
|
||||
|
||||
patterns.insert<GlweFromTablePattern>(&getContext());
|
||||
|
||||
// Add patterns to rewrite tensor operators that works on encrypted tensors
|
||||
// Add patterns to rewrite tensor operators that works on encrypted
|
||||
// tensors
|
||||
patterns
|
||||
.insert<ExtractSliceOpPattern, ExtractOpPattern, InsertSliceOpPattern,
|
||||
InsertOpPattern, FromElementsOpPattern>(&getContext());
|
||||
@@ -939,7 +927,8 @@ void ConcreteToBConcretePass::runOnOperation() {
|
||||
patterns.insert<ForOpPattern>(&getContext());
|
||||
|
||||
// Add patterns to rewrite some of memref ops that was introduced by the
|
||||
// linalg bufferization of encrypted tensor (first conversion of this pass)
|
||||
// linalg bufferization of encrypted tensor (first conversion of this
|
||||
// pass)
|
||||
insertTensorShapeOpPattern<mlir::memref::ExpandShapeOp, mlir::MemRefType,
|
||||
false>(getContext(), patterns, target);
|
||||
insertTensorShapeOpPattern<mlir::tensor::ExpandShapeOp, mlir::TensorType,
|
||||
|
||||
@@ -26,10 +26,6 @@ namespace FHE = mlir::concretelang::FHE;
|
||||
namespace TFHE = mlir::concretelang::TFHE;
|
||||
|
||||
namespace {
|
||||
struct FHEToTFHEPass : public FHEToTFHEBase<FHEToTFHEPass> {
|
||||
void runOnOperation() final;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
using mlir::concretelang::FHE::EncryptedIntegerType;
|
||||
using mlir::concretelang::TFHE::GLWECipherTextType;
|
||||
@@ -84,10 +80,10 @@ public:
|
||||
/// : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) ->
|
||||
/// !TFHE.glwe<{_,_,_}{2}>
|
||||
/// ```
|
||||
struct ApplyLookupTableEintOpPattern
|
||||
struct ApplyLookupTableEintOpToKeyswitchBootstrapPattern
|
||||
: public mlir::OpRewritePattern<FHE::ApplyLookupTableEintOp> {
|
||||
ApplyLookupTableEintOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
ApplyLookupTableEintOpToKeyswitchBootstrapPattern(
|
||||
mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<FHE::ApplyLookupTableEintOp>(context,
|
||||
benefit) {}
|
||||
|
||||
@@ -115,6 +111,51 @@ struct ApplyLookupTableEintOpPattern
|
||||
};
|
||||
};
|
||||
|
||||
/// This rewrite pattern transforms any instance of `FHE.apply_lookup_table`
|
||||
/// operators.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %0 = "FHE.apply_lookup_table"(%ct, %lut): (!FHE.eint<2>, tensor<4xi64>)
|
||||
/// ->(!FHE.eint<2>)
|
||||
/// ```
|
||||
///
|
||||
/// becomes:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %0 = "TFHE.wop_pbs_glwe"(%ct, %lut)
|
||||
/// : (!TFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) ->
|
||||
/// (!TFHE.glwe<{_,_,_}{2}>)
|
||||
/// ```
|
||||
struct ApplyLookupTableEintOpToWopPBSPattern
|
||||
: public mlir::OpRewritePattern<FHE::ApplyLookupTableEintOp> {
|
||||
ApplyLookupTableEintOpToWopPBSPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<FHE::ApplyLookupTableEintOp>(context,
|
||||
benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(FHE::ApplyLookupTableEintOp lutOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
FHEToTFHETypeConverter converter;
|
||||
auto inputTy = converter.convertType(lutOp.a().getType())
|
||||
.cast<TFHE::GLWECipherTextType>();
|
||||
auto resultTy = converter.convertType(lutOp.getType());
|
||||
// %0 = "TFHE.wop_pbs_glwe"(%ct, %lut)
|
||||
// : (!TFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) ->
|
||||
// (!TFHE.glwe<{_,_,_}{2}>)
|
||||
auto wopPBS = rewriter.replaceOpWithNewOp<TFHE::WopPBSGLWEOp>(
|
||||
lutOp, resultTy, lutOp.a(), lutOp.lut(), -1, -1, -1, -1, -1, -1, -1, -1,
|
||||
-1, -1, -1);
|
||||
mlir::concretelang::convertOperandAndResultTypes(
|
||||
rewriter, wopPBS, [&](mlir::MLIRContext *, mlir::Type t) {
|
||||
return converter.convertType(t);
|
||||
});
|
||||
return ::mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
/// This rewrite pattern transforms any instance of `FHE.sub_eint_int`
|
||||
/// operators to a negation and an addition.
|
||||
struct SubEintIntOpPattern : public mlir::OpRewritePattern<FHE::SubEintIntOp> {
|
||||
@@ -194,97 +235,118 @@ struct SubEintOpPattern : public mlir::OpRewritePattern<FHE::SubEintOp> {
|
||||
};
|
||||
};
|
||||
|
||||
void FHEToTFHEPass::runOnOperation() {
|
||||
auto op = this->getOperation();
|
||||
struct FHEToTFHEPass : public FHEToTFHEBase<FHEToTFHEPass> {
|
||||
|
||||
mlir::ConversionTarget target(getContext());
|
||||
FHEToTFHETypeConverter converter;
|
||||
FHEToTFHEPass(mlir::concretelang::ApplyLookupTableLowering lutLowerStrategy)
|
||||
: lutLowerStrategy(lutLowerStrategy) {}
|
||||
|
||||
// Mark ops from the target dialect as legal operations
|
||||
target.addLegalDialect<mlir::concretelang::TFHE::TFHEDialect>();
|
||||
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
|
||||
void runOnOperation() {
|
||||
auto op = this->getOperation();
|
||||
|
||||
// Make sure that no ops from `FHE` remain after the lowering
|
||||
target.addIllegalDialect<mlir::concretelang::FHE::FHEDialect>();
|
||||
mlir::ConversionTarget target(getContext());
|
||||
FHEToTFHETypeConverter converter;
|
||||
|
||||
// Make sure that no ops `linalg.generic` that have illegal types
|
||||
target
|
||||
.addDynamicallyLegalOp<mlir::linalg::GenericOp, mlir::tensor::GenerateOp>(
|
||||
[&](mlir::Operation *op) {
|
||||
return (
|
||||
converter.isLegal(op->getOperandTypes()) &&
|
||||
converter.isLegal(op->getResultTypes()) &&
|
||||
converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
|
||||
});
|
||||
// Mark ops from the target dialect as legal operations
|
||||
target.addLegalDialect<mlir::concretelang::TFHE::TFHEDialect>();
|
||||
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
|
||||
|
||||
// Make sure that func has legal signature
|
||||
target.addDynamicallyLegalOp<mlir::func::FuncOp>(
|
||||
[&](mlir::func::FuncOp funcOp) {
|
||||
return converter.isSignatureLegal(funcOp.getFunctionType()) &&
|
||||
converter.isLegal(&funcOp.getBody());
|
||||
});
|
||||
// Make sure that no ops from `FHE` remain after the lowering
|
||||
target.addIllegalDialect<mlir::concretelang::FHE::FHEDialect>();
|
||||
|
||||
// Add all patterns required to lower all ops from `FHE` to
|
||||
// `TFHE`
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
// Make sure that no ops `linalg.generic` that have illegal types
|
||||
target.addDynamicallyLegalOp<mlir::linalg::GenericOp,
|
||||
mlir::tensor::GenerateOp>(
|
||||
[&](mlir::Operation *op) {
|
||||
return (
|
||||
converter.isLegal(op->getOperandTypes()) &&
|
||||
converter.isLegal(op->getResultTypes()) &&
|
||||
converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
|
||||
});
|
||||
|
||||
populateWithGeneratedFHEToTFHE(patterns);
|
||||
// Make sure that func has legal signature
|
||||
target.addDynamicallyLegalOp<mlir::func::FuncOp>(
|
||||
[&](mlir::func::FuncOp funcOp) {
|
||||
return converter.isSignatureLegal(funcOp.getFunctionType()) &&
|
||||
converter.isLegal(&funcOp.getBody());
|
||||
});
|
||||
|
||||
patterns.add<
|
||||
mlir::concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>>(
|
||||
patterns.getContext(), converter);
|
||||
// Add all patterns required to lower all ops from `FHE` to
|
||||
// `TFHE`
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
|
||||
patterns.add<ApplyLookupTableEintOpPattern>(&getContext());
|
||||
patterns.add<SubEintOpPattern>(&getContext());
|
||||
patterns.add<SubEintIntOpPattern>(&getContext());
|
||||
populateWithGeneratedFHEToTFHE(patterns);
|
||||
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::linalg::GenericOp,
|
||||
FHEToTFHETypeConverter>>(
|
||||
&getContext(), converter);
|
||||
patterns.add<
|
||||
mlir::concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>>(
|
||||
patterns.getContext(), converter);
|
||||
|
||||
patterns.add<
|
||||
mlir::concretelang::GenericTypeConverterPattern<mlir::linalg::YieldOp>>(
|
||||
patterns.getContext(), converter);
|
||||
switch (lutLowerStrategy) {
|
||||
case mlir::concretelang::KeySwitchBoostrapLowering:
|
||||
patterns.add<ApplyLookupTableEintOpToKeyswitchBootstrapPattern>(
|
||||
&getContext());
|
||||
break;
|
||||
case mlir::concretelang::WopPBSLowering:
|
||||
patterns.add<ApplyLookupTableEintOpToWopPBSPattern>(&getContext());
|
||||
break;
|
||||
}
|
||||
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
|
||||
FHEToTFHETypeConverter>>(
|
||||
&getContext(), converter);
|
||||
patterns.add<SubEintOpPattern>(&getContext());
|
||||
patterns.add<SubEintIntOpPattern>(&getContext());
|
||||
|
||||
patterns.add<
|
||||
RegionOpTypeConverterPattern<mlir::scf::ForOp, FHEToTFHETypeConverter>>(
|
||||
&getContext(), converter);
|
||||
patterns.add<mlir::concretelang::GenericTypeAndOpConverterPattern<
|
||||
mlir::concretelang::FHE::ZeroTensorOp,
|
||||
mlir::concretelang::TFHE::ZeroTensorGLWEOp>>(&getContext(), converter);
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::linalg::GenericOp,
|
||||
FHEToTFHETypeConverter>>(
|
||||
&getContext(), converter);
|
||||
|
||||
mlir::concretelang::populateWithTensorTypeConverterPatterns(patterns, target,
|
||||
converter);
|
||||
patterns.add<
|
||||
mlir::concretelang::GenericTypeConverterPattern<mlir::linalg::YieldOp>>(
|
||||
patterns.getContext(), converter);
|
||||
|
||||
mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
|
||||
patterns, converter);
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
|
||||
FHEToTFHETypeConverter>>(
|
||||
&getContext(), converter);
|
||||
|
||||
// Conversion of RT Dialect Ops
|
||||
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DataflowTaskOp>>(patterns.getContext(),
|
||||
converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DataflowTaskOp>(target, converter);
|
||||
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DataflowYieldOp>>(patterns.getContext(),
|
||||
converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DataflowYieldOp>(target, converter);
|
||||
patterns.add<
|
||||
RegionOpTypeConverterPattern<mlir::scf::ForOp, FHEToTFHETypeConverter>>(
|
||||
&getContext(), converter);
|
||||
patterns.add<mlir::concretelang::GenericTypeAndOpConverterPattern<
|
||||
mlir::concretelang::FHE::ZeroTensorOp,
|
||||
mlir::concretelang::TFHE::ZeroTensorGLWEOp>>(&getContext(), converter);
|
||||
|
||||
// Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
|
||||
this->signalPassFailure();
|
||||
mlir::concretelang::populateWithTensorTypeConverterPatterns(
|
||||
patterns, target, converter);
|
||||
|
||||
mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
|
||||
patterns, converter);
|
||||
|
||||
// Conversion of RT Dialect Ops
|
||||
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DataflowTaskOp>>(patterns.getContext(),
|
||||
converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DataflowTaskOp>(target, converter);
|
||||
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DataflowYieldOp>>(patterns.getContext(),
|
||||
converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DataflowYieldOp>(target, converter);
|
||||
|
||||
// Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns))
|
||||
.failed()) {
|
||||
this->signalPassFailure();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
mlir::concretelang::ApplyLookupTableLowering lutLowerStrategy;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertFHEToTFHEPass() {
|
||||
return std::make_unique<FHEToTFHEPass>();
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertFHEToTFHEPass(ApplyLookupTableLowering lower) {
|
||||
return std::make_unique<FHEToTFHEPass>(lower);
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -59,12 +59,17 @@ public:
|
||||
auto dimension = cryptoParameters.getNBigGlweDimension();
|
||||
auto polynomialSize = 1;
|
||||
auto precision = (signed)type.getP();
|
||||
auto crtDecomposition =
|
||||
cryptoParameters.largeInteger.hasValue()
|
||||
? cryptoParameters.largeInteger->crtDecomposition
|
||||
: mlir::concretelang::CRTDecomposition{};
|
||||
if ((int)dimension == type.getDimension() &&
|
||||
(int)polynomialSize == type.getPolynomialSize()) {
|
||||
return type;
|
||||
}
|
||||
return TFHE::GLWECipherTextType::get(type.getContext(), dimension,
|
||||
polynomialSize, bits, precision);
|
||||
polynomialSize, bits, precision,
|
||||
crtDecomposition);
|
||||
}
|
||||
|
||||
TFHE::GLWECipherTextType glweLookupTableType(GLWECipherTextType &type) {
|
||||
@@ -73,7 +78,7 @@ public:
|
||||
auto polynomialSize = cryptoParameters.getPolynomialSize();
|
||||
auto precision = (signed)type.getP();
|
||||
return TFHE::GLWECipherTextType::get(type.getContext(), dimension,
|
||||
polynomialSize, bits, precision);
|
||||
polynomialSize, bits, precision, {});
|
||||
}
|
||||
|
||||
TFHE::GLWECipherTextType glweIntraPBSType(GLWECipherTextType &type) {
|
||||
@@ -82,7 +87,7 @@ public:
|
||||
auto polynomialSize = 1;
|
||||
auto precision = (signed)type.getP();
|
||||
return TFHE::GLWECipherTextType::get(type.getContext(), dimension,
|
||||
polynomialSize, bits, precision);
|
||||
polynomialSize, bits, precision, {});
|
||||
}
|
||||
|
||||
mlir::concretelang::V0Parameter cryptoParameters;
|
||||
@@ -154,6 +159,49 @@ private:
|
||||
mlir::concretelang::V0Parameter &cryptoParameters;
|
||||
};
|
||||
|
||||
struct WopPBSGLWEOpPattern : public mlir::OpRewritePattern<TFHE::WopPBSGLWEOp> {
|
||||
WopPBSGLWEOpPattern(mlir::MLIRContext *context,
|
||||
TFHEGlobalParametrizationTypeConverter &converter,
|
||||
mlir::concretelang::V0Parameter &cryptoParameters,
|
||||
mlir::PatternBenefit benefit =
|
||||
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<TFHE::WopPBSGLWEOp>(context, benefit),
|
||||
converter(converter), cryptoParameters(cryptoParameters) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(TFHE::WopPBSGLWEOp wopPBSOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto newOp = rewriter.replaceOpWithNewOp<TFHE::WopPBSGLWEOp>(
|
||||
wopPBSOp, converter.convertType(wopPBSOp.result().getType()),
|
||||
wopPBSOp.ciphertext(), wopPBSOp.lookupTable(),
|
||||
// Bootstrap parameters
|
||||
cryptoParameters.brLevel, cryptoParameters.brLogBase,
|
||||
// Keyswitch parameters
|
||||
cryptoParameters.ksLevel, cryptoParameters.ksLogBase,
|
||||
// Packing keyswitch key parameters
|
||||
cryptoParameters.largeInteger->wopPBS.packingKeySwitch
|
||||
.inputLweDimension,
|
||||
cryptoParameters.largeInteger->wopPBS.packingKeySwitch.inputLweCount,
|
||||
cryptoParameters.largeInteger->wopPBS.packingKeySwitch
|
||||
.outputPolynomialSize,
|
||||
cryptoParameters.largeInteger->wopPBS.packingKeySwitch.level,
|
||||
cryptoParameters.largeInteger->wopPBS.packingKeySwitch.baseLog,
|
||||
// Circuit bootstrap parameters
|
||||
cryptoParameters.largeInteger->wopPBS.circuitBootstrap.level,
|
||||
cryptoParameters.largeInteger->wopPBS.circuitBootstrap.baseLog);
|
||||
rewriter.startRootUpdate(newOp);
|
||||
auto ciphertextType =
|
||||
wopPBSOp.ciphertext().getType().cast<TFHE::GLWECipherTextType>();
|
||||
newOp.ciphertext().setType(converter.glweInterPBSType(ciphertextType));
|
||||
rewriter.finalizeRootUpdate(newOp);
|
||||
return mlir::success();
|
||||
};
|
||||
|
||||
private:
|
||||
TFHEGlobalParametrizationTypeConverter &converter;
|
||||
mlir::concretelang::V0Parameter &cryptoParameters;
|
||||
};
|
||||
|
||||
/// This rewrite pattern transforms any instance of `TFHE.glwe_from_table` by
|
||||
/// parametrize GLWE return type and pad the table if the precision has been
|
||||
/// changed.
|
||||
@@ -271,6 +319,16 @@ void TFHEGlobalParametrizationPass::runOnOperation() {
|
||||
return converter.isLegal(op->getResultTypes());
|
||||
});
|
||||
|
||||
// Parametrize wop pbs
|
||||
patterns.add<WopPBSGLWEOpPattern>(&getContext(), converter,
|
||||
cryptoParameters);
|
||||
target.addDynamicallyLegalOp<TFHE::WopPBSGLWEOp>(
|
||||
[&](TFHE::WopPBSGLWEOp op) {
|
||||
return !op.getType()
|
||||
.cast<TFHE::GLWECipherTextType>()
|
||||
.hasUnparametrizedParameters();
|
||||
});
|
||||
|
||||
// Add all patterns to convert TFHE types
|
||||
populateWithTFHEOpTypeConversionPatterns(patterns, target, converter);
|
||||
patterns.add<RegionOpTypeConverterPattern<
|
||||
|
||||
@@ -109,6 +109,45 @@ private:
|
||||
mlir::TypeConverter &converter;
|
||||
};
|
||||
|
||||
struct WopPBSGLWEOpPattern : public mlir::OpRewritePattern<TFHE::WopPBSGLWEOp> {
|
||||
WopPBSGLWEOpPattern(mlir::MLIRContext *context,
|
||||
mlir::TypeConverter &converter,
|
||||
mlir::PatternBenefit benefit = 100)
|
||||
: mlir::OpRewritePattern<TFHE::WopPBSGLWEOp>(context, benefit),
|
||||
converter(converter) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(TFHE::WopPBSGLWEOp wopOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::Type resultType = converter.convertType(wopOp.getType());
|
||||
|
||||
auto newOp = rewriter.replaceOpWithNewOp<Concrete::WopPBSLweOp>(
|
||||
wopOp, resultType, wopOp.ciphertext(), wopOp.lookupTable(),
|
||||
// Bootstrap parameters
|
||||
wopOp.bootstrapLevel(), wopOp.bootstrapBaseLog(),
|
||||
// Keyswitch parameters
|
||||
wopOp.keyswitchLevel(), wopOp.keyswitchBaseLog(),
|
||||
// Packing keyswitch key parameters
|
||||
wopOp.packingKeySwitchInputLweDimension(),
|
||||
wopOp.packingKeySwitchinputLweCount(),
|
||||
wopOp.packingKeySwitchoutputPolynomialSize(),
|
||||
wopOp.packingKeySwitchLevel(), wopOp.packingKeySwitchBaseLog(),
|
||||
// Circuit bootstrap parameters
|
||||
wopOp.circuitBootstrapLevel(), wopOp.circuitBootstrapBaseLog());
|
||||
|
||||
rewriter.startRootUpdate(newOp);
|
||||
|
||||
newOp.ciphertext().setType(
|
||||
converter.convertType(wopOp.ciphertext().getType()));
|
||||
|
||||
rewriter.finalizeRootUpdate(newOp);
|
||||
return ::mlir::success();
|
||||
}
|
||||
|
||||
private:
|
||||
mlir::TypeConverter &converter;
|
||||
};
|
||||
|
||||
void TFHEToConcretePass::runOnOperation() {
|
||||
auto op = this->getOperation();
|
||||
|
||||
@@ -147,6 +186,7 @@ void TFHEToConcretePass::runOnOperation() {
|
||||
mlir::concretelang::Concrete::ZeroTensorLWEOp>>(&getContext(), converter);
|
||||
patterns.add<GLWEFromTableOpPattern>(&getContext());
|
||||
patterns.add<BootstrapGLWEOpPattern>(&getContext(), converter);
|
||||
patterns.add<WopPBSGLWEOpPattern>(&getContext(), converter);
|
||||
target.addDynamicallyLegalOp<Concrete::BootstrapLweOp>(
|
||||
[&](Concrete::BootstrapLweOp op) {
|
||||
return (converter.isLegal(op->getOperandTypes()) &&
|
||||
|
||||
@@ -36,23 +36,23 @@ namespace {} // namespace
|
||||
|
||||
namespace {
|
||||
|
||||
mlir::Type getDynamic1DMemrefWithUnknownOffset(mlir::RewriterBase &rewriter) {
|
||||
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
|
||||
size_t rank) {
|
||||
mlir::MLIRContext *ctx = rewriter.getContext();
|
||||
|
||||
return mlir::MemRefType::get(
|
||||
{-1}, rewriter.getI64Type(),
|
||||
mlir::AffineMap::get(1, 1,
|
||||
mlir::getAffineDimExpr(0, ctx) +
|
||||
mlir::getAffineSymbolExpr(0, ctx)));
|
||||
std::vector<int64_t> shape(rank, -1);
|
||||
return mlir::MemRefType::get(shape, rewriter.getI64Type(),
|
||||
rewriter.getMultiDimIdentityMap(rank));
|
||||
}
|
||||
|
||||
/// Returns `memref.cast %0 : memref<AxT> to memref<?xT>` if %0 a 1D memref
|
||||
mlir::Value getCasted1DMemRef(mlir::RewriterBase &rewriter, mlir::Location loc,
|
||||
mlir::Value value) {
|
||||
// Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>`
|
||||
mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Location loc,
|
||||
mlir::Value value) {
|
||||
mlir::Type valueType = value.getType();
|
||||
if (valueType.isa<mlir::MemRefType>()) {
|
||||
if (auto memrefTy = valueType.dyn_cast_or_null<mlir::MemRefType>()) {
|
||||
return rewriter.create<mlir::memref::CastOp>(
|
||||
loc, getDynamic1DMemrefWithUnknownOffset(rewriter), value);
|
||||
loc,
|
||||
getDynamicMemrefWithUnknownOffset(rewriter, memrefTy.getShape().size()),
|
||||
value);
|
||||
} else {
|
||||
return value;
|
||||
}
|
||||
@@ -69,10 +69,13 @@ char memref_bootstrap_lwe_u64[] = "memref_bootstrap_lwe_u64";
|
||||
char memref_expand_lut_in_trivial_glwe_ct_u64[] =
|
||||
"memref_expand_lut_in_trivial_glwe_ct_u64";
|
||||
|
||||
char memref_wop_pbs_crt_buffer[] = "memref_wop_pbs_crt_buffer";
|
||||
|
||||
mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
|
||||
mlir::Operation *op, mlir::RewriterBase &rewriter, char const *funcName) {
|
||||
|
||||
auto memref1DType = getDynamic1DMemrefWithUnknownOffset(rewriter);
|
||||
auto memref1DType = getDynamicMemrefWithUnknownOffset(rewriter, 1);
|
||||
auto memref2DType = getDynamicMemrefWithUnknownOffset(rewriter, 2);
|
||||
auto contextType =
|
||||
mlir::concretelang::Concrete::ContextType::get(rewriter.getContext());
|
||||
|
||||
@@ -109,6 +112,15 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
|
||||
memref1DType,
|
||||
},
|
||||
{});
|
||||
} else if (funcName == memref_wop_pbs_crt_buffer) {
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{
|
||||
memref2DType,
|
||||
memref2DType,
|
||||
memref1DType,
|
||||
contextType,
|
||||
},
|
||||
{});
|
||||
} else {
|
||||
op->emitError("unknwon external function") << funcName;
|
||||
return mlir::failure();
|
||||
@@ -185,7 +197,7 @@ struct BufferizableWithCallOpInterface
|
||||
|
||||
// The first operand is the result
|
||||
mlir::SmallVector<mlir::Value, 3> operands{
|
||||
getCasted1DMemRef(rewriter, loc, *outMemref),
|
||||
getCastedMemRef(rewriter, loc, *outMemref),
|
||||
};
|
||||
// For all tensor operand get the corresponding casted buffer
|
||||
for (auto &operand : op->getOpOperands()) {
|
||||
@@ -194,7 +206,7 @@ struct BufferizableWithCallOpInterface
|
||||
} else {
|
||||
auto memrefOperand =
|
||||
bufferization::getBuffer(rewriter, operand.get(), options);
|
||||
operands.push_back(getCasted1DMemRef(rewriter, loc, memrefOperand));
|
||||
operands.push_back(getCastedMemRef(rewriter, loc, memrefOperand));
|
||||
}
|
||||
}
|
||||
// Append the context argument
|
||||
@@ -264,14 +276,14 @@ struct BufferizableGlweFromTableOpInterface
|
||||
auto loc = op->getLoc();
|
||||
auto castOp = cast<BConcrete::FillGlweFromTable>(op);
|
||||
|
||||
auto glweOp = getCasted1DMemRef(
|
||||
rewriter, loc,
|
||||
bufferization::getBuffer(rewriter, castOp->getOpOperand(0).get(),
|
||||
options));
|
||||
auto lutOp = getCasted1DMemRef(
|
||||
rewriter, loc,
|
||||
bufferization::getBuffer(rewriter, castOp->getOpOperand(1).get(),
|
||||
options));
|
||||
auto glweOp =
|
||||
getCastedMemRef(rewriter, loc,
|
||||
bufferization::getBuffer(
|
||||
rewriter, castOp->getOpOperand(0).get(), options));
|
||||
auto lutOp =
|
||||
getCastedMemRef(rewriter, loc,
|
||||
bufferization::getBuffer(
|
||||
rewriter, castOp->getOpOperand(1).get(), options));
|
||||
|
||||
auto polySizeOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op->getLoc(), rewriter.getI32IntegerAttr(castOp.polynomialSize()));
|
||||
@@ -326,6 +338,10 @@ void mlir::concretelang::BConcrete::
|
||||
BConcrete::BootstrapLweBufferOp::attachInterface<
|
||||
BufferizableWithCallOpInterface<BConcrete::BootstrapLweBufferOp,
|
||||
memref_bootstrap_lwe_u64, true>>(*ctx);
|
||||
// TODO(16bits): hack
|
||||
BConcrete::WopPBSCRTLweBufferOp::attachInterface<
|
||||
BufferizableWithCallOpInterface<BConcrete::WopPBSCRTLweBufferOp,
|
||||
memref_wop_pbs_crt_buffer, true>>(*ctx);
|
||||
BConcrete::FillGlweFromTable::attachInterface<
|
||||
BufferizableGlweFromTableOpInterface>(*ctx);
|
||||
});
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
add_mlir_dialect_library(ConcretelangBConcreteTransforms
|
||||
BufferizableOpInterfaceImpl.cpp
|
||||
AddRuntimeContext.cpp
|
||||
EliminateCRTOps.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/BConcrete
|
||||
@@ -18,4 +19,4 @@ add_mlir_dialect_library(ConcretelangBConcreteTransforms
|
||||
MLIRMemRefDialect
|
||||
MLIRPass
|
||||
MLIRTransforms
|
||||
)
|
||||
)
|
||||
|
||||
561
compiler/lib/Dialect/BConcrete/Transforms/EliminateCRTOps.cpp
Normal file
561
compiler/lib/Dialect/BConcrete/Transforms/EliminateCRTOps.cpp
Normal file
@@ -0,0 +1,561 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "concretelang/ClientLib/CRT.h"
|
||||
#include "concretelang/Conversion/Tools.h"
|
||||
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
|
||||
#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h"
|
||||
#include "concretelang/Dialect/BConcrete/Transforms/Passes.h"
|
||||
|
||||
namespace arith = mlir::arith;
|
||||
namespace tensor = mlir::tensor;
|
||||
namespace bufferization = mlir::bufferization;
|
||||
namespace scf = mlir::scf;
|
||||
namespace BConcrete = mlir::concretelang::BConcrete;
|
||||
namespace crt = concretelang::clientlib::crt;
|
||||
|
||||
namespace {
|
||||
|
||||
char encode_crt[] = "encode_crt";
|
||||
|
||||
// This template rewrite pattern transforms any instance of
|
||||
// `BConcreteCRTOp` operators to `BConcreteOp` on
|
||||
// each block.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ```mlir
|
||||
// %0 = "BConcreteCRTOp"(%arg0, %arg1) {crtDecomposition = [...]}
|
||||
// : (tensor<nbBlocksxlweSizexi64>, tensor<nbBlocksxlweSizexi64>) ->
|
||||
// (tensor<nbBlocksxlweSizexi64>)
|
||||
// ```
|
||||
//
|
||||
// becomes:
|
||||
//
|
||||
// ```mlir
|
||||
// %c0 = arith.constant 0 : index
|
||||
// %c1 = arith.constant 1 : index
|
||||
// %cB = arith.constant nbBlocks : index
|
||||
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
|
||||
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
|
||||
// (tensor<nbBlocksxlweSizexi64>) {
|
||||
// %blockArg = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
|
||||
// : tensor<lweSizexi64>
|
||||
// %tmp = "BConcreteOp"(%blockArg)
|
||||
// : (tensor<lweSizexi64>) -> (tensor<lweSizexi64>)
|
||||
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, 1]
|
||||
// : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
|
||||
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
|
||||
// }
|
||||
// ```
|
||||
template <typename BConcreteCRTOp, typename BConcreteOp>
|
||||
struct BConcreteCRTUnaryOpPattern
|
||||
: public mlir::OpRewritePattern<BConcreteCRTOp> {
|
||||
BConcreteCRTUnaryOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<BConcreteCRTOp>(context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(BConcreteCRTOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto resultTy =
|
||||
((mlir::Type)op.getResult().getType()).cast<mlir::RankedTensorType>();
|
||||
auto loc = op.getLoc();
|
||||
assert(resultTy.getShape().size() == 2);
|
||||
auto shape = resultTy.getShape();
|
||||
|
||||
// %c0 = arith.constant 0 : index
|
||||
// %c1 = arith.constant 1 : index
|
||||
// %cB = arith.constant nbBlocks : index
|
||||
auto c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
||||
auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
||||
auto cB = rewriter.create<arith::ConstantIndexOp>(loc, shape[0]);
|
||||
|
||||
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
|
||||
mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
|
||||
op.getLoc(), resultTy, mlir::ValueRange{});
|
||||
|
||||
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
|
||||
// (tensor<nbBlocksxlweSizexi64>) {
|
||||
rewriter.replaceOpWithNewOp<scf::ForOp>(
|
||||
op, c0, cB, c1, init,
|
||||
[&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value i,
|
||||
mlir::ValueRange iterArgs) {
|
||||
// [%i, 0]
|
||||
mlir::SmallVector<mlir::OpFoldResult> offsets{
|
||||
i, rewriter.getI64IntegerAttr(0)};
|
||||
// [1, lweSize]
|
||||
mlir::SmallVector<mlir::OpFoldResult> sizes{
|
||||
rewriter.getI64IntegerAttr(1),
|
||||
rewriter.getI64IntegerAttr(shape[1])};
|
||||
// [1, 1]
|
||||
mlir::SmallVector<mlir::OpFoldResult> strides{
|
||||
rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1)};
|
||||
|
||||
auto blockTy = mlir::RankedTensorType::get({shape[1]},
|
||||
resultTy.getElementType());
|
||||
|
||||
// %blockArg = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
|
||||
// : tensor<lweSizexi64>
|
||||
auto blockArg = builder.create<tensor::ExtractSliceOp>(
|
||||
loc, blockTy, op.ciphertext(), offsets, sizes, strides);
|
||||
// %tmp = "BConcrete.add_lwe_buffer"(%blockArg0, %blockArg1)
|
||||
// : (tensor<lweSizexi64>, tensor<lweSizexi64>) ->
|
||||
// (tensor<lweSizexi64>)
|
||||
auto tmp = builder.create<BConcreteOp>(loc, blockTy, blockArg);
|
||||
|
||||
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1,
|
||||
// 1] : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
|
||||
auto res = builder.create<tensor::InsertSliceOp>(
|
||||
loc, tmp, iterArgs[0], offsets, sizes, strides);
|
||||
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
|
||||
builder.create<scf::YieldOp>(loc, (mlir::Value)res);
|
||||
});
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
// This template rewrite pattern transforms any instance of
|
||||
// `BConcreteCRTOp` operators to `BConcreteOp` on
|
||||
// each block.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ```mlir
|
||||
// %0 = "BConcreteCRTOp"(%arg0, %arg1) {crtDecomposition = [...]}
|
||||
// : (tensor<nbBlocksxlweSizexi64>, tensor<nbBlocksxlweSizexi64>) ->
|
||||
// (tensor<nbBlocksxlweSizexi64>)
|
||||
// ```
|
||||
//
|
||||
// becomes:
|
||||
//
|
||||
// ```mlir
|
||||
// %c0 = arith.constant 0 : index
|
||||
// %c1 = arith.constant 1 : index
|
||||
// %cB = arith.constant nbBlocks : index
|
||||
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
|
||||
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
|
||||
// (tensor<nbBlocksxlweSizexi64>) {
|
||||
// %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
|
||||
// : tensor<lweSizexi64>
|
||||
// %blockArg1 = tensor.extract_slice %arg1[%i, 0] [1, lweSize] [1, 1]
|
||||
// : tensor<lweSizexi64>
|
||||
// %tmp = "BConcreteOp"(%blockArg0, %blockArg1)
|
||||
// : (tensor<lweSizexi64>, tensor<lweSizexi64>) ->
|
||||
// (tensor<lweSizexi64>)
|
||||
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, 1]
|
||||
// : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
|
||||
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
|
||||
// }
|
||||
// ```
|
||||
template <typename BConcreteCRTOp, typename BConcreteOp>
|
||||
struct BConcreteCRTBinaryOpPattern
|
||||
: public mlir::OpRewritePattern<BConcreteCRTOp> {
|
||||
BConcreteCRTBinaryOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<BConcreteCRTOp>(context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(BConcreteCRTOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto resultTy =
|
||||
((mlir::Type)op.getResult().getType()).cast<mlir::RankedTensorType>();
|
||||
auto loc = op.getLoc();
|
||||
assert(resultTy.getShape().size() == 2);
|
||||
auto shape = resultTy.getShape();
|
||||
|
||||
// %c0 = arith.constant 0 : index
|
||||
// %c1 = arith.constant 1 : index
|
||||
// %cB = arith.constant nbBlocks : index
|
||||
auto c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
||||
auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
||||
auto cB = rewriter.create<arith::ConstantIndexOp>(loc, shape[0]);
|
||||
|
||||
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
|
||||
mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
|
||||
op.getLoc(), resultTy, mlir::ValueRange{});
|
||||
|
||||
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
|
||||
// (tensor<nbBlocksxlweSizexi64>) {
|
||||
rewriter.replaceOpWithNewOp<scf::ForOp>(
|
||||
op, c0, cB, c1, init,
|
||||
[&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value i,
|
||||
mlir::ValueRange iterArgs) {
|
||||
// [%i, 0]
|
||||
mlir::SmallVector<mlir::OpFoldResult> offsets{
|
||||
i, rewriter.getI64IntegerAttr(0)};
|
||||
// [1, lweSize]
|
||||
mlir::SmallVector<mlir::OpFoldResult> sizes{
|
||||
rewriter.getI64IntegerAttr(1),
|
||||
rewriter.getI64IntegerAttr(shape[1])};
|
||||
// [1, 1]
|
||||
mlir::SmallVector<mlir::OpFoldResult> strides{
|
||||
rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1)};
|
||||
|
||||
auto blockTy = mlir::RankedTensorType::get({shape[1]},
|
||||
resultTy.getElementType());
|
||||
|
||||
// %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
|
||||
// : tensor<lweSizexi64>
|
||||
auto blockArg0 = builder.create<tensor::ExtractSliceOp>(
|
||||
loc, blockTy, op.lhs(), offsets, sizes, strides);
|
||||
// %blockArg1 = tensor.extract_slice %arg1[%i, 0] [1, lweSize] [1, 1]
|
||||
// : tensor<lweSizexi64>
|
||||
auto blockArg1 = builder.create<tensor::ExtractSliceOp>(
|
||||
loc, blockTy, op.rhs(), offsets, sizes, strides);
|
||||
// %tmp = "BConcrete.add_lwe_buffer"(%blockArg0, %blockArg1)
|
||||
// : (tensor<lweSizexi64>, tensor<lweSizexi64>) ->
|
||||
// (tensor<lweSizexi64>)
|
||||
auto tmp =
|
||||
builder.create<BConcreteOp>(loc, blockTy, blockArg0, blockArg1);
|
||||
|
||||
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1,
|
||||
// 1] : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
|
||||
auto res = builder.create<tensor::InsertSliceOp>(
|
||||
loc, tmp, iterArgs[0], offsets, sizes, strides);
|
||||
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
|
||||
builder.create<scf::YieldOp>(loc, (mlir::Value)res);
|
||||
});
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
// This template rewrite pattern transforms any instance of
|
||||
// `BConcreteCRTOp` operators to `BConcreteOp` on
|
||||
// each block with the crt decomposition of the cleartext.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ```mlir
|
||||
// %0 = "BConcreteCRTOp"(%arg0, %x) {crtDecomposition = [d0...dn]}
|
||||
// : (tensor<nbBlocksxlweSizexi64>, i64) -> (tensor<nbBlocksxlweSizexi64>)
|
||||
// ```
|
||||
//
|
||||
// becomes:
|
||||
//
|
||||
// ```mlir
|
||||
// // Build the decomposition of the plaintext
|
||||
// %x0_a = arith.constant 64/d0 : f64
|
||||
// %x0_b = arith.mulf %x, %x0_a : i64
|
||||
// %x0 = arith.fptoui %x0_b : f64 to i64
|
||||
// ...
|
||||
// %xn_a = arith.constant 64/dn : f64
|
||||
// %xn_b = arith.mulf %x, %xn_a : i64
|
||||
// %xn = arith.fptoui %xn_b : f64 to i64
|
||||
// %x_decomp = tensor.from_elements %x0, ..., %xn : tensor<nbBlocksxi64>
|
||||
// // Loop on blocks
|
||||
// %c0 = arith.constant 0 : index
|
||||
// %c1 = arith.constant 1 : index
|
||||
// %cB = arith.constant nbBlocks : index
|
||||
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
|
||||
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
|
||||
// (tensor<nbBlocksxlweSizexi64>) {
|
||||
// %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
|
||||
// : tensor<lweSizexi64>
|
||||
// %blockArg1 = tensor.extract %x_decomp[%i] : tensor<nbBlocksxi64>
|
||||
// %tmp = "BConcreteOp"(%blockArg0, %blockArg1)
|
||||
// : (tensor<lweSizexi64>, i64) -> (tensor<lweSizexi64>)
|
||||
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, 1]
|
||||
// : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
|
||||
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
|
||||
// }
|
||||
// ```
|
||||
struct AddPlaintextCRTLweBufferOpPattern
|
||||
: public mlir::OpRewritePattern<BConcrete::AddPlaintextCRTLweBufferOp> {
|
||||
AddPlaintextCRTLweBufferOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<BConcrete::AddPlaintextCRTLweBufferOp>(context,
|
||||
benefit) {
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(BConcrete::AddPlaintextCRTLweBufferOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto resultTy =
|
||||
((mlir::Type)op.getResult().getType()).cast<mlir::RankedTensorType>();
|
||||
auto loc = op.getLoc();
|
||||
assert(resultTy.getShape().size() == 2);
|
||||
auto shape = resultTy.getShape();
|
||||
|
||||
auto rhs = op.rhs();
|
||||
mlir::SmallVector<mlir::Value, 5> plaintextElements;
|
||||
uint64_t moduliProduct = 1;
|
||||
for (mlir::Attribute di : op.crtDecomposition()) {
|
||||
moduliProduct *= di.cast<mlir::IntegerAttr>().getValue().getZExtValue();
|
||||
}
|
||||
if (auto cst =
|
||||
mlir::dyn_cast_or_null<arith::ConstantIntOp>(rhs.getDefiningOp())) {
|
||||
auto apCst = cst.getValue().cast<mlir::IntegerAttr>().getValue();
|
||||
auto value = apCst.getSExtValue();
|
||||
|
||||
// constant value, encode at compile time
|
||||
for (mlir::Attribute di : op.crtDecomposition()) {
|
||||
auto modulus = di.cast<mlir::IntegerAttr>().getValue().getZExtValue();
|
||||
|
||||
auto encoded = crt::encode(value, modulus, moduliProduct);
|
||||
plaintextElements.push_back(
|
||||
rewriter.create<arith::ConstantIntOp>(loc, encoded, 64));
|
||||
}
|
||||
} else {
|
||||
// dynamic value, encode at runtime
|
||||
if (insertForwardDeclaration(
|
||||
op, rewriter, encode_crt,
|
||||
mlir::FunctionType::get(rewriter.getContext(),
|
||||
{rewriter.getI64Type(),
|
||||
rewriter.getI64Type(),
|
||||
rewriter.getI64Type()},
|
||||
{rewriter.getI64Type()}))
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
auto extOp =
|
||||
rewriter.create<arith::ExtSIOp>(loc, rewriter.getI64Type(), rhs);
|
||||
auto moduliProductOp =
|
||||
rewriter.create<arith::ConstantIntOp>(loc, moduliProduct, 64);
|
||||
for (mlir::Attribute di : op.crtDecomposition()) {
|
||||
auto modulus = di.cast<mlir::IntegerAttr>().getValue().getZExtValue();
|
||||
auto modulusOp =
|
||||
rewriter.create<arith::ConstantIntOp>(loc, modulus, 64);
|
||||
plaintextElements.push_back(
|
||||
rewriter
|
||||
.create<mlir::func::CallOp>(
|
||||
loc, encode_crt, mlir::TypeRange{rewriter.getI64Type()},
|
||||
mlir::ValueRange{extOp, modulusOp, moduliProductOp})
|
||||
.getResult(0));
|
||||
}
|
||||
}
|
||||
|
||||
// %x_decomp = tensor.from_elements %x0, ..., %xn : tensor<nbBlocksxi64>
|
||||
auto x_decomp =
|
||||
rewriter.create<tensor::FromElementsOp>(loc, plaintextElements);
|
||||
|
||||
// %c0 = arith.constant 0 : index
|
||||
// %c1 = arith.constant 1 : index
|
||||
// %cB = arith.constant nbBlocks : index
|
||||
auto c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
||||
auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
||||
auto cB = rewriter.create<arith::ConstantIndexOp>(loc, shape[0]);
|
||||
|
||||
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
|
||||
mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
|
||||
op.getLoc(), resultTy, mlir::ValueRange{});
|
||||
|
||||
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
|
||||
// (tensor<nbBlocksxlweSizexi64>) {
|
||||
rewriter.replaceOpWithNewOp<scf::ForOp>(
|
||||
op, c0, cB, c1, init,
|
||||
[&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value i,
|
||||
mlir::ValueRange iterArgs) {
|
||||
// [%i, 0]
|
||||
mlir::SmallVector<mlir::OpFoldResult> offsets{
|
||||
i, rewriter.getI64IntegerAttr(0)};
|
||||
// [1, lweSize]
|
||||
mlir::SmallVector<mlir::OpFoldResult> sizes{
|
||||
rewriter.getI64IntegerAttr(1),
|
||||
rewriter.getI64IntegerAttr(shape[1])};
|
||||
// [1, 1]
|
||||
mlir::SmallVector<mlir::OpFoldResult> strides{
|
||||
rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1)};
|
||||
|
||||
auto blockTy = mlir::RankedTensorType::get({shape[1]},
|
||||
resultTy.getElementType());
|
||||
|
||||
// %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
|
||||
// : tensor<lweSizexi64>
|
||||
auto blockArg0 = builder.create<tensor::ExtractSliceOp>(
|
||||
loc, blockTy, op.lhs(), offsets, sizes, strides);
|
||||
// %blockArg1 = tensor.extract %x_decomp[%i] : tensor<nbBlocksxi64>
|
||||
auto blockArg1 = builder.create<tensor::ExtractOp>(loc, x_decomp, i);
|
||||
// %tmp = "BConcreteOp"(%blockArg0, %blockArg1)
|
||||
// : (tensor<lweSizexi64>, i64) -> (tensor<lweSizexi64>)
|
||||
auto tmp = builder.create<BConcrete::AddPlaintextLweBufferOp>(
|
||||
loc, blockTy, blockArg0, blockArg1);
|
||||
|
||||
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1,
|
||||
// 1] : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
|
||||
auto res = builder.create<tensor::InsertSliceOp>(
|
||||
loc, tmp, iterArgs[0], offsets, sizes, strides);
|
||||
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
|
||||
builder.create<scf::YieldOp>(loc, (mlir::Value)res);
|
||||
});
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
// This template rewrite pattern transforms any instance of
|
||||
// `BConcreteCRTOp` operators to `BConcreteOp` on
|
||||
// each block with the crt decomposition of the cleartext.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ```mlir
|
||||
// %0 = "BConcreteCRTOp"(%arg0, %x) {crtDecomposition = [d0...dn]}
|
||||
// : (tensor<nbBlocksxlweSizexi64>, i64) -> (tensor<nbBlocksxlweSizexi64>)
|
||||
// ```
|
||||
//
|
||||
// becomes:
|
||||
//
|
||||
// ```mlir
|
||||
// // Build the decomposition of the plaintext
|
||||
// %x0_a = arith.constant 64/d0 : f64
|
||||
// %x0_b = arith.mulf %x, %x0_a : i64
|
||||
// %x0 = arith.fptoui %x0_b : f64 to i64
|
||||
// ...
|
||||
// %xn_a = arith.constant 64/dn : f64
|
||||
// %xn_b = arith.mulf %x, %xn_a : i64
|
||||
// %xn = arith.fptoui %xn_b : f64 to i64
|
||||
// %x_decomp = tensor.from_elements %x0, ..., %xn : tensor<nbBlocksxi64>
|
||||
// // Loop on blocks
|
||||
// %c0 = arith.constant 0 : index
|
||||
// %c1 = arith.constant 1 : index
|
||||
// %cB = arith.constant nbBlocks : index
|
||||
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
|
||||
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
|
||||
// (tensor<nbBlocksxlweSizexi64>) {
|
||||
// %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
|
||||
// : tensor<lweSizexi64>
|
||||
// %blockArg1 = tensor.extract %x_decomp[%i] : tensor<nbBlocksxi64>
|
||||
// %tmp = "BConcreteOp"(%blockArg0, %blockArg1)
|
||||
// : (tensor<lweSizexi64>, i64) -> (tensor<lweSizexi64>)
|
||||
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, 1]
|
||||
// : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
|
||||
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
|
||||
// }
|
||||
// ```
|
||||
struct MulCleartextCRTLweBufferOpPattern
|
||||
: public mlir::OpRewritePattern<BConcrete::MulCleartextCRTLweBufferOp> {
|
||||
MulCleartextCRTLweBufferOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<BConcrete::MulCleartextCRTLweBufferOp>(context,
|
||||
benefit) {
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(BConcrete::MulCleartextCRTLweBufferOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto resultTy =
|
||||
((mlir::Type)op.getResult().getType()).cast<mlir::RankedTensorType>();
|
||||
auto loc = op.getLoc();
|
||||
assert(resultTy.getShape().size() == 2);
|
||||
auto shape = resultTy.getShape();
|
||||
|
||||
// %c0 = arith.constant 0 : index
|
||||
// %c1 = arith.constant 1 : index
|
||||
// %cB = arith.constant nbBlocks : index
|
||||
auto c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
||||
auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
||||
auto cB = rewriter.create<arith::ConstantIndexOp>(loc, shape[0]);
|
||||
|
||||
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
|
||||
mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
|
||||
op.getLoc(), resultTy, mlir::ValueRange{});
|
||||
|
||||
auto rhs = rewriter.create<arith::ExtUIOp>(op.getLoc(),
|
||||
rewriter.getI64Type(), op.rhs());
|
||||
|
||||
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
|
||||
// (tensor<nbBlocksxlweSizexi64>) {
|
||||
rewriter.replaceOpWithNewOp<scf::ForOp>(
|
||||
op, c0, cB, c1, init,
|
||||
[&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value i,
|
||||
mlir::ValueRange iterArgs) {
|
||||
// [%i, 0]
|
||||
mlir::SmallVector<mlir::OpFoldResult> offsets{
|
||||
i, rewriter.getI64IntegerAttr(0)};
|
||||
// [1, lweSize]
|
||||
mlir::SmallVector<mlir::OpFoldResult> sizes{
|
||||
rewriter.getI64IntegerAttr(1),
|
||||
rewriter.getI64IntegerAttr(shape[1])};
|
||||
// [1, 1]
|
||||
mlir::SmallVector<mlir::OpFoldResult> strides{
|
||||
rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1)};
|
||||
|
||||
auto blockTy = mlir::RankedTensorType::get({shape[1]},
|
||||
resultTy.getElementType());
|
||||
|
||||
// %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
|
||||
// : tensor<lweSizexi64>
|
||||
auto blockArg0 = builder.create<tensor::ExtractSliceOp>(
|
||||
loc, blockTy, op.lhs(), offsets, sizes, strides);
|
||||
|
||||
// %tmp = BConcrete.mul_cleartext_lwe_buffer(%blockArg0, %x)
|
||||
// : (tensor<lweSizexi64>, i64) -> (tensor<lweSizexi64>)
|
||||
auto tmp = builder.create<BConcrete::MulCleartextLweBufferOp>(
|
||||
loc, blockTy, blockArg0, rhs);
|
||||
|
||||
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1,
|
||||
// 1] : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
|
||||
auto res = builder.create<tensor::InsertSliceOp>(
|
||||
loc, tmp, iterArgs[0], offsets, sizes, strides);
|
||||
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
|
||||
builder.create<scf::YieldOp>(loc, (mlir::Value)res);
|
||||
});
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
struct EliminateCRTOpsPass : public EliminateCRTOpsBase<EliminateCRTOpsPass> {
|
||||
void runOnOperation() final;
|
||||
};
|
||||
|
||||
void EliminateCRTOpsPass::runOnOperation() {
|
||||
auto op = getOperation();
|
||||
|
||||
mlir::ConversionTarget target(getContext());
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
|
||||
// add_crt_lwe_buffers
|
||||
target.addIllegalOp<BConcrete::AddCRTLweBuffersOp>();
|
||||
patterns.add<BConcreteCRTBinaryOpPattern<BConcrete::AddCRTLweBuffersOp,
|
||||
BConcrete::AddLweBuffersOp>>(
|
||||
&getContext());
|
||||
|
||||
// add_plaintext_crt_lwe_buffers
|
||||
target.addIllegalOp<BConcrete::AddPlaintextCRTLweBufferOp>();
|
||||
patterns.add<AddPlaintextCRTLweBufferOpPattern>(&getContext());
|
||||
|
||||
// mul_cleartext_crt_lwe_buffer
|
||||
target.addIllegalOp<BConcrete::MulCleartextCRTLweBufferOp>();
|
||||
patterns.add<MulCleartextCRTLweBufferOpPattern>(&getContext());
|
||||
|
||||
target.addIllegalOp<BConcrete::NegateCRTLweBufferOp>();
|
||||
patterns.add<BConcreteCRTUnaryOpPattern<BConcrete::NegateCRTLweBufferOp,
|
||||
BConcrete::NegateLweBufferOp>>(
|
||||
&getContext());
|
||||
|
||||
// This dialect are used to transforms crt ops to bconcrete ops
|
||||
target
|
||||
.addLegalDialect<arith::ArithmeticDialect, tensor::TensorDialect,
|
||||
scf::SCFDialect, bufferization::BufferizationDialect,
|
||||
mlir::func::FuncDialect, BConcrete::BConcreteDialect>();
|
||||
|
||||
// Apply the conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
|
||||
this->signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createEliminateCRTOps() {
|
||||
return std::make_unique<EliminateCRTOpsPass>();
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -25,6 +25,13 @@ void ConcreteDialect::initialize() {
|
||||
>();
|
||||
}
|
||||
|
||||
void printSigned(mlir::AsmPrinter &p, signed i) {
|
||||
if (i == -1)
|
||||
p << "_";
|
||||
else
|
||||
p << i;
|
||||
}
|
||||
|
||||
mlir::Type GlweCiphertextType::parse(mlir::AsmParser &parser) {
|
||||
if (parser.parseLess())
|
||||
return Type();
|
||||
@@ -50,42 +57,61 @@ mlir::Type GlweCiphertextType::parse(mlir::AsmParser &parser) {
|
||||
|
||||
void GlweCiphertextType::print(mlir::AsmPrinter &p) const {
|
||||
p << "<";
|
||||
if (getImpl()->glweDimension == -1)
|
||||
p << "_";
|
||||
else
|
||||
p << getImpl()->glweDimension;
|
||||
printSigned(p, getGlweDimension());
|
||||
p << ",";
|
||||
if (getImpl()->polynomialSize == -1)
|
||||
p << "_";
|
||||
else
|
||||
p << getImpl()->polynomialSize;
|
||||
printSigned(p, getPolynomialSize());
|
||||
p << ",";
|
||||
if (getImpl()->p == -1)
|
||||
p << "_";
|
||||
else
|
||||
p << getImpl()->p;
|
||||
printSigned(p, getP());
|
||||
p << ">";
|
||||
}
|
||||
|
||||
void LweCiphertextType::print(mlir::AsmPrinter &p) const {
|
||||
p << "<";
|
||||
|
||||
if (getDimension() == -1)
|
||||
p << "_";
|
||||
else
|
||||
p << getDimension();
|
||||
|
||||
// decomposition parameters if any
|
||||
auto crt = getCrtDecomposition();
|
||||
if (!crt.empty()) {
|
||||
p << "crt=[";
|
||||
for (auto c : crt.drop_back(1)) {
|
||||
printSigned(p, c);
|
||||
p << ",";
|
||||
}
|
||||
printSigned(p, crt.back());
|
||||
p << "]";
|
||||
p << ",";
|
||||
}
|
||||
printSigned(p, getDimension());
|
||||
p << ",";
|
||||
if (getP() == -1)
|
||||
p << "_";
|
||||
else
|
||||
p << getP();
|
||||
printSigned(p, getP());
|
||||
p << ">";
|
||||
}
|
||||
|
||||
mlir::Type LweCiphertextType::parse(mlir::AsmParser &parser) {
|
||||
if (parser.parseLess())
|
||||
return mlir::Type();
|
||||
|
||||
// Parse for the crt decomposition if any
|
||||
std::vector<int64_t> crtDecomposition;
|
||||
if (!parser.parseOptionalKeyword("crt")) {
|
||||
if (parser.parseEqual() || parser.parseLSquare())
|
||||
return mlir::Type();
|
||||
while (true) {
|
||||
int64_t c = -1;
|
||||
if (parser.parseOptionalKeyword("_") && parser.parseInteger(c)) {
|
||||
return mlir::Type();
|
||||
}
|
||||
crtDecomposition.push_back(c);
|
||||
if (parser.parseOptionalComma()) {
|
||||
if (parser.parseRSquare()) {
|
||||
return mlir::Type();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (parser.parseComma())
|
||||
return mlir::Type();
|
||||
}
|
||||
|
||||
int dimension = -1;
|
||||
if (parser.parseOptionalKeyword("_") && parser.parseInteger(dimension))
|
||||
return mlir::Type();
|
||||
@@ -99,7 +125,7 @@ mlir::Type LweCiphertextType::parse(mlir::AsmParser &parser) {
|
||||
|
||||
mlir::Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
|
||||
|
||||
return getChecked(loc, loc.getContext(), dimension, p);
|
||||
return getChecked(loc, loc.getContext(), dimension, p, crtDecomposition);
|
||||
}
|
||||
|
||||
void CleartextType::print(mlir::AsmPrinter &p) const {
|
||||
|
||||
@@ -16,29 +16,10 @@ namespace concretelang {
|
||||
|
||||
namespace {
|
||||
|
||||
/// Get the integer value that the cleartext was created from if it exists.
|
||||
llvm::Optional<mlir::Value>
|
||||
getIntegerFromCleartextIfExists(mlir::Value cleartext) {
|
||||
assert(
|
||||
cleartext.getType().isa<mlir::concretelang::Concrete::CleartextType>());
|
||||
// Cleartext are supposed to be created from integers
|
||||
auto intToCleartextOp = cleartext.getDefiningOp();
|
||||
if (intToCleartextOp == nullptr)
|
||||
return {};
|
||||
if (llvm::isa<Concrete::IntToCleartextOp>(intToCleartextOp)) {
|
||||
// We want to match when the integer value is constant
|
||||
return intToCleartextOp->getOperand(0);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
/// Get the constant integer that the cleartext was created from if it exists.
|
||||
llvm::Optional<IntegerAttr>
|
||||
getConstantIntFromCleartextIfExists(mlir::Value cleartext) {
|
||||
auto cleartextInt = getIntegerFromCleartextIfExists(cleartext);
|
||||
if (!cleartextInt.hasValue())
|
||||
return {};
|
||||
auto constantOp = cleartextInt.getValue().getDefiningOp();
|
||||
auto constantOp = cleartext.getDefiningOp();
|
||||
if (constantOp == nullptr)
|
||||
return {};
|
||||
if (llvm::isa<arith::ConstantOp>(constantOp)) {
|
||||
|
||||
@@ -32,7 +32,8 @@ void TFHEDialect::initialize() {
|
||||
/// - The bits parameter is 64 (we support only this for v0)
|
||||
::mlir::LogicalResult GLWECipherTextType::verify(
|
||||
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
|
||||
signed dimension, signed polynomialSize, signed bits, signed p) {
|
||||
signed dimension, signed polynomialSize, signed bits, signed p,
|
||||
llvm::ArrayRef<int64_t>) {
|
||||
if (bits != -1 && bits != 64) {
|
||||
emitError() << "GLWE bits parameter can only be 64";
|
||||
return ::mlir::failure();
|
||||
|
||||
@@ -40,6 +40,10 @@ mlir::LogicalResult _verifyGLWEIntegerOperator(mlir::OpState &op,
|
||||
emitOpErrorForIncompatibleGLWEParameter(op, "p");
|
||||
return mlir::failure();
|
||||
}
|
||||
if (a.getCrtDecomposition() != result.getCrtDecomposition()) {
|
||||
emitOpErrorForIncompatibleGLWEParameter(op, "crt");
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
// verify consistency of width of inputs
|
||||
if ((int)b.getWidth() > a.getP() + 1) {
|
||||
@@ -107,6 +111,11 @@ mlir::LogicalResult verifyBinaryGLWEOperator(Operator &op) {
|
||||
emitOpErrorForIncompatibleGLWEParameter(op, "p");
|
||||
return mlir::failure();
|
||||
}
|
||||
if (a.getCrtDecomposition() != b.getCrtDecomposition() ||
|
||||
a.getCrtDecomposition() != result.getCrtDecomposition()) {
|
||||
emitOpErrorForIncompatibleGLWEParameter(op, "crt");
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
@@ -137,6 +146,10 @@ mlir::LogicalResult verifyUnaryGLWEOperator(Operator &op) {
|
||||
emitOpErrorForIncompatibleGLWEParameter(op, "p");
|
||||
return mlir::failure();
|
||||
}
|
||||
if (a.getCrtDecomposition() != result.getCrtDecomposition()) {
|
||||
emitOpErrorForIncompatibleGLWEParameter(op, "crt");
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
@@ -9,28 +9,35 @@ namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace TFHE {
|
||||
|
||||
void printSigned(mlir::AsmPrinter &p, signed i) {
|
||||
if (i == -1)
|
||||
p << "_";
|
||||
else
|
||||
p << i;
|
||||
}
|
||||
|
||||
void GLWECipherTextType::print(mlir::AsmPrinter &p) const {
|
||||
p << "<{";
|
||||
if (getDimension() == -1)
|
||||
p << "_";
|
||||
else
|
||||
p << getDimension();
|
||||
p << ",";
|
||||
if (getPolynomialSize() == -1)
|
||||
p << "_";
|
||||
else
|
||||
p << getPolynomialSize();
|
||||
p << ",";
|
||||
if (getBits() == -1)
|
||||
p << "_";
|
||||
else
|
||||
p << getBits();
|
||||
p << "}";
|
||||
p << "<";
|
||||
auto crt = getCrtDecomposition();
|
||||
if (!crt.empty()) {
|
||||
p << "crt=[";
|
||||
for (auto c : crt.drop_back(1)) {
|
||||
printSigned(p, c);
|
||||
p << ",";
|
||||
}
|
||||
printSigned(p, crt.back());
|
||||
p << "]";
|
||||
}
|
||||
p << "{";
|
||||
if (getP() == -1)
|
||||
p << "_";
|
||||
else
|
||||
p << getP();
|
||||
printSigned(p, getDimension());
|
||||
p << ",";
|
||||
printSigned(p, getPolynomialSize());
|
||||
p << ",";
|
||||
printSigned(p, getBits());
|
||||
p << "}";
|
||||
|
||||
p << "{";
|
||||
printSigned(p, getP());
|
||||
p << "}>";
|
||||
}
|
||||
|
||||
@@ -38,9 +45,31 @@ mlir::Type GLWECipherTextType::parse(AsmParser &parser) {
|
||||
if (parser.parseLess())
|
||||
return mlir::Type();
|
||||
|
||||
// First parameters block
|
||||
// Parse for the crt decomposition if any
|
||||
std::vector<int64_t> crtDecomposition;
|
||||
if (!parser.parseOptionalKeyword("crt")) {
|
||||
if (parser.parseEqual() || parser.parseLSquare())
|
||||
return mlir::Type();
|
||||
while (true) {
|
||||
signed c = -1;
|
||||
if (parser.parseOptionalKeyword("_") && parser.parseInteger(c)) {
|
||||
return mlir::Type();
|
||||
}
|
||||
crtDecomposition.push_back(c);
|
||||
if (parser.parseOptionalComma()) {
|
||||
if (parser.parseRSquare()) {
|
||||
return mlir::Type();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (parser.parseLBrace())
|
||||
return mlir::Type();
|
||||
|
||||
// First parameters block
|
||||
int dimension = -1;
|
||||
if (parser.parseOptionalKeyword("_") && parser.parseInteger(dimension))
|
||||
return mlir::Type();
|
||||
@@ -69,7 +98,8 @@ mlir::Type GLWECipherTextType::parse(AsmParser &parser) {
|
||||
if (parser.parseGreater())
|
||||
return mlir::Type();
|
||||
Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
|
||||
return getChecked(loc, loc.getContext(), dimension, polynomialSize, bits, p);
|
||||
return getChecked(loc, loc.getContext(), dimension, polynomialSize, bits, p,
|
||||
llvm::ArrayRef<int64_t>(crtDecomposition));
|
||||
}
|
||||
} // namespace TFHE
|
||||
} // namespace concretelang
|
||||
|
||||
@@ -26,10 +26,11 @@ target_link_libraries(
|
||||
ConcretelangRuntime
|
||||
PUBLIC
|
||||
Concrete
|
||||
ConcretelangClientLib
|
||||
|
||||
pthread m dl
|
||||
$<TARGET_OBJECTS:mlir_c_runner_utils>
|
||||
)
|
||||
|
||||
install(TARGETS ConcretelangRuntime omp EXPORT ConcretelangRuntime)
|
||||
install(EXPORT ConcretelangRuntime DESTINATION "./")
|
||||
|
||||
|
||||
@@ -3,11 +3,13 @@
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang/Runtime/wrappers.h"
|
||||
#include <assert.h>
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "concretelang/ClientLib/CRT.h"
|
||||
#include "concretelang/Runtime/wrappers.h"
|
||||
|
||||
void memref_expand_lut_in_trivial_glwe_ct_u64(
|
||||
uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
|
||||
uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
|
||||
@@ -98,6 +100,10 @@ void memref_bootstrap_lwe_u64(
|
||||
glwe_ct_aligned + glwe_ct_offset);
|
||||
}
|
||||
|
||||
uint64_t encode_crt(int64_t plaintext, uint64_t modulus, uint64_t product) {
|
||||
return concretelang::clientlib::crt::encode(plaintext, modulus, product);
|
||||
}
|
||||
|
||||
void memref_copy_one_rank(uint64_t *src_allocated, uint64_t *src_aligned,
|
||||
uint64_t src_offset, uint64_t src_size,
|
||||
uint64_t src_stride, uint64_t *dst_allocated,
|
||||
|
||||
@@ -38,171 +38,166 @@ TensorData multi_arity_call_dynamic_rank(void *(*func)(void *...),
|
||||
using concretelang::clientlib::MemRefDescriptor;
|
||||
constexpr auto convert = concretelang::clientlib::tensorDataFromMemRef;
|
||||
switch (rank) {
|
||||
case 0: {
|
||||
case 1: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<1> (*)(void *...)>(func), args);
|
||||
return convert(1, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 1: {
|
||||
case 2: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<2> (*)(void *...)>(func), args);
|
||||
return convert(2, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 2: {
|
||||
case 3: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<3> (*)(void *...)>(func), args);
|
||||
return convert(3, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 3: {
|
||||
case 4: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<4> (*)(void *...)>(func), args);
|
||||
return convert(4, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 4: {
|
||||
case 5: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<5> (*)(void *...)>(func), args);
|
||||
return convert(5, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 5: {
|
||||
case 6: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<6> (*)(void *...)>(func), args);
|
||||
return convert(6, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 6: {
|
||||
case 7: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<7> (*)(void *...)>(func), args);
|
||||
return convert(7, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 7: {
|
||||
case 8: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<8> (*)(void *...)>(func), args);
|
||||
return convert(8, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 8: {
|
||||
case 9: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<9> (*)(void *...)>(func), args);
|
||||
return convert(9, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 9: {
|
||||
case 10: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<10> (*)(void *...)>(func), args);
|
||||
return convert(10, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 10: {
|
||||
case 11: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<11> (*)(void *...)>(func), args);
|
||||
return convert(11, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 11: {
|
||||
case 12: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<12> (*)(void *...)>(func), args);
|
||||
return convert(12, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 12: {
|
||||
case 13: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<13> (*)(void *...)>(func), args);
|
||||
return convert(13, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 13: {
|
||||
case 14: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<14> (*)(void *...)>(func), args);
|
||||
return convert(14, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 14: {
|
||||
case 15: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<15> (*)(void *...)>(func), args);
|
||||
return convert(15, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 15: {
|
||||
case 16: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<16> (*)(void *...)>(func), args);
|
||||
return convert(16, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 16: {
|
||||
case 17: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<17> (*)(void *...)>(func), args);
|
||||
return convert(17, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 17: {
|
||||
case 18: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<18> (*)(void *...)>(func), args);
|
||||
return convert(18, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 18: {
|
||||
case 19: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<19> (*)(void *...)>(func), args);
|
||||
return convert(19, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 19: {
|
||||
case 20: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<20> (*)(void *...)>(func), args);
|
||||
return convert(20, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 20: {
|
||||
case 21: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<21> (*)(void *...)>(func), args);
|
||||
return convert(21, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 21: {
|
||||
case 22: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<22> (*)(void *...)>(func), args);
|
||||
return convert(22, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 22: {
|
||||
case 23: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<23> (*)(void *...)>(func), args);
|
||||
return convert(23, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 23: {
|
||||
case 24: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<24> (*)(void *...)>(func), args);
|
||||
return convert(24, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 24: {
|
||||
case 25: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<25> (*)(void *...)>(func), args);
|
||||
return convert(25, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 25: {
|
||||
case 26: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<26> (*)(void *...)>(func), args);
|
||||
return convert(26, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 26: {
|
||||
case 27: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<27> (*)(void *...)>(func), args);
|
||||
return convert(27, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 27: {
|
||||
case 28: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<28> (*)(void *...)>(func), args);
|
||||
return convert(28, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 28: {
|
||||
case 29: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<29> (*)(void *...)>(func), args);
|
||||
return convert(29, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 29: {
|
||||
case 30: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<30> (*)(void *...)>(func), args);
|
||||
return convert(30, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 30: {
|
||||
case 31: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<31> (*)(void *...)>(func), args);
|
||||
return convert(31, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 31: {
|
||||
case 32: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<32> (*)(void *...)>(func), args);
|
||||
return convert(32, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 32: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<33> (*)(void *...)>(func), args);
|
||||
return convert(33, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
}
|
||||
|
||||
default:
|
||||
assert(false);
|
||||
|
||||
@@ -69,12 +69,6 @@ ServerLambda::load(std::string funcName, std::string outputPath) {
|
||||
return ServerLambda::loadFromModule(module, funcName);
|
||||
}
|
||||
|
||||
TensorData dynamicCall(void *(*func)(void *...),
|
||||
std::vector<void *> &preparedArgs, CircuitGate &output) {
|
||||
size_t rank = output.shape.dimensions.size();
|
||||
return multi_arity_call_dynamic_rank(func, preparedArgs, rank);
|
||||
}
|
||||
|
||||
std::unique_ptr<clientlib::PublicResult>
|
||||
ServerLambda::call(PublicArguments &args, EvaluationKeys &evaluationKeys) {
|
||||
std::vector<void *> preparedArgs(args.preparedArgs.begin(),
|
||||
@@ -84,9 +78,12 @@ ServerLambda::call(PublicArguments &args, EvaluationKeys &evaluationKeys) {
|
||||
runtimeContext.evaluationKeys = evaluationKeys;
|
||||
preparedArgs.push_back((void *)&runtimeContext);
|
||||
|
||||
return clientlib::PublicResult::fromBuffers(
|
||||
clientParameters,
|
||||
{dynamicCall(this->func, preparedArgs, clientParameters.outputs[0])});
|
||||
assert(clientParameters.outputs.size() == 1 &&
|
||||
"ServerLambda::call is implemented for only one output");
|
||||
auto output = args.clientParameters.outputs[0];
|
||||
auto rank = args.clientParameters.bufferShape(output).size();
|
||||
auto result = multi_arity_call_dynamic_rank(func, preparedArgs, rank);
|
||||
return clientlib::PublicResult::fromBuffers(clientParameters, {result});
|
||||
;
|
||||
}
|
||||
|
||||
|
||||
@@ -43,8 +43,8 @@ TensorData multi_arity_call_dynamic_rank(void *(*func)(void *...),
|
||||
constexpr auto convert = concretelang::clientlib::tensorDataFromMemRef;
|
||||
switch (rank) {""")
|
||||
|
||||
for tensor_rank in range(0, 33):
|
||||
memref_rank = tensor_rank + 1
|
||||
for tensor_rank in range(1, 33):
|
||||
memref_rank = tensor_rank
|
||||
print(f""" case {tensor_rank}: {{
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<{memref_rank}> (*)(void *...)>(func), args);
|
||||
|
||||
@@ -166,30 +166,39 @@ CompilerEngine::getConcreteOptimizerDescription(CompilationResult &res) {
|
||||
return std::move(descriptions->begin()->second);
|
||||
}
|
||||
|
||||
/// set the fheContext field if the v0Constraint can be computed
|
||||
/// set the fheContext field if the v0Constraint can be computed
|
||||
llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) {
|
||||
auto descrOrErr = getConcreteOptimizerDescription(res);
|
||||
if (auto err = descrOrErr.takeError()) {
|
||||
return err;
|
||||
}
|
||||
// The function is non-crypto and without constraint override
|
||||
if (!descrOrErr.get().hasValue()) {
|
||||
if (compilerOptions.v0Parameter.hasValue()) {
|
||||
// parameters come from the compiler options
|
||||
V0Parameter v0Params = compilerOptions.v0Parameter.value();
|
||||
if (compilerOptions.largeIntegerParameter.hasValue()) {
|
||||
v0Params.largeInteger = compilerOptions.largeIntegerParameter;
|
||||
}
|
||||
res.fheContext.emplace(mlir::concretelang::V0FHEContext{{0, 0}, v0Params});
|
||||
return llvm::Error::success();
|
||||
}
|
||||
auto descr = std::move(descrOrErr.get().getValue());
|
||||
auto config = this->compilerOptions.optimizerConfig;
|
||||
|
||||
auto fheParams = (compilerOptions.v0Parameter.hasValue())
|
||||
? compilerOptions.v0Parameter
|
||||
: getParameter(descr, config);
|
||||
if (!fheParams) {
|
||||
return StreamStringError()
|
||||
<< "Could not determine V0 parameters for 2-norm of "
|
||||
<< (*descrOrErr)->constraint.norm2 << " and p of "
|
||||
<< (*descrOrErr)->constraint.p;
|
||||
// compute parameters
|
||||
else {
|
||||
auto descr = getConcreteOptimizerDescription(res);
|
||||
if (auto err = descr.takeError()) {
|
||||
return err;
|
||||
}
|
||||
if (!descr.get().hasValue()) {
|
||||
return llvm::Error::success();
|
||||
}
|
||||
auto optV0Params =
|
||||
getParameter(descr.get().value(), compilerOptions.optimizerConfig);
|
||||
if (!optV0Params) {
|
||||
return StreamStringError()
|
||||
<< "Could not determine V0 parameters for 2-norm of "
|
||||
<< (*descr)->constraint.norm2 << " and p of "
|
||||
<< (*descr)->constraint.p;
|
||||
}
|
||||
res.fheContext.emplace(mlir::concretelang::V0FHEContext{
|
||||
descr.get().value().constraint, optV0Params.getValue()});
|
||||
}
|
||||
res.fheContext.emplace(
|
||||
mlir::concretelang::V0FHEContext{descr.constraint, fheParams.getValue()});
|
||||
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
@@ -282,7 +291,7 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
|
||||
// FHE -> TFHE
|
||||
if (mlir::concretelang::pipeline::lowerFHEToTFHE(mlirContext, module,
|
||||
enablePass)
|
||||
res.fheContext, enablePass)
|
||||
.failed()) {
|
||||
return errorDiag("Lowering from FHE to TFHE failed");
|
||||
}
|
||||
|
||||
@@ -110,21 +110,13 @@ JITLambda::call(clientlib::PublicArguments &args,
|
||||
// Prepare the outputs vector to store the output value of the lambda.
|
||||
auto numOutputs = 0;
|
||||
for (auto &output : args.clientParameters.outputs) {
|
||||
if (output.shape.dimensions.empty()) {
|
||||
auto shape = args.clientParameters.bufferShape(output);
|
||||
if (shape.size() == 0) {
|
||||
// scalar gate
|
||||
if (output.encryption.hasValue()) {
|
||||
// encrypted scalar : memref<lweSizexi64>
|
||||
numOutputs += numArgOfRankedMemrefCallingConvention(1);
|
||||
} else {
|
||||
// clear scalar
|
||||
numOutputs += 1;
|
||||
}
|
||||
numOutputs += 1;
|
||||
} else {
|
||||
// memref gate : rank+1 if the output is encrypted for the lwe size
|
||||
// dimension
|
||||
auto rank = output.shape.dimensions.size() +
|
||||
(output.encryption.hasValue() ? 1 : 0);
|
||||
numOutputs += numArgOfRankedMemrefCallingConvention(rank);
|
||||
// buffer gate
|
||||
numOutputs += numArgOfRankedMemrefCallingConvention(shape.size());
|
||||
}
|
||||
}
|
||||
std::vector<void *> outputs(numOutputs);
|
||||
@@ -148,7 +140,6 @@ JITLambda::call(clientlib::PublicArguments &args,
|
||||
for (auto &out : outputs) {
|
||||
rawArgs[i++] = &out;
|
||||
}
|
||||
|
||||
// Invoke
|
||||
if (auto err = invokeRaw(rawArgs)) {
|
||||
return std::move(err);
|
||||
@@ -159,14 +150,14 @@ JITLambda::call(clientlib::PublicArguments &args,
|
||||
{
|
||||
size_t outputOffset = 0;
|
||||
for (auto &output : args.clientParameters.outputs) {
|
||||
if (output.shape.dimensions.empty() && !output.encryption.hasValue()) {
|
||||
// clear scalar
|
||||
auto shape = args.clientParameters.bufferShape(output);
|
||||
if (shape.size() == 0) {
|
||||
// scalar scalar
|
||||
buffers.push_back(
|
||||
clientlib::tensorDataFromScalar((uint64_t)outputs[outputOffset++]));
|
||||
} else {
|
||||
// encrypted scalar, and tensor gate are memref
|
||||
auto rank = output.shape.dimensions.size() +
|
||||
(output.encryption.hasValue() ? 1 : 0);
|
||||
// buffer gate
|
||||
auto rank = shape.size();
|
||||
auto allocated = (uint64_t *)outputs[outputOffset++];
|
||||
auto aligned = (uint64_t *)outputs[outputOffset++];
|
||||
auto offset = (size_t)outputs[outputOffset++];
|
||||
|
||||
@@ -182,6 +182,7 @@ markFHELinalgForTiling(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
llvm::Optional<V0FHEContext> &fheContext,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("FHEToTFHE", pm, context);
|
||||
@@ -192,8 +193,14 @@ lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
// to linalg.generic operations
|
||||
addPotentiallyNestedPass(pm, mlir::createLinalgGeneralizationPass(),
|
||||
enablePass);
|
||||
addPotentiallyNestedPass(pm, mlir::concretelang::createConvertFHEToTFHEPass(),
|
||||
enablePass);
|
||||
mlir::concretelang::ApplyLookupTableLowering lowerStrategy =
|
||||
mlir::concretelang::KeySwitchBoostrapLowering;
|
||||
if (fheContext.hasValue() && fheContext->parameter.largeInteger.hasValue()) {
|
||||
lowerStrategy = mlir::concretelang::WopPBSLowering;
|
||||
}
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createConvertFHEToTFHEPass(lowerStrategy),
|
||||
enablePass);
|
||||
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
@@ -260,6 +267,8 @@ lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("BConcreteToStd", pm, context);
|
||||
addPotentiallyNestedPass(pm, mlir::concretelang::createEliminateCRTOps(),
|
||||
enablePass);
|
||||
addPotentiallyNestedPass(pm, mlir::concretelang::createAddRuntimeContext(),
|
||||
enablePass);
|
||||
return pm.run(module.getOperation());
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <llvm/ADT/Optional.h>
|
||||
#include <llvm/ADT/STLExtras.h>
|
||||
#include <llvm/Support/Error.h>
|
||||
|
||||
@@ -12,6 +13,7 @@
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
#include "concretelang/Support/V0Curves.h"
|
||||
|
||||
namespace mlir {
|
||||
@@ -20,6 +22,7 @@ namespace concretelang {
|
||||
using ::concretelang::clientlib::BIG_KEY;
|
||||
using ::concretelang::clientlib::CircuitGate;
|
||||
using ::concretelang::clientlib::ClientParameters;
|
||||
using ::concretelang::clientlib::Encoding;
|
||||
using ::concretelang::clientlib::EncryptionGate;
|
||||
using ::concretelang::clientlib::LweSecretKeyID;
|
||||
using ::concretelang::clientlib::Precision;
|
||||
@@ -52,23 +55,21 @@ llvm::Expected<CircuitGate> gateFromMLIRType(LweSecretKeyID secretKeyID,
|
||||
},
|
||||
};
|
||||
}
|
||||
if (auto lweType = type.dyn_cast_or_null<
|
||||
mlir::concretelang::Concrete::LweCiphertextType>()) {
|
||||
// TODO - Get the width from the LWECiphertextType instead of global
|
||||
// precision (could be possible after merge concrete-ciphertext-parameter)
|
||||
size_t precision = (size_t)lweType.getP();
|
||||
if (auto lweTy = type.dyn_cast_or_null<
|
||||
mlir::concretelang::Concrete::LweCiphertextType>()) {
|
||||
return CircuitGate{
|
||||
/* .encryption = */ llvm::Optional<EncryptionGate>({
|
||||
/* .secretKeyID = */ secretKeyID,
|
||||
/* .variance = */ variance,
|
||||
/* .encoding = */
|
||||
{
|
||||
/* .precision = */ precision,
|
||||
/* .precision = */ lweTy.getP(),
|
||||
/* .crt = */ lweTy.getCrtDecomposition().vec(),
|
||||
},
|
||||
}),
|
||||
/*.shape = */
|
||||
{
|
||||
/*.width = */ precision,
|
||||
/*.width = */ lweTy.getP(),
|
||||
/*.dimensions = */ std::vector<int64_t>(),
|
||||
/*.size = */ 0,
|
||||
},
|
||||
@@ -99,6 +100,7 @@ createClientParametersForV0(V0FHEContext fheContext,
|
||||
auto v0Param = fheContext.parameter;
|
||||
Variance encryptionVariance = v0Curve->getVariance(
|
||||
v0Param.glweDimension, v0Param.getPolynomialSize(), 64);
|
||||
// Variance encryptionVariance = 0.;
|
||||
Variance keyswitchVariance = v0Curve->getVariance(1, v0Param.nSmall, 64);
|
||||
// Static client parameters from global parameters for v0
|
||||
ClientParameters c;
|
||||
@@ -138,10 +140,9 @@ createClientParametersForV0(V0FHEContext fheContext,
|
||||
return op.getName() == functionName;
|
||||
});
|
||||
if (funcOp == rangeOps.end()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"cannot find the function for generate client parameters '" +
|
||||
functionName + "'",
|
||||
llvm::inconvertibleErrorCode());
|
||||
return StreamStringError(
|
||||
"cannot find the function for generate client parameters: ")
|
||||
<< functionName;
|
||||
}
|
||||
|
||||
// Create input and output circuit gate parameters
|
||||
|
||||
@@ -170,7 +170,8 @@ llvm::cl::opt<std::string>
|
||||
llvm::cl::list<uint64_t>
|
||||
jitArgs("jit-args",
|
||||
llvm::cl::desc("Value of arguments to pass to the main func"),
|
||||
llvm::cl::value_desc("argument(uint64)"), llvm::cl::ZeroOrMore);
|
||||
llvm::cl::value_desc("argument(uint64)"), llvm::cl::ZeroOrMore,
|
||||
llvm::cl::MiscFlags::CommaSeparated);
|
||||
|
||||
llvm::cl::opt<std::string> jitKeySetCachePath(
|
||||
"jit-keyset-cache-path",
|
||||
@@ -210,6 +211,31 @@ llvm::cl::list<int64_t> v0Parameter(
|
||||
"logPolynomialSize, nSmall, brLevel, brLobBase, ksLevel, ksLogBase]"),
|
||||
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
|
||||
|
||||
llvm::cl::list<int64_t> largeIntegerCRTDecomposition(
|
||||
"large-integer-crt-decomposition",
|
||||
llvm::cl::desc(
|
||||
"Use the large integer to lower FHE.eint with the given decomposition, "
|
||||
"must be used with the other large-integers options (experimental)"),
|
||||
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
|
||||
|
||||
llvm::cl::list<int64_t> largeIntegerPackingKeyswitch(
|
||||
"large-integer-packing-keyswitch",
|
||||
llvm::cl::desc(
|
||||
"Use the large integer to lower FHE.eint with the given parameters for "
|
||||
"packing keyswitch, must be used with the other large-integers options "
|
||||
"(experimental) [inputLweDimension, inputLweCount, "
|
||||
"outputPolynomialSize, level, baseLog]"),
|
||||
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
|
||||
|
||||
llvm::cl::list<int64_t> largeIntegerCircuitBootstrap(
|
||||
"large-integer-circuit-bootstrap",
|
||||
llvm::cl::desc(
|
||||
"Use the large integer to lower FHE.eint with the given parameters for "
|
||||
"the cicuit boostrap, must be used with the other large-integers "
|
||||
"options "
|
||||
"(experimental) [level, baseLog]"),
|
||||
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
|
||||
|
||||
} // namespace cmdline
|
||||
|
||||
namespace llvm {
|
||||
@@ -257,6 +283,7 @@ cmdlineCompilationOptions() {
|
||||
if (!cmdline::fhelinalgTileSizes.empty())
|
||||
options.fhelinalgTileSizes.emplace(cmdline::fhelinalgTileSizes);
|
||||
|
||||
// Setup the v0 parameter options
|
||||
if (!cmdline::v0Parameter.empty()) {
|
||||
if (cmdline::v0Parameter.size() != 7) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
@@ -270,10 +297,44 @@ cmdlineCompilationOptions() {
|
||||
cmdline::v0Parameter[6]);
|
||||
}
|
||||
|
||||
if (!cmdline::v0Constraint.empty() && !cmdline::optimizerV0) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"You must use --v0-constraint with --optimizer-v0-strategy",
|
||||
llvm::inconvertibleErrorCode());
|
||||
// Setup the large integer options
|
||||
if (!cmdline::largeIntegerCRTDecomposition.empty() ||
|
||||
!cmdline::largeIntegerPackingKeyswitch.empty() ||
|
||||
!cmdline::largeIntegerPackingKeyswitch.empty()) {
|
||||
if (cmdline::largeIntegerCRTDecomposition.empty() ||
|
||||
cmdline::largeIntegerPackingKeyswitch.empty() ||
|
||||
cmdline::largeIntegerPackingKeyswitch.empty()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"The large-integers options should all be set",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
if (cmdline::largeIntegerPackingKeyswitch.size() != 5) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"The large-integers-packing-keyswitch must be a list of 5 integer",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
if (cmdline::largeIntegerCircuitBootstrap.size() != 2) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"The large-integers-packing-keyswitch must be a list of 2 integer",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
options.largeIntegerParameter = mlir::concretelang::LargeIntegerParameter();
|
||||
options.largeIntegerParameter->crtDecomposition =
|
||||
cmdline::largeIntegerCRTDecomposition;
|
||||
options.largeIntegerParameter->wopPBS.packingKeySwitch.inputLweDimension =
|
||||
cmdline::largeIntegerPackingKeyswitch[0];
|
||||
options.largeIntegerParameter->wopPBS.packingKeySwitch.inputLweCount =
|
||||
cmdline::largeIntegerPackingKeyswitch[1];
|
||||
options.largeIntegerParameter->wopPBS.packingKeySwitch
|
||||
.outputPolynomialSize = cmdline::largeIntegerPackingKeyswitch[2];
|
||||
options.largeIntegerParameter->wopPBS.packingKeySwitch.level =
|
||||
cmdline::largeIntegerPackingKeyswitch[3];
|
||||
options.largeIntegerParameter->wopPBS.packingKeySwitch.baseLog =
|
||||
cmdline::largeIntegerPackingKeyswitch[4];
|
||||
options.largeIntegerParameter->wopPBS.circuitBootstrap.level =
|
||||
cmdline::largeIntegerCircuitBootstrap[0];
|
||||
options.largeIntegerParameter->wopPBS.circuitBootstrap.baseLog =
|
||||
cmdline::largeIntegerCircuitBootstrap[1];
|
||||
}
|
||||
|
||||
options.optimizerConfig.p_error = cmdline::pbsErrorProbability;
|
||||
@@ -502,9 +563,9 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
|
||||
}
|
||||
|
||||
if (cmdline::action == Action::COMPILE) {
|
||||
auto err =
|
||||
outputLib->emitArtifacts(/*sharedLib=*/true, /*staticLib=*/true,
|
||||
/*clientParameters=*/true, /*cppHeader=*/true);
|
||||
auto err = outputLib->emitArtifacts(
|
||||
/*sharedLib=*/true, /*staticLib=*/true,
|
||||
/*clientParameters=*/true, /*cppHeader=*/true);
|
||||
if (err) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
@@ -1,10 +1,19 @@
|
||||
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
|
||||
|
||||
//CHECK: func.func @add_glwe(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<2049xi64>) -> tensor<2049xi64> {
|
||||
//CHECK: func @add_lwe_ciphertexts(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<2049xi64>) -> tensor<2049xi64> {
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.add_lwe_buffer"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, tensor<2049xi64>) -> tensor<2049xi64>
|
||||
//CHECK: return %[[V0]] : tensor<2049xi64>
|
||||
//CHECK: }
|
||||
func.func @add_glwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> {
|
||||
func.func @add_lwe_ciphertexts(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> {
|
||||
%0 = "Concrete.add_lwe_ciphertexts"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
return %0 : !Concrete.lwe_ciphertext<2048,7>
|
||||
}
|
||||
|
||||
//CHECK: func @add_crt_lwe_ciphertexts(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.add_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> tensor<5x2049xi64>
|
||||
//CHECK: return %[[V0]] : tensor<5x2049xi64>
|
||||
//CHECK: }
|
||||
func.func @add_crt_lwe_ciphertexts(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>, %arg1: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7> {
|
||||
%0 = "Concrete.add_lwe_ciphertexts"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>, !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>
|
||||
return %0 : !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>
|
||||
}
|
||||
|
||||
@@ -1,30 +1,40 @@
|
||||
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
|
||||
|
||||
|
||||
//CHECK: func.func @add_glwe_const_int(%arg0: tensor<1025xi64>) -> tensor<1025xi64> {
|
||||
//CHECK: func.func @add_glwe_const_int(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> {
|
||||
//CHECK: %c1_i8 = arith.constant 1 : i8
|
||||
//CHECK: %0 = arith.extui %c1_i8 : i8 to i64
|
||||
//CHECK: %[[V0:.*]] = arith.extui %c1_i8 : i8 to i64
|
||||
//CHECK: %c56_i64 = arith.constant 56 : i64
|
||||
//CHECK: %1 = arith.shli %0, %c56_i64 : i64
|
||||
//CHECK: %2 = "BConcrete.add_plaintext_lwe_buffer"(%arg0, %1) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
|
||||
//CHECK: return %2 : tensor<1025xi64>
|
||||
//CHECK: %[[V1:.*]] = arith.shli %[[V0]], %c56_i64 : i64
|
||||
//CHECK: %[[V2:.*]] = "BConcrete.add_plaintext_lwe_buffer"(%[[A0]], %[[V1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
|
||||
//CHECK: return %[[V2]] : tensor<1025xi64>
|
||||
//CHECK: }
|
||||
func.func @add_glwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
|
||||
%0 = arith.constant 1 : i8
|
||||
%1 = "Concrete.encode_int"(%0) : (i8) -> !Concrete.plaintext<8>
|
||||
%2 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %1) : (!Concrete.lwe_ciphertext<1024,7>, !Concrete.plaintext<8>) -> !Concrete.lwe_ciphertext<1024,7>
|
||||
%2 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<1024,7>, i8) -> !Concrete.lwe_ciphertext<1024,7>
|
||||
return %2 : !Concrete.lwe_ciphertext<1024,7>
|
||||
}
|
||||
|
||||
//CHECK: func.func @add_glwe_int(%arg0: tensor<1025xi64>, %arg1: i5) -> tensor<1025xi64> {
|
||||
//CHECK: %0 = arith.extui %arg1 : i5 to i64
|
||||
//CHECK: func.func @add_glwe_int(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: i5) -> tensor<1025xi64> {
|
||||
//CHECK: %[[V0:.*]] = arith.extui %[[A1]] : i5 to i64
|
||||
//CHECK: %c59_i64 = arith.constant 59 : i64
|
||||
//CHECK: %1 = arith.shli %0, %c59_i64 : i64
|
||||
//CHECK: %2 = "BConcrete.add_plaintext_lwe_buffer"(%arg0, %1) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
|
||||
//CHECK: return %2 : tensor<1025xi64>
|
||||
//CHECK: %[[V1:.*]] = arith.shli %[[V0]], %c59_i64 : i64
|
||||
//CHECK: %[[V2:.*]] = "BConcrete.add_plaintext_lwe_buffer"(%[[A0]], %[[V1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
|
||||
//CHECK: return %[[V2]] : tensor<1025xi64>
|
||||
//CHECK: }
|
||||
func.func @add_glwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4> {
|
||||
%0 = "Concrete.encode_int"(%arg1) : (i5) -> !Concrete.plaintext<5>
|
||||
%1 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<1024,4>, !Concrete.plaintext<5>) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
%1 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<1024,4>, i5) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
return %1 : !Concrete.lwe_ciphertext<1024,4>
|
||||
}
|
||||
|
||||
|
||||
//CHECK: func.func @add_plaintext_lwe_ciphertext(%[[A0:.*]]: tensor<5x1025xi64>) -> tensor<5x1025xi64> {
|
||||
//CHECK: %c1_i8 = arith.constant 1 : i8
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_crt_lwe_buffer"(%[[A0]], %c1_i8) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>, i8) -> tensor<5x1025xi64>
|
||||
//CHECK: return %[[V0]] : tensor<5x1025xi64>
|
||||
//CHECK: }
|
||||
func.func @add_plaintext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7> {
|
||||
%0 = arith.constant 1 : i8
|
||||
%2 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>, i8) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>
|
||||
return %2 : !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>
|
||||
}
|
||||
|
||||
@@ -6,3 +6,10 @@
|
||||
func.func @identity(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
|
||||
return %arg0 : !Concrete.lwe_ciphertext<1024,7>
|
||||
}
|
||||
|
||||
// CHECK: func.func @identity_crt(%arg0: tensor<5x1025xi64>) -> tensor<5x1025xi64> {
|
||||
// CHECK-NEXT: return %arg0 : tensor<5x1025xi64>
|
||||
// CHECK-NEXT: }
|
||||
func.func @identity_crt(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7> {
|
||||
return %arg0 : !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>
|
||||
}
|
||||
|
||||
@@ -1,26 +1,33 @@
|
||||
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
|
||||
|
||||
//CHECK: func.func @mul_lwe_const_int(%arg0: tensor<1025xi64>) -> tensor<1025xi64> {
|
||||
//CHECK: func.func @mul_lwe_const_int(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> {
|
||||
//CHECK: %c1_i8 = arith.constant 1 : i8
|
||||
//CHECK: %0 = arith.extui %c1_i8 : i8 to i64
|
||||
//CHECK: %1 = "BConcrete.mul_cleartext_lwe_buffer"(%arg0, %0) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
|
||||
//CHECK: return %1 : tensor<1025xi64>
|
||||
//CHECK: %[[V0:.*]] = arith.extui %c1_i8 : i8 to i64
|
||||
//CHECK: %[[V1:.*]] = "BConcrete.mul_cleartext_lwe_buffer"(%[[A0]], %[[V0]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
|
||||
//CHECK: return %[[V1]] : tensor<1025xi64>
|
||||
//CHECK: }
|
||||
|
||||
func.func @mul_lwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
|
||||
%0 = arith.constant 1 : i8
|
||||
%1 = "Concrete.int_to_cleartext"(%0) : (i8) -> !Concrete.cleartext<8>
|
||||
%2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %1) : (!Concrete.lwe_ciphertext<1024,7>, !Concrete.cleartext<8>) -> !Concrete.lwe_ciphertext<1024,7>
|
||||
%2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<1024,7>, i8) -> !Concrete.lwe_ciphertext<1024,7>
|
||||
return %2 : !Concrete.lwe_ciphertext<1024,7>
|
||||
}
|
||||
|
||||
//CHECK: func.func @mul_lwe_int(%arg0: tensor<1025xi64>, %arg1: i5) -> tensor<1025xi64> {
|
||||
//CHECK: %0 = arith.extui %arg1 : i5 to i64
|
||||
//CHECK: %1 = "BConcrete.mul_cleartext_lwe_buffer"(%arg0, %0) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
|
||||
//CHECK: return %1 : tensor<1025xi64>
|
||||
//CHECK: func.func @mul_lwe_int(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: i5) -> tensor<1025xi64> {
|
||||
//CHECK: %[[V0:.*]] = arith.extui %[[A1]] : i5 to i64
|
||||
//CHECK: %[[V1:.*]] = "BConcrete.mul_cleartext_lwe_buffer"(%[[A0]], %[[V0]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
|
||||
//CHECK: return %[[V1]] : tensor<1025xi64>
|
||||
//CHECK: }
|
||||
func.func @mul_lwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4> {
|
||||
%0 = "Concrete.int_to_cleartext"(%arg1) : (i5) -> !Concrete.cleartext<5>
|
||||
%1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<1024,4>, !Concrete.cleartext<5>) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
%1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<1024,4>, i5) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
return %1 : !Concrete.lwe_ciphertext<1024,4>
|
||||
}
|
||||
|
||||
|
||||
//CHECK: func.func @mul_cleartext_lwe_ciphertext_crt(%[[A0:.*]]: tensor<5x1025xi64>, %[[A1:.*]]: i5) -> tensor<5x1025xi64> {
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>, i5) -> tensor<5x1025xi64>
|
||||
//CHECK: return %[[V0]] : tensor<5x1025xi64>
|
||||
//CHECK: }
|
||||
func.func @mul_cleartext_lwe_ciphertext_crt(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4> {
|
||||
%1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>, i5) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>
|
||||
return %1 : !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>
|
||||
}
|
||||
|
||||
@@ -8,3 +8,12 @@ func.func @neg_lwe(%arg0: !Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_cip
|
||||
%0 = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
return %0 : !Concrete.lwe_ciphertext<1024,4>
|
||||
}
|
||||
|
||||
//CHECK: func.func @negate_lwe_ciphertext(%[[A0:.*]]: tensor<5x1025xi64>) -> tensor<5x1025xi64> {
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.negate_crt_lwe_buffer"(%[[A0]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>) -> tensor<5x1025xi64>
|
||||
//CHECK: return %[[V0]] : tensor<5x1025xi64>
|
||||
//CHECK: }
|
||||
func.func @negate_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4> {
|
||||
%0 = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>
|
||||
return %0 : !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>
|
||||
}
|
||||
|
||||
@@ -1,24 +1,20 @@
|
||||
// RUN: concretecompiler --split-input-file --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func
|
||||
// DISABLED-CHECK: func.func @tensor_collapse_shape(%arg0: tensor<2x3x4x5x6x1025xi64>) -> tensor<720x1025xi64> {
|
||||
// DISABLED-CHECK-NEXT: %0 = bufferization.to_memref %arg0 : memref<2x3x4x5x6x1025xi64>
|
||||
// DISABLED-CHECK-NEXT: %1 = memref.collapse_shape %0 [[_:\[\[0, 1, 2, 3, 4\], \[5\]\]]] : memref<2x3x4x5x6x1025xi64> into memref<720x1025xi64>
|
||||
// DISABLED-CHECK-NEXT: %2 = bufferization.to_tensor %1 : memref<720x1025xi64>
|
||||
// DISABLED-CHECK-NEXT: return %2 : tensor<720x1025xi64>
|
||||
//CHECK: func.func @tensor_collapse_shape(%[[A0:.*]]: tensor<2x3x4x5x6x1025xi64>) -> tensor<720x1025xi64> {
|
||||
//CHECK: %[[V0:.*]] = tensor.collapse_shape %[[A0]] [[_:\[\[0, 1, 2, 3, 4\], \[5\]\]]] : tensor<2x3x4x5x6x1025xi64> into tensor<720x1025xi64>
|
||||
//CHECK: return %[[V0]] : tensor<720x1025xi64>
|
||||
//CHECK: }
|
||||
func.func @tensor_collapse_shape(%arg0: tensor<2x3x4x5x6x!Concrete.lwe_ciphertext<1024,4>>) -> tensor<720x!Concrete.lwe_ciphertext<1024,4>> {
|
||||
%0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3, 4]] {MANP = 1 : ui1}: tensor<2x3x4x5x6x!Concrete.lwe_ciphertext<1024,4>> into tensor<720x!Concrete.lwe_ciphertext<1024,4>>
|
||||
return %0 : tensor<720x!Concrete.lwe_ciphertext<1024,4>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// DISABLED-CHECK: func.func @tensor_collatenspse_shape(%arg0: tensor<2x3x5x1025xi64>) -> tensor<5x6x1025xi64> {
|
||||
// DISABLED-CHECK-NEXT: %0 = bufferization.to_memref %arg0 : memref<2x3x5x1025xi64>
|
||||
// DISABLED-CHECK-NEXT: %1 = memref.collapse_shape %0 [[_:\[\[0, 1, 2\], \[3\]\]]] : memref<2x3x5x1025xi64> into memref<30x1025xi64>
|
||||
// DISABLED-CHECK-NEXT: %2 = memref.expand_shape %1 [[_:\[\[0, 1\], \[2\]\]]] : memref<30x1025xi64> into memref<5x6x1025xi64>
|
||||
// DISABLED-CHECK-NEXT: %3 = bufferization.to_tensor %2 : memref<5x6x1025xi64>
|
||||
// DISABLED-CHECK-NEXT: return %3 : tensor<5x6x1025xi64>
|
||||
//CHECK: func.func @tensor_collatenspse_shape(%[[A0:.*]]: tensor<2x3x5x1025xi64>) -> tensor<5x6x1025xi64> {
|
||||
//CHECK: %[[V0:.*]] = tensor.collapse_shape %[[A0]] [[_:\[\[0, 1, 2\], \[3\]\]]] : tensor<2x3x5x1025xi64> into tensor<30x1025xi64>
|
||||
//CHECK: %[[V1:.*]] = tensor.expand_shape %[[V0]] [[_:\[\[0, 1\], \[2\]\]]] : tensor<30x1025xi64> into tensor<5x6x1025xi64>
|
||||
//CHECK: return %[[V1]] : tensor<5x6x1025xi64>
|
||||
//CHECK: }
|
||||
func.func @tensor_collatenspse_shape(%arg0: tensor<2x3x5x!Concrete.lwe_ciphertext<1024,4>>) -> tensor<5x6x!Concrete.lwe_ciphertext<1024,4>> {
|
||||
%0 = tensor.collapse_shape %arg0 [[0, 1, 2]] {MANP = 1 : ui1}: tensor<2x3x5x!Concrete.lwe_ciphertext<1024,4>> into tensor<30x!Concrete.lwe_ciphertext<1024,4>>
|
||||
%1 = tensor.expand_shape %0 [[0, 1]] {MANP = 1 : ui1}: tensor<30x!Concrete.lwe_ciphertext<1024,4>> into tensor<5x6x!Concrete.lwe_ciphertext<1024,4>>
|
||||
@@ -26,13 +22,31 @@ func.func @tensor_collatenspse_shape(%arg0: tensor<2x3x5x!Concrete.lwe_ciphertex
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// DISABLED-CHECK: func.func @tensor_collatenspse_shape(%arg0: tensor<2x3x2x3x4x1025xi64>) -> tensor<6x2x12x1025xi64> {
|
||||
// DISABLED-CHECK-NEXT: %0 = bufferization.to_memref %arg0 : memref<2x3x2x3x4x1025xi64>
|
||||
// DISABLED-CHECK-NEXT: %1 = memref.collapse_shape %0 [[_:\[\[0, 1\], \[2\], \[3, 4\], \[5\]\]]] : memref<2x3x2x3x4x1025xi64> into memref<6x2x12x1025xi64>
|
||||
// DISABLED-CHECK-NEXT: %2 = bufferization.to_tensor %1 : memref<6x2x12x1025xi64>
|
||||
// DISABLED-CHECK-NEXT: return %2 : tensor<6x2x12x1025xi64>
|
||||
//CHECK: func.func @tensor_collatenspse_shape(%[[A0:.*]]: tensor<2x3x2x3x4x1025xi64>) -> tensor<6x2x12x1025xi64> {
|
||||
//CHECK: %[[V0:.*]] = tensor.collapse_shape %[[A0]] [[_:\[\[0, 1\], \[2\], \[3, 4\], \[5\]\]]] : tensor<2x3x2x3x4x1025xi64> into tensor<6x2x12x1025xi64>
|
||||
//CHECK: return %[[V0]] : tensor<6x2x12x1025xi64>
|
||||
//CHECK: }
|
||||
func.func @tensor_collatenspse_shape(%arg0: tensor<2x3x2x3x4x!Concrete.lwe_ciphertext<1024,4>>) -> tensor<6x2x12x!Concrete.lwe_ciphertext<1024,4>> {
|
||||
%0 = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] {MANP = 1 : ui1}: tensor<2x3x2x3x4x!Concrete.lwe_ciphertext<1024,4>> into tensor<6x2x12x!Concrete.lwe_ciphertext<1024,4>>
|
||||
return %0 : tensor<6x2x12x!Concrete.lwe_ciphertext<1024,4>>
|
||||
}
|
||||
|
||||
// -----
|
||||
//CHECK: func.func @tensor_collapse_shape_crt(%[[A0:.*]]: tensor<2x3x4x5x6x5x1025xi64>) -> tensor<720x5x1025xi64> {
|
||||
//CHECK: %[[V0:.*]] = tensor.collapse_shape %[[A0]] [[_:\[\[0, 1, 2, 3, 4\], \[5\], \[6\]\]]] : tensor<2x3x4x5x6x5x1025xi64> into tensor<720x5x1025xi64>
|
||||
//CHECK: return %[[V0]] : tensor<720x5x1025xi64>
|
||||
//CHECK: }
|
||||
func.func @tensor_collapse_shape_crt(%arg0: tensor<2x3x4x5x6x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>>) -> tensor<720x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>> {
|
||||
%0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3, 4]] {MANP = 1 : ui1}: tensor<2x3x4x5x6x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>> into tensor<720x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>>
|
||||
return %0 : tensor<720x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>>
|
||||
}
|
||||
|
||||
// -----
|
||||
//CHECK: func.func @tensor_expand_shape_crt(%[[A0:.*]]: tensor<30x1025xi64>) -> tensor<5x6x1025xi64> {
|
||||
//CHECK: %[[V0:.*]] = tensor.expand_shape %[[A0]] [[_:\[\[0, 1\], \[2\]\]]] : tensor<30x1025xi64> into tensor<5x6x1025xi64>
|
||||
//CHECK: return %[[V0]] : tensor<5x6x1025xi64>
|
||||
//CHECK: }
|
||||
func.func @tensor_expand_shape_crt(%arg0: tensor<30x!Concrete.lwe_ciphertext<1024,4>>) -> tensor<5x6x!Concrete.lwe_ciphertext<1024,4>> {
|
||||
%0 = tensor.expand_shape %arg0 [[0, 1]] {MANP = 1 : ui1}: tensor<30x!Concrete.lwe_ciphertext<1024,4>> into tensor<5x6x!Concrete.lwe_ciphertext<1024,4>>
|
||||
return %0 : tensor<5x6x!Concrete.lwe_ciphertext<1024,4>>
|
||||
}
|
||||
|
||||
@@ -1,7 +1,15 @@
|
||||
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @tensor_identity(%arg0: tensor<2x3x4x1025xi64>) -> tensor<2x3x4x1025xi64> {
|
||||
// CHECK-NEXT: return %arg0 : tensor<2x3x4x1025xi64>
|
||||
// CHECK-NEXT: }
|
||||
func.func @tensor_identity(%arg0: tensor<2x3x4x!Concrete.lwe_ciphertext<1024,7>>) -> tensor<2x3x4x!Concrete.lwe_ciphertext<1024,7>> {
|
||||
return %arg0 : tensor<2x3x4x!Concrete.lwe_ciphertext<1024,7>>
|
||||
}
|
||||
|
||||
// CHECK: func.func @tensor_identity_crt(%arg0: tensor<2x3x4x5x1025xi64>) -> tensor<2x3x4x5x1025xi64> {
|
||||
// CHECK-NEXT: return %arg0 : tensor<2x3x4x5x1025xi64>
|
||||
// CHECK-NEXT: }
|
||||
func.func @tensor_identity_crt(%arg0: tensor<2x3x4x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>>) -> tensor<2x3x4x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>> {
|
||||
return %arg0 : tensor<2x3x4x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>>
|
||||
}
|
||||
|
||||
@@ -1,22 +1,22 @@
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @add_glwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7>
|
||||
//CHECK: func.func @add_glwe_const_int(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
|
||||
//CHECK: %c1_i8 = arith.constant 1 : i8
|
||||
//CHECK: %[[V0:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[A0]], %c1_i8) : (!Concrete.lwe_ciphertext<1024,7>, i8) -> !Concrete.lwe_ciphertext<1024,7>
|
||||
//CHECK: return %[[V0]] : !Concrete.lwe_ciphertext<1024,7>
|
||||
//CHECK: }
|
||||
func.func @add_glwe_const_int(%arg0: !TFHE.glwe<{1024,1,64}{7}>) -> !TFHE.glwe<{1024,1,64}{7}> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8
|
||||
// CHECK-NEXT: %[[V2:.*]] = "Concrete.encode_int"(%[[V1]]) : (i8) -> !Concrete.plaintext<8>
|
||||
// CHECK-NEXT: %[[V3:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %[[V2]]) : (!Concrete.lwe_ciphertext<1024,7>, !Concrete.plaintext<8>) -> !Concrete.lwe_ciphertext<1024,7>
|
||||
// CHECK-NEXT: return %[[V3]] : !Concrete.lwe_ciphertext<1024,7>
|
||||
%0 = arith.constant 1 : i8
|
||||
%1 = "TFHE.add_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,1,64}{7}>, i8) -> (!TFHE.glwe<{1024,1,64}{7}>)
|
||||
return %1: !TFHE.glwe<{1024,1,64}{7}>
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: func.func @add_glwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
//CHECK: func.func @add_glwe_int(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,4>, %[[A1:.*]]: i5) -> !Concrete.lwe_ciphertext<1024,4> {
|
||||
//CHECK: %[[V0:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[A0]], %[[A1]]) : (!Concrete.lwe_ciphertext<1024,4>, i5) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
//CHECK: return %[[V0]] : !Concrete.lwe_ciphertext<1024,4>
|
||||
//CHECK: }
|
||||
func.func @add_glwe_int(%arg0: !TFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !TFHE.glwe<{1024,1,64}{4}> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.encode_int"(%arg1) : (i5) -> !Concrete.plaintext<5>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %[[V1]]) : (!Concrete.lwe_ciphertext<1024,4>, !Concrete.plaintext<5>) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
// CHECK-NEXT: return %[[V2]] : !Concrete.lwe_ciphertext<1024,4>
|
||||
%1 = "TFHE.add_glwe_int"(%arg0, %arg1): (!TFHE.glwe<{1024,1,64}{4}>, i5) -> (!TFHE.glwe<{1024,1,64}{4}>)
|
||||
return %1: !TFHE.glwe<{1024,1,64}{4}>
|
||||
}
|
||||
|
||||
@@ -1,22 +1,22 @@
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @mul_glwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7>
|
||||
//CHECK: func.func @mul_glwe_const_int(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
|
||||
//CHECK: %c1_i8 = arith.constant 1 : i8
|
||||
//CHECK: %[[V0:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%[[A0]], %c1_i8) : (!Concrete.lwe_ciphertext<1024,7>, i8) -> !Concrete.lwe_ciphertext<1024,7>
|
||||
//CHECK: return %[[V0]] : !Concrete.lwe_ciphertext<1024,7>
|
||||
//CHECK: }
|
||||
func.func @mul_glwe_const_int(%arg0: !TFHE.glwe<{1024,1,64}{7}>) -> !TFHE.glwe<{1024,1,64}{7}> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8
|
||||
// CHECK-NEXT: %[[V2:.*]] = "Concrete.int_to_cleartext"(%[[V1]]) : (i8) -> !Concrete.cleartext<8>
|
||||
// CHECK-NEXT: %[[V3:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %[[V2]]) : (!Concrete.lwe_ciphertext<1024,7>, !Concrete.cleartext<8>) -> !Concrete.lwe_ciphertext<1024,7>
|
||||
// CHECK-NEXT: return %[[V3]] : !Concrete.lwe_ciphertext<1024,7>
|
||||
%0 = arith.constant 1 : i8
|
||||
%1 = "TFHE.mul_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,1,64}{7}>, i8) -> (!TFHE.glwe<{1024,1,64}{7}>)
|
||||
return %1: !TFHE.glwe<{1024,1,64}{7}>
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: func.func @mul_glwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
//CHECK: func.func @mul_glwe_int(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,4>, %[[A1:.*]]: i5) -> !Concrete.lwe_ciphertext<1024,4> {
|
||||
//CHECK: %[[V0:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%[[A0]], %[[A1]]) : (!Concrete.lwe_ciphertext<1024,4>, i5) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
//CHECK: return %[[V0]] : !Concrete.lwe_ciphertext<1024,4>
|
||||
//CHECK: }
|
||||
func.func @mul_glwe_int(%arg0: !TFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !TFHE.glwe<{1024,1,64}{4}> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.int_to_cleartext"(%arg1) : (i5) -> !Concrete.cleartext<5>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %[[V1]]) : (!Concrete.lwe_ciphertext<1024,4>, !Concrete.cleartext<5>) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
// CHECK-NEXT: return %[[V2]] : !Concrete.lwe_ciphertext<1024,4>
|
||||
%1 = "TFHE.mul_glwe_int"(%arg0, %arg1): (!TFHE.glwe<{1024,1,64}{4}>, i5) -> (!TFHE.glwe<{1024,1,64}{4}>)
|
||||
return %1: !TFHE.glwe<{1024,1,64}{4}>
|
||||
}
|
||||
|
||||
@@ -1,23 +1,23 @@
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @sub_const_int_glwe(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7>
|
||||
//CHECK: func.func @sub_const_int_glwe(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
|
||||
//CHECK: %c1_i8 = arith.constant 1 : i8
|
||||
//CHECK: %[[V0:.*]] = "Concrete.negate_lwe_ciphertext"(%[[A0]]) : (!Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7>
|
||||
//CHECK: %[[V1:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[V0]], %c1_i8) : (!Concrete.lwe_ciphertext<1024,7>, i8) -> !Concrete.lwe_ciphertext<1024,7>
|
||||
//CHECK: return %[[V1]] : !Concrete.lwe_ciphertext<1024,7>
|
||||
//CHECK: }
|
||||
func.func @sub_const_int_glwe(%arg0: !TFHE.glwe<{1024,1,64}{7}>) -> !TFHE.glwe<{1024,1,64}{7}> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8
|
||||
// CHECK-NEXT: %[[NEG:.*]] = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "Concrete.encode_int"(%[[V1]]) : (i8) -> !Concrete.plaintext<8>
|
||||
// CHECK-NEXT: %[[V3:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[NEG]], %[[V2]]) : (!Concrete.lwe_ciphertext<1024,7>, !Concrete.plaintext<8>) -> !Concrete.lwe_ciphertext<1024,7>
|
||||
// CHECK-NEXT: return %[[V3]] : !Concrete.lwe_ciphertext<1024,7>
|
||||
%0 = arith.constant 1 : i8
|
||||
%1 = "TFHE.sub_int_glwe"(%0, %arg0): (i8, !TFHE.glwe<{1024,1,64}{7}>) -> (!TFHE.glwe<{1024,1,64}{7}>)
|
||||
return %1: !TFHE.glwe<{1024,1,64}{7}>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @sub_int_glwe(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
//CHECK: func.func @sub_int_glwe(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,4>, %[[A1:.*]]: i5) -> !Concrete.lwe_ciphertext<1024,4> {
|
||||
//CHECK: %[[V0:.*]] = "Concrete.negate_lwe_ciphertext"(%[[A0]]) : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
//CHECK: %[[V1:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[V0]], %[[A1]]) : (!Concrete.lwe_ciphertext<1024,4>, i5) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
//CHECK: return %[[V1]] : !Concrete.lwe_ciphertext<1024,4>
|
||||
//CHECK: }
|
||||
func.func @sub_int_glwe(%arg0: !TFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !TFHE.glwe<{1024,1,64}{4}> {
|
||||
// CHECK-NEXT: %[[NEG:.*]] = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.encode_int"(%arg1) : (i5) -> !Concrete.plaintext<5>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[NEG]], %[[V1]]) : (!Concrete.lwe_ciphertext<1024,4>, !Concrete.plaintext<5>) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
// CHECK-NEXT: return %[[V2]] : !Concrete.lwe_ciphertext<1024,4>
|
||||
%1 = "TFHE.sub_int_glwe"(%arg1, %arg0): (i5, !TFHE.glwe<{1024,1,64}{4}>) -> (!TFHE.glwe<{1024,1,64}{4}>)
|
||||
return %1: !TFHE.glwe<{1024,1,64}{4}>
|
||||
}
|
||||
|
||||
@@ -9,6 +9,15 @@ func.func @add_lwe_ciphertexts(%arg0: tensor<2049xi64>, %arg1: tensor<2049xi64>)
|
||||
return %0 : tensor<2049xi64>
|
||||
}
|
||||
|
||||
//CHECK: func.func @add_crt_lwe_ciphertexts(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.add_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> tensor<5x2049xi64>
|
||||
//CHECK: return %[[V0]] : tensor<5x2049xi64>
|
||||
//CHECK: }
|
||||
func.func @add_crt_lwe_ciphertexts(%arg0: tensor<5x2049xi64>, %arg1: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
|
||||
%0 = "BConcrete.add_crt_lwe_buffer"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> ( tensor<5x2049xi64>)
|
||||
return %0 : tensor<5x2049xi64>
|
||||
}
|
||||
|
||||
//CHECK: func.func @add_plaintext_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: i64) -> tensor<2049xi64> {
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_lwe_buffer"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, i64) -> tensor<2049xi64>
|
||||
//CHECK: return %[[V0]] : tensor<2049xi64>
|
||||
@@ -18,7 +27,16 @@ func.func @add_plaintext_lwe_ciphertext(%arg0: tensor<2049xi64>, %arg1: i64) ->
|
||||
return %0 : tensor<2049xi64>
|
||||
}
|
||||
|
||||
//CHECK: func.func @mul_cleartext_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: i64) -> tensor<2049xi64> {
|
||||
//CHECK: func.func @add_plaintext_crt_lwe_ciphertext(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: i64) -> tensor<5x2049xi64> {
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> tensor<5x2049xi64>
|
||||
//CHECK: return %[[V0]] : tensor<5x2049xi64>
|
||||
//CHECK: }
|
||||
func.func @add_plaintext_crt_lwe_ciphertext(%arg0: tensor<5x2049xi64>, %arg1: i64) -> tensor<5x2049xi64> {
|
||||
%0 = "BConcrete.add_plaintext_crt_lwe_buffer"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> ( tensor<5x2049xi64>)
|
||||
return %0 : tensor<5x2049xi64>
|
||||
}
|
||||
|
||||
//CHECK: func @mul_cleartext_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: i64) -> tensor<2049xi64> {
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_lwe_buffer"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, i64) -> tensor<2049xi64>
|
||||
//CHECK: return %[[V0]] : tensor<2049xi64>
|
||||
//CHECK: }
|
||||
@@ -27,6 +45,15 @@ func.func @mul_cleartext_lwe_ciphertext(%arg0: tensor<2049xi64>, %arg1: i64) ->
|
||||
return %0 : tensor<2049xi64>
|
||||
}
|
||||
|
||||
//CHECK: func.func @mul_cleartext_crt_lwe_ciphertext(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: i64) -> tensor<5x2049xi64> {
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> tensor<5x2049xi64>
|
||||
//CHECK: return %[[V0]] : tensor<5x2049xi64>
|
||||
//CHECK: }
|
||||
func.func @mul_cleartext_crt_lwe_ciphertext(%arg0: tensor<5x2049xi64>, %arg1: i64) -> tensor<5x2049xi64> {
|
||||
%0 = "BConcrete.mul_cleartext_crt_lwe_buffer"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> (tensor<5x2049xi64>)
|
||||
return %0 : tensor<5x2049xi64>
|
||||
}
|
||||
|
||||
//CHECK: func.func @negate_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>) -> tensor<2049xi64> {
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.negate_lwe_buffer"(%[[A0]]) : (tensor<2049xi64>) -> tensor<2049xi64>
|
||||
//CHECK: return %[[V0]] : tensor<2049xi64>
|
||||
@@ -36,6 +63,15 @@ func.func @negate_lwe_ciphertext(%arg0: tensor<2049xi64>) -> tensor<2049xi64> {
|
||||
return %0 : tensor<2049xi64>
|
||||
}
|
||||
|
||||
//CHECK: func.func @negate_crt_lwe_ciphertext(%[[A0:.*]]: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.negate_crt_lwe_buffer"(%[[A0]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>) -> tensor<5x2049xi64>
|
||||
//CHECK: return %[[V0]] : tensor<5x2049xi64>
|
||||
//CHECK: }
|
||||
func.func @negate_crt_lwe_ciphertext(%arg0: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
|
||||
%0 = "BConcrete.negate_crt_lwe_buffer"(%arg0) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>) -> (tensor<5x2049xi64>)
|
||||
return %0 : tensor<5x2049xi64>
|
||||
}
|
||||
|
||||
//CHECK: func.func @bootstrap_lwe(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<4096xi64>) -> tensor<2049xi64> {
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.bootstrap_lwe_buffer"(%[[A0]], %[[A1]]) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 1024 : i32} : (tensor<2049xi64>, tensor<4096xi64>) -> tensor<2049xi64>
|
||||
//CHECK: return %[[V0]] : tensor<2049xi64>
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
// RUN: concretecompiler --optimize-concrete=false --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext_0(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
//CHECK: func.func @mul_cleartext_lwe_ciphertext_0(%[[A0:.*]]: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> {
|
||||
//CHECK: %c0_i7 = arith.constant 0 : i7
|
||||
//CHECK: %[[V0:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%[[A0]], %c0_i7) : (!Concrete.lwe_ciphertext<2048,7>, i7) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
//CHECK: return %[[V0]] : !Concrete.lwe_ciphertext<2048,7>
|
||||
//CHECK: }
|
||||
func.func @mul_cleartext_lwe_ciphertext_0(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = arith.constant 0 : i7
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.int_to_cleartext"(%[[V0]]) : (i7) -> !Concrete.cleartext<7>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %[[V1]]) : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
// CHECK-NEXT: return %[[V2]] : !Concrete.lwe_ciphertext<2048,7>
|
||||
|
||||
%0 = arith.constant 0 : i7
|
||||
%1 = "Concrete.int_to_cleartext"(%0) : (i7) -> !Concrete.cleartext<7>
|
||||
%2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> (!Concrete.lwe_ciphertext<2048,7>)
|
||||
%2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %0): (!Concrete.lwe_ciphertext<2048,7>, i7) -> (!Concrete.lwe_ciphertext<2048,7>)
|
||||
return %2: !Concrete.lwe_ciphertext<2048,7>
|
||||
}
|
||||
|
||||
@@ -9,21 +9,21 @@ func.func @add_lwe_ciphertexts(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !
|
||||
return %1: !Concrete.lwe_ciphertext<2048,7>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @add_plaintext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.plaintext<5>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
func.func @add_plaintext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.plaintext<5>) -> !Concrete.lwe_ciphertext<2048,7> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.plaintext<5>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
// CHECK-LABEL: func.func @add_plaintext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: i5) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
func.func @add_plaintext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: i5) -> !Concrete.lwe_ciphertext<2048,7> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, i5) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
// CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7>
|
||||
|
||||
%1 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %arg1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.plaintext<5>) -> (!Concrete.lwe_ciphertext<2048,7>)
|
||||
%1 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %arg1): (!Concrete.lwe_ciphertext<2048,7>, i5) -> (!Concrete.lwe_ciphertext<2048,7>)
|
||||
return %1: !Concrete.lwe_ciphertext<2048,7>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
func.func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: i7) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
func.func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: i7) -> !Concrete.lwe_ciphertext<2048,7> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, i7) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
// CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7>
|
||||
|
||||
%1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> (!Concrete.lwe_ciphertext<2048,7>)
|
||||
%1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1): (!Concrete.lwe_ciphertext<2048,7>, i7) -> (!Concrete.lwe_ciphertext<2048,7>)
|
||||
return %1: !Concrete.lwe_ciphertext<2048,7>
|
||||
}
|
||||
|
||||
@@ -38,9 +38,9 @@ func.func @negate_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Co
|
||||
|
||||
// CHECK-LABEL: func.func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
func.func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, level = -1 : i32} : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
// CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7>
|
||||
%1 = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, glweDimension = 1 : i32, level = -1 : i32, polynomialSize = 2048 : i32} : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
%1 = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = -1 : i32, level = -1 : i32} : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.glwe_ciphertext<2048,1,7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
return %1: !Concrete.lwe_ciphertext<2048,7>
|
||||
}
|
||||
|
||||
@@ -51,22 +51,3 @@ func.func @keyswitch_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.l
|
||||
%1 = "Concrete.keyswitch_lwe"(%arg0){baseLog = 2 : i32, level = 3 : i32}: (!Concrete.lwe_ciphertext<2048,7>) -> (!Concrete.lwe_ciphertext<2048,7>)
|
||||
return %1: !Concrete.lwe_ciphertext<2048,7>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @encode_int(%arg0: i6) -> !Concrete.plaintext<6>
|
||||
func.func @encode_int(%arg0: i6) -> (!Concrete.plaintext<6>) {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.encode_int"(%arg0) : (i6) -> !Concrete.plaintext<6>
|
||||
// CHECK-NEXT: return %[[V1]] : !Concrete.plaintext<6>
|
||||
|
||||
%0 = "Concrete.encode_int"(%arg0): (i6) -> !Concrete.plaintext<6>
|
||||
return %0: !Concrete.plaintext<6>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @int_to_cleartext() -> !Concrete.cleartext<6>
|
||||
func.func @int_to_cleartext() -> !Concrete.cleartext<6> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = arith.constant 5 : i6
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.int_to_cleartext"(%[[V0]]) : (i6) -> !Concrete.cleartext<6>
|
||||
// CHECK-NEXT: return %[[V1]] : !Concrete.cleartext<6>
|
||||
%0 = arith.constant 5 : i6
|
||||
%1 = "Concrete.int_to_cleartext"(%0) : (i6) -> !Concrete.cleartext<6>
|
||||
return %1 : !Concrete.cleartext<6>
|
||||
}
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
// RUN: concretecompiler --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
|
||||
|
||||
// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
func.func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
// CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7>
|
||||
// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: i7) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
func.func @mul_cleartext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: i7) -> !Concrete.lwe_ciphertext<2048,7> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, i7) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
// CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7>
|
||||
|
||||
%1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> (!Concrete.lwe_ciphertext<2048,7>)
|
||||
return %1: !Concrete.lwe_ciphertext<2048,7>
|
||||
}
|
||||
%1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1): (!Concrete.lwe_ciphertext<2048,7>, i7) -> (!Concrete.lwe_ciphertext<2048,7>)
|
||||
return %1: !Concrete.lwe_ciphertext<2048,7>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @mul_cleartext_lwe_ciphertext_0(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
func.func @mul_cleartext_lwe_ciphertext_0(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> {
|
||||
@@ -16,8 +16,7 @@ func.func @mul_cleartext_lwe_ciphertext_0(%arg0: !Concrete.lwe_ciphertext<2048,7
|
||||
// CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7>
|
||||
|
||||
%0 = arith.constant 0 : i7
|
||||
%1 = "Concrete.int_to_cleartext"(%0) : (i7) -> !Concrete.cleartext<7>
|
||||
%2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> (!Concrete.lwe_ciphertext<2048,7>)
|
||||
%2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %0): (!Concrete.lwe_ciphertext<2048,7>, i7) -> (!Concrete.lwe_ciphertext<2048,7>)
|
||||
return %2: !Concrete.lwe_ciphertext<2048,7>
|
||||
}
|
||||
|
||||
@@ -27,7 +26,6 @@ func.func @mul_cleartext_lwe_ciphertext_minus_1(%arg0: !Concrete.lwe_ciphertext<
|
||||
// CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7>
|
||||
|
||||
%0 = arith.constant -1 : i7
|
||||
%1 = "Concrete.int_to_cleartext"(%0) : (i7) -> !Concrete.cleartext<7>
|
||||
%2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %1): (!Concrete.lwe_ciphertext<2048,7>, !Concrete.cleartext<7>) -> (!Concrete.lwe_ciphertext<2048,7>)
|
||||
%2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %0): (!Concrete.lwe_ciphertext<2048,7>, i7) -> (!Concrete.lwe_ciphertext<2048,7>)
|
||||
return %2: !Concrete.lwe_ciphertext<2048,7>
|
||||
}
|
||||
|
||||
@@ -13,7 +13,13 @@ func.func @type_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Conc
|
||||
return %arg0: !Concrete.lwe_ciphertext<2048,7>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @type_cleartext(%arg0: !Concrete.cleartext<5>) -> !Concrete.cleartext<5>
|
||||
// CHECK-LABEL: func @type_lwe_ciphertext_with_crt(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>
|
||||
func.func @type_lwe_ciphertext_with_crt(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7> {
|
||||
// CHECK-NEXT: return %arg0 : !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>
|
||||
return %arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @type_cleartext(%arg0: !Concrete.cleartext<5>) -> !Concrete.cleartext<5>
|
||||
func.func @type_cleartext(%arg0: !Concrete.cleartext<5>) -> !Concrete.cleartext<5> {
|
||||
// CHECK-NEXT: return %arg0 : !Concrete.cleartext<5>
|
||||
return %arg0: !Concrete.cleartext<5>
|
||||
|
||||
@@ -52,3 +52,20 @@ func.func @add_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>, %arg1: !TFHE.glwe<{1024,
|
||||
return %1: !TFHE.glwe<{1024,12,64}{7}>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// GLWE polynomialSize parameter result
|
||||
func.func @add_glwe(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>, %arg1: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}> {
|
||||
// expected-error @+1 {{'TFHE.add_glwe' op should have the same GLWE 'crt' parameter}}
|
||||
%1 = "TFHE.add_glwe"(%arg0, %arg1): (!TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>, !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> (!TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>)
|
||||
return %1: !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// GLWE polynomialSize parameter inputs
|
||||
func.func @add_glwe(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>, %arg1: !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>) -> !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}> {
|
||||
// expected-error @+1 {{'TFHE.add_glwe' op should have the same GLWE 'crt' parameter}}
|
||||
%1 = "TFHE.add_glwe"(%arg0, %arg1): (!TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>, !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>) -> (!TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>)
|
||||
return %1: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>
|
||||
}
|
||||
|
||||
@@ -30,6 +30,16 @@ func.func @add_glwe_int(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,
|
||||
|
||||
// -----
|
||||
|
||||
// GLWE crt parameter
|
||||
func.func @add_glwe_int(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}> {
|
||||
%0 = arith.constant 1 : i8
|
||||
// expected-error @+1 {{'TFHE.add_glwe_int' op should have the same GLWE 'crt' parameter}}
|
||||
%1 = "TFHE.add_glwe_int"(%arg0, %0): (!TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>, i8) -> (!TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>)
|
||||
return %1: !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// integer width doesn't match GLWE parameter
|
||||
func.func @add_glwe_int(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,12,64}{7}> {
|
||||
%0 = arith.constant 1 : i9
|
||||
|
||||
@@ -30,6 +30,16 @@ func.func @mul_glwe_int(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,
|
||||
|
||||
// -----
|
||||
|
||||
// GLWE crt parameter
|
||||
func.func @mul_glwe_int(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}> {
|
||||
%0 = arith.constant 1 : i8
|
||||
// expected-error @+1 {{'TFHE.mul_glwe_int' op should have the same GLWE 'crt' parameter}}
|
||||
%1 = "TFHE.mul_glwe_int"(%arg0, %0): (!TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>, i8) -> (!TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>)
|
||||
return %1: !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// integer width doesn't match GLWE parameter
|
||||
func.func @mul_glwe_int(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,12,64}{7}> {
|
||||
%0 = arith.constant 1 : i9
|
||||
|
||||
@@ -27,6 +27,15 @@ func.func @neg_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,11,6
|
||||
|
||||
// -----
|
||||
|
||||
// GLWE crt parameter
|
||||
func.func @neg_glwe(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}> {
|
||||
// expected-error @+1 {{'TFHE.neg_glwe' op should have the same GLWE 'crt' parameter}}
|
||||
%1 = "TFHE.neg_glwe"(%arg0): (!TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> (!TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>)
|
||||
return %1: !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// integer width doesn't match GLWE parameter
|
||||
func.func @neg_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,11,64}{7}> {
|
||||
// expected-error @+1 {{'TFHE.neg_glwe' op should have the same GLWE 'polynomialSize' parameter}}
|
||||
|
||||
@@ -28,6 +28,17 @@ func.func @sub_int_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,
|
||||
return %1: !TFHE.glwe<{1024,11,64}{7}>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// GLWE polynomialSize parameter
|
||||
func.func @sub_int_glwe(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}> {
|
||||
%0 = arith.constant 1 : i8
|
||||
// expected-error @+1 {{'TFHE.sub_int_glwe' op should have the same GLWE 'crt' parameter}}
|
||||
%1 = "TFHE.sub_int_glwe"(%0, %arg0): (i8, !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> (!TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>)
|
||||
return %1: !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// integer width doesn't match GLWE parameter
|
||||
|
||||
@@ -11,3 +11,15 @@ func.func @glwe_1(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> {
|
||||
// CHECK-LABEL: return %arg0 : !TFHE.glwe<{_,_,_}{7}>
|
||||
return %arg0: !TFHE.glwe<{_,_,_}{7}>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @glwe_crt(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{_,_,_}{7}>) -> !TFHE.glwe<crt=[2,3,5,7,11]{_,_,_}{7}>
|
||||
func.func @glwe_crt(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{_,_,_}{7}>) -> !TFHE.glwe<crt=[2,3,5,7,11]{_,_,_}{7}> {
|
||||
// CHECK-LABEL: return %arg0 : !TFHE.glwe<crt=[2,3,5,7,11]{_,_,_}{7}>
|
||||
return %arg0: !TFHE.glwe<crt=[2,3,5,7,11]{_,_,_}{7}>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @glwe_crt_undef(%arg0: !TFHE.glwe<crt=[_,_,_,_,_]{_,_,_}{7}>) -> !TFHE.glwe<crt=[_,_,_,_,_]{_,_,_}{7}>
|
||||
func.func @glwe_crt_undef(%arg0: !TFHE.glwe<crt=[_,_,_,_,_]{_,_,_}{7}>) -> !TFHE.glwe<crt=[_,_,_,_,_]{_,_,_}{7}> {
|
||||
// CHECK-LABEL: return %arg0 : !TFHE.glwe<crt=[_,_,_,_,_]{_,_,_}{7}>
|
||||
return %arg0: !TFHE.glwe<crt=[_,_,_,_,_]{_,_,_}{7}>
|
||||
}
|
||||
|
||||
@@ -67,7 +67,7 @@ llvm::Error checkResult(ScalarDesc &desc,
|
||||
}
|
||||
if (desc.value != res64->getValue()) {
|
||||
return StreamStringError("unexpected result value: got ")
|
||||
<< res64->getValue() << "expected " << desc.value;
|
||||
<< res64->getValue() << " expected " << desc.value;
|
||||
}
|
||||
return llvm::Error::success();
|
||||
}
|
||||
@@ -204,6 +204,12 @@ template <> struct llvm::yaml::MappingTraits<EndToEndDesc> {
|
||||
desc.v0Constraint->p = v0constraint[0];
|
||||
desc.v0Constraint->norm2 = v0constraint[1];
|
||||
}
|
||||
mlir::concretelang::LargeIntegerParameter largeInterger;
|
||||
io.mapOptional("large-integer-crt-decomposition",
|
||||
largeInterger.crtDecomposition);
|
||||
if (!largeInterger.crtDecomposition.empty()) {
|
||||
desc.largeIntegerParameter = largeInterger;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ struct TensorDescription {
|
||||
ValueWidth width;
|
||||
};
|
||||
struct ScalarDesc {
|
||||
uint64_t value;
|
||||
int64_t value;
|
||||
ValueWidth width;
|
||||
};
|
||||
|
||||
@@ -43,6 +43,8 @@ struct EndToEndDesc {
|
||||
std::vector<TestDescription> tests;
|
||||
llvm::Optional<mlir::concretelang::V0Parameter> v0Parameter;
|
||||
llvm::Optional<mlir::concretelang::V0FHEConstraint> v0Constraint;
|
||||
llvm::Optional<mlir::concretelang::LargeIntegerParameter>
|
||||
largeIntegerParameter;
|
||||
};
|
||||
|
||||
llvm::Expected<mlir::concretelang::LambdaArgument *>
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
## Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
## Exceptions. See
|
||||
## https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
## for license information.
|
||||
##
|
||||
## This file contains programs and references to end to end test the compiler.
|
||||
## Program in this file aims to test the `tensor` dialects on encrypted integers.
|
||||
##
|
||||
## Operators:
|
||||
## - tensor.extract
|
||||
## - tensor.insert
|
||||
## - tensor.extract_slice
|
||||
## - tensor.insert_slice
|
||||
|
||||
description: identity
|
||||
program: |
|
||||
func.func @main(%t: tensor<2x10x!FHE.eint<6>>) -> tensor<2x10x!FHE.eint<6>> {
|
||||
@@ -14,6 +28,24 @@ tests:
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
shape: [2,10]
|
||||
---
|
||||
description: identity_16bits
|
||||
program: |
|
||||
func.func @main(%t: tensor<2x10x!FHE.eint<16>>) -> tensor<2x10x!FHE.eint<16>> {
|
||||
return %t : tensor<2x10x!FHE.eint<16>>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769,
|
||||
38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0]
|
||||
shape: [2,10]
|
||||
outputs:
|
||||
- tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769,
|
||||
38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0]
|
||||
shape: [2,10]
|
||||
v0-constraint: [16, 0]
|
||||
v0-parameter: [1,11,565,1,23,5,3]
|
||||
large-integer-crt-decomposition: [7,8,9,11,13]
|
||||
---
|
||||
description: extract
|
||||
program: |
|
||||
func.func @main(%t: tensor<2x10x!FHE.eint<6>>, %i: index, %j: index) ->
|
||||
@@ -59,6 +91,157 @@ tests:
|
||||
outputs:
|
||||
- scalar: 9
|
||||
---
|
||||
description: extract_16bits
|
||||
program: |
|
||||
func.func @main(%t: tensor<2x10x!FHE.eint<16>>, %i: index, %j: index) ->
|
||||
!FHE.eint<16> {
|
||||
%c = tensor.extract %t[%i, %j] : tensor<2x10x!FHE.eint<16>>
|
||||
return %c : !FHE.eint<16>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769,
|
||||
38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0]
|
||||
shape: [2,10]
|
||||
- scalar: 0
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 65535
|
||||
- inputs:
|
||||
- tensor: [63, 12, 7, 43, 52, 9, 26, 34, 22, 0,
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
shape: [2,10]
|
||||
- scalar: 0
|
||||
- scalar: 9
|
||||
outputs:
|
||||
- scalar: 0
|
||||
- inputs:
|
||||
- tensor: [63, 12, 7, 43, 52, 9, 26, 34, 22, 0,
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
shape: [2,10]
|
||||
- scalar: 1
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 0
|
||||
- inputs:
|
||||
- tensor: [63, 12, 7, 43, 52, 9, 26, 34, 22, 0,
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
shape: [2,10]
|
||||
- scalar: 1
|
||||
- scalar: 9
|
||||
outputs:
|
||||
- scalar: 9
|
||||
v0-constraint: [16, 0]
|
||||
v0-parameter: [1,11,565,1,23,5,3]
|
||||
large-integer-crt-decomposition: [7,8,9,11,13]
|
||||
---
|
||||
description: insert
|
||||
program: |
|
||||
func.func @main(%t: tensor<2x10x!FHE.eint<6>>, %i: index, %j: index, %x: !FHE.eint<6>) -> tensor<2x10x!FHE.eint<6>> {
|
||||
%r = tensor.insert %x into %t[%i, %j] : tensor<2x10x!FHE.eint<6>>
|
||||
return %r : tensor<2x10x!FHE.eint<6>>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- tensor: [63, 12, 7, 43, 52, 9, 26, 34, 22, 0,
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
shape: [2,10]
|
||||
- scalar: 0
|
||||
- scalar: 0
|
||||
- scalar: 42
|
||||
outputs:
|
||||
- tensor: [42, 12, 7, 43, 52, 9, 26, 34, 22, 0,
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
shape: [2,10]
|
||||
- inputs:
|
||||
- tensor: [63, 12, 7, 43, 52, 9, 26, 34, 22, 0,
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
shape: [2,10]
|
||||
- scalar: 0
|
||||
- scalar: 1
|
||||
- scalar: 42
|
||||
outputs:
|
||||
- tensor: [63, 42, 7, 43, 52, 9, 26, 34, 22, 0,
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
shape: [2,10]
|
||||
- inputs:
|
||||
- tensor: [63, 12, 7, 43, 52, 9, 26, 34, 22, 0,
|
||||
42, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
shape: [2,10]
|
||||
- scalar: 1
|
||||
- scalar: 0
|
||||
- scalar: 42
|
||||
outputs:
|
||||
- tensor: [63, 12, 7, 43, 52, 9, 26, 34, 22, 0,
|
||||
42, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
shape: [2,10]
|
||||
- inputs:
|
||||
- tensor: [63, 12, 7, 43, 52, 9, 26, 34, 22, 0,
|
||||
42, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
shape: [2,10]
|
||||
- scalar: 1
|
||||
- scalar: 9
|
||||
- scalar: 42
|
||||
outputs:
|
||||
- tensor: [63, 12, 7, 43, 52, 9, 26, 34, 22, 0,
|
||||
42, 1, 2, 3, 4, 5, 6, 7, 8, 42]
|
||||
shape: [2,10]
|
||||
---
|
||||
description: insert_16bits
|
||||
program: |
|
||||
func.func @main(%t: tensor<2x10x!FHE.eint<16>>, %i: index, %j: index, %x: !FHE.eint<16>) -> tensor<2x10x!FHE.eint<16>> {
|
||||
%r = tensor.insert %x into %t[%i, %j] : tensor<2x10x!FHE.eint<16>>
|
||||
return %r : tensor<2x10x!FHE.eint<16>>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769,
|
||||
38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0]
|
||||
shape: [2,10]
|
||||
- scalar: 0
|
||||
- scalar: 0
|
||||
- scalar: 42
|
||||
outputs:
|
||||
- tensor: [42, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769,
|
||||
38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0]
|
||||
shape: [2,10]
|
||||
- inputs:
|
||||
- tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769,
|
||||
38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0]
|
||||
shape: [2,10]
|
||||
- scalar: 0
|
||||
- scalar: 1
|
||||
- scalar: 42
|
||||
outputs:
|
||||
- tensor: [65535, 42, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769,
|
||||
38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0]
|
||||
shape: [2,10]
|
||||
- inputs:
|
||||
- tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769,
|
||||
38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0]
|
||||
shape: [2,10]
|
||||
- scalar: 1
|
||||
- scalar: 0
|
||||
- scalar: 42
|
||||
outputs:
|
||||
- tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769,
|
||||
42, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0]
|
||||
shape: [2,10]
|
||||
- inputs:
|
||||
- tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769,
|
||||
38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0]
|
||||
shape: [2,10]
|
||||
- scalar: 1
|
||||
- scalar: 9
|
||||
- scalar: 42
|
||||
outputs:
|
||||
- tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769,
|
||||
38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 42]
|
||||
shape: [2,10]
|
||||
v0-constraint: [16, 0]
|
||||
v0-parameter: [1,11,565,1,23,5,3]
|
||||
large-integer-crt-decomposition: [7,8,9,11,13]
|
||||
---
|
||||
description: extract_slice
|
||||
program: |
|
||||
func.func @main(%t: tensor<2x10x!FHE.eint<6>>) -> tensor<1x5x!FHE.eint<6>> {
|
||||
@@ -91,6 +274,24 @@ tests:
|
||||
- tensor: [ 5, 6, 7, 8, 9]
|
||||
shape: [5]
|
||||
---
|
||||
description: extract_slice_16bits
|
||||
program: |
|
||||
func.func @main(%t: tensor<2x10x!FHE.eint<16>>) -> tensor<1x5x!FHE.eint<16>> {
|
||||
%r = tensor.extract_slice %t[1, 5][1, 5][1, 1] : tensor<2x10x!FHE.eint<16>> to tensor<1x5x!FHE.eint<16>>
|
||||
return %r : tensor<1x5x!FHE.eint<16>>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769,
|
||||
38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0]
|
||||
shape: [2,10]
|
||||
outputs:
|
||||
- tensor: [64814, 65491, 4271, 9294, 0]
|
||||
shape: [1,5]
|
||||
v0-constraint: [16, 0]
|
||||
v0-parameter: [1,11,565,1,23,5,3]
|
||||
large-integer-crt-decomposition: [7,8,9,11,13]
|
||||
---
|
||||
description: extract_slice_stride
|
||||
program: |
|
||||
func.func @main(%t: tensor<2x10x!FHE.eint<6>>) -> tensor<1x5x!FHE.eint<6>> {
|
||||
@@ -122,6 +323,24 @@ tests:
|
||||
- tensor: [3, 2, 1]
|
||||
shape: [3]
|
||||
---
|
||||
description: extract_slice_stride_16bits
|
||||
program: |
|
||||
func.func @main(%t: tensor<2x10x!FHE.eint<6>>) -> tensor<1x5x!FHE.eint<6>> {
|
||||
%r = tensor.extract_slice %t[1, 0][1, 5][1, 2] : tensor<2x10x!FHE.eint<6>> to tensor<1x5x!FHE.eint<6>>
|
||||
return %r : tensor<1x5x!FHE.eint<6>>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769,
|
||||
38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0]
|
||||
shape: [2,10]
|
||||
outputs:
|
||||
- tensor: [38786, 65112, 60515, 65491, 9294]
|
||||
shape: [1,5]
|
||||
v0-constraint: [16, 0]
|
||||
v0-parameter: [1,11,565,1,23,5,3]
|
||||
large-integer-crt-decomposition: [7,8,9,11,13]
|
||||
---
|
||||
description: insert_slice
|
||||
program: |
|
||||
func.func @main(%t0: tensor<2x10x!FHE.eint<6>>, %t1: tensor<2x2x!FHE.eint<6>>) -> tensor<2x10x!FHE.eint<6>> {
|
||||
@@ -179,3 +398,25 @@ tests:
|
||||
- tensor: [0, 1, 2,
|
||||
3, 4, 5]
|
||||
shape: [2, 3]
|
||||
---
|
||||
description: insert_slice_16bits
|
||||
program: |
|
||||
func.func @main(%t0: tensor<2x10x!FHE.eint<16>>, %t1: tensor<2x2x!FHE.eint<16>>) -> tensor<2x10x!FHE.eint<16>> {
|
||||
%r = tensor.insert_slice %t1 into %t0[0, 5][2, 2][1, 1] : tensor<2x2x!FHE.eint<16>> into tensor<2x10x!FHE.eint<16>>
|
||||
return %r : tensor<2x10x!FHE.eint<16>>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- tensor: [65535, 46706, 18752, 55384, 55709, 1726, 35063, 57650, 45551, 5769,
|
||||
38786, 36362, 65112, 5748, 60515, 64814, 65491, 4271, 9294, 0]
|
||||
shape: [2,10]
|
||||
- tensor: [1000, 1001,
|
||||
1002, 1003]
|
||||
shape: [2,2]
|
||||
outputs:
|
||||
- tensor: [65535, 46706, 18752, 55384, 55709, 1000, 1001, 57650, 45551, 5769,
|
||||
38786, 36362, 65112, 5748, 60515, 1002, 1003, 4271, 9294, 0]
|
||||
shape: [2,10]
|
||||
v0-constraint: [16, 0]
|
||||
v0-parameter: [1,11,565,1,23,5,3]
|
||||
large-integer-crt-decomposition: [7,8,9,11,13]
|
||||
|
||||
@@ -9,6 +9,28 @@ tests:
|
||||
outputs:
|
||||
- scalar: 1
|
||||
---
|
||||
description: identity_16bits
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> {
|
||||
return %arg0: !FHE.eint<16>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 1
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 0
|
||||
- inputs:
|
||||
- scalar: 72071
|
||||
outputs:
|
||||
- scalar: 72071
|
||||
v0-constraint: [16, 0]
|
||||
v0-parameter: [1,11,565,1,23,5,3]
|
||||
large-integer-crt-decomposition: [7,8,9,11,13]
|
||||
---
|
||||
description: zero_tensor
|
||||
program: |
|
||||
func.func @main() -> tensor<2x2x4x!FHE.eint<6>> {
|
||||
@@ -20,6 +42,20 @@ tests:
|
||||
- tensor: [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]
|
||||
shape: [2,2,4]
|
||||
---
|
||||
description: zero_tensor_16bits
|
||||
program: |
|
||||
func.func @main() -> tensor<2x2x4x!FHE.eint<16>> {
|
||||
%0 = "FHE.zero_tensor"() : () -> tensor<2x2x4x!FHE.eint<16>>
|
||||
return %0 : tensor<2x2x4x!FHE.eint<16>>
|
||||
}
|
||||
tests:
|
||||
- outputs:
|
||||
- tensor: [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]
|
||||
shape: [2,2,4]
|
||||
v0-constraint: [16, 0]
|
||||
v0-parameter: [1,11,565,1,23,5,3]
|
||||
large-integer-crt-decomposition: [7,8,9,11,13]
|
||||
---
|
||||
description: add_eint_int_cst
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
|
||||
@@ -32,18 +68,30 @@ tests:
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 1
|
||||
---
|
||||
description: add_eint_int_cst_16bits
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> {
|
||||
%0 = arith.constant 1 : i17
|
||||
%1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.eint<16>, i17) -> (!FHE.eint<16>)
|
||||
return %1: !FHE.eint<16>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 2
|
||||
- inputs:
|
||||
- scalar: 2
|
||||
- scalar: 72070
|
||||
outputs:
|
||||
- scalar: 3
|
||||
- scalar: 72071
|
||||
- inputs:
|
||||
- scalar: 3
|
||||
- scalar: 72071
|
||||
outputs:
|
||||
- scalar: 4
|
||||
- scalar: 0
|
||||
v0-constraint: [16, 0]
|
||||
v0-parameter: [1,11,565,1,23,5,3]
|
||||
large-integer-crt-decomposition: [7,8,9,11,13]
|
||||
---
|
||||
description: add_eint_int_arg
|
||||
program: |
|
||||
@@ -63,30 +111,75 @@ tests:
|
||||
outputs:
|
||||
- scalar: 3
|
||||
---
|
||||
description: sub_int_eint_cst
|
||||
description: add_eint_int_arg_16bits
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
|
||||
%0 = arith.constant 7 : i3
|
||||
%1 = "FHE.sub_int_eint"(%0, %arg0): (i3, !FHE.eint<2>) -> (!FHE.eint<2>)
|
||||
func.func @main(%arg0: !FHE.eint<16>, %arg1: i17) -> !FHE.eint<16> {
|
||||
%1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<16>, i17) -> (!FHE.eint<16>)
|
||||
return %1: !FHE.eint<16>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 0
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 5
|
||||
outputs:
|
||||
- scalar: 5
|
||||
- inputs:
|
||||
- scalar: 36036
|
||||
- scalar: 36035
|
||||
outputs:
|
||||
- scalar: 72071
|
||||
v0-constraint: [16, 0]
|
||||
v0-parameter: [1,11,565,1,23,5,3]
|
||||
large-integer-crt-decomposition: [7,8,9,11,13]
|
||||
---
|
||||
description: add_eint
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2> {
|
||||
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<2>)
|
||||
return %1: !FHE.eint<2>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 6
|
||||
- scalar: 1
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
- scalar: 2
|
||||
outputs:
|
||||
- scalar: 3
|
||||
---
|
||||
description: add_eint_16bits
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<16>, %arg1: !FHE.eint<16>) -> !FHE.eint<16> {
|
||||
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<16>, !FHE.eint<16>) -> (!FHE.eint<16>)
|
||||
return %1: !FHE.eint<16>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 0
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 5
|
||||
outputs:
|
||||
- scalar: 5
|
||||
- inputs:
|
||||
- scalar: 3
|
||||
- scalar: 36036
|
||||
- scalar: 36035
|
||||
outputs:
|
||||
- scalar: 4
|
||||
- inputs:
|
||||
- scalar: 4
|
||||
outputs:
|
||||
- scalar: 3
|
||||
- scalar: 72071
|
||||
v0-constraint: [16, 0]
|
||||
v0-parameter: [1,11,565,1,23,5,3]
|
||||
large-integer-crt-decomposition: [7,8,9,11,13]
|
||||
---
|
||||
description: sub_eint_int_cst
|
||||
program: |
|
||||
@@ -105,6 +198,22 @@ tests:
|
||||
outputs:
|
||||
- scalar: 0
|
||||
---
|
||||
description: sub_eint_int_cst_16bits
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> {
|
||||
%0 = arith.constant 7 : i17
|
||||
%1 = "FHE.sub_eint_int"(%arg0, %0): (!FHE.eint<16>, i17) -> (!FHE.eint<16>)
|
||||
return %1: !FHE.eint<16>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 7
|
||||
outputs:
|
||||
- scalar: 0
|
||||
v0-constraint: [16, 0]
|
||||
v0-parameter: [1,11,565,1,23,5,3]
|
||||
large-integer-crt-decomposition: [7,8,9,11,13]
|
||||
---
|
||||
description: sub_eint_int_arg
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<4>, %arg1: i5) -> !FHE.eint<4> {
|
||||
@@ -128,6 +237,22 @@ tests:
|
||||
outputs:
|
||||
- scalar: 3
|
||||
---
|
||||
description: sub_eint_int_arg_16bits_fixme
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<16>, %arg1: i17) -> !FHE.eint<16> {
|
||||
%1 = "FHE.sub_eint_int"(%arg0, %arg1): (!FHE.eint<16>, i17) -> (!FHE.eint<16>)
|
||||
return %1: !FHE.eint<16>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 72071
|
||||
- scalar: 2
|
||||
outputs:
|
||||
- scalar: 72069
|
||||
v0-constraint: [16, 0]
|
||||
v0-parameter: [1,11,565,1,23,5,3]
|
||||
large-integer-crt-decomposition: [7,8,9,11,13]
|
||||
---
|
||||
description: sub_eint
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<4> {
|
||||
@@ -151,6 +276,32 @@ tests:
|
||||
outputs:
|
||||
- scalar: 3
|
||||
---
|
||||
description: sub_eint_16bits
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<16>, %arg1: !FHE.eint<16>) -> !FHE.eint<16> {
|
||||
%1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.eint<16>, !FHE.eint<16>) -> (!FHE.eint<16>)
|
||||
return %1: !FHE.eint<16>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 0
|
||||
- inputs:
|
||||
- scalar: 72071
|
||||
- scalar: 72071
|
||||
outputs:
|
||||
- scalar: 0
|
||||
- inputs:
|
||||
- scalar: 7
|
||||
- scalar: 4
|
||||
outputs:
|
||||
- scalar: 3
|
||||
v0-constraint: [16, 0]
|
||||
v0-parameter: [1,11,565,1,23,5,3]
|
||||
large-integer-crt-decomposition: [7,8,9,11,13]
|
||||
---
|
||||
description: sub_int_eint_arg
|
||||
program: |
|
||||
func.func @main(%arg0: i3, %arg1: !FHE.eint<2>) -> !FHE.eint<2> {
|
||||
@@ -174,6 +325,79 @@ tests:
|
||||
outputs:
|
||||
- scalar: 5
|
||||
---
|
||||
description: sub_int_eint_arg_16bits
|
||||
program: |
|
||||
func.func @main(%arg0: i17, %arg1: !FHE.eint<16>) -> !FHE.eint<16> {
|
||||
%1 = "FHE.sub_int_eint"(%arg0, %arg1): (i17, !FHE.eint<16>) -> (!FHE.eint<16>)
|
||||
return %1: !FHE.eint<16>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 0
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: -1
|
||||
outputs:
|
||||
- scalar: 1
|
||||
#- inputs:
|
||||
# - scalar: 72071
|
||||
# - scalar: 0
|
||||
# outputs:
|
||||
# - scalar: 72071
|
||||
#- inputs:
|
||||
# - scalar: 7
|
||||
# - scalar: 4
|
||||
# outputs:
|
||||
# - scalar: 3
|
||||
|
||||
v0-constraint: [16, 0]
|
||||
v0-parameter: [1,11,565,1,23,5,3]
|
||||
large-integer-crt-decomposition: [7,8,9,11,13]
|
||||
---
|
||||
description: sub_int_eint_cst
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
|
||||
%0 = arith.constant 7 : i3
|
||||
%1 = "FHE.sub_int_eint"(%0, %arg0): (i3, !FHE.eint<2>) -> (!FHE.eint<2>)
|
||||
return %1 : !FHE.eint<2>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 6
|
||||
- inputs:
|
||||
- scalar: 2
|
||||
outputs:
|
||||
- scalar: 5
|
||||
---
|
||||
description: sub_int_eint_cst_16bits
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> {
|
||||
%0 = arith.constant -1 : i17
|
||||
%1 = "FHE.sub_int_eint"(%0, %arg0): (i17, !FHE.eint<16>) -> (!FHE.eint<16>)
|
||||
return %1: !FHE.eint<16>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 72071
|
||||
outputs:
|
||||
- scalar: 0
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 72071
|
||||
- inputs:
|
||||
- scalar: 32000
|
||||
outputs:
|
||||
- scalar: 40071
|
||||
v0-constraint: [16, 0]
|
||||
v0-parameter: [1,11,565,1,23,5,3]
|
||||
large-integer-crt-decomposition: [7,8,9,11,13]
|
||||
---
|
||||
description: neg_eint
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
@@ -198,6 +422,29 @@ tests:
|
||||
outputs:
|
||||
- scalar: 6
|
||||
---
|
||||
description: neg_eint_16bits
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> {
|
||||
%1 = "FHE.neg_eint"(%arg0): (!FHE.eint<16>) -> (!FHE.eint<16>)
|
||||
return %1: !FHE.eint<16>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 0
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 72071
|
||||
- inputs:
|
||||
- scalar: 72071
|
||||
outputs:
|
||||
- scalar: 1
|
||||
v0-constraint: [16, 0]
|
||||
v0-parameter: [1,11,565,1,23,5,3]
|
||||
large-integer-crt-decomposition: [7,8,9,11,13]
|
||||
---
|
||||
description: neg_eint_3bits
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> {
|
||||
@@ -247,6 +494,30 @@ tests:
|
||||
outputs:
|
||||
- scalar: 4
|
||||
---
|
||||
description: mul_eint_int_cst_16bits
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> {
|
||||
%0 = arith.constant 2 : i17
|
||||
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<16>, i17) -> (!FHE.eint<16>)
|
||||
return %1: !FHE.eint<16>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 0
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 2
|
||||
- inputs:
|
||||
- scalar: 36035
|
||||
outputs:
|
||||
- scalar: 72070
|
||||
v0-constraint: [16, 0]
|
||||
v0-parameter: [1,11,565,1,23,5,3]
|
||||
large-integer-crt-decomposition: [7,8,9,11,13]
|
||||
---
|
||||
description: mul_eint_int_arg
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<2>, %arg1: i3) -> !FHE.eint<2> {
|
||||
@@ -270,28 +541,31 @@ tests:
|
||||
outputs:
|
||||
- scalar: 4
|
||||
---
|
||||
description: add_eint
|
||||
description: mul_eint_int_arg_16bits
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
func.func @main(%arg0: !FHE.eint<16>, %arg1: i17) -> !FHE.eint<16> {
|
||||
%1 = "FHE.mul_eint_int"(%arg0, %arg1): (!FHE.eint<16>, i17) -> (!FHE.eint<16>)
|
||||
return %1: !FHE.eint<16>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
- scalar: 2
|
||||
- scalar: 0
|
||||
- scalar: 87
|
||||
outputs:
|
||||
- scalar: 3
|
||||
- inputs:
|
||||
- scalar: 4
|
||||
- scalar: 5
|
||||
outputs:
|
||||
- scalar: 9
|
||||
- scalar: 0
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
- scalar: 1
|
||||
- scalar: 72071
|
||||
outputs:
|
||||
- scalar: 72071
|
||||
- inputs:
|
||||
- scalar: 2
|
||||
- scalar: 3572
|
||||
outputs:
|
||||
- scalar: 7144
|
||||
v0-constraint: [16, 0]
|
||||
v0-parameter: [1,11,565,1,23,5,3]
|
||||
large-integer-crt-decomposition: [7,8,9,11,13]
|
||||
---
|
||||
description: apply_lookup_table_1_bits
|
||||
program: |
|
||||
|
||||
@@ -1,3 +1,229 @@
|
||||
description: add_eint_int_term_to_term
|
||||
program: |
|
||||
// Returns the term to term addition of `%a0` with `%a1`
|
||||
func.func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!FHE.eint<6>> {
|
||||
%res = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4xi7>) -> tensor<4x!FHE.eint<6>>
|
||||
return %res : tensor<4x!FHE.eint<6>>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- tensor: [31, 6, 12, 9]
|
||||
shape: [4]
|
||||
width: 8
|
||||
- tensor: [32, 9, 2, 3]
|
||||
shape: [4]
|
||||
width: 8
|
||||
outputs:
|
||||
- tensor: [63, 15, 14, 12]
|
||||
shape: [4]
|
||||
---
|
||||
description: add_eint_int_term_to_term_16bits
|
||||
program: |
|
||||
// Returns the term to term addition of `%a0` with `%a1`
|
||||
func.func @main(%a0: tensor<4x!FHE.eint<16>>, %a1: tensor<4xi17>) -> tensor<4x!FHE.eint<16>> {
|
||||
%res = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<16>>, tensor<4xi17>) -> tensor<4x!FHE.eint<16>>
|
||||
return %res : tensor<4x!FHE.eint<16>>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- tensor: [32767, 1276, 10212, 0]
|
||||
shape: [4]
|
||||
- tensor: [32768, 20967, 3, 0]
|
||||
shape: [4]
|
||||
outputs:
|
||||
- tensor: [65535, 22243, 10215, 0]
|
||||
shape: [4]
|
||||
v0-constraint: [16, 0]
|
||||
v0-parameter: [1,11,565,1,23,5,3]
|
||||
large-integer-crt-decomposition: [7,8,9,11,13]
|
||||
---
|
||||
description: add_eint_term_to_term
|
||||
program: |
|
||||
func.func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>> {
|
||||
%res = "FHELinalg.add_eint"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>>
|
||||
return %res : tensor<4x!FHE.eint<6>>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- tensor: [31, 6, 12, 9]
|
||||
shape: [4]
|
||||
- tensor: [32, 9, 2, 3]
|
||||
shape: [4]
|
||||
outputs:
|
||||
- tensor: [63, 15, 14, 12]
|
||||
shape: [4]
|
||||
---
|
||||
description: add_eint_term_to_term_16bits
|
||||
program: |
|
||||
func.func @main(%a0: tensor<4x!FHE.eint<16>>, %a1: tensor<4x!FHE.eint<16>>) -> tensor<4x!FHE.eint<16>> {
|
||||
%res = "FHELinalg.add_eint"(%a0, %a1) : (tensor<4x!FHE.eint<16>>, tensor<4x!FHE.eint<16>>) -> tensor<4x!FHE.eint<16>>
|
||||
return %res : tensor<4x!FHE.eint<16>>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- tensor: [32767, 1276, 10212, 0]
|
||||
shape: [4]
|
||||
- tensor: [32768, 20967, 3, 0]
|
||||
shape: [4]
|
||||
outputs:
|
||||
- tensor: [65535, 22243, 10215, 0]
|
||||
shape: [4]
|
||||
v0-constraint: [16, 0]
|
||||
v0-parameter: [1,11,565,1,23,5,3]
|
||||
large-integer-crt-decomposition: [7,8,9,11,13]
|
||||
---
|
||||
description: sub_int_eint_term_to_term
|
||||
program: |
|
||||
// Returns the term to term substraction of `%a0` with `%a1`
|
||||
func.func @main(%a0: tensor<4xi5>, %a1: tensor<4x!FHE.eint<4>>) -> tensor<4x!FHE.eint<4>> {
|
||||
%res = "FHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4xi5>, tensor<4x!FHE.eint<4>>) -> tensor<4x!FHE.eint<4>>
|
||||
return %res : tensor<4x!FHE.eint<4>>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- tensor: [15, 9, 12, 9]
|
||||
shape: [4]
|
||||
width: 8
|
||||
- tensor: [15, 6, 2, 3]
|
||||
shape: [4]
|
||||
width: 8
|
||||
outputs:
|
||||
- tensor: [0, 3, 10, 6]
|
||||
shape: [4]
|
||||
---
|
||||
description: sub_int_eint_term_to_term_16bits
|
||||
program: |
|
||||
// Returns the term to term substraction of `%a0` with `%a1`
|
||||
func.func @main(%a0: tensor<4xi17>, %a1: tensor<4x!FHE.eint<16>>) -> tensor<4x!FHE.eint<16>> {
|
||||
%res = "FHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4xi17>, tensor<4x!FHE.eint<16>>) -> tensor<4x!FHE.eint<16>>
|
||||
return %res : tensor<4x!FHE.eint<16>>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- tensor: [65535, 22243, 10215, 0]
|
||||
shape: [4]
|
||||
- tensor: [65535, 1276, 10212, 0]
|
||||
shape: [4]
|
||||
outputs:
|
||||
- tensor: [0, 20967, 3, 0]
|
||||
shape: [4]
|
||||
v0-constraint: [16, 0]
|
||||
v0-parameter: [1,11,565,1,23,5,3]
|
||||
large-integer-crt-decomposition: [7,8,9,11,13]
|
||||
---
|
||||
description: sub_eint_int_term_to_term
|
||||
program: |
|
||||
func.func @main(%a0: tensor<4xi5>, %a1: tensor<4x!FHE.eint<4>>) -> tensor<4x!FHE.eint<4>> {
|
||||
%res = "FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<4x!FHE.eint<4>>, tensor<4xi5>) -> tensor<4x!FHE.eint<4>>
|
||||
return %res : tensor<4x!FHE.eint<4>>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- tensor: [15, 6, 2, 3]
|
||||
shape: [4]
|
||||
width: 8
|
||||
- tensor: [15, 9, 12, 9]
|
||||
shape: [4]
|
||||
width: 8
|
||||
outputs:
|
||||
- tensor: [0, 3, 10, 6]
|
||||
shape: [4]
|
||||
---
|
||||
description: sub_eint_int_term_to_term_16bits
|
||||
program: |
|
||||
func.func @main(%a0: tensor<4xi17>, %a1: tensor<4x!FHE.eint<16>>) -> tensor<4x!FHE.eint<16>> {
|
||||
%res = "FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<4x!FHE.eint<16>>, tensor<4xi17>) -> tensor<4x!FHE.eint<16>>
|
||||
return %res : tensor<4x!FHE.eint<16>>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- tensor: [65535, 1276, 10212, 0]
|
||||
shape: [4]
|
||||
- tensor: [65535, 22243, 10215, 0]
|
||||
shape: [4]
|
||||
outputs:
|
||||
- tensor: [0, 20967, 3, 0]
|
||||
shape: [4]
|
||||
v0-constraint: [16, 0]
|
||||
v0-parameter: [1,11,565,1,23,5,3]
|
||||
large-integer-crt-decomposition: [7,8,9,11,13]
|
||||
---
|
||||
description: sub_eint_term_to_term
|
||||
program: |
|
||||
func.func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>> {
|
||||
%res = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>>
|
||||
return %res : tensor<4x!FHE.eint<6>>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- tensor: [31, 6, 12, 9]
|
||||
shape: [4]
|
||||
width: 8
|
||||
- tensor: [4, 2, 9, 3]
|
||||
shape: [4]
|
||||
width: 8
|
||||
outputs:
|
||||
- tensor: [27, 4, 3, 6]
|
||||
shape: [4]
|
||||
---
|
||||
description: sub_eint_term_to_term_16bits
|
||||
program: |
|
||||
func.func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>> {
|
||||
%res = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>>
|
||||
return %res : tensor<4x!FHE.eint<6>>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- tensor: [65535, 22243, 10215, 0]
|
||||
shape: [4]
|
||||
- tensor: [65535, 1276, 10212, 0]
|
||||
shape: [4]
|
||||
outputs:
|
||||
- tensor: [0, 20967, 3, 0]
|
||||
shape: [4]
|
||||
v0-constraint: [16, 0]
|
||||
v0-parameter: [1,11,565,1,23,5,3]
|
||||
large-integer-crt-decomposition: [7,8,9,11,13]
|
||||
---
|
||||
description: mul_eint_int_term_to_term
|
||||
program: |
|
||||
// Returns the term to term multiplication of `%a0` with `%a1`
|
||||
func.func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!FHE.eint<6>> {
|
||||
%res = "FHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4xi7>) -> tensor<4x!FHE.eint<6>>
|
||||
return %res : tensor<4x!FHE.eint<6>>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- tensor: [31, 6, 12, 9]
|
||||
shape: [4]
|
||||
width: 8
|
||||
- tensor: [2, 3, 2, 3]
|
||||
shape: [4]
|
||||
width: 8
|
||||
outputs:
|
||||
- tensor: [62, 18, 24, 27]
|
||||
shape: [4]
|
||||
---
|
||||
description: mul_eint_int_term_to_term_16bits
|
||||
program: |
|
||||
// Returns the term to term multiplication of `%a0` with `%a1`
|
||||
func.func @main(%a0: tensor<4x!FHE.eint<16>>, %a1: tensor<4xi17>) -> tensor<4x!FHE.eint<16>> {
|
||||
%res = "FHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<16>>, tensor<4xi17>) -> tensor<4x!FHE.eint<16>>
|
||||
return %res : tensor<4x!FHE.eint<16>>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- tensor: [1, 65535, 12, 0]
|
||||
shape: [4]
|
||||
- tensor: [65535, 1, 1987, 0]
|
||||
shape: [4]
|
||||
outputs:
|
||||
- tensor: [65535, 65535, 23844, 0]
|
||||
shape: [4]
|
||||
v0-constraint: [16, 0]
|
||||
v0-parameter: [1,11,565,1,23,5,3]
|
||||
large-integer-crt-decomposition: [7,8,9,11,13]
|
||||
---
|
||||
description: transpose1d
|
||||
program: |
|
||||
func.func @main(%input: tensor<3x!FHE.eint<6>>) -> tensor<3x!FHE.eint<6>> {
|
||||
|
||||
@@ -20,6 +20,9 @@ void compile_and_run(EndToEndDesc desc, LambdaSupport support) {
|
||||
if (desc.v0Parameter.hasValue()) {
|
||||
options.v0Parameter = *desc.v0Parameter;
|
||||
}
|
||||
if (desc.largeIntegerParameter.hasValue()) {
|
||||
options.largeIntegerParameter = *desc.largeIntegerParameter;
|
||||
}
|
||||
|
||||
/* 0 - Enable parallel testing where required */
|
||||
#ifdef CONCRETELANG_PARALLEL_TESTING_ENABLED
|
||||
|
||||
@@ -70,37 +70,6 @@ func.func @main(%arg0: tensor<3x2x!FHE.eint<4>>) -> tensor<3x!FHE.eint<4>> {
|
||||
EXPECT_EQ((*res)[2], 6);
|
||||
}
|
||||
|
||||
TEST(End2EndJit_FHELinalg, add_eint_int_term_to_term) {
|
||||
|
||||
checkedJit(lambda, R"XXX(
|
||||
// Returns the term to term addition of `%a0` with `%a1`
|
||||
func.func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!FHE.eint<6>> {
|
||||
%res = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4xi7>) -> tensor<4x!FHE.eint<6>>
|
||||
return %res : tensor<4x!FHE.eint<6>>
|
||||
}
|
||||
)XXX");
|
||||
std::vector<uint8_t> a0{31, 6, 12, 9};
|
||||
std::vector<uint8_t> a1{32, 9, 2, 3};
|
||||
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
arg0(a0);
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
arg1(a1);
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), (size_t)4);
|
||||
|
||||
for (size_t i = 0; i < 4; i++) {
|
||||
EXPECT_EQ((*res)[i], (uint64_t)a0[i] + a1[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Same as add_eint_int_term_to_term test above, but returning a lambda argument
|
||||
TEST(End2EndJit_FHELinalg, add_eint_int_term_to_term_ret_lambda_argument) {
|
||||
|
||||
@@ -371,40 +340,6 @@ TEST(End2EndJit_FHELinalg, add_eint_int_matrix_line_missing_dim) {
|
||||
// FHELinalg add_eint ///////////////////////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(End2EndJit_FHELinalg, add_eint_term_to_term) {
|
||||
|
||||
checkedJit(lambda, R"XXX(
|
||||
// Returns the term to term addition of `%a0` with `%a1`
|
||||
func.func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>> {
|
||||
%res = "FHELinalg.add_eint"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>>
|
||||
return %res : tensor<4x!FHE.eint<6>>
|
||||
}
|
||||
)XXX");
|
||||
|
||||
std::vector<uint8_t> a0{31, 6, 12, 9};
|
||||
std::vector<uint8_t> a1{32, 9, 2, 3};
|
||||
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
arg0(a0);
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
arg1(a1);
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), (uint64_t)4);
|
||||
|
||||
for (size_t i = 0; i < 4; i++) {
|
||||
EXPECT_EQ((*res)[i], (uint64_t)a0[i] + a1[i])
|
||||
<< "result differ at pos " << i << ", expect " << a0[i] + a1[i]
|
||||
<< " got " << (*res)[i];
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_FHELinalg, add_eint_term_to_term_broadcast) {
|
||||
|
||||
checkedJit(lambda, R"XXX(
|
||||
@@ -854,36 +789,6 @@ TEST(End2EndJit_FHELinalg, sub_int_eint_matrix_line_missing_dim) {
|
||||
// FHELinalg sub_eint_int ///////////////////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(End2EndJit_FHELinalg, sub_eint_int_term_to_term) {
|
||||
|
||||
checkedJit(lambda, R"XXX(
|
||||
func.func @main(%a0: tensor<4xi5>, %a1: tensor<4x!FHE.eint<4>>) -> tensor<4x!FHE.eint<4>> {
|
||||
%res = "FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<4x!FHE.eint<4>>, tensor<4xi5>) -> tensor<4x!FHE.eint<4>>
|
||||
return %res : tensor<4x!FHE.eint<4>>
|
||||
}
|
||||
)XXX");
|
||||
std::vector<uint8_t> a0{31, 6, 2, 3};
|
||||
std::vector<uint8_t> a1{32, 9, 12, 9};
|
||||
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
arg0(a0);
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
arg1(a1);
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), (uint64_t)4);
|
||||
|
||||
for (size_t i = 0; i < 4; i++) {
|
||||
EXPECT_EQ((*res)[i], (uint64_t)a1[i] - a0[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_FHELinalg, sub_eint_int_term_to_term_broadcast) {
|
||||
|
||||
checkedJit(lambda, R"XXX(
|
||||
|
||||
@@ -6,8 +6,13 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#define ASSERT_LLVM_ERROR(err) \
|
||||
if (err) { \
|
||||
ASSERT_TRUE(false) << llvm::toString(err); \
|
||||
{ \
|
||||
llvm::Error e = err; \
|
||||
if (e) { \
|
||||
handleAllErrors(std::move(e), [](const llvm::ErrorInfoBase &ei) { \
|
||||
ASSERT_TRUE(false) << ei.message(); \
|
||||
}); \
|
||||
} \
|
||||
}
|
||||
|
||||
// Checks that the value `val` is not in an error state. Returns
|
||||
|
||||
@@ -8,6 +8,7 @@ add_unittest(
|
||||
unit_tests_concretelang_clientlib
|
||||
|
||||
ClientParameters.cpp
|
||||
CRT.cpp
|
||||
KeySet.cpp
|
||||
)
|
||||
|
||||
|
||||
50
compiler/tests/unit_tests/concretelang/ClientLib/CRT.cpp
Normal file
50
compiler/tests/unit_tests/concretelang/ClientLib/CRT.cpp
Normal file
@@ -0,0 +1,50 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "concretelang/ClientLib/CRT.h"
|
||||
#include "tests_tools/assert.h"
|
||||
namespace {
|
||||
namespace crt = concretelang::clientlib::crt;
|
||||
typedef std::vector<int64_t> CRTModuli;
|
||||
|
||||
// Define a fixture for instantiate test with client parameters
|
||||
class CRTTest : public ::testing::TestWithParam<CRTModuli> {};
|
||||
|
||||
TEST_P(CRTTest, crt_iCrt) {
|
||||
auto moduli = GetParam();
|
||||
|
||||
// Max representable value from moduli
|
||||
uint64_t maxValue = 1;
|
||||
for (auto modulus : moduli)
|
||||
maxValue *= modulus;
|
||||
maxValue = maxValue - 1;
|
||||
|
||||
std::vector<uint64_t> valuesToTest{0, maxValue / 2, maxValue};
|
||||
for (auto a : valuesToTest) {
|
||||
auto remainders = crt::crt(moduli, a);
|
||||
auto b = crt::iCrt(moduli, remainders);
|
||||
|
||||
ASSERT_EQ(a, b);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<CRTModuli> generateAllParameters() {
|
||||
return {
|
||||
// This is our default moduli for the 16 bits
|
||||
{7, 8, 9, 11, 13},
|
||||
};
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(CRTSuite, CRTTest,
|
||||
::testing::ValuesIn(generateAllParameters()),
|
||||
[](const testing::TestParamInfo<CRTModuli> info) {
|
||||
auto moduli = info.param;
|
||||
std::string desc("mod");
|
||||
if (!moduli.empty()) {
|
||||
for (auto b : moduli) {
|
||||
desc = desc + "_" + std::to_string(b);
|
||||
}
|
||||
}
|
||||
return desc;
|
||||
});
|
||||
|
||||
} // namespace
|
||||
@@ -42,7 +42,7 @@ TEST(Support, client_parameters_json_serde) {
|
||||
}}};
|
||||
params0.inputs = {
|
||||
{
|
||||
/*.encryption = */ {{clientlib::SMALL_KEY, 0.00, {4}}},
|
||||
/*.encryption = */ {{clientlib::SMALL_KEY, 0.00, {4, {1, 2, 3, 4}}}},
|
||||
/*.shape = */ {32, {1, 2, 3, 4}, 1 * 2 * 3 * 4},
|
||||
},
|
||||
{
|
||||
|
||||
@@ -8,14 +8,14 @@
|
||||
namespace clientlib = concretelang::clientlib;
|
||||
|
||||
// Define a fixture for instantiate test with client parameters
|
||||
class ClientParametersTest
|
||||
class KeySetTest
|
||||
: public ::testing::TestWithParam<clientlib::ClientParameters> {
|
||||
protected:
|
||||
clientlib::ClientParameters clientParameters;
|
||||
};
|
||||
|
||||
// Test case encrypt and decrypt
|
||||
TEST_P(ClientParametersTest, encrypt_decrypt) {
|
||||
TEST_P(KeySetTest, encrypt_decrypt) {
|
||||
|
||||
auto clientParameters = GetParam();
|
||||
|
||||
@@ -29,7 +29,7 @@ TEST_P(ClientParametersTest, encrypt_decrypt) {
|
||||
ASSERT_OUTCOME_HAS_VALUE(keySet->allocate_lwe(0, &ciphertext, size));
|
||||
|
||||
// Encrypt
|
||||
uint64_t input = 3;
|
||||
uint64_t input = 0;
|
||||
ASSERT_OUTCOME_HAS_VALUE(keySet->encrypt_lwe(0, ciphertext, input));
|
||||
|
||||
// Decrypt
|
||||
@@ -45,9 +45,9 @@ TEST_P(ClientParametersTest, encrypt_decrypt) {
|
||||
|
||||
/// Create a client parameters with just one secret key of `dimension` and with
|
||||
/// one input scalar gate and one output scalar gate on the same key
|
||||
clientlib::ClientParameters
|
||||
generateClientParameterOneScalarOneScalar(clientlib::LweDimension dimension,
|
||||
clientlib::Precision precision) {
|
||||
clientlib::ClientParameters generateClientParameterOneScalarOneScalar(
|
||||
clientlib::LweDimension dimension, clientlib::Precision precision,
|
||||
clientlib::CRTDecomposition crtDecomposition) {
|
||||
// One secret key with the given dimension
|
||||
clientlib::ClientParameters params;
|
||||
params.secretKeys.insert({clientlib::SMALL_KEY, {/*.dimension =*/dimension}});
|
||||
@@ -56,6 +56,7 @@ generateClientParameterOneScalarOneScalar(clientlib::LweDimension dimension,
|
||||
clientlib::EncryptionGate encryption;
|
||||
encryption.secretKeyID = clientlib::SMALL_KEY;
|
||||
encryption.encoding.precision = precision;
|
||||
encryption.encoding.crt = crtDecomposition;
|
||||
clientlib::CircuitGate gate;
|
||||
gate.encryption = encryption;
|
||||
params.inputs.push_back(gate);
|
||||
@@ -74,13 +75,23 @@ std::vector<clientlib::ClientParameters> generateAllParameters() {
|
||||
llvm::for_each(llvm::enumerate(precisions),
|
||||
[](auto p) { p.value() = p.index() + 1; });
|
||||
|
||||
// All crt decomposition to test
|
||||
std::vector<clientlib::CRTDecomposition> crtDecompositions{
|
||||
// Empty crt decompositon means no decomposition
|
||||
{},
|
||||
// The default decomposition for 16 bits
|
||||
{7, 8, 9, 11, 13},
|
||||
};
|
||||
|
||||
// All client parameters to test
|
||||
std::vector<clientlib::ClientParameters> parameters;
|
||||
|
||||
for (auto dimension : lweDimensions) {
|
||||
for (auto precision : precisions) {
|
||||
parameters.push_back(
|
||||
generateClientParameterOneScalarOneScalar(dimension, precision));
|
||||
for (auto crtDecomposition : crtDecompositions) {
|
||||
parameters.push_back(generateClientParameterOneScalarOneScalar(
|
||||
dimension, precision, crtDecomposition));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,13 +99,21 @@ std::vector<clientlib::ClientParameters> generateAllParameters() {
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
OneScalarOnScalar, ClientParametersTest,
|
||||
::testing::ValuesIn(generateAllParameters()),
|
||||
OneScalarOnScalar, KeySetTest, ::testing::ValuesIn(generateAllParameters()),
|
||||
[](const testing::TestParamInfo<clientlib::ClientParameters> info) {
|
||||
auto cp = info.param;
|
||||
auto input_0 = cp.inputs[0];
|
||||
return std::string("lweDimension_") +
|
||||
std::to_string(cp.lweSecretKeyParam(input_0).value().dimension) +
|
||||
"_precision_" +
|
||||
std::to_string(input_0.encryption.getValue().encoding.precision);
|
||||
auto paramDescription =
|
||||
std::string("lweDimension_") +
|
||||
std::to_string(cp.lweSecretKeyParam(input_0).value().dimension) +
|
||||
"_precision_" +
|
||||
std::to_string(input_0.encryption.getValue().encoding.precision);
|
||||
auto crt = input_0.encryption.getValue().encoding.crt;
|
||||
if (!crt.empty()) {
|
||||
paramDescription = paramDescription + "_crt_";
|
||||
for (auto b : crt) {
|
||||
paramDescription = paramDescription + "_" + std::to_string(b);
|
||||
}
|
||||
}
|
||||
return paramDescription;
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user