refactor(encodings): raise plaintext/lut encodings higher up in the pipeline

This commit is contained in:
aPere3
2022-11-10 14:57:48 +01:00
committed by Alexandre Péré
parent 7226c89cf1
commit 2fd9b6f0e3
94 changed files with 2731 additions and 2040 deletions

View File

@@ -3,5 +3,4 @@ mlir_tablegen(Passes.h.inc -gen-pass-decls -name Conversion)
add_public_tablegen_target(ConcretelangConversionPassIncGen)
add_dependencies(mlir-headers ConcretelangConversionPassIncGen)
add_subdirectory(FHEToTFHE)
add_subdirectory(TFHEToConcrete)

View File

@@ -1,6 +0,0 @@
set(LLVM_TARGET_DEFINITIONS Patterns.td)
mlir_tablegen(Patterns.h.inc -gen-rewriters -name FHE)
add_public_tablegen_target(FHEToTFHEPatternsIncGen)
add_dependencies(mlir-headers FHEToTFHEPatternsIncGen)
add_concretelang_doc(Patterns FHEToTFHEPatterns concretelang/ -gen-pass-doc)

View File

@@ -1,27 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CONVERSION_FHETOTFHE_PASS_H_
#define CONCRETELANG_CONVERSION_FHETOTFHE_PASS_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace concretelang {
// ApplyLookupTableLowering indicates the strategy to lower an
// FHE.apply_lookup_table op.
// NOTE(review): the first enumerator keeps a historical misspelling of
// "Bootstrap"; renaming it would break existing callers.
enum ApplyLookupTableLowering {
// Lower to a keyswitch followed by a bootstrap (standard PBS path).
KeySwitchBoostrapLowering,
// Lower to a WoP-PBS (without-padding programmable bootstrap).
WopPBSLowering,
};
/// Create a pass to convert `FHE` dialect to `TFHE` dialect.
std::unique_ptr<OperationPass<ModuleOp>>
createConvertFHEToTFHEPass(ApplyLookupTableLowering lower);
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -1,89 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CONVERSION_FHETOTFHE_PATTERNS_H_
#define CONCRETELANG_CONVERSION_FHETOTFHE_PATTERNS_H_
#include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h"
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
namespace concretelang {
using FHE::EncryptedIntegerType;
using TFHE::GLWECipherTextType;
/// Converts an `FHE::EncryptedIntegerType` into a `TFHE::GLWECipherTextType`.
///
/// Only the plaintext width is carried over from the encrypted integer; the
/// GLWE dimension, polynomial size and bit count are left unparametrized (-1)
/// at this stage of the pipeline, and the crt decomposition is empty.
///
/// Marked `inline`: this function is *defined* in a header, so without
/// `inline` every translation unit including it would emit a strong symbol
/// and violate the ODR.
inline GLWECipherTextType
convertTypeEncryptedIntegerToGLWE(mlir::MLIRContext *context,
                                  EncryptedIntegerType eint) {
  return GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth(),
                                 llvm::ArrayRef<int64_t>());
}
/// Converts the type `t` to `TFHE::GLWECipherTextType` if `t` is an
/// `FHE::EncryptedIntegerType`, otherwise just returns `t` unchanged.
///
/// Marked `inline`: defined in a header, so `inline` is required to avoid
/// ODR violations when the header is included from several translation units.
inline mlir::Type
convertTypeToGLWEIfEncryptedIntegerType(mlir::MLIRContext *context,
                                        mlir::Type t) {
  if (auto eint = t.dyn_cast<EncryptedIntegerType>())
    return convertTypeEncryptedIntegerToGLWE(context, eint);
  return t;
}
/// Builds a `TFHE::ZeroGLWEOp` in place of an FHE zero op and rewrites its
/// FHE `eint` result type to the corresponding GLWE type.
///
/// \param rewriter pattern rewriter used to create the new op.
/// \param loc      location to attach to the new op.
/// \param result   result of the matched FHE op, providing the result type.
/// \return the (type-converted) result value of the new `ZeroGLWEOp`.
///
/// Marked `inline`: defined in a header, so `inline` is required to avoid
/// ODR violations when the header is included from several translation units.
inline mlir::Value createZeroGLWEOpFromFHE(mlir::PatternRewriter &rewriter,
                                           mlir::Location loc,
                                           mlir::OpResult result) {
  mlir::SmallVector<mlir::Value> args{};
  mlir::SmallVector<mlir::NamedAttribute, 0> attrs;
  mlir::SmallVector<mlir::Type, 1> resTypes{result.getType()};
  TFHE::ZeroGLWEOp op =
      rewriter.create<TFHE::ZeroGLWEOp>(loc, resTypes, args, attrs);
  // Swap any FHE eint operand/result types for their GLWE counterparts.
  convertOperandAndResultTypes(rewriter, op,
                               convertTypeToGLWEIfEncryptedIntegerType);
  return op.getODSResults(0).front();
}
/// Builds a binary TFHE `Operator` from the two operands and result of a
/// matched FHE op, then rewrites any FHE `eint` types to GLWE types.
/// Returns the result value of the freshly-created op.
template <class Operator>
mlir::Value createGLWEOpFromFHE(mlir::PatternRewriter &rewriter,
                                mlir::Location loc, mlir::Value arg0,
                                mlir::Value arg1, mlir::OpResult result) {
  mlir::SmallVector<mlir::Type, 1> resultTypes{result.getType()};
  mlir::SmallVector<mlir::Value, 2> operands{arg0, arg1};
  mlir::SmallVector<mlir::NamedAttribute, 0> attributes;
  Operator newOp =
      rewriter.create<Operator>(loc, resultTypes, operands, attributes);
  convertOperandAndResultTypes(rewriter, newOp,
                               convertTypeToGLWEIfEncryptedIntegerType);
  return newOp.getODSResults(0).front();
}
/// Builds a unary TFHE `Operator` from the single operand and result of a
/// matched FHE op, then rewrites any FHE `eint` types to GLWE types.
/// Returns the result value of the freshly-created op.
template <class Operator>
mlir::Value createGLWEOpFromFHE(mlir::PatternRewriter &rewriter,
                                mlir::Location loc, mlir::Value arg0,
                                mlir::OpResult result) {
  mlir::SmallVector<mlir::Type, 1> resultTypes{result.getType()};
  mlir::SmallVector<mlir::Value, 1> operands{arg0};
  mlir::SmallVector<mlir::NamedAttribute, 0> attributes;
  Operator newOp =
      rewriter.create<Operator>(loc, resultTypes, operands, attributes);
  convertOperandAndResultTypes(rewriter, newOp,
                               convertTypeToGLWEIfEncryptedIntegerType);
  return newOp.getODSResults(0).front();
}
} // namespace concretelang
} // namespace mlir
namespace {
#include "concretelang/Conversion/FHEToTFHE/Patterns.h.inc"
}
/// Populates `patterns` with the tablegen-generated FHE-to-TFHE rewrite
/// patterns (from Patterns.h.inc, included in the anonymous namespace above).
///
/// Marked `inline`: this non-template function is defined in a header, so
/// without `inline` every translation unit including it would emit a strong
/// symbol and violate the ODR.
inline void populateWithGeneratedFHEToTFHE(mlir::RewritePatternSet &patterns) {
  populateWithGenerated(patterns);
}
#endif

View File

@@ -1,45 +0,0 @@
// DRR (declarative rewrite rule) patterns lowering FHE arithmetic ops to
// their TFHE GLWE counterparts. Each pattern delegates op construction to a
// NativeCodeCall helper (defined in Patterns.h) so that FHE eint types can
// be rewritten to GLWE types on the freshly-built op.
#ifndef CONCRETELANG_CONVERSION_FHETOTFHE_PATTERNS
#define CONCRETELANG_CONVERSION_FHETOTFHE_PATTERNS
include "mlir/Pass/PassBase.td"
include "mlir/IR/PatternBase.td"
include "concretelang/Dialect/FHE/IR/FHEOps.td"
include "concretelang/Dialect/TFHE/IR/TFHEOps.td"
// FHE.zero -> TFHE.zero (trivial encryption of 0).
def createZeroGLWEOp : NativeCodeCall<"mlir::concretelang::createZeroGLWEOpFromFHE($_builder, $_loc, $0)">;
def ZeroEintPattern : Pat<
(FHE_ZeroEintOp:$result),
(createZeroGLWEOp $result)>;
// FHE.add_eint_int -> TFHE.add_glwe_int (ciphertext + plaintext).
def createAddGLWEIntOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromFHE<mlir::concretelang::TFHE::AddGLWEIntOp>($_builder, $_loc, $0, $1, $2)">;
def AddEintIntPattern : Pat<
(FHE_AddEintIntOp:$result $arg0, $arg1),
(createAddGLWEIntOp $arg0, $arg1, $result)>;
// FHE.add_eint -> TFHE.add_glwe (ciphertext + ciphertext).
def createAddGLWEOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromFHE<mlir::concretelang::TFHE::AddGLWEOp>($_builder, $_loc, $0, $1, $2)">;
def AddEintPattern : Pat<
(FHE_AddEintOp:$result $arg0, $arg1),
(createAddGLWEOp $arg0, $arg1, $result)>;
// FHE.sub_int_eint -> TFHE.sub_glwe_int (plaintext - ciphertext).
def createSubGLWEIntOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromFHE<mlir::concretelang::TFHE::SubGLWEIntOp>($_builder, $_loc, $0, $1, $2)">;
def SubIntEintPattern : Pat<
(FHE_SubIntEintOp:$result $arg0, $arg1),
(createSubGLWEIntOp $arg0, $arg1, $result)>;
// FHE.neg_eint -> TFHE.neg_glwe (unary negation).
def createNegGLWEOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromFHE<mlir::concretelang::TFHE::NegGLWEOp>($_builder, $_loc, $0, $1)">;
def NegEintPattern : Pat<
(FHE_NegEintOp:$result $arg0),
(createNegGLWEOp $arg0, $result)>;
// FHE.mul_eint_int -> TFHE.mul_glwe_int (ciphertext * plaintext).
def createMulGLWEIntOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromFHE<mlir::concretelang::TFHE::MulGLWEIntOp>($_builder, $_loc, $0, $1, $2)">;
def MulEintIntPattern : Pat<
(FHE_MulEintIntOp:$result $arg0, $arg1),
(createMulGLWEIntOp $arg0, $arg1, $result)>;
#endif

View File

@@ -0,0 +1,51 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CONVERSION_FHETOTFHECRT_PASS_H_
#define CONCRETELANG_CONVERSION_FHETOTFHECRT_PASS_H_
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Casting.h"
#include <cmath>
#include <list>
#include <utility>
namespace mlir {
namespace concretelang {
/// Parameters driving the CRT-based lowering of `FHE` ops to `TFHE` ops.
struct CrtLoweringParameters {
  /// Moduli of the CRT decomposition (presumably pairwise coprime — the
  /// product is accumulated without checks).
  mlir::SmallVector<int64_t> mods;
  /// Number of bits needed to represent a residue for each modulus,
  /// i.e. ceil(log2(mod)).
  mlir::SmallVector<int64_t> bits;
  /// Number of moduli in the decomposition.
  size_t nMods;
  /// Product of all moduli.
  size_t modsProd;
  /// Total number of bits across all residues.
  size_t bitsTotal;
  /// Polynomial size used by the lowering.
  size_t polynomialSize;
  /// Total lookup-table size: one (2^bitsTotal, at least poly-sized) table
  /// per modulus.
  size_t lutSize;

  /// Derives all fields from the moduli and the polynomial size.
  /// Takes `mods` by value and moves it into the member, avoiding the extra
  /// copy the previous copy-initialization performed.
  CrtLoweringParameters(mlir::SmallVector<int64_t> mods, size_t polySize)
      : mods(std::move(mods)), polynomialSize(polySize) {
    nMods = this->mods.size();
    modsProd = 1;
    bitsTotal = 0;
    bits.clear();
    bits.reserve(nMods);
    for (int64_t mod : this->mods) {
      modsProd *= mod;
      // Bits required to hold a residue in [0, mod).
      uint64_t nbits =
          static_cast<uint64_t>(ceil(log2(static_cast<double>(mod))));
      bits.push_back(nbits);
      bitsTotal += nbits;
    }
    // The per-modulus table must cover every residue combination, and span
    // at least one full polynomial.
    size_t lutCrtSize = size_t(1) << bitsTotal;
    lutCrtSize = std::max(lutCrtSize, polynomialSize);
    lutSize = nMods * lutCrtSize;
  }
};
/// Create a pass to convert `FHE` dialect to `TFHE` dialect with the crt
// strategy.
std::unique_ptr<OperationPass<mlir::ModuleOp>>
createConvertFHEToTFHECrtPass(CrtLoweringParameters lowering);
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -0,0 +1,28 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CONVERSION_FHETOTFHESCALAR_PASS_H_
#define CONCRETELANG_CONVERSION_FHETOTFHESCALAR_PASS_H_
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Casting.h"
#include <list>
namespace mlir {
namespace concretelang {
/// Parameters driving the scalar (non-CRT) lowering of `FHE` ops to `TFHE`
/// ops.
struct ScalarLoweringParameters {
  /// Polynomial size used by the lowering.
  size_t polynomialSize;

  /// NOTE(review): kept implicit (non-explicit) on purpose — callers may rely
  /// on conversion from a bare `size_t`. Stray `;` after the body removed.
  ScalarLoweringParameters(size_t polySize) : polynomialSize(polySize) {}
};
/// Create a pass to convert `FHE` dialect to `TFHE` dialect with the scalar
// strategy.
std::unique_ptr<OperationPass<mlir::ModuleOp>>
createConvertFHEToTFHEScalarPass(ScalarLoweringParameters loweringParameters);
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -17,7 +17,8 @@
#include "concretelang/Conversion/ConcreteToBConcrete/Pass.h"
#include "concretelang/Conversion/ExtractSDFGOps/Pass.h"
#include "concretelang/Conversion/FHETensorOpsToLinalg/Pass.h"
#include "concretelang/Conversion/FHEToTFHE/Pass.h"
#include "concretelang/Conversion/FHEToTFHECrt/Pass.h"
#include "concretelang/Conversion/FHEToTFHEScalar/Pass.h"
#include "concretelang/Conversion/LinalgExtras/Passes.h"
#include "concretelang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h"
#include "concretelang/Conversion/SDFGToStreamEmulator/Pass.h"

View File

@@ -9,10 +9,18 @@ def FHETensorOpsToLinalg : Pass<"fhe-tensor-ops-to-linalg", "::mlir::func::FuncO
let dependentDialects = ["mlir::linalg::LinalgDialect"];
}
def FHEToTFHE : Pass<"fhe-to-tfhe", "mlir::ModuleOp"> {
let summary = "Lowers operations from the FHE dialect to TFHE";
def FHEToTFHEScalar : Pass<"fhe-to-tfhe-scalar", "mlir::ModuleOp"> {
let summary = "Lowers operations from the FHE dialect to TFHE using the scalar strategy.";
let description = [{ Lowers operations from the FHE dialect to Std + Math }];
let constructor = "mlir::concretelang::createConvertFHEToTFHEPass()";
let constructor = "mlir::concretelang::createConvertFHEToTFHEScalarPass()";
let options = [];
let dependentDialects = ["mlir::linalg::LinalgDialect"];
}
// Pass lowering FHE ops to TFHE ops using the CRT (Chinese Remainder
// Theorem) decomposition strategy.
def FHEToTFHECrt : Pass<"fhe-to-tfhe-crt", "mlir::ModuleOp"> {
  let summary = "Lowers operations from the FHE dialect to TFHE using the crt strategy.";
  // Fixed copy-pasted description: this pass targets TFHE, not Std + Math.
  let description = [{ Lowers operations from the FHE dialect to TFHE using the crt strategy. }];
  let constructor = "mlir::concretelang::createConvertFHEToTFHECrtPass()";
  let options = [];
  let dependentDialects = ["mlir::linalg::LinalgDialect"];
}

View File

@@ -26,8 +26,7 @@ LweCiphertextType convertTypeToLWE(mlir::MLIRContext *context,
auto glwe = type.dyn_cast_or_null<GLWECipherTextType>();
if (glwe != nullptr) {
assert(glwe.getPolynomialSize() == 1);
return LweCiphertextType::get(context, glwe.getDimension(), glwe.getP(),
glwe.getCrtDecomposition());
return LweCiphertextType::get(context, glwe.getDimension(), glwe.getP());
}
auto lwe = type.dyn_cast_or_null<LweCiphertextType>();
if (lwe != nullptr) {

View File

@@ -25,56 +25,21 @@ def BConcrete_AddLweTensorOp : BConcrete_Op<"add_lwe_tensor", [NoSideEffect]> {
let results = (outs 1DTensorOf<[I64]>:$result);
}
def BConcrete_AddCRTLweTensorOp : BConcrete_Op<"add_crt_lwe_tensor", [NoSideEffect]> {
let arguments = (ins
2DTensorOf<[I64]>:$lhs,
2DTensorOf<[I64]>:$rhs,
I64ArrayAttr:$crtDecomposition
);
let results = (outs 2DTensorOf<[I64]>:$result);
}
def BConcrete_AddPlaintextLweTensorOp : BConcrete_Op<"add_plaintext_lwe_tensor", [NoSideEffect]> {
let arguments = (ins 1DTensorOf<[I64]>:$lhs, I64:$rhs);
let results = (outs 1DTensorOf<[I64]>:$result);
}
def BConcrete_AddPlaintextCRTLweTensorOp : BConcrete_Op<"add_plaintext_crt_lwe_tensor", [NoSideEffect]> {
let arguments = (ins
2DTensorOf<[I64]>:$lhs,
AnyInteger:$rhs,
I64ArrayAttr:$crtDecomposition
);
let results = (outs 2DTensorOf<[I64]>:$result);
}
def BConcrete_MulCleartextLweTensorOp : BConcrete_Op<"mul_cleartext_lwe_tensor", [NoSideEffect]> {
let arguments = (ins 1DTensorOf<[I64]>:$lhs, I64:$rhs);
let results = (outs 1DTensorOf<[I64]>:$result);
}
def BConcrete_MulCleartextCRTLweTensorOp : BConcrete_Op<"mul_cleartext_crt_lwe_tensor", [NoSideEffect]> {
let arguments = (ins
2DTensorOf<[I64]>:$lhs,
AnyInteger:$rhs,
I64ArrayAttr:$crtDecomposition
);
let results = (outs 2DTensorOf<[I64]>:$result);
}
def BConcrete_NegateLweTensorOp : BConcrete_Op<"negate_lwe_tensor", [NoSideEffect]> {
let arguments = (ins 1DTensorOf<[I64]>:$ciphertext);
let results = (outs 1DTensorOf<[I64]>:$result);
}
def BConcrete_NegateCRTLweTensorOp : BConcrete_Op<"negate_crt_lwe_tensor", [NoSideEffect]> {
let arguments = (ins
2DTensorOf<[I64]>:$ciphertext,
I64ArrayAttr:$crtDecomposition
);
let results = (outs 2DTensorOf<[I64]>:$result);
}
def BConcrete_KeySwitchLweTensorOp : BConcrete_Op<"keyswitch_lwe_tensor", [NoSideEffect]> {
let arguments = (ins
// LweKeySwitchKeyType:$keyswitch_key,
@@ -99,6 +64,49 @@ def BConcrete_BatchedKeySwitchLweTensorOp : BConcrete_Op<"batched_keyswitch_lwe_
let results = (outs 2DTensorOf<[I64]>:$result);
}
def BConcrete_EncodeExpandLutForBootstrapTensorOp : BConcrete_Op<"encode_expand_lut_for_bootstrap_tensor", [NoSideEffect]> {
let summary =
"Encode and expand a lookup table so that it can be used for a bootstrap.";
let arguments = (ins
1DTensorOf<[I64]> : $input_lookup_table,
I32Attr: $polySize,
I32Attr: $outputBits
);
let results = (outs 1DTensorOf<[I64]> : $result);
}
def BConcrete_EncodeExpandLutForWopPBSTensorOp : BConcrete_Op<"encode_expand_lut_for_woppbs_tensor", [NoSideEffect]> {
let summary =
"Encode and expand a lookup table so that it can be used for a wop pbs.";
let arguments = (ins
1DTensorOf<[I64]> : $input_lookup_table,
I64ArrayAttr: $crtDecomposition,
I64ArrayAttr: $crtBits,
I32Attr : $polySize,
I32Attr : $modulusProduct
);
let results = (outs 1DTensorOf<[I64]> : $result);
}
def BConcrete_EncodePlaintextWithCrtTensorOp : BConcrete_Op<"encode_plaintext_with_crt_tensor", [NoSideEffect]> {
let summary =
"Encodes a plaintext by decomposing it on a crt basis.";
let arguments = (ins
I64 : $input,
I64ArrayAttr: $mods,
I64Attr: $modsProd
);
let results = (outs 1DTensorOf<[I64]> : $result);
}
def BConcrete_BootstrapLweTensorOp : BConcrete_Op<"bootstrap_lwe_tensor", [NoSideEffect]> {
let arguments = (ins
1DTensorOf<[I64]>:$input_ciphertext,
@@ -144,8 +152,7 @@ def BConcrete_WopPBSCRTLweTensorOp : BConcrete_Op<"wop_pbs_crt_lwe_tensor", [NoS
I32Attr : $packingKeySwitchBaseLog,
// Circuit bootstrap parameters
I32Attr : $circuitBootstrapLevel,
I32Attr : $circuitBootstrapBaseLog,
I64ArrayAttr:$crtDecomposition
I32Attr : $circuitBootstrapBaseLog
);
let results = (outs 2DTensorOf<[I64]>:$result);
}
@@ -153,6 +160,8 @@ def BConcrete_WopPBSCRTLweTensorOp : BConcrete_Op<"wop_pbs_crt_lwe_tensor", [NoS
// BConcrete memref operators /////////////////////////////////////////////////
def BConcrete_LweBuffer : MemRefRankOf<[I64], [1]>;
def BConcrete_LutBuffer : MemRefRankOf<[I64], [1]>;
def BConcrete_CrtPlaintextBuffer : MemRefRankOf<[I64], [1]>;
def BConcrete_LweCRTBuffer : MemRefRankOf<[I64], [2]>;
def BConcrete_BatchLweBuffer : MemRefRankOf<[I64], [2]>;
@@ -209,11 +218,49 @@ def BConcrete_BatchedKeySwitchLweBufferOp : BConcrete_Op<"batched_keyswitch_lwe_
);
}
def BConcrete_EncodeExpandLutForBootstrapBufferOp : BConcrete_Op<"encode_expand_lut_for_bootstrap_buffer"> {
let summary =
"Encode and expand a lookup table so that it can be used for a bootstrap.";
let arguments = (ins
BConcrete_LutBuffer: $result,
BConcrete_LutBuffer: $input_lookup_table,
I32Attr: $polySize,
I32Attr: $outputBits
);
}
def BConcrete_EncodeExpandLutForWopPBSBufferOp : BConcrete_Op<"encode_expand_lut_for_woppbs_buffer"> {
let summary =
"Encode and expand a lookup table so that it can be used for a wop pbs.";
let arguments = (ins
BConcrete_LutBuffer : $result,
BConcrete_LutBuffer : $input_lookup_table,
I64ArrayAttr: $crtDecomposition,
I64ArrayAttr: $crtBits,
I32Attr : $polySize,
I32Attr : $modulusProduct
);
}
def BConcrete_EncodePlaintextWithCrtBufferOp : BConcrete_Op<"encode_plaintext_with_crt_buffer"> {
let summary =
"Encodes a plaintext by decomposing it on a crt basis.";
let arguments = (ins
BConcrete_CrtPlaintextBuffer: $result,
I64 : $input,
I64ArrayAttr: $mods,
I64Attr: $modsProd
);
}
def BConcrete_BootstrapLweBufferOp : BConcrete_Op<"bootstrap_lwe_buffer"> {
let arguments = (ins
BConcrete_LweBuffer:$result,
BConcrete_LweBuffer:$input_ciphertext,
MemRefRankOf<[I64], [1]>:$lookup_table,
BConcrete_LutBuffer:$lookup_table,
I32Attr:$inputLweDim,
I32Attr:$polySize,
I32Attr:$level,
@@ -227,7 +274,7 @@ def BConcrete_BatchedBootstrapLweBufferOp : BConcrete_Op<"batched_bootstrap_lwe_
let arguments = (ins
BConcrete_BatchLweBuffer:$result,
BConcrete_BatchLweBuffer:$input_ciphertext,
MemRefRankOf<[I64], [1]>:$lookup_table,
BConcrete_LutBuffer:$lookup_table,
I32Attr:$inputLweDim,
I32Attr:$polySize,
I32Attr:$level,
@@ -241,7 +288,7 @@ def BConcrete_WopPBSCRTLweBufferOp : BConcrete_Op<"wop_pbs_crt_lwe_buffer"> {
let arguments = (ins
BConcrete_LweCRTBuffer:$result,
BConcrete_LweCRTBuffer:$ciphertext,
MemRefRankOf<[I64], [1]>:$lookup_table,
BConcrete_LutBuffer:$lookup_table,
// Bootstrap parameters
I32Attr : $bootstrapLevel,
I32Attr : $bootstrapBaseLog,

View File

@@ -53,6 +53,47 @@ def Concrete_NegateLweCiphertextOp : Concrete_Op<"negate_lwe_ciphertext"> {
let results = (outs Concrete_LweCiphertextType:$result);
}
def Concrete_EncodeExpandLutForBootstrapOp : Concrete_Op<"encode_expand_lut_for_bootstrap"> {
let summary =
"Encode and expand a lookup table so that it can be used for a bootstrap.";
let arguments = (ins
1DTensorOf<[I64]> : $input_lookup_table,
I32Attr: $polySize,
I32Attr: $outputBits
);
let results = (outs 1DTensorOf<[I64]> : $result);
}
def Concrete_EncodeExpandLutForWopPBSOp : Concrete_Op<"encode_expand_lut_for_woppbs"> {
let summary =
"Encode and expand a lookup table so that it can be used for a wop pbs.";
let arguments = (ins
1DTensorOf<[I64]> : $input_lookup_table,
I64ArrayAttr: $crtDecomposition,
I64ArrayAttr: $crtBits,
I32Attr : $polySize,
I32Attr : $modulusProduct
);
let results = (outs 1DTensorOf<[I64]> : $result);
}
def Concrete_EncodePlaintextWithCrtOp : Concrete_Op<"encode_plaintext_with_crt"> {
let summary =
"Encodes a plaintext by decomposing it on a crt basis.";
let arguments = (ins
I64 : $input,
I64ArrayAttr: $mods,
I64Attr: $modsProd
);
let results = (outs 1DTensorOf<[I64]> : $result);
}
def Concrete_BootstrapLweOp : Concrete_Op<"bootstrap_lwe", [BatchableOpInterface]> {
let summary = "Bootstraps a LWE ciphertext with a GLWE trivial encryption of the lookup table";
@@ -153,7 +194,7 @@ def Concrete_WopPBSLweOp : Concrete_Op<"wop_pbs_lwe"> {
let summary = "";
let arguments = (ins
Concrete_LweCiphertextType:$ciphertext,
Type<And<[TensorOf<[Concrete_LweCiphertextType]>.predicate, HasStaticShapePred]>>:$ciphertexts,
1DTensorOf<[I64]>:$accumulator,
// Bootstrap parameters
I32Attr : $bootstrapLevel,
@@ -168,9 +209,11 @@ def Concrete_WopPBSLweOp : Concrete_Op<"wop_pbs_lwe"> {
I32Attr : $packingKeySwitchBaseLog,
// Circuit bootstrap parameters
I32Attr : $circuitBootstrapLevel,
I32Attr : $circuitBootstrapBaseLog
I32Attr : $circuitBootstrapBaseLog,
// Crt decomposition
I64ArrayAttr: $crtDecomposition
);
let results = (outs Concrete_LweCiphertextType:$result);
let results = (outs Type<And<[TensorOf<[Concrete_LweCiphertextType]>.predicate, HasStaticShapePred]>>:$result);
}
#endif

View File

@@ -40,9 +40,7 @@ def Concrete_LweCiphertextType : Concrete_Type<"LweCiphertext", [MemRefElementTy
// The dimension of the lwe ciphertext
"signed":$dimension,
// Precision of the lwe ciphertext
"signed":$p,
// CRT decomposition for large integers
ArrayRefParameter<"int64_t", "CRT decomposition">:$crtDecomposition
"signed":$p
);

View File

@@ -19,6 +19,48 @@ include "concretelang/Dialect/TFHE/IR/TFHETypes.td"
class TFHE_Op<string mnemonic, list<Trait> traits = []>
: Op<TFHE_Dialect, mnemonic, traits>;
def TFHE_EncodeExpandLutForBootstrapOp : TFHE_Op<"encode_expand_lut_for_bootstrap"> {
let summary =
"Encode and expand a lookup table so that it can be used for a bootstrap.";
let arguments = (ins
1DTensorOf<[I64]> : $input_lookup_table,
I32Attr: $polySize,
I32Attr: $outputBits
);
let results = (outs 1DTensorOf<[I64]> : $result);
}
def TFHE_EncodeExpandLutForWopPBSOp : TFHE_Op<"encode_expand_lut_for_woppbs"> {
let summary =
"Encode and expand a lookup table so that it can be used for a wop pbs.";
let arguments = (ins
1DTensorOf<[I64]> : $input_lookup_table,
I64ArrayAttr: $crtDecomposition,
I64ArrayAttr: $crtBits,
I32Attr : $polySize,
I32Attr : $modulusProduct
);
let results = (outs 1DTensorOf<[I64]> : $result);
}
def TFHE_EncodePlaintextWithCrtOp : TFHE_Op<"encode_plaintext_with_crt"> {
let summary =
"Encodes a plaintext by decomposing it on a crt basis.";
let arguments = (ins
I64 : $input,
I64ArrayAttr: $mods,
I64Attr: $modsProd
);
let results = (outs 1DTensorOf<[I64]> : $result);
}
def TFHE_ZeroGLWEOp : TFHE_Op<"zero"> {
let summary = "Returns a trivial encyption of 0";
@@ -92,6 +134,7 @@ def TFHE_KeySwitchGLWEOp : TFHE_Op<"keyswitch_glwe"> {
let results = (outs TFHE_GLWECipherTextType : $result);
}
def TFHE_BootstrapGLWEOp : TFHE_Op<"bootstrap_glwe"> {
let summary =
"Programmable bootstraping of a GLWE ciphertext with a lookup table";
@@ -112,7 +155,7 @@ def TFHE_WopPBSGLWEOp : TFHE_Op<"wop_pbs_glwe"> {
let summary = "";
let arguments = (ins
TFHE_GLWECipherTextType : $ciphertext,
Type<And<[TensorOf<[TFHE_GLWECipherTextType]>.predicate, HasStaticShapePred]>>: $ciphertexts,
1DTensorOf<[I64]> : $lookupTable,
// Bootstrap parameters
I32Attr : $bootstrapLevel,
@@ -127,9 +170,11 @@ def TFHE_WopPBSGLWEOp : TFHE_Op<"wop_pbs_glwe"> {
I32Attr : $packingKeySwitchBaseLog,
// Circuit bootstrap parameters
I32Attr : $circuitBootstrapLevel,
I32Attr : $circuitBootstrapBaseLog
I32Attr : $circuitBootstrapBaseLog,
// Crt decomposition
I64ArrayAttr: $crtDecomposition
);
let results = (outs TFHE_GLWECipherTextType:$result);
let results = (outs Type<And<[TensorOf<[TFHE_GLWECipherTextType]>.predicate, HasStaticShapePred]>>:$result);
}
#endif

View File

@@ -25,9 +25,7 @@ def TFHE_GLWECipherTextType
// Number of bits of the ciphertext
"signed":$bits,
// Number of bits of the plain text representation
"signed":$p,
// CRT decomposition for large integers
ArrayRefParameter<"int64_t", "CRT decomposition">:$crtDecomposition
"signed":$p
);
let hasCustomAssemblyFormat = 1;

View File

@@ -21,16 +21,33 @@ extern "C" {
/// \param out_MESSAGE_BITS number of bits of message to be used
/// \param lut original LUT
/// \param lut_size
void encode_and_expand_lut(uint64_t *output, size_t output_size,
size_t out_MESSAGE_BITS, const uint64_t *lut,
size_t lut_size);
void memref_encode_expand_lut_for_bootstrap(
uint64_t *output_lut_allocated, uint64_t *output_lut_aligned,
uint64_t output_lut_offset, uint64_t output_lut_size,
uint64_t output_lut_stride, uint64_t *input_lut_allocated,
uint64_t *input_lut_aligned, uint64_t input_lut_offset,
uint64_t input_lut_size, uint64_t input_lut_stride, uint32_t poly_size,
uint32_t out_MESSAGE_BITS);
void memref_expand_lut_in_trivial_glwe_ct_u64(
uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
uint32_t poly_size, uint32_t glwe_dimension, uint32_t out_precision,
uint64_t *lut_allocated, uint64_t *lut_aligned, uint64_t lut_offset,
uint64_t lut_size, uint64_t lut_stride);
void memref_encode_expand_lut_for_woppbs(
uint64_t *output_lut_allocated, uint64_t *output_lut_aligned,
uint64_t output_lut_offset, uint64_t output_lut_size,
uint64_t output_lut_stride, uint64_t *input_lut_allocated,
uint64_t *input_lut_aligned, uint64_t input_lut_offset,
uint64_t input_lut_size, uint64_t input_lut_stride,
uint64_t *crt_decomposition_allocated, uint64_t *crt_decomposition_aligned,
uint64_t crt_decomposition_offset, uint64_t crt_decomposition_size,
uint64_t crt_decomposition_stride, uint64_t *crt_bits_allocated,
uint64_t *crt_bits_aligned, uint64_t crt_bits_offset,
uint64_t crt_bits_size, uint64_t crt_bits_stride, uint32_t poly_size,
uint32_t modulus_product);
/// Encodes the plaintext `input` by decomposing it over the CRT basis given
/// by the `mods` memref, writing one residue per entry of `output`.
/// Parameter names fixed: the mods memref descriptor fields were mislabeled
/// `output_lut_*` (copy-paste from the LUT declarations above).
void memref_encode_plaintext_with_crt(
    uint64_t *output_allocated, uint64_t *output_aligned,
    uint64_t output_offset, uint64_t output_size, uint64_t output_stride,
    uint64_t input, uint64_t *mods_allocated, uint64_t *mods_aligned,
    uint64_t mods_offset, uint64_t mods_size, uint64_t mods_stride,
    uint64_t mods_product);
void memref_add_lwe_ciphertexts_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,

View File

@@ -192,6 +192,10 @@ public:
/// Read sources and exit before any lowering
FHE,
/// Read sources and lower all the FHELinalg operations to FHE operations
/// and scf loops
FHE_NO_LINALG,
/// Read sources and lower all FHE operations to TFHE
/// operations
TFHE,
@@ -200,10 +204,6 @@ public:
/// operations
CONCRETE,
/// Read sources and lower all FHE and TFHE operations to Concrete
/// operations with all linalg ops replaced by loops
CONCRETEWITHLOOPS,
/// Read sources and lower all FHE, TFHE and Concrete operations to
/// BConcrete operations
BCONCRETE,

View File

@@ -34,6 +34,12 @@ markFHELinalgForTiling(mlir::MLIRContext &context, mlir::ModuleOp &module,
llvm::ArrayRef<int64_t> tileSizes,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
lowerFHELinalgToFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
llvm::Optional<V0FHEContext> &fheContext,
std::function<bool(mlir::Pass *)> enablePass,
bool parallelize, bool batch);
mlir::LogicalResult
lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
llvm::Optional<V0FHEContext> &fheContext,

View File

@@ -43,6 +43,12 @@ char memref_expand_lut_in_trivial_glwe_ct_u64[] =
char memref_wop_pbs_crt_buffer[] = "memref_wop_pbs_crt_buffer";
char memref_encode_plaintext_with_crt[] = "memref_encode_plaintext_with_crt";
char memref_encode_expand_lut_for_bootstrap[] =
"memref_encode_expand_lut_for_bootstrap";
char memref_encode_expand_lut_for_woppbs[] =
"memref_encode_expand_lut_for_woppbs";
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
size_t rank) {
std::vector<int64_t> shape(rank, -1);
@@ -161,6 +167,24 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
contextType,
},
{});
} else if (funcName == memref_encode_plaintext_with_crt) {
funcType = mlir::FunctionType::get(rewriter.getContext(),
{memref1DType, rewriter.getI64Type(),
memref1DType, rewriter.getI64Type()},
{});
} else if (funcName == memref_encode_expand_lut_for_bootstrap) {
funcType =
mlir::FunctionType::get(rewriter.getContext(),
{memref1DType, memref1DType,
rewriter.getI32Type(), rewriter.getI32Type()},
{});
} else if (funcName == memref_encode_expand_lut_for_woppbs) {
funcType = mlir::FunctionType::get(
rewriter.getContext(),
{memref1DType, memref1DType, memref1DType, memref1DType,
rewriter.getI32Type(), rewriter.getI32Type()},
{});
} else {
op->emitError("unknwon external function") << funcName;
return mlir::failure();
@@ -301,6 +325,92 @@ void wopPBSAddOperands(BConcrete::WopPBSCRTLweBufferOp op,
operands.push_back(getContextArgument(op));
}
/// Appends the extra CAPI-call operands for an
/// `encode_plaintext_with_crt_buffer` op: the `mods` array attribute
/// materialized as a constant global memref, followed by the `modsProd`
/// attribute as a constant scalar. Push order must match the C function's
/// parameter order.
void encodePlaintextWithCrtAddOperands(
BConcrete::EncodePlaintextWithCrtBufferOp op,
mlir::SmallVector<mlir::Value> &operands, mlir::RewriterBase &rewriter) {
// mods: lower the I64ArrayAttr to a dense-tensor constant, promote it to a
// module-level global, and pass a casted memref reference to it.
mlir::Type modsType = mlir::RankedTensorType::get({(int)op.modsAttr().size()},
rewriter.getI64Type());
std::vector<int64_t> modsValues;
for (auto a : op.mods()) {
modsValues.push_back(a.cast<mlir::IntegerAttr>().getValue().getZExtValue());
}
auto modsAttr = rewriter.getI64TensorAttr(modsValues);
auto modsOp =
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), modsAttr, modsType);
// The constant op exists only so getGlobalFor can create/reuse a global for
// its value; it is erased immediately afterwards.
auto modsGlobalMemref = mlir::bufferization::getGlobalFor(modsOp, 0);
rewriter.eraseOp(modsOp);
assert(!failed(modsGlobalMemref));
auto modsGlobalRef = rewriter.create<memref::GetGlobalOp>(
op.getLoc(), (*modsGlobalMemref).type(), (*modsGlobalMemref).getName());
operands.push_back(getCastedMemRef(rewriter, modsGlobalRef));
// mods_prod: passed as a plain i64 constant.
operands.push_back(
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.modsProdAttr()));
}
/// Appends the extra CAPI-call operands for an
/// `encode_expand_lut_for_bootstrap_buffer` op: the `polySize` attribute,
/// then the `outputBits` attribute, each materialized as a constant. Push
/// order must match the C function's parameter order.
void encodeExpandLutForBootstrapAddOperands(
    BConcrete::EncodeExpandLutForBootstrapBufferOp op,
    mlir::SmallVector<mlir::Value> &operands, mlir::RewriterBase &rewriter) {
  auto loc = op.getLoc();
  // poly_size
  auto polySizeCst =
      rewriter.create<mlir::arith::ConstantOp>(loc, op.polySizeAttr());
  operands.push_back(polySizeCst);
  // out_MESSAGE_BITS
  auto outputBitsCst =
      rewriter.create<mlir::arith::ConstantOp>(loc, op.outputBitsAttr());
  operands.push_back(outputBitsCst);
}
/// Appends the extra CAPI-call operands for an
/// `encode_expand_lut_for_woppbs_buffer` op, in the order the C function
/// expects: crt_decomposition memref, crt_bits memref, poly_size,
/// modulus_product. Both array attributes are materialized as constant
/// global memrefs.
void encodeExpandLutForWopPBSAddOperands(
BConcrete::EncodeExpandLutForWopPBSBufferOp op,
mlir::SmallVector<mlir::Value> &operands, mlir::RewriterBase &rewriter) {
// crt_decomposition: lower the I64ArrayAttr to a dense-tensor constant,
// promote it to a module-level global, and pass a casted memref reference.
mlir::Type crtDecompositionType = mlir::RankedTensorType::get(
{(int)op.crtDecompositionAttr().size()}, rewriter.getI64Type());
std::vector<int64_t> crtDecompositionValues;
for (auto a : op.crtDecomposition()) {
crtDecompositionValues.push_back(
a.cast<mlir::IntegerAttr>().getValue().getZExtValue());
}
auto crtDecompositionAttr = rewriter.getI64TensorAttr(crtDecompositionValues);
auto crtDecompositionOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), crtDecompositionAttr, crtDecompositionType);
// The constant op exists only so getGlobalFor can create/reuse a global;
// it is erased immediately afterwards.
auto crtDecompositionGlobalMemref =
mlir::bufferization::getGlobalFor(crtDecompositionOp, 0);
rewriter.eraseOp(crtDecompositionOp);
assert(!failed(crtDecompositionGlobalMemref));
auto crtDecompositionGlobalRef = rewriter.create<memref::GetGlobalOp>(
op.getLoc(), (*crtDecompositionGlobalMemref).type(),
(*crtDecompositionGlobalMemref).getName());
operands.push_back(getCastedMemRef(rewriter, crtDecompositionGlobalRef));
// crt_bits: same globalization scheme as crt_decomposition above.
mlir::Type crtBitsType = mlir::RankedTensorType::get(
{(int)op.crtBitsAttr().size()}, rewriter.getI64Type());
std::vector<int64_t> crtBitsValues;
for (auto a : op.crtBits()) {
crtBitsValues.push_back(
a.cast<mlir::IntegerAttr>().getValue().getZExtValue());
}
auto crtBitsAttr = rewriter.getI64TensorAttr(crtBitsValues);
auto crtBitsOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), crtBitsAttr, crtBitsType);
auto crtBitsGlobalMemref = mlir::bufferization::getGlobalFor(crtBitsOp, 0);
rewriter.eraseOp(crtBitsOp);
assert(!failed(crtBitsGlobalMemref));
auto crtBitsGlobalRef = rewriter.create<memref::GetGlobalOp>(
op.getLoc(), (*crtBitsGlobalMemref).type(),
(*crtBitsGlobalMemref).getName());
operands.push_back(getCastedMemRef(rewriter, crtBitsGlobalRef));
// poly_size: plain i32 constant.
operands.push_back(
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.polySizeAttr()));
// modulus_product: plain i32 constant.
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.modulusProductAttr()));
}
struct BConcreteToCAPIPass : public BConcreteToCAPIBase<BConcreteToCAPIPass> {
BConcreteToCAPIPass(bool gpu) : gpu(gpu) {}
@@ -334,6 +444,18 @@ struct BConcreteToCAPIPass : public BConcreteToCAPIBase<BConcreteToCAPIPass> {
patterns.add<BConcreteToCAPICallPattern<BConcrete::NegateLweBufferOp,
memref_negate_lwe_ciphertext_u64>>(
&getContext());
patterns.add<
BConcreteToCAPICallPattern<BConcrete::EncodePlaintextWithCrtBufferOp,
memref_encode_plaintext_with_crt>>(
&getContext(), encodePlaintextWithCrtAddOperands);
patterns.add<BConcreteToCAPICallPattern<
BConcrete::EncodeExpandLutForBootstrapBufferOp,
memref_encode_expand_lut_for_bootstrap>>(
&getContext(), encodeExpandLutForBootstrapAddOperands);
patterns.add<
BConcreteToCAPICallPattern<BConcrete::EncodeExpandLutForWopPBSBufferOp,
memref_encode_expand_lut_for_woppbs>>(
&getContext(), encodeExpandLutForWopPBSAddOperands);
if (gpu) {
patterns.add<BConcreteToCAPICallPattern<BConcrete::KeySwitchLweBufferOp,
memref_keyswitch_lwe_cuda_u64>>(

View File

@@ -1,4 +1,5 @@
add_subdirectory(FHEToTFHE)
add_subdirectory(FHEToTFHEScalar)
add_subdirectory(FHEToTFHECrt)
add_subdirectory(TFHEGlobalParametrization)
add_subdirectory(TFHEToConcrete)
add_subdirectory(FHETensorOpsToLinalg)

View File

@@ -68,10 +68,6 @@ public:
addConversion([&](mlir::concretelang::Concrete::LweCiphertextType type) {
assert(type.getDimension() != -1);
llvm::SmallVector<int64_t, 2> shape;
auto crt = type.getCrtDecomposition();
if (!crt.empty()) {
shape.push_back(crt.size());
}
shape.push_back(type.getDimension() + 1);
return mlir::RankedTensorType::get(
shape, mlir::IntegerType::get(type.getContext(), 64));
@@ -95,10 +91,6 @@ public:
mlir::SmallVector<int64_t> newShape;
newShape.reserve(type.getShape().size() + 1);
newShape.append(type.getShape().begin(), type.getShape().end());
auto crt = lwe.getCrtDecomposition();
if (!crt.empty()) {
newShape.push_back(crt.size());
}
newShape.push_back(lwe.getDimension() + 1);
mlir::Type r = mlir::RankedTensorType::get(
newShape, mlir::IntegerType::get(type.getContext(), 64));
@@ -146,7 +138,7 @@ struct ZeroOpPattern : public mlir::OpRewritePattern<ZeroOp> {
};
};
template <typename ConcreteOp, typename BConcreteOp, typename BConcreteCRTOp>
template <typename ConcreteOp, typename BConcreteOp>
struct LowToBConcrete : public mlir::OpRewritePattern<ConcreteOp> {
LowToBConcrete(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<ConcreteOp>(context, benefit) {}
@@ -161,29 +153,9 @@ struct LowToBConcrete : public mlir::OpRewritePattern<ConcreteOp> {
concreteOp.getOperation()->getAttrs();
mlir::Operation *bConcreteOp;
if (resultTyRange.size() == 1 &&
resultTyRange.front()
.isa<mlir::concretelang::Concrete::LweCiphertextType>()) {
auto crt = resultTyRange.front()
.cast<mlir::concretelang::Concrete::LweCiphertextType>()
.getCrtDecomposition();
if (crt.empty()) {
bConcreteOp = rewriter.replaceOpWithNewOp<BConcreteOp>(
concreteOp, resultTyRange, concreteOp.getOperation()->getOperands(),
attributes);
} else {
auto newAttributes = attributes.vec();
newAttributes.push_back(rewriter.getNamedAttr(
"crtDecomposition", rewriter.getI64ArrayAttr(crt)));
bConcreteOp = rewriter.replaceOpWithNewOp<BConcreteCRTOp>(
concreteOp, resultTyRange, concreteOp.getOperation()->getOperands(),
newAttributes);
}
} else {
bConcreteOp = rewriter.replaceOpWithNewOp<BConcreteOp>(
concreteOp, resultTyRange, concreteOp.getOperation()->getOperands(),
attributes);
}
bConcreteOp = rewriter.replaceOpWithNewOp<BConcreteOp>(
concreteOp, resultTyRange, concreteOp.getOperation()->getOperands(),
attributes);
mlir::concretelang::convertOperandAndResultTypes(
rewriter, bConcreteOp, [&](mlir::MLIRContext *, mlir::Type t) {
@@ -361,7 +333,6 @@ struct AddPlaintextLweCiphertextOpPattern
matchAndRewrite(Concrete::AddPlaintextLweCiphertextOp concreteOp,
::mlir::PatternRewriter &rewriter) const override {
ConcreteToBConcreteTypeConverter converter;
auto loc = concreteOp.getLoc();
mlir::concretelang::Concrete::LweCiphertextType resultTy =
((mlir::Type)concreteOp->getResult(0).getType())
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
@@ -371,31 +342,11 @@ struct AddPlaintextLweCiphertextOpPattern
llvm::ArrayRef<::mlir::NamedAttribute> attributes =
concreteOp.getOperation()->getAttrs();
auto crt = resultTy.getCrtDecomposition();
mlir::Operation *bConcreteOp;
if (crt.empty()) {
// Encode the plaintext value
mlir::Value castedInt = rewriter.create<mlir::arith::ExtUIOp>(
loc, rewriter.getIntegerType(64), concreteOp.rhs());
mlir::Value constantShiftOp = rewriter.create<mlir::arith::ConstantOp>(
loc,
rewriter.getI64IntegerAttr(64 - concreteOp.getType().getP() - 1));
auto encoded = rewriter.create<mlir::arith::ShLIOp>(
loc, rewriter.getI64Type(), castedInt, constantShiftOp);
bConcreteOp =
rewriter.replaceOpWithNewOp<BConcrete::AddPlaintextLweTensorOp>(
concreteOp, newResultTy,
mlir::ValueRange{concreteOp.lhs(), encoded}, attributes);
} else {
// The encoding is done when we eliminate CRT ops
auto newAttributes = attributes.vec();
newAttributes.push_back(rewriter.getNamedAttr(
"crtDecomposition", rewriter.getI64ArrayAttr(crt)));
bConcreteOp =
rewriter.replaceOpWithNewOp<BConcrete::AddPlaintextCRTLweTensorOp>(
concreteOp, newResultTy, concreteOp.getOperation()->getOperands(),
newAttributes);
}
bConcreteOp =
rewriter.replaceOpWithNewOp<BConcrete::AddPlaintextLweTensorOp>(
concreteOp, newResultTy,
mlir::ValueRange{concreteOp.lhs(), concreteOp.rhs()}, attributes);
mlir::concretelang::convertOperandAndResultTypes(
rewriter, bConcreteOp, [&](mlir::MLIRContext *, mlir::Type t) {
@@ -417,7 +368,6 @@ struct MulCleartextLweCiphertextOpPattern
matchAndRewrite(Concrete::MulCleartextLweCiphertextOp concreteOp,
::mlir::PatternRewriter &rewriter) const override {
ConcreteToBConcreteTypeConverter converter;
auto loc = concreteOp.getLoc();
mlir::concretelang::Concrete::LweCiphertextType resultTy =
((mlir::Type)concreteOp->getResult(0).getType())
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
@@ -427,25 +377,11 @@ struct MulCleartextLweCiphertextOpPattern
llvm::ArrayRef<::mlir::NamedAttribute> attributes =
concreteOp.getOperation()->getAttrs();
auto crt = resultTy.getCrtDecomposition();
mlir::Operation *bConcreteOp;
if (crt.empty()) {
// Encode the plaintext value
mlir::Value castedInt = rewriter.create<mlir::arith::ExtUIOp>(
loc, rewriter.getIntegerType(64), concreteOp.rhs());
bConcreteOp =
rewriter.replaceOpWithNewOp<BConcrete::MulCleartextLweTensorOp>(
concreteOp, newResultTy,
mlir::ValueRange{concreteOp.lhs(), castedInt}, attributes);
} else {
auto newAttributes = attributes.vec();
newAttributes.push_back(rewriter.getNamedAttr(
"crtDecomposition", rewriter.getI64ArrayAttr(crt)));
bConcreteOp =
rewriter.replaceOpWithNewOp<BConcrete::MulCleartextCRTLweTensorOp>(
concreteOp, newResultTy, concreteOp.getOperation()->getOperands(),
newAttributes);
}
bConcreteOp =
rewriter.replaceOpWithNewOp<BConcrete::MulCleartextLweTensorOp>(
concreteOp, newResultTy,
mlir::ValueRange{concreteOp.lhs(), concreteOp.rhs()}, attributes);
mlir::concretelang::convertOperandAndResultTypes(
rewriter, bConcreteOp, [&](mlir::MLIRContext *, mlir::Type t) {
@@ -468,11 +404,6 @@ struct ExtractSliceOpPattern
::mlir::PatternRewriter &rewriter) const override {
ConcreteToBConcreteTypeConverter converter;
auto resultTy = extractSliceOp.result().getType();
auto lweResultTy =
resultTy.cast<mlir::RankedTensorType>()
.getElementType()
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
auto nbBlock = lweResultTy.getCrtDecomposition().size();
auto newResultTy =
converter.convertType(resultTy).cast<mlir::RankedTensorType>();
@@ -480,19 +411,12 @@ struct ExtractSliceOpPattern
mlir::SmallVector<mlir::Attribute> staticOffsets;
staticOffsets.append(extractSliceOp.static_offsets().begin(),
extractSliceOp.static_offsets().end());
if (nbBlock != 0) {
staticOffsets.push_back(rewriter.getI64IntegerAttr(0));
}
staticOffsets.push_back(rewriter.getI64IntegerAttr(0));
// add the lweSize to the sizes
mlir::SmallVector<mlir::Attribute> staticSizes;
staticSizes.append(extractSliceOp.static_sizes().begin(),
extractSliceOp.static_sizes().end());
if (nbBlock != 0) {
staticSizes.push_back(rewriter.getI64IntegerAttr(
newResultTy.getDimSize(newResultTy.getRank() - 2)));
}
staticSizes.push_back(rewriter.getI64IntegerAttr(
newResultTy.getDimSize(newResultTy.getRank() - 1)));
@@ -500,9 +424,6 @@ struct ExtractSliceOpPattern
mlir::SmallVector<mlir::Attribute> staticStrides;
staticStrides.append(extractSliceOp.static_strides().begin(),
extractSliceOp.static_strides().end());
if (nbBlock != 0) {
staticStrides.push_back(rewriter.getI64IntegerAttr(1));
}
staticStrides.push_back(rewriter.getI64IntegerAttr(1));
// replace tensor.extract_slice to the new one
@@ -545,29 +466,20 @@ struct ExtractOpPattern
if (lweResultTy == nullptr) {
return mlir::failure();
}
auto nbBlock = lweResultTy.getCrtDecomposition().size();
auto newResultTy =
converter.convertType(lweResultTy).cast<mlir::RankedTensorType>();
auto rankOfResult = extractOp.indices().size() +
/* for the lwe dimension */ 1 +
/* for the block dimension */
(nbBlock == 0 ? 0 : 1);
auto rankOfResult = extractOp.indices().size() + 1;
// [min..., 0] for static_offsets ()
mlir::SmallVector<mlir::Attribute> staticOffsets(
rankOfResult,
rewriter.getI64IntegerAttr(std::numeric_limits<int64_t>::min()));
if (nbBlock != 0) {
staticOffsets[staticOffsets.size() - 2] = rewriter.getI64IntegerAttr(0);
}
staticOffsets[staticOffsets.size() - 1] = rewriter.getI64IntegerAttr(0);
// [1..., lweDimension+1] for static_sizes or
// [1..., nbBlock, lweDimension+1]
mlir::SmallVector<mlir::Attribute> staticSizes(
rankOfResult, rewriter.getI64IntegerAttr(1));
if (nbBlock != 0) {
staticSizes[staticSizes.size() - 2] = rewriter.getI64IntegerAttr(nbBlock);
}
staticSizes[staticSizes.size() - 1] = rewriter.getI64IntegerAttr(
newResultTy.getDimSize(newResultTy.getRank() - 1));
@@ -577,14 +489,8 @@ struct ExtractOpPattern
// replace tensor.extract_slice to the new one
mlir::SmallVector<int64_t> extractedSliceShape(rankOfResult, 1);
if (nbBlock != 0) {
extractedSliceShape[extractedSliceShape.size() - 2] = nbBlock;
extractedSliceShape[extractedSliceShape.size() - 1] =
newResultTy.getDimSize(1);
} else {
extractedSliceShape[extractedSliceShape.size() - 1] =
newResultTy.getDimSize(0);
}
extractedSliceShape[extractedSliceShape.size() - 1] =
newResultTy.getDimSize(0);
auto extractedSliceType =
mlir::RankedTensorType::get(extractedSliceShape, rewriter.getI64Type());
@@ -601,17 +507,12 @@ struct ExtractOpPattern
});
mlir::ReassociationIndices reassociation;
for (int64_t i = 0;
i < extractedSliceType.getRank() - (nbBlock == 0 ? 0 : 1); i++) {
for (int64_t i = 0; i < extractedSliceType.getRank(); i++) {
reassociation.push_back(i);
}
mlir::SmallVector<mlir::ReassociationIndices> reassocs{reassociation};
if (nbBlock != 0) {
reassocs.push_back({extractedSliceType.getRank() - 1});
}
mlir::tensor::CollapseShapeOp collapseOp =
rewriter.replaceOpWithNewOp<mlir::tensor::CollapseShapeOp>(
extractOp, newResultTy, extractedSlice, reassocs);
@@ -644,7 +545,6 @@ struct InsertSliceOpPattern
if (lweResultTy == nullptr) {
return mlir::failure();
}
auto nbBlock = lweResultTy.getCrtDecomposition().size();
auto newResultTy =
converter.convertType(resultTy).cast<mlir::RankedTensorType>();
@@ -652,19 +552,12 @@ struct InsertSliceOpPattern
mlir::SmallVector<mlir::Attribute> staticOffsets;
staticOffsets.append(insertSliceOp.static_offsets().begin(),
insertSliceOp.static_offsets().end());
if (nbBlock != 0) {
staticOffsets.push_back(rewriter.getI64IntegerAttr(0));
}
staticOffsets.push_back(rewriter.getI64IntegerAttr(0));
// add lweDimension+1 to static_sizes
mlir::SmallVector<mlir::Attribute> staticSizes;
staticSizes.append(insertSliceOp.static_sizes().begin(),
insertSliceOp.static_sizes().end());
if (nbBlock != 0) {
staticSizes.push_back(rewriter.getI64IntegerAttr(
newResultTy.getDimSize(newResultTy.getRank() - 2)));
}
staticSizes.push_back(rewriter.getI64IntegerAttr(
newResultTy.getDimSize(newResultTy.getRank() - 1)));
@@ -672,9 +565,6 @@ struct InsertSliceOpPattern
mlir::SmallVector<mlir::Attribute> staticStrides;
staticStrides.append(insertSliceOp.static_strides().begin(),
insertSliceOp.static_strides().end());
if (nbBlock != 0) {
staticStrides.push_back(rewriter.getI64IntegerAttr(1));
}
staticStrides.push_back(rewriter.getI64IntegerAttr(1));
// replace tensor.insert_slice with the new one
@@ -710,7 +600,6 @@ struct InsertOpPattern : public mlir::OpRewritePattern<mlir::tensor::InsertOp> {
if (lweResultTy == nullptr) {
return mlir::failure();
};
auto hasBlock = lweResultTy.getCrtDecomposition().size() != 0;
mlir::RankedTensorType newResultTy =
converter.convertType(resultTy).cast<mlir::RankedTensorType>();
@@ -718,9 +607,6 @@ struct InsertOpPattern : public mlir::OpRewritePattern<mlir::tensor::InsertOp> {
mlir::SmallVector<mlir::OpFoldResult> offsets;
offsets.append(insertOp.indices().begin(), insertOp.indices().end());
offsets.push_back(rewriter.getIndexAttr(0));
if (hasBlock) {
offsets.push_back(rewriter.getIndexAttr(0));
}
// Inserting a smaller tensor into a (potentially) bigger one. Set
// dimensions for all leading dimensions of the target tensor not
@@ -729,10 +615,6 @@ struct InsertOpPattern : public mlir::OpRewritePattern<mlir::tensor::InsertOp> {
rewriter.getI64IntegerAttr(1));
// Add size for the bufferized source element
if (hasBlock) {
sizes.push_back(rewriter.getI64IntegerAttr(
newResultTy.getDimSize(newResultTy.getRank() - 2)));
}
sizes.push_back(rewriter.getI64IntegerAttr(
newResultTy.getDimSize(newResultTy.getRank() - 1)));
@@ -871,9 +753,6 @@ struct TensorShapeOpPattern : public mlir::OpRewritePattern<ShapeOp> {
::mlir::PatternRewriter &rewriter) const override {
ConcreteToBConcreteTypeConverter converter;
auto resultTy = ((mlir::Type)shapeOp.result().getType()).cast<VecTy>();
auto lweResultTy =
((mlir::Type)resultTy.getElementType())
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
auto newResultTy =
((mlir::Type)converter.convertType(resultTy)).cast<VecTy>();
@@ -886,12 +765,6 @@ struct TensorShapeOpPattern : public mlir::OpRewritePattern<ShapeOp> {
auto oldReassocs = shapeOp.getReassociationIndices();
mlir::SmallVector<mlir::ReassociationIndices> newReassocs;
newReassocs.append(oldReassocs.begin(), oldReassocs.end());
// add [rank-1] to reassociations if crt decomp
if (!lweResultTy.getCrtDecomposition().empty()) {
mlir::ReassociationIndices lweAssoc;
lweAssoc.push_back(reassocTy.getRank() - 2);
newReassocs.push_back(lweAssoc);
}
// add [rank] to reassociations
{
@@ -1020,14 +893,21 @@ void ConcreteToBConcretePass::runOnOperation() {
LowerBootstrap, LowerBatchedBootstrap, LowerKeySwitch,
LowerBatchedKeySwitch,
LowToBConcrete<mlir::concretelang::Concrete::AddLweCiphertextsOp,
mlir::concretelang::BConcrete::AddLweTensorOp,
BConcrete::AddCRTLweTensorOp>,
mlir::concretelang::BConcrete::AddLweTensorOp>,
AddPlaintextLweCiphertextOpPattern, MulCleartextLweCiphertextOpPattern,
LowToBConcrete<
mlir::concretelang::Concrete::EncodeExpandLutForBootstrapOp,
mlir::concretelang::BConcrete::EncodeExpandLutForBootstrapTensorOp>,
LowToBConcrete<
mlir::concretelang::Concrete::EncodeExpandLutForWopPBSOp,
mlir::concretelang::BConcrete::EncodeExpandLutForWopPBSTensorOp>,
LowToBConcrete<
mlir::concretelang::Concrete::EncodePlaintextWithCrtOp,
mlir::concretelang::BConcrete::EncodePlaintextWithCrtTensorOp>,
LowToBConcrete<mlir::concretelang::Concrete::NegateLweCiphertextOp,
mlir::concretelang::BConcrete::NegateLweTensorOp,
BConcrete::NegateCRTLweTensorOp>,
LowToBConcrete<Concrete::WopPBSLweOp, BConcrete::WopPBSCRTLweTensorOp,
BConcrete::WopPBSCRTLweTensorOp>>(&getContext());
mlir::concretelang::BConcrete::NegateLweTensorOp>,
LowToBConcrete<Concrete::WopPBSLweOp, BConcrete::WopPBSCRTLweTensorOp>>(
&getContext());
// Add patterns to rewrite tensor operators that works on encrypted
// tensors

View File

@@ -1,391 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <iostream>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/Linalg/IR/Linalg.h>
#include <mlir/IR/Operation.h>
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "concretelang/Conversion/FHEToTFHE/Patterns.h"
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
#include "concretelang/Dialect/RT/IR/RTDialect.h"
#include "concretelang/Dialect/RT/IR/RTOps.h"
#include "concretelang/Dialect/RT/IR/RTTypes.h"
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
namespace FHE = mlir::concretelang::FHE;
namespace TFHE = mlir::concretelang::TFHE;
namespace {
using mlir::concretelang::FHE::EncryptedIntegerType;
using mlir::concretelang::TFHE::GLWECipherTextType;
/// FHEToTFHETypeConverter is a TypeConverter that transforms
/// `FHE.eint<p>` to `TFHE.glwe<{_,_,_}{p}>`
class FHEToTFHETypeConverter : public mlir::TypeConverter {
public:
  FHEToTFHETypeConverter() {
    // Fallback: any type not matched by a more specific conversion below is
    // left unchanged.
    addConversion([](mlir::Type type) { return type; });
    // Scalar encrypted integers become GLWE ciphertexts of the same width.
    addConversion([](EncryptedIntegerType type) {
      return mlir::concretelang::convertTypeEncryptedIntegerToGLWE(
          type.getContext(), type);
    });
    // Tensors of encrypted integers become same-shape tensors of GLWE
    // ciphertexts; tensors of any other element type are returned as-is.
    addConversion([](mlir::RankedTensorType type) {
      auto eint =
          type.getElementType().dyn_cast_or_null<EncryptedIntegerType>();
      if (eint == nullptr) {
        return (mlir::Type)(type);
      }
      mlir::Type r = mlir::RankedTensorType::get(
          type.getShape(),
          mlir::concretelang::convertTypeEncryptedIntegerToGLWE(
              eint.getContext(), eint));
      return r;
    });
    // RT future/pointer types are converted by recursively converting the
    // element type they wrap.
    addConversion([&](mlir::concretelang::RT::FutureType type) {
      return mlir::concretelang::RT::FutureType::get(
          this->convertType(type.dyn_cast<mlir::concretelang::RT::FutureType>()
                                .getElementType()));
    });
    addConversion([&](mlir::concretelang::RT::PointerType type) {
      return mlir::concretelang::RT::PointerType::get(
          this->convertType(type.dyn_cast<mlir::concretelang::RT::PointerType>()
                                .getElementType()));
    });
  }
};
/// This rewrite pattern transforms any instance of `FHE.apply_lookup_table`
/// operators.
///
/// Example:
///
/// ```mlir
/// %0 = "FHE.apply_lookup_table"(%ct, %lut): (!FHE.eint<2>, tensor<4xi64>)
/// ->(!FHE.eint<2>)
/// ```
///
/// becomes:
///
/// ```mlir
/// %glwe_ks = "TFHE.keyswitch_glwe"(%ct)
/// {baseLog = -1 : i32, level = -1 : i32}
/// : (!TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
/// %0 = "TFHE.bootstrap_glwe"(%glwe_ks, %lut)
/// {baseLog = -1 : i32, glweDimension = -1 : i32, level = -1 : i32,
/// polynomialSize = -1 : i32}
/// : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) ->
/// !TFHE.glwe<{_,_,_}{2}>
/// ```
struct ApplyLookupTableEintOpToKeyswitchBootstrapPattern
    : public mlir::OpRewritePattern<FHE::ApplyLookupTableEintOp> {
  ApplyLookupTableEintOpToKeyswitchBootstrapPattern(
      mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern<FHE::ApplyLookupTableEintOp>(context,
                                                              benefit) {}
  ::mlir::LogicalResult
  matchAndRewrite(FHE::ApplyLookupTableEintOp lutOp,
                  mlir::PatternRewriter &rewriter) const override {
    FHEToTFHETypeConverter converter;
    auto inputTy = converter.convertType(lutOp.a().getType())
                       .cast<TFHE::GLWECipherTextType>();
    auto resultTy = converter.convertType(lutOp.getType());
    // Keyswitch the input ciphertext first. The `-1` arguments are
    // placeholder parameters (see the `-1` attributes in the doc example
    // above); presumably they are resolved by a later parametrization pass —
    // NOTE(review): confirm.
    auto glweKs = rewriter.create<TFHE::KeySwitchGLWEOp>(
        lutOp.getLoc(), inputTy, lutOp.a(), -1, -1);
    mlir::concretelang::convertOperandAndResultTypes(
        rewriter, glweKs, [&](mlir::MLIRContext *, mlir::Type t) {
          return converter.convertType(t);
        });
    // %0 = "TFHE.bootstrap_glwe"(%glwe_ks, %glwe_lut)
    rewriter.replaceOpWithNewOp<TFHE::BootstrapGLWEOp>(
        lutOp, resultTy, glweKs, lutOp.lut(), -1, -1, -1, -1);
    return ::mlir::success();
  };
};
/// This rewrite pattern transforms any instance of `FHE.apply_lookup_table`
/// operators.
///
/// Example:
///
/// ```mlir
/// %0 = "FHE.apply_lookup_table"(%ct, %lut): (!FHE.eint<2>, tensor<4xi64>)
/// ->(!FHE.eint<2>)
/// ```
///
/// becomes:
///
/// ```mlir
/// %0 = "TFHE.wop_pbs_glwe"(%ct, %lut)
/// : (!TFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) ->
/// (!TFHE.glwe<{_,_,_}{2}>)
/// ```
struct ApplyLookupTableEintOpToWopPBSPattern
    : public mlir::OpRewritePattern<FHE::ApplyLookupTableEintOp> {
  ApplyLookupTableEintOpToWopPBSPattern(mlir::MLIRContext *context,
                                        mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern<FHE::ApplyLookupTableEintOp>(context,
                                                              benefit) {}
  ::mlir::LogicalResult
  matchAndRewrite(FHE::ApplyLookupTableEintOp lutOp,
                  mlir::PatternRewriter &rewriter) const override {
    FHEToTFHETypeConverter converter;
    auto resultTy = converter.convertType(lutOp.getType());
    // %0 = "TFHE.wop_pbs_glwe"(%ct, %lut)
    // : (!TFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) ->
    // (!TFHE.glwe<{_,_,_}{2}>)
    // The `-1` arguments are placeholder parameters; presumably they are
    // resolved by a later parametrization pass — NOTE(review): confirm.
    auto wopPBS = rewriter.replaceOpWithNewOp<TFHE::WopPBSGLWEOp>(
        lutOp, resultTy, lutOp.a(), lutOp.lut(), -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1);
    mlir::concretelang::convertOperandAndResultTypes(
        rewriter, wopPBS, [&](mlir::MLIRContext *, mlir::Type t) {
          return converter.convertType(t);
        });
    return ::mlir::success();
  };
};
/// Rewrites `FHE.sub_eint_int` as an addition with the negated clear operand:
/// `lhs - rhs` becomes `lhs + (rhs * -1)`.
struct SubEintIntOpPattern : public mlir::OpRewritePattern<FHE::SubEintIntOp> {
  SubEintIntOpPattern(mlir::MLIRContext *context,
                      mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern<FHE::SubEintIntOp>(context, benefit) {}
  ::mlir::LogicalResult
  matchAndRewrite(FHE::SubEintIntOp op,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Location loc = op.getLoc();
    mlir::Value ciphertext = op.getOperand(0);
    mlir::Value clear = op.getOperand(1);
    // Build the clear negation: clear * (-1).
    mlir::Attribute minusOneAttr = mlir::IntegerAttr::get(clear.getType(), -1);
    mlir::Value minusOne =
        rewriter.create<mlir::arith::ConstantOp>(loc, minusOneAttr)
            .getResult();
    mlir::Value negated =
        rewriter.create<mlir::arith::MulIOp>(loc, clear, minusOne).getResult();
    // Emit the TFHE addition and convert its operand/result types.
    FHEToTFHETypeConverter converter;
    auto convertedResultTy = converter.convertType(op.getType());
    auto add = rewriter.create<TFHE::AddGLWEIntOp>(loc, convertedResultTy,
                                                   ciphertext, negated);
    mlir::concretelang::convertOperandAndResultTypes(
        rewriter, add, [&](mlir::MLIRContext *, mlir::Type t) {
          return converter.convertType(t);
        });
    rewriter.replaceOp(op, {add.getResult()});
    return mlir::success();
  };
};
/// Rewrites `FHE.sub_eint` as an addition with the negated encrypted operand:
/// `lhs - rhs` becomes `lhs + (-rhs)`.
struct SubEintOpPattern : public mlir::OpRewritePattern<FHE::SubEintOp> {
  SubEintOpPattern(mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern<FHE::SubEintOp>(context, benefit) {}
  ::mlir::LogicalResult
  matchAndRewrite(FHE::SubEintOp op,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Location loc = op.getLoc();
    mlir::Value minuend = op.getOperand(0);
    mlir::Value subtrahend = op.getOperand(1);
    FHEToTFHETypeConverter converter;
    // Negate the encrypted right-hand side, then convert the op's types.
    auto negatedTy = converter.convertType(subtrahend.getType());
    auto negated =
        rewriter.create<TFHE::NegGLWEOp>(loc, negatedTy, subtrahend);
    mlir::concretelang::convertOperandAndResultTypes(
        rewriter, negated, [&](mlir::MLIRContext *, mlir::Type t) {
          return converter.convertType(t);
        });
    // Emit the TFHE addition and convert its operand/result types as well.
    auto convertedResultTy = converter.convertType(op.getType());
    auto add = rewriter.create<TFHE::AddGLWEOp>(loc, convertedResultTy,
                                                minuend, negated.getResult());
    mlir::concretelang::convertOperandAndResultTypes(
        rewriter, add, [&](mlir::MLIRContext *, mlir::Type t) {
          return converter.convertType(t);
        });
    rewriter.replaceOp(op, {add.getResult()});
    return mlir::success();
  };
};
/// Conversion pass lowering the `FHE` dialect to the `TFHE` dialect.
///
/// The strategy used to lower `FHE.apply_lookup_table` is selected by the
/// `lutLowerStrategy` constructor argument (keyswitch+bootstrap or WoP-PBS).
struct FHEToTFHEPass : public FHEToTFHEBase<FHEToTFHEPass> {
  FHEToTFHEPass(mlir::concretelang::ApplyLookupTableLowering lutLowerStrategy)
      : lutLowerStrategy(lutLowerStrategy) {}
  void runOnOperation() override {
    auto op = this->getOperation();
    mlir::ConversionTarget target(getContext());
    FHEToTFHETypeConverter converter;
    // Mark ops from the target dialect as legal operations
    target.addLegalDialect<mlir::concretelang::TFHE::TFHEDialect>();
    target.addLegalDialect<mlir::arith::ArithmeticDialect>();
    // Make sure that no ops from `FHE` remain after the lowering
    target.addIllegalDialect<mlir::concretelang::FHE::FHEDialect>();
    // Make sure that no ops `linalg.generic` that have illegal types
    // remain: operands, results and region arguments must all be converted.
    target.addDynamicallyLegalOp<mlir::linalg::GenericOp,
                                 mlir::tensor::GenerateOp>(
        [&](mlir::Operation *op) {
          return (
              converter.isLegal(op->getOperandTypes()) &&
              converter.isLegal(op->getResultTypes()) &&
              converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
        });
    // Make sure that func has legal signature
    target.addDynamicallyLegalOp<mlir::func::FuncOp>(
        [&](mlir::func::FuncOp funcOp) {
          return converter.isSignatureLegal(funcOp.getFunctionType()) &&
                 converter.isLegal(&funcOp.getBody());
        });
    target.addDynamicallyLegalOp<mlir::func::ConstantOp>(
        [&](mlir::func::ConstantOp op) {
          return FunctionConstantOpConversion<FHEToTFHETypeConverter>::isLegal(
              op, converter);
        });
    // Add all patterns required to lower all ops from `FHE` to
    // `TFHE`
    mlir::RewritePatternSet patterns(&getContext());
    populateWithGeneratedFHEToTFHE(patterns);
    patterns.add<
        mlir::concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>>(
        patterns.getContext(), converter);
    // Register the lookup-table lowering selected by the pass option.
    switch (lutLowerStrategy) {
    case mlir::concretelang::KeySwitchBoostrapLowering:
      patterns.add<ApplyLookupTableEintOpToKeyswitchBootstrapPattern>(
          &getContext());
      break;
    case mlir::concretelang::WopPBSLowering:
      patterns.add<ApplyLookupTableEintOpToWopPBSPattern>(&getContext());
      break;
    }
    patterns.add<SubEintOpPattern>(&getContext());
    patterns.add<SubEintIntOpPattern>(&getContext());
    patterns.add<FunctionConstantOpConversion<FHEToTFHETypeConverter>>(
        &getContext(), converter);
    // Convert the types carried by region-holding ops (linalg.generic,
    // tensor.generate, scf.for) and their terminators.
    patterns.add<RegionOpTypeConverterPattern<mlir::linalg::GenericOp,
                                              FHEToTFHETypeConverter>>(
        &getContext(), converter);
    patterns.add<
        mlir::concretelang::GenericTypeConverterPattern<mlir::linalg::YieldOp>>(
        patterns.getContext(), converter);
    patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
                                              FHEToTFHETypeConverter>>(
        &getContext(), converter);
    patterns.add<
        RegionOpTypeConverterPattern<mlir::scf::ForOp, FHEToTFHETypeConverter>>(
        &getContext(), converter);
    patterns.add<mlir::concretelang::GenericTypeAndOpConverterPattern<
        mlir::concretelang::FHE::ZeroTensorOp,
        mlir::concretelang::TFHE::ZeroTensorGLWEOp>>(&getContext(), converter);
    mlir::concretelang::populateWithTensorTypeConverterPatterns(
        patterns, target, converter);
    mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
        patterns, converter);
    // Conversion of RT Dialect Ops
    patterns.add<
        mlir::concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>,
        mlir::concretelang::GenericTypeConverterPattern<mlir::scf::YieldOp>,
        mlir::concretelang::GenericTypeConverterPattern<
            mlir::concretelang::RT::MakeReadyFutureOp>,
        mlir::concretelang::GenericTypeConverterPattern<
            mlir::concretelang::RT::AwaitFutureOp>,
        mlir::concretelang::GenericTypeConverterPattern<
            mlir::concretelang::RT::CreateAsyncTaskOp>,
        mlir::concretelang::GenericTypeConverterPattern<
            mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>,
        mlir::concretelang::GenericTypeConverterPattern<
            mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
        mlir::concretelang::GenericTypeConverterPattern<
            mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>,
        mlir::concretelang::GenericTypeConverterPattern<
            mlir::concretelang::RT::WorkFunctionReturnOp>,
        mlir::concretelang::GenericTypeConverterPattern<
            mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
                                                                 converter);
    // RT ops are legal once all their types have been converted.
    mlir::concretelang::addDynamicallyLegalTypeOp<
        mlir::concretelang::RT::MakeReadyFutureOp>(target, converter);
    mlir::concretelang::addDynamicallyLegalTypeOp<
        mlir::concretelang::RT::AwaitFutureOp>(target, converter);
    mlir::concretelang::addDynamicallyLegalTypeOp<
        mlir::concretelang::RT::CreateAsyncTaskOp>(target, converter);
    mlir::concretelang::addDynamicallyLegalTypeOp<
        mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>(target, converter);
    mlir::concretelang::addDynamicallyLegalTypeOp<
        mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>(
        target, converter);
    mlir::concretelang::addDynamicallyLegalTypeOp<
        mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>(target, converter);
    mlir::concretelang::addDynamicallyLegalTypeOp<
        mlir::concretelang::RT::WorkFunctionReturnOp>(target, converter);
    mlir::concretelang::addDynamicallyLegalTypeOp<
        mlir::concretelang::RT::RegisterTaskWorkFunctionOp>(target, converter);
    // Apply conversion
    if (mlir::applyPartialConversion(op, target, std::move(patterns))
            .failed()) {
      this->signalPassFailure();
    }
  }
private:
  // Strategy used to lower `FHE.apply_lookup_table` (see constructor).
  mlir::concretelang::ApplyLookupTableLowering lutLowerStrategy;
};
} // namespace
namespace mlir {
namespace concretelang {
/// Creates the FHE-to-TFHE lowering pass; `lower` selects the strategy used
/// to lower `FHE.apply_lookup_table` ops (keyswitch+bootstrap or WoP-PBS).
std::unique_ptr<OperationPass<ModuleOp>>
createConvertFHEToTFHEPass(ApplyLookupTableLowering lower) {
  return std::make_unique<FHEToTFHEPass>(lower);
}
} // namespace concretelang
} // namespace mlir

View File

@@ -1,6 +1,6 @@
add_mlir_dialect_library(
FHEToTFHE
FHEToTFHE.cpp
FHEToTFHECrt
FHEToTFHECrt.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE
DEPENDS
@@ -12,4 +12,4 @@ add_mlir_dialect_library(
MLIRTransforms
MLIRMathDialect)
target_link_libraries(FHEToTFHE PUBLIC MLIRIR)
target_link_libraries(FHEToTFHECrt PUBLIC MLIRIR)

View File

@@ -0,0 +1,960 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <iostream>
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/Linalg/IR/Linalg.h>
#include <mlir/IR/Operation.h>
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "concretelang/Conversion/FHEToTFHECrt/Pass.h"
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
#include "concretelang/Dialect/RT/IR/RTDialect.h"
#include "concretelang/Dialect/RT/IR/RTOps.h"
#include "concretelang/Dialect/RT/IR/RTTypes.h"
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
namespace FHE = mlir::concretelang::FHE;
namespace TFHE = mlir::concretelang::TFHE;
namespace concretelang = mlir::concretelang;
namespace fhe_to_tfhe_crt_conversion {
namespace typing {
/// Converts `FHE::EncryptedInteger` into `Tensor<TFHE::GlweCiphertext>`.
///
/// The resulting rank-1 tensor holds `crtLength` GLWE ciphertexts, one per
/// CRT residue, with unparametrized (`-1`) GLWE parameters.
mlir::RankedTensorType convertEint(mlir::MLIRContext *context,
                                   FHE::EncryptedIntegerType eint,
                                   uint64_t crtLength) {
  auto glweTy =
      TFHE::GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth());
  int64_t shape[1] = {(int64_t)crtLength};
  return mlir::RankedTensorType::get(shape, glweTy);
}
/// Converts `Tensor<FHE::EncryptedInteger>` into a
/// `Tensor<TFHE::GlweCiphertext>` when the element type matches; any other
/// tensor type is returned unchanged. The CRT dimension of size `crtLength`
/// is appended as the innermost dimension of the shape.
mlir::Type maybeConvertEintTensor(mlir::MLIRContext *context,
                                  mlir::RankedTensorType maybeEintTensor,
                                  uint64_t crtLength) {
  auto eint =
      maybeEintTensor.getElementType().dyn_cast<FHE::EncryptedIntegerType>();
  if (!eint) {
    return maybeEintTensor;
  }
  // Extend the original shape with the trailing CRT dimension.
  auto oldShape = maybeEintTensor.getShape();
  mlir::SmallVector<int64_t> newShape(oldShape.begin(), oldShape.end());
  newShape.push_back((int64_t)crtLength);
  auto glweTy =
      TFHE::GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth());
  return mlir::RankedTensorType::get(llvm::ArrayRef<int64_t>(newShape), glweTy);
}
/// Converts the type `FHE::EncryptedInteger` to `Tensor<TFHE::GlweCiphertext>`
/// when the input type is an encrypted integer; every other type is returned
/// unchanged.
mlir::Type maybeConvertEint(mlir::MLIRContext *context, mlir::Type t,
                            uint64_t crtLength) {
  auto eint = t.dyn_cast<FHE::EncryptedIntegerType>();
  if (!eint) {
    return t;
  }
  return convertEint(context, eint, crtLength);
}
/// The type converter used to convert `FHE` to `TFHE` types using the crt
/// strategy.
class TypeConverter : public mlir::TypeConverter {
public:
  TypeConverter(concretelang::CrtLoweringParameters loweringParameters) {
    // Number of CRT moduli; determines the size of the appended crt
    // dimension.
    size_t nMods = loweringParameters.nMods;
    // Fallback: leave any type not handled below untouched.
    addConversion([](mlir::Type type) { return type; });
    // Scalar encrypted integers become rank-1 tensors of GLWE ciphertexts.
    addConversion([=](FHE::EncryptedIntegerType type) {
      return convertEint(type.getContext(), type, nMods);
    });
    // Tensors of encrypted integers gain a trailing crt dimension; other
    // tensors are returned as-is.
    addConversion([=](mlir::RankedTensorType type) {
      return maybeConvertEintTensor(type.getContext(), type, nMods);
    });
    // RT future/pointer types are converted by recursively converting the
    // element type they wrap.
    addConversion([&](concretelang::RT::FutureType type) {
      return concretelang::RT::FutureType::get(this->convertType(
          type.dyn_cast<concretelang::RT::FutureType>().getElementType()));
    });
    addConversion([&](concretelang::RT::PointerType type) {
      return concretelang::RT::PointerType::get(this->convertType(
          type.dyn_cast<concretelang::RT::PointerType>().getElementType()));
    });
  }
  /// Returns a lambda that uses this converter to turn one type into another.
  /// NOTE: the lambda captures `this` by reference, so it must not outlive
  /// the converter instance.
  std::function<mlir::Type(mlir::MLIRContext *, mlir::Type)>
  getConversionLambda() {
    return [&](mlir::MLIRContext *, mlir::Type t) { return convertType(t); };
  }
};
} // namespace typing
namespace lowering {
/// A pattern rewriter superclass used by most op rewriters during the
/// conversion.
///
/// Bundles the crt lowering parameters with the helpers that emit the
/// per-limb `scf::for` loops shared by the arithmetic op lowerings.
template <typename T> struct CrtOpPattern : public mlir::OpRewritePattern<T> {
  /// The lowering parameters are bound to the op rewriter.
  concretelang::CrtLoweringParameters loweringParameters;
  CrtOpPattern(mlir::MLIRContext *context,
               concretelang::CrtLoweringParameters params,
               mlir::PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<T>(context, benefit),
        loweringParameters(params) {}
  /// Writes an `scf::for` that loops over the crt dimension of two tensors and
  /// execute the input lambda to write the loop body. Returns the first result
  /// of the op.
  ///
  /// Note:
  /// -----
  ///
  /// + The type of `firstArgTensor` type is used as output type.
  /// + `body` receives the induction variable and both tensors as iteration
  ///   arguments, and is expected to yield the updated pair.
  mlir::Value writeBinaryTensorLoop(
      mlir::Location location, mlir::Value firstTensor,
      mlir::Value secondTensor, mlir::PatternRewriter &rewriter,
      mlir::function_ref<void(mlir::OpBuilder &, mlir::Location, mlir::Value,
                              mlir::ValueRange)>
          body) const {
    // Create the loop
    // Bounds are [0, nMods) with step 1: one iteration per crt limb.
    mlir::arith::ConstantOp zeroConstantOp =
        rewriter.create<mlir::arith::ConstantIndexOp>(location, 0);
    mlir::arith::ConstantOp oneConstantOp =
        rewriter.create<mlir::arith::ConstantIndexOp>(location, 1);
    mlir::arith::ConstantOp crtSizeConstantOp =
        rewriter.create<mlir::arith::ConstantIndexOp>(location,
                                                      loweringParameters.nMods);
    mlir::scf::ForOp newOp = rewriter.create<mlir::scf::ForOp>(
        location, zeroConstantOp, crtSizeConstantOp, oneConstantOp,
        mlir::ValueRange{firstTensor, secondTensor}, body);
    // Convert the types of the new operation
    // (rewrites any leftover `FHE` operand/result types to their `TFHE`
    // counterparts).
    typing::TypeConverter converter(loweringParameters);
    concretelang::convertOperandAndResultTypes(rewriter, newOp,
                                               converter.getConversionLambda());
    return newOp.getResult(0);
  }
  /// Writes an `scf::for` that loops over the crt dimension of one tensor and
  /// execute the input lambda to write the loop body. Returns the first result
  /// of the op.
  ///
  /// Note:
  /// -----
  ///
  /// + The type of `firstArgTensor` type is used as output type.
  /// + `body` receives the induction variable and the tensor as the single
  ///   iteration argument, and is expected to yield the updated tensor.
  mlir::Value writeUnaryTensorLoop(
      mlir::Location location, mlir::Value tensor,
      mlir::PatternRewriter &rewriter,
      mlir::function_ref<void(mlir::OpBuilder &, mlir::Location, mlir::Value,
                              mlir::ValueRange)>
          body) const {
    // Create the loop
    // Bounds are [0, nMods) with step 1: one iteration per crt limb.
    mlir::arith::ConstantOp zeroConstantOp =
        rewriter.create<mlir::arith::ConstantIndexOp>(location, 0);
    mlir::arith::ConstantOp oneConstantOp =
        rewriter.create<mlir::arith::ConstantIndexOp>(location, 1);
    mlir::arith::ConstantOp crtSizeConstantOp =
        rewriter.create<mlir::arith::ConstantIndexOp>(location,
                                                      loweringParameters.nMods);
    mlir::scf::ForOp newOp = rewriter.create<mlir::scf::ForOp>(
        location, zeroConstantOp, crtSizeConstantOp, oneConstantOp, tensor,
        body);
    // Convert the types of the new operation
    // (rewrites any leftover `FHE` operand/result types to their `TFHE`
    // counterparts).
    typing::TypeConverter converter(loweringParameters);
    concretelang::convertOperandAndResultTypes(rewriter, newOp,
                                               converter.getConversionLambda());
    return newOp.getResult(0);
  }
  /// Writes the crt encoding of a plaintext of arbitrary precision.
  /// Returns a `tensor<nMods x i64>` holding one encoded residue per crt
  /// modulus.
  mlir::Value writePlaintextCrtEncoding(mlir::Location location,
                                        mlir::Value rawPlaintext,
                                        mlir::PatternRewriter &rewriter) const {
    // Zero-extend the raw plaintext to i64 before encoding it.
    mlir::Value castedPlaintext = rewriter.create<mlir::arith::ExtUIOp>(
        location, rewriter.getI64Type(), rawPlaintext);
    return rewriter.create<TFHE::EncodePlaintextWithCrtOp>(
        location,
        mlir::RankedTensorType::get(
            mlir::ArrayRef<int64_t>(loweringParameters.nMods),
            rewriter.getI64Type()),
        castedPlaintext, rewriter.getI64ArrayAttr(loweringParameters.mods),
        rewriter.getI64IntegerAttr(loweringParameters.modsProd));
  }
};
/// Rewriter for the `FHE::add_eint_int` operation.
///
/// Lowered as: crt-encode the clear integer, then add each encoded residue to
/// the matching glwe ciphertext limb in a loop.
struct AddEintIntOpPattern : public CrtOpPattern<FHE::AddEintIntOp> {
  AddEintIntOpPattern(mlir::MLIRContext *context,
                      concretelang::CrtLoweringParameters params,
                      mlir::PatternBenefit benefit = 1)
      : CrtOpPattern<FHE::AddEintIntOp>(context, params, benefit) {}
  ::mlir::LogicalResult
  matchAndRewrite(FHE::AddEintIntOp op,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Location location = op.getLoc();
    mlir::Value eintOperand = op.a();
    mlir::Value intOperand = op.b();
    // Convert operand type to glwe tensor.
    // Note: the operand types are mutated in place so the loop emitted below
    // already sees the converted (glwe tensor) types.
    typing::TypeConverter converter(loweringParameters);
    intOperand.setType(converter.convertType(intOperand.getType()));
    eintOperand.setType(converter.convertType(eintOperand.getType()));
    // Write plaintext encoding
    mlir::Value encodedPlaintextTensor =
        writePlaintextCrtEncoding(op.getLoc(), intOperand, rewriter);
    // Write add loop.
    mlir::Type ciphertextScalarType =
        converter.convertType(eintOperand.getType())
            .cast<mlir::RankedTensorType>()
            .getElementType();
    mlir::Value output = writeBinaryTensorLoop(
        location, eintOperand, encodedPlaintextTensor, rewriter,
        [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter,
            mlir::ValueRange args) {
          // args[0] is the (accumulated) ciphertext tensor, args[1] the
          // encoded plaintext tensor.
          mlir::Value extractedEint =
              builder.create<mlir::tensor::ExtractOp>(loc, args[0], iter);
          mlir::Value extractedInt =
              builder.create<mlir::tensor::ExtractOp>(loc, args[1], iter);
          mlir::Value output = builder.create<TFHE::AddGLWEIntOp>(
              loc, ciphertextScalarType, extractedEint, extractedInt);
          mlir::Value newTensor = builder.create<mlir::tensor::InsertOp>(
              loc, output, args[0], iter);
          builder.create<mlir::scf::YieldOp>(
              loc, mlir::ValueRange{newTensor, args[1]});
        });
    // Rewrite original op.
    rewriter.replaceOp(op, output);
    return mlir::success();
  }
};
/// Rewriter for the `FHE::sub_int_eint` operation.
///
/// Lowered as: crt-encode the clear integer, then compute (clear - encrypted)
/// limb-wise in a loop.
struct SubIntEintOpPattern : public CrtOpPattern<FHE::SubIntEintOp> {
  SubIntEintOpPattern(mlir::MLIRContext *context,
                      concretelang::CrtLoweringParameters params,
                      mlir::PatternBenefit benefit = 1)
      : CrtOpPattern<FHE::SubIntEintOp>(context, params, benefit) {}
  ::mlir::LogicalResult
  matchAndRewrite(FHE::SubIntEintOp op,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Location location = op.getLoc();
    mlir::Value intOperand = op.a();
    mlir::Value eintOperand = op.b();
    // Convert operand type to glwe tensor.
    // Note: the operand types are mutated in place so the loop emitted below
    // already sees the converted (glwe tensor) types.
    typing::TypeConverter converter(loweringParameters);
    intOperand.setType(converter.convertType(intOperand.getType()));
    eintOperand.setType(converter.convertType(eintOperand.getType()));
    // Write plaintext encoding
    mlir::Value encodedPlaintextTensor =
        writePlaintextCrtEncoding(op.getLoc(), intOperand, rewriter);
    // Write sub loop (int - eint, computed limb-wise).
    mlir::Type ciphertextScalarType =
        converter.convertType(eintOperand.getType())
            .cast<mlir::RankedTensorType>()
            .getElementType();
    mlir::Value output = writeBinaryTensorLoop(
        location, eintOperand, encodedPlaintextTensor, rewriter,
        [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter,
            mlir::ValueRange args) {
          mlir::Value extractedEint =
              builder.create<mlir::tensor::ExtractOp>(loc, args[0], iter);
          mlir::Value extractedInt =
              builder.create<mlir::tensor::ExtractOp>(loc, args[1], iter);
          // Operand order matters: SubGLWEIntOp computes (int - eint).
          mlir::Value output = builder.create<TFHE::SubGLWEIntOp>(
              loc, ciphertextScalarType, extractedInt, extractedEint);
          mlir::Value newTensor = builder.create<mlir::tensor::InsertOp>(
              loc, output, args[0], iter);
          builder.create<mlir::scf::YieldOp>(
              loc, mlir::ValueRange{newTensor, args[1]});
        });
    // Rewrite original op.
    rewriter.replaceOp(op, output);
    return mlir::success();
  }
};
/// Rewriter for the `FHE::sub_eint_int` operation.
///
/// Lowered as: eint - int == eint + (-int), i.e. negate the clear operand,
/// crt-encode it, then add it limb-wise in a loop.
struct SubEintIntOpPattern : public CrtOpPattern<FHE::SubEintIntOp> {
  SubEintIntOpPattern(mlir::MLIRContext *context,
                      concretelang::CrtLoweringParameters params,
                      mlir::PatternBenefit benefit = 1)
      : CrtOpPattern<FHE::SubEintIntOp>(context, params, benefit) {}
  ::mlir::LogicalResult
  matchAndRewrite(FHE::SubEintIntOp op,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Location location = op.getLoc();
    mlir::Value eintOperand = op.a();
    mlir::Value intOperand = op.b();
    // Convert operand type to glwe tensor.
    // Note: the operand types are mutated in place so the loop emitted below
    // already sees the converted (glwe tensor) types.
    typing::TypeConverter converter(loweringParameters);
    intOperand.setType(converter.convertType(intOperand.getType()));
    eintOperand.setType(converter.convertType(eintOperand.getType()));
    // Write plaintext negation
    // (multiply the clear operand by -1 so the subtraction becomes an add).
    mlir::Type intType = intOperand.getType();
    mlir::Attribute minusOneAttr = mlir::IntegerAttr::get(intType, -1);
    mlir::Value minusOne =
        rewriter.create<mlir::arith::ConstantOp>(location, minusOneAttr)
            .getResult();
    mlir::Value negative =
        rewriter.create<mlir::arith::MulIOp>(location, intOperand, minusOne)
            .getResult();
    // Write plaintext encoding
    mlir::Value encodedPlaintextTensor =
        writePlaintextCrtEncoding(op.getLoc(), negative, rewriter);
    // Write add loop (adds the negated, crt-encoded plaintext).
    mlir::Type ciphertextScalarType =
        converter.convertType(eintOperand.getType())
            .cast<mlir::RankedTensorType>()
            .getElementType();
    mlir::Value output = writeBinaryTensorLoop(
        location, eintOperand, encodedPlaintextTensor, rewriter,
        [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter,
            mlir::ValueRange args) {
          mlir::Value extractedEint =
              builder.create<mlir::tensor::ExtractOp>(loc, args[0], iter);
          mlir::Value extractedInt =
              builder.create<mlir::tensor::ExtractOp>(loc, args[1], iter);
          mlir::Value output = builder.create<TFHE::AddGLWEIntOp>(
              loc, ciphertextScalarType, extractedEint, extractedInt);
          mlir::Value newTensor = builder.create<mlir::tensor::InsertOp>(
              loc, output, args[0], iter);
          builder.create<mlir::scf::YieldOp>(
              loc, mlir::ValueRange{newTensor, args[1]});
        });
    // Rewrite original op.
    rewriter.replaceOp(op, output);
    return mlir::success();
  }
};
/// Rewriter for the `FHE::add_eint` operation.
///
/// Both encrypted operands are retyped as glwe ciphertext tensors and added
/// limb by limb over the crt dimension.
struct AddEintOpPattern : CrtOpPattern<FHE::AddEintOp> {
  AddEintOpPattern(mlir::MLIRContext *context,
                   concretelang::CrtLoweringParameters params,
                   mlir::PatternBenefit benefit = 1)
      : CrtOpPattern<FHE::AddEintOp>(context, params, benefit) {}
  ::mlir::LogicalResult
  matchAndRewrite(FHE::AddEintOp op,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Location loc = op.getLoc();
    typing::TypeConverter converter(loweringParameters);
    // Retype both encrypted operands in place as glwe ciphertext tensors.
    mlir::Value lhs = op.a();
    mlir::Value rhs = op.b();
    lhs.setType(converter.convertType(lhs.getType()));
    rhs.setType(converter.convertType(rhs.getType()));
    // Element type of the converted tensor, used for the per-limb adds.
    mlir::Type glweScalarType = converter.convertType(lhs.getType())
                                    .cast<mlir::RankedTensorType>()
                                    .getElementType();
    // Loop over the crt dimension and add the limbs pairwise.
    auto bodyBuilder = [&](mlir::OpBuilder &builder, mlir::Location bodyLoc,
                           mlir::Value index, mlir::ValueRange iterArgs) {
      mlir::Value lhsLimb =
          builder.create<mlir::tensor::ExtractOp>(bodyLoc, iterArgs[0], index);
      mlir::Value rhsLimb =
          builder.create<mlir::tensor::ExtractOp>(bodyLoc, iterArgs[1], index);
      mlir::Value sum = builder.create<TFHE::AddGLWEOp>(bodyLoc, glweScalarType,
                                                        lhsLimb, rhsLimb);
      mlir::Value updated = builder.create<mlir::tensor::InsertOp>(
          bodyLoc, sum, iterArgs[0], index);
      builder.create<mlir::scf::YieldOp>(
          bodyLoc, mlir::ValueRange{updated, iterArgs[1]});
    };
    mlir::Value result =
        writeBinaryTensorLoop(loc, lhs, rhs, rewriter, bodyBuilder);
    // Replace the original op with the loop result.
    rewriter.replaceOp(op, result);
    return mlir::success();
  }
};
/// Rewriter for the `FHE::sub_eint` operation.
///
/// Lowered as lhs + (-rhs): each rhs limb is negated and added to the
/// matching lhs limb in a loop over the crt dimension.
struct SubEintOpPattern : CrtOpPattern<FHE::SubEintOp> {
  SubEintOpPattern(mlir::MLIRContext *context,
                   concretelang::CrtLoweringParameters params,
                   mlir::PatternBenefit benefit = 1)
      : CrtOpPattern<FHE::SubEintOp>(context, params, benefit) {}
  ::mlir::LogicalResult
  matchAndRewrite(FHE::SubEintOp op,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Location location = op.getLoc();
    mlir::Value lhsOperand = op.a();
    mlir::Value rhsOperand = op.b();
    // Convert operand type to glwe tensor.
    // Note: the operand types are mutated in place so the loop emitted below
    // already sees the converted (glwe tensor) types.
    typing::TypeConverter converter(loweringParameters);
    lhsOperand.setType(converter.convertType(lhsOperand.getType()));
    rhsOperand.setType(converter.convertType(rhsOperand.getType()));
    // Write sub loop.
    mlir::Type ciphertextScalarType =
        converter.convertType(lhsOperand.getType())
            .cast<mlir::RankedTensorType>()
            .getElementType();
    mlir::Value output = writeBinaryTensorLoop(
        location, lhsOperand, rhsOperand, rewriter,
        [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter,
            mlir::ValueRange args) {
          mlir::Value extractedLhs =
              builder.create<mlir::tensor::ExtractOp>(loc, args[0], iter);
          mlir::Value extractedRhs =
              builder.create<mlir::tensor::ExtractOp>(loc, args[1], iter);
          // lhs - rhs is emitted as lhs + (-rhs).
          mlir::Value negatedRhs = builder.create<TFHE::NegGLWEOp>(
              loc, ciphertextScalarType, extractedRhs);
          mlir::Value output = builder.create<TFHE::AddGLWEOp>(
              loc, ciphertextScalarType, extractedLhs, negatedRhs);
          mlir::Value newTensor = builder.create<mlir::tensor::InsertOp>(
              loc, output, args[0], iter);
          builder.create<mlir::scf::YieldOp>(
              loc, mlir::ValueRange{newTensor, args[1]});
        });
    // Rewrite original op.
    rewriter.replaceOp(op, output);
    return mlir::success();
  }
};
/// Rewriter for the `FHE::neg_eint` operation.
///
/// Negates every limb of the crt decomposition in a loop over the crt
/// dimension.
struct NegEintOpPattern : CrtOpPattern<FHE::NegEintOp> {
  NegEintOpPattern(mlir::MLIRContext *context,
                   concretelang::CrtLoweringParameters params,
                   mlir::PatternBenefit benefit = 1)
      : CrtOpPattern<FHE::NegEintOp>(context, params, benefit) {}
  ::mlir::LogicalResult
  matchAndRewrite(FHE::NegEintOp op,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Location loc = op.getLoc();
    typing::TypeConverter converter{loweringParameters};
    // Retype the encrypted operand in place as a glwe ciphertext tensor.
    mlir::Value input = op.a();
    input.setType(converter.convertType(input.getType()));
    // Element type of the converted tensor, used for the per-limb negations.
    mlir::Type glweScalarType = converter.convertType(input.getType())
                                    .cast<mlir::RankedTensorType>()
                                    .getElementType();
    // Emit the per-limb negation loop.
    mlir::Value result = writeUnaryTensorLoop(
        loc, input, rewriter,
        [&](mlir::OpBuilder &builder, mlir::Location bodyLoc,
            mlir::Value index, mlir::ValueRange iterArgs) {
          mlir::Value limb = builder.create<mlir::tensor::ExtractOp>(
              bodyLoc, iterArgs[0], index);
          mlir::Value negatedLimb = builder.create<TFHE::NegGLWEOp>(
              bodyLoc, glweScalarType, limb);
          mlir::Value updated = builder.create<mlir::tensor::InsertOp>(
              bodyLoc, negatedLimb, iterArgs[0], index);
          builder.create<mlir::scf::YieldOp>(bodyLoc,
                                             mlir::ValueRange{updated});
        });
    // Replace the original op with the loop result.
    rewriter.replaceOp(op, result);
    return mlir::success();
  }
};
/// Rewriter for the `FHE::mul_eint_int` operation.
///
/// Sign-extends the clear multiplier to i64 and multiplies every limb of the
/// crt decomposition by it in a loop.
struct MulEintIntOpPattern : CrtOpPattern<FHE::MulEintIntOp> {
  MulEintIntOpPattern(mlir::MLIRContext *context,
                      concretelang::CrtLoweringParameters params,
                      mlir::PatternBenefit benefit = 1)
      : CrtOpPattern<FHE::MulEintIntOp>(context, params, benefit) {}
  ::mlir::LogicalResult
  matchAndRewrite(FHE::MulEintIntOp op,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Location loc = op.getLoc();
    typing::TypeConverter converter{loweringParameters};
    // Retype the encrypted operand in place as a glwe ciphertext tensor.
    mlir::Value ciphertext = op.a();
    ciphertext.setType(converter.convertType(ciphertext.getType()));
    // The cleartext multiplier is merely sign-extended to i64, no crt
    // encoding is needed for a multiplication.
    mlir::Value cleartext = rewriter.create<mlir::arith::ExtSIOp>(
        loc, rewriter.getI64Type(), op.b());
    // Element type of the converted tensor, used for the per-limb products.
    mlir::Type glweScalarType = converter.convertType(ciphertext.getType())
                                    .cast<mlir::RankedTensorType>()
                                    .getElementType();
    // Emit the per-limb multiplication loop.
    mlir::Value result = writeUnaryTensorLoop(
        loc, ciphertext, rewriter,
        [&](mlir::OpBuilder &builder, mlir::Location bodyLoc,
            mlir::Value index, mlir::ValueRange iterArgs) {
          mlir::Value limb = builder.create<mlir::tensor::ExtractOp>(
              bodyLoc, iterArgs[0], index);
          mlir::Value product = builder.create<TFHE::MulGLWEIntOp>(
              bodyLoc, glweScalarType, limb, cleartext);
          mlir::Value updated = builder.create<mlir::tensor::InsertOp>(
              bodyLoc, product, iterArgs[0], index);
          builder.create<mlir::scf::YieldOp>(bodyLoc,
                                             mlir::ValueRange{updated});
        });
    // Replace the original op with the loop result.
    rewriter.replaceOp(op, result);
    return mlir::success();
  }
};
/// Rewriter for the `FHE::apply_lookup_table` operation.
///
/// Lowered to an encode/expand of the lut followed by a wop pbs.
struct ApplyLookupTableEintOpPattern
    : public CrtOpPattern<FHE::ApplyLookupTableEintOp> {
  ApplyLookupTableEintOpPattern(mlir::MLIRContext *context,
                                concretelang::CrtLoweringParameters params,
                                mlir::PatternBenefit benefit = 1)
      : CrtOpPattern<FHE::ApplyLookupTableEintOp>(context, params, benefit) {}
  ::mlir::LogicalResult
  matchAndRewrite(FHE::ApplyLookupTableEintOp op,
                  mlir::PatternRewriter &rewriter) const override {
    typing::TypeConverter converter(loweringParameters);
    // Replace the lut with an encoded / expanded one.
    mlir::Value newLut =
        rewriter
            .create<TFHE::EncodeExpandLutForWopPBSOp>(
                op.getLoc(),
                mlir::RankedTensorType::get(
                    mlir::ArrayRef<int64_t>(loweringParameters.lutSize),
                    rewriter.getI64Type()),
                op.lut(),
                rewriter.getI64ArrayAttr(
                    mlir::ArrayRef<int64_t>(loweringParameters.mods)),
                rewriter.getI64ArrayAttr(
                    mlir::ArrayRef<int64_t>(loweringParameters.bits)),
                rewriter.getI32IntegerAttr(loweringParameters.polynomialSize),
                // NOTE(review): modsProd is emitted as an i32 attribute here
                // but as i64 in writePlaintextCrtEncoding -- confirm against
                // the op definition.
                rewriter.getI32IntegerAttr(loweringParameters.modsProd))
            .getResult();
    // Emit the wop pbs with the expanded lut. The -1 values mirror the unset
    // (-1) glwe parameters -- presumably filled by a later parametrization
    // pass; confirm.
    auto wopPBS = rewriter.create<TFHE::WopPBSGLWEOp>(
        op.getLoc(), op.getType(), op.a(), newLut, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, rewriter.getI64ArrayAttr({}));
    concretelang::convertOperandAndResultTypes(rewriter, wopPBS,
                                               converter.getConversionLambda());
    rewriter.replaceOp(op, {wopPBS.getResult()});
    return ::mlir::success();
  };
};
/// Rewriter for the `tensor::extract` operation.
///
/// Extracting one encrypted integer becomes extracting a slice (all its crt
/// limbs) of the converted tensor.
struct TensorExtractOpPattern : public CrtOpPattern<mlir::tensor::ExtractOp> {
  TensorExtractOpPattern(mlir::MLIRContext *context,
                         concretelang::CrtLoweringParameters params,
                         mlir::PatternBenefit benefit = 1)
      : CrtOpPattern<mlir::tensor::ExtractOp>(context, params, benefit) {}
  ::mlir::LogicalResult
  matchAndRewrite(mlir::tensor::ExtractOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Only tensors of encrypted integers (or already converted glwe
    // ciphertexts) are handled here. Return `failure` for anything else so
    // the driver knows no rewrite took place (returning `success` without
    // modifying the IR violates the matchAndRewrite contract).
    mlir::Type elementType =
        op.getTensor().getType().cast<mlir::TensorType>().getElementType();
    if (!elementType.isa<FHE::EncryptedIntegerType>() &&
        !elementType.isa<TFHE::GLWECipherTextType>()) {
      return mlir::failure();
    }
    typing::TypeConverter converter{loweringParameters};
    // The original indices become unit-size/stride offsets, plus a full
    // slice over the trailing crt dimension to grab every limb.
    mlir::SmallVector<mlir::OpFoldResult> offsets;
    mlir::SmallVector<mlir::OpFoldResult> sizes;
    mlir::SmallVector<mlir::OpFoldResult> strides;
    for (auto index : op.getIndices()) {
      offsets.push_back(index);
      sizes.push_back(rewriter.getI64IntegerAttr(1));
      strides.push_back(rewriter.getI64IntegerAttr(1));
    }
    offsets.push_back(
        rewriter.create<mlir::arith::ConstantIndexOp>(op.getLoc(), 0)
            .getResult());
    sizes.push_back(rewriter.getI64IntegerAttr(loweringParameters.nMods));
    strides.push_back(rewriter.getI64IntegerAttr(1));
    auto newOp = rewriter.create<mlir::tensor::ExtractSliceOp>(
        op.getLoc(),
        converter.convertType(op.getResult().getType())
            .cast<mlir::RankedTensorType>(),
        op.getTensor(), offsets, sizes, strides);
    concretelang::convertOperandAndResultTypes(rewriter, newOp,
                                               converter.getConversionLambda());
    rewriter.replaceOp(op, {newOp.getResult()});
    return mlir::success();
  }
};
/// Rewriter for the `tensor::insert` operation.
///
/// Inserting one encrypted integer becomes inserting a slice (all its crt
/// limbs) into the converted destination tensor.
struct TensorInsertOpPattern : public CrtOpPattern<mlir::tensor::InsertOp> {
  TensorInsertOpPattern(mlir::MLIRContext *context,
                        concretelang::CrtLoweringParameters params,
                        mlir::PatternBenefit benefit = 1)
      : CrtOpPattern<mlir::tensor::InsertOp>(context, params, benefit) {}
  ::mlir::LogicalResult
  matchAndRewrite(mlir::tensor::InsertOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Only tensors of encrypted integers (or already converted glwe
    // ciphertexts) are handled here. Return `failure` for anything else so
    // the driver knows no rewrite took place (returning `success` without
    // modifying the IR violates the matchAndRewrite contract).
    mlir::Type elementType =
        op.getDest().getType().cast<mlir::TensorType>().getElementType();
    if (!elementType.isa<FHE::EncryptedIntegerType>() &&
        !elementType.isa<TFHE::GLWECipherTextType>()) {
      return mlir::failure();
    }
    typing::TypeConverter converter{loweringParameters};
    // The original indices become unit-size/stride offsets; a full slice
    // over the trailing crt dimension receives the limbs of the value.
    mlir::SmallVector<mlir::OpFoldResult> offsets;
    mlir::SmallVector<mlir::OpFoldResult> sizes;
    mlir::SmallVector<mlir::OpFoldResult> strides;
    for (auto index : op.getIndices()) {
      offsets.push_back(index);
      sizes.push_back(rewriter.getI64IntegerAttr(1));
      strides.push_back(rewriter.getI64IntegerAttr(1));
    }
    offsets.push_back(
        rewriter.create<mlir::arith::ConstantIndexOp>(op.getLoc(), 0)
            .getResult());
    sizes.push_back(rewriter.getI64IntegerAttr(loweringParameters.nMods));
    strides.push_back(rewriter.getI64IntegerAttr(1));
    auto newOp = rewriter.create<mlir::tensor::InsertSliceOp>(
        op.getLoc(), op.getScalar(), op.getDest(), offsets, sizes, strides);
    concretelang::convertOperandAndResultTypes(rewriter, newOp,
                                               converter.getConversionLambda());
    rewriter.replaceOp(op, {newOp});
    return mlir::success();
  }
};
/// Rewriter for the `tensor::from_elements` operation.
///
/// The converted result gains a trailing crt dimension, so the op is rebuilt
/// as a tensor allocation followed by one `insert_slice` per original
/// element.
struct TensorFromElementsOpPattern
    : public CrtOpPattern<mlir::tensor::FromElementsOp> {
  TensorFromElementsOpPattern(mlir::MLIRContext *context,
                              concretelang::CrtLoweringParameters params,
                              mlir::PatternBenefit benefit = 1)
      : CrtOpPattern<mlir::tensor::FromElementsOp>(context, params, benefit) {}
  ::mlir::LogicalResult
  matchAndRewrite(mlir::tensor::FromElementsOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Only tensors of encrypted integers (or already converted glwe
    // ciphertexts) are handled here. Return `failure` for anything else so
    // the driver knows no rewrite took place (returning `success` without
    // modifying the IR violates the matchAndRewrite contract).
    mlir::Type elementType = op.getResult()
                                 .getType()
                                 .cast<mlir::RankedTensorType>()
                                 .getElementType();
    if (!elementType.isa<FHE::EncryptedIntegerType>() &&
        !elementType.isa<TFHE::GLWECipherTextType>()) {
      return mlir::failure();
    }
    typing::TypeConverter converter{loweringParameters};
    // Create dest tensor allocation op
    mlir::Value outputTensor =
        rewriter.create<mlir::bufferization::AllocTensorOp>(
            op.getLoc(),
            converter.convertType(op.getResult().getType())
                .cast<mlir::RankedTensorType>(),
            mlir::ValueRange{});
    // Create insert_slice ops to insert the different pieces. Each element
    // occupies a unit slice along dim 0 and spans the remaining dimensions
    // entirely (including the trailing crt dimension).
    auto outputShape =
        outputTensor.getType().cast<mlir::RankedTensorType>().getShape();
    mlir::SmallVector<mlir::OpFoldResult> offsets{
        rewriter.getI64IntegerAttr(0)};
    mlir::SmallVector<mlir::OpFoldResult> sizes{rewriter.getI64IntegerAttr(1)};
    mlir::SmallVector<mlir::OpFoldResult> strides{
        rewriter.getI64IntegerAttr(1)};
    for (size_t dimIndex = 1; dimIndex < outputShape.size(); ++dimIndex) {
      sizes.push_back(rewriter.getI64IntegerAttr(outputShape[dimIndex]));
      strides.push_back(rewriter.getI64IntegerAttr(1));
      offsets.push_back(rewriter.getI64IntegerAttr(0));
    }
    for (size_t insertionIndex = 0; insertionIndex < op.getElements().size();
         ++insertionIndex) {
      // Only the dim-0 offset changes from one element to the next.
      offsets[0] = rewriter.getI64IntegerAttr(insertionIndex);
      mlir::tensor::InsertSliceOp insertOp =
          rewriter.create<mlir::tensor::InsertSliceOp>(
              op.getLoc(), op.getElements()[insertionIndex], outputTensor,
              offsets, sizes, strides);
      concretelang::convertOperandAndResultTypes(
          rewriter, insertOp, converter.getConversionLambda());
      outputTensor = insertOp.getResult();
    }
    rewriter.replaceOp(op, {outputTensor});
    return mlir::success();
  }
};
} // namespace lowering
/// The conversion pass lowering the `FHE` dialect to the `TFHE` dialect with
/// the crt strategy.
struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
  FHEToTFHECrtPass(concretelang::CrtLoweringParameters params)
      : loweringParameters(params) {}
  void runOnOperation() override {
    auto op = this->getOperation();
    mlir::ConversionTarget target(getContext());
    typing::TypeConverter converter(loweringParameters);
    //------------------------------------------- Marking legal/illegal dialects
    target.addIllegalDialect<FHE::FHEDialect>();
    target.addLegalDialect<TFHE::TFHEDialect>();
    target.addLegalDialect<mlir::arith::ArithmeticDialect>();
    // Region-carrying ops are legal only once their operand, result and
    // entry-block argument types have all been converted.
    target.addDynamicallyLegalOp<mlir::tensor::GenerateOp, mlir::scf::ForOp>(
        [&](mlir::Operation *op) {
          return (
              converter.isLegal(op->getOperandTypes()) &&
              converter.isLegal(op->getResultTypes()) &&
              converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
        });
    // Structural ops are legal once their operand/result types are converted.
    target.addDynamicallyLegalOp<mlir::tensor::InsertOp,
                                 mlir::tensor::ExtractOp, mlir::scf::YieldOp>(
        [&](mlir::Operation *op) {
          return (converter.isLegal(op->getOperandTypes()) &&
                  converter.isLegal(op->getResultTypes()));
        });
    target.addDynamicallyLegalOp<mlir::func::FuncOp>(
        [&](mlir::func::FuncOp funcOp) {
          return converter.isSignatureLegal(funcOp.getFunctionType()) &&
                 converter.isLegal(&funcOp.getBody());
        });
    target.addDynamicallyLegalOp<mlir::func::ConstantOp>(
        [&](mlir::func::ConstantOp op) {
          return FunctionConstantOpConversion<typing::TypeConverter>::isLegal(
              op, converter);
        });
    target.addLegalOp<mlir::func::CallOp>();
    target.addLegalOp<mlir::bufferization::AllocTensorOp>();
    concretelang::addDynamicallyLegalTypeOp<mlir::tensor::ExtractSliceOp>(
        target, converter);
    concretelang::addDynamicallyLegalTypeOp<mlir::tensor::InsertSliceOp>(
        target, converter);
    concretelang::addDynamicallyLegalTypeOp<mlir::tensor::FromElementsOp>(
        target, converter);
    concretelang::addDynamicallyLegalTypeOp<mlir::tensor::ExpandShapeOp>(
        target, converter);
    concretelang::addDynamicallyLegalTypeOp<mlir::tensor::CollapseShapeOp>(
        target, converter);
    // The `RT` dialect ops are legal once their types are converted.
    concretelang::addDynamicallyLegalTypeOp<
        concretelang::RT::MakeReadyFutureOp>(target, converter);
    concretelang::addDynamicallyLegalTypeOp<concretelang::RT::AwaitFutureOp>(
        target, converter);
    concretelang::addDynamicallyLegalTypeOp<
        concretelang::RT::CreateAsyncTaskOp>(target, converter);
    concretelang::addDynamicallyLegalTypeOp<
        concretelang::RT::BuildReturnPtrPlaceholderOp>(target, converter);
    concretelang::addDynamicallyLegalTypeOp<
        concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>(target,
                                                                     converter);
    concretelang::addDynamicallyLegalTypeOp<
        concretelang::RT::DerefReturnPtrPlaceholderOp>(target, converter);
    concretelang::addDynamicallyLegalTypeOp<
        concretelang::RT::WorkFunctionReturnOp>(target, converter);
    concretelang::addDynamicallyLegalTypeOp<
        concretelang::RT::RegisterTaskWorkFunctionOp>(target, converter);
    //---------------------------------------------------------- Adding patterns
    mlir::RewritePatternSet patterns(&getContext());
    // Patterns for the `FHE` dialect operations
    patterns.add<
        //    |_ `FHE::zero_eint`
        concretelang::GenericTypeAndOpConverterPattern<FHE::ZeroEintOp,
                                                       TFHE::ZeroGLWEOp>,
        //    |_ `FHE::zero_tensor`
        concretelang::GenericTypeAndOpConverterPattern<FHE::ZeroTensorOp,
                                                       TFHE::ZeroTensorGLWEOp>>(
        &getContext(), converter);
    //    |_ `FHE::add_eint_int`
    patterns.add<lowering::AddEintIntOpPattern,
                 //    |_ `FHE::add_eint`
                 lowering::AddEintOpPattern,
                 //    |_ `FHE::sub_int_eint`
                 lowering::SubIntEintOpPattern,
                 //    |_ `FHE::sub_eint_int`
                 lowering::SubEintIntOpPattern,
                 //    |_ `FHE::sub_eint`
                 lowering::SubEintOpPattern,
                 //    |_ `FHE::neg_eint`
                 lowering::NegEintOpPattern,
                 //    |_ `FHE::mul_eint_int`
                 lowering::MulEintIntOpPattern,
                 //    |_ `FHE::apply_lookup_table`
                 lowering::ApplyLookupTableEintOpPattern>(&getContext(),
                                                          loweringParameters);
    // Patterns for the relics of the `FHELinalg` dialect operations.
    //    |_ `linalg::generic` turned to nested `scf::for`
    patterns.add<concretelang::GenericTypeConverterPattern<mlir::scf::ForOp>>(
        patterns.getContext(), converter);
    patterns.add<concretelang::GenericTypeConverterPattern<mlir::scf::YieldOp>>(
        patterns.getContext(), converter);
    patterns.add<
        RegionOpTypeConverterPattern<mlir::scf::ForOp, typing::TypeConverter>>(
        &getContext(), converter);
    patterns.add<lowering::TensorExtractOpPattern>(&getContext(),
                                                   loweringParameters);
    patterns.add<lowering::TensorInsertOpPattern>(&getContext(),
                                                  loweringParameters);
    patterns.add<concretelang::GenericTypeConverterPattern<
        mlir::tensor::ExtractSliceOp>>(patterns.getContext(), converter);
    patterns.add<
        concretelang::GenericTypeConverterPattern<mlir::tensor::InsertSliceOp>>(
        patterns.getContext(), converter);
    patterns.add<concretelang::GenericTypeConverterPattern<
        mlir::tensor::CollapseShapeOp>>(patterns.getContext(), converter);
    patterns.add<
        concretelang::GenericTypeConverterPattern<mlir::tensor::ExpandShapeOp>>(
        patterns.getContext(), converter);
    patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
                                              typing::TypeConverter>>(
        &getContext(), converter);
    // Patterns for `func` dialect operations.
    mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
        patterns, converter);
    patterns
        .add<concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>>(
            patterns.getContext(), converter);
    patterns.add<FunctionConstantOpConversion<typing::TypeConverter>>(
        &getContext(), converter);
    // Pattern for the `tensor::from_element` op.
    patterns.add<lowering::TensorFromElementsOpPattern>(patterns.getContext(),
                                                        loweringParameters);
    // Patterns for the `RT` dialect operations.
    // NOTE(review): the generic patterns for `func::ReturnOp` and
    // `scf::YieldOp` were already registered above; the duplicates below look
    // redundant -- confirm before removing.
    patterns
        .add<concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>,
             concretelang::GenericTypeConverterPattern<mlir::scf::YieldOp>,
             concretelang::GenericTypeConverterPattern<
                 concretelang::RT::MakeReadyFutureOp>,
             concretelang::GenericTypeConverterPattern<
                 concretelang::RT::AwaitFutureOp>,
             concretelang::GenericTypeConverterPattern<
                 concretelang::RT::CreateAsyncTaskOp>,
             concretelang::GenericTypeConverterPattern<
                 concretelang::RT::BuildReturnPtrPlaceholderOp>,
             concretelang::GenericTypeConverterPattern<
                 concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
             concretelang::GenericTypeConverterPattern<
                 concretelang::RT::DerefReturnPtrPlaceholderOp>,
             concretelang::GenericTypeConverterPattern<
                 concretelang::RT::WorkFunctionReturnOp>,
             concretelang::GenericTypeConverterPattern<
                 concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
                                                                converter);
    //--------------------------------------------------------- Apply conversion
    if (mlir::applyPartialConversion(op, target, std::move(patterns))
            .failed()) {
      this->signalPassFailure();
    }
  }
private:
  concretelang::CrtLoweringParameters loweringParameters;
};
} // namespace fhe_to_tfhe_crt_conversion
namespace mlir {
namespace concretelang {
/// Creates an instance of the `FHE` to `TFHE` conversion pass using the crt
/// lowering strategy described by `lowering`.
std::unique_ptr<OperationPass<mlir::ModuleOp>>
createConvertFHEToTFHECrtPass(CrtLoweringParameters lowering) {
  return std::make_unique<fhe_to_tfhe_crt_conversion::FHEToTFHECrtPass>(
      lowering);
}
} // namespace concretelang
} // namespace mlir

View File

@@ -0,0 +1,15 @@
# Library implementing the FHE -> TFHE lowering with the scalar strategy.
add_mlir_dialect_library(
  FHEToTFHEScalar
  FHEToTFHEScalar.cpp
  ADDITIONAL_HEADER_DIRS
  ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE
  DEPENDS
  FHEDialect
  mlir-headers
  LINK_LIBS
  PUBLIC
  MLIRIR
  MLIRTransforms
  MLIRMathDialect)
# NOTE(review): MLIRIR is already linked via LINK_LIBS PUBLIC above; this
# extra link line looks redundant -- confirm before removing.
target_link_libraries(FHEToTFHEScalar PUBLIC MLIRIR)

View File

@@ -0,0 +1,518 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <iostream>
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/Linalg/IR/Linalg.h>
#include <mlir/IR/Operation.h>
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "concretelang/Conversion/FHEToTFHEScalar/Pass.h"
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Tools.h"
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
#include "concretelang/Dialect/RT/IR/RTDialect.h"
#include "concretelang/Dialect/RT/IR/RTOps.h"
#include "concretelang/Dialect/RT/IR/RTTypes.h"
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
namespace FHE = mlir::concretelang::FHE;
namespace TFHE = mlir::concretelang::TFHE;
namespace concretelang = mlir::concretelang;
namespace fhe_to_tfhe_scalar_conversion {
namespace typing {
/// Converts `FHE::EncryptedInteger` into `TFHE::GlweCiphetext`.
TFHE::GLWECipherTextType convertEint(mlir::MLIRContext *context,
                                     FHE::EncryptedIntegerType eint) {
  // Glwe parameters are unknown at this stage (-1 placeholders); only the
  // precision is carried over from the encrypted integer type.
  auto width = eint.getWidth();
  return TFHE::GLWECipherTextType::get(context, -1, -1, -1, width);
}
/// Converts `Tensor<FHE::EncryptedInteger>` into a
/// `Tensor<TFHE::GlweCiphertext>` if the element type is appropriate.
/// Otherwise return the input type.
mlir::Type maybeConvertEintTensor(mlir::MLIRContext *context,
                                  mlir::RankedTensorType maybeEintTensor) {
  // Tensors whose elements are not encrypted integers pass through unchanged.
  auto eint =
      maybeEintTensor.getElementType().dyn_cast<FHE::EncryptedIntegerType>();
  if (!eint)
    return maybeEintTensor;
  // Same shape, glwe ciphertext elements (parameters unset, width kept).
  auto glweType =
      TFHE::GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth());
  return mlir::RankedTensorType::get(maybeEintTensor.getShape(), glweType);
}
/// Converts the type `FHE::EncryptedInteger` to `TFHE::GlweCiphetext` if the
/// input type is appropriate. Otherwise returns the input type unchanged.
mlir::Type maybeConvertEint(mlir::MLIRContext *context, mlir::Type t) {
  if (!t.isa<FHE::EncryptedIntegerType>())
    return t;
  return convertEint(context, t.cast<FHE::EncryptedIntegerType>());
}
/// The type converter used to convert `FHE` to `TFHE` types using the scalar
/// strategy.
///
/// Conversion registration order matters: MLIR tries the most recently
/// registered conversion first, so the specific conversions below shadow the
/// identity fallback registered first.
class TypeConverter : public mlir::TypeConverter {

public:
  TypeConverter() {
    // Identity fallback: any type not handled below is kept as-is.
    addConversion([](mlir::Type type) { return type; });
    // Scalar encrypted integers become unparametrized GLWE ciphertexts.
    addConversion([](FHE::EncryptedIntegerType type) {
      return convertEint(type.getContext(), type);
    });
    // Tensors of encrypted integers become tensors of GLWE ciphertexts.
    addConversion([](mlir::RankedTensorType type) {
      return maybeConvertEintTensor(type.getContext(), type);
    });
    // RT futures/pointers are converted by recursively converting their
    // element type. Note: `type` is already statically typed, so no cast is
    // needed before reading the element type.
    addConversion([&](concretelang::RT::FutureType type) {
      return concretelang::RT::FutureType::get(
          this->convertType(type.getElementType()));
    });
    addConversion([&](concretelang::RT::PointerType type) {
      return concretelang::RT::PointerType::get(
          this->convertType(type.getElementType()));
    });
  }

  /// Returns a lambda that uses this converter to turn one type into another.
  std::function<mlir::Type(mlir::MLIRContext *, mlir::Type)>
  getConversionLambda() {
    return [&](mlir::MLIRContext *, mlir::Type t) { return convertType(t); };
  }
};
} // namespace typing
namespace lowering {
/// A pattern rewriter superclass used by most op rewriters during the
/// conversion.
template <typename T>
struct ScalarOpPattern : public mlir::OpRewritePattern<T> {
  ScalarOpPattern(mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<T>(context, benefit) {}

  /// Writes the encoding of a plaintext of arbitrary precision using shift.
  ///
  /// The raw value is zero-extended to 64 bits, then shifted left so the
  /// message sits in the most significant bits, one bit below the padding
  /// bit (hence `64 - 1 - encryptedWidth`).
  mlir::Value
  writePlaintextShiftEncoding(mlir::Location location, mlir::Value rawPlaintext,
                              int64_t encryptedWidth,
                              mlir::PatternRewriter &rewriter) const {
    const int64_t shiftAmount = 64 - 1 - encryptedWidth;
    auto i64Ty = rewriter.getIntegerType(64);
    mlir::Value widened =
        rewriter.create<mlir::arith::ExtUIOp>(location, i64Ty, rawPlaintext);
    mlir::Value shiftCst = rewriter.create<mlir::arith::ConstantOp>(
        location, rewriter.getI64IntegerAttr(shiftAmount));
    return rewriter.create<mlir::arith::ShLIOp>(
        location, rewriter.getI64Type(), widened, shiftCst);
  }
};
/// Rewriter for the `FHE::zero` operation.
///
/// Emits a `TFHE::ZeroGLWEOp` and converts its operand/result types with the
/// scalar type converter before replacing the original op.
struct ZeroEintOpPattern : public mlir::OpRewritePattern<FHE::ZeroEintOp> {
  ZeroEintOpPattern(mlir::MLIRContext *context,
                    mlir::PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<FHE::ZeroEintOp>(context, benefit) {}

  mlir::LogicalResult
  matchAndRewrite(FHE::ZeroEintOp op,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Location loc = op.getLoc();
    auto zeroOp = rewriter.create<TFHE::ZeroGLWEOp>(loc, op.getType());
    typing::TypeConverter converter;
    concretelang::convertOperandAndResultTypes(rewriter, zeroOp,
                                               converter.getConversionLambda());
    rewriter.replaceOp(op, {zeroOp.getResult()});
    return mlir::success();
  }
};
/// Rewriter for the `FHE::add_eint_int` operation.
///
/// The clear operand is shift-encoded as a plaintext, then added to the
/// ciphertext with `TFHE::AddGLWEIntOp`.
struct AddEintIntOpPattern : public ScalarOpPattern<FHE::AddEintIntOp> {
  AddEintIntOpPattern(mlir::MLIRContext *context,
                      mlir::PatternBenefit benefit = 1)
      : ScalarOpPattern<FHE::AddEintIntOp>(context, benefit) {}

  mlir::LogicalResult
  matchAndRewrite(FHE::AddEintIntOp op,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Location loc = op.getLoc();
    mlir::Value ciphertext = op.a();
    mlir::Value clear = op.b();

    // Shift-encode the clear operand as a plaintext.
    mlir::Value encoded = writePlaintextShiftEncoding(
        op.getLoc(), clear,
        ciphertext.getType().cast<FHE::EncryptedIntegerType>().getWidth(),
        rewriter);

    // Emit the TFHE addition and convert its operand/result types.
    auto added = rewriter.create<TFHE::AddGLWEIntOp>(loc, op.getType(),
                                                     ciphertext, encoded);
    typing::TypeConverter converter;
    concretelang::convertOperandAndResultTypes(rewriter, added,
                                               converter.getConversionLambda());
    rewriter.replaceOp(op, {added.getResult()});
    return mlir::success();
  }
};
/// Rewriter for the `FHE::sub_eint_int` operation.
///
/// Lowered as `a + (-b)`: the clear operand is negated, shift-encoded as a
/// plaintext, then added to the ciphertext with `TFHE::AddGLWEIntOp`.
struct SubEintIntOpPattern : public ScalarOpPattern<FHE::SubEintIntOp> {
  SubEintIntOpPattern(mlir::MLIRContext *context,
                      mlir::PatternBenefit benefit = 1)
      : ScalarOpPattern<FHE::SubEintIntOp>(context, benefit) {}

  mlir::LogicalResult
  matchAndRewrite(FHE::SubEintIntOp op,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Location loc = op.getLoc();
    mlir::Value ciphertext = op.a();
    mlir::Value clear = op.b();

    // Negate the clear operand: b * (-1).
    mlir::Type clearTy = clear.getType();
    mlir::Value minusOneCst =
        rewriter
            .create<mlir::arith::ConstantOp>(
                loc, mlir::IntegerAttr::get(clearTy, -1))
            .getResult();
    mlir::Value negated =
        rewriter.create<mlir::arith::MulIOp>(loc, clear, minusOneCst)
            .getResult();

    // Shift-encode the negated value as a plaintext.
    mlir::Value encoded = writePlaintextShiftEncoding(
        op.getLoc(), negated,
        ciphertext.getType().cast<FHE::EncryptedIntegerType>().getWidth(),
        rewriter);

    // Emit the TFHE addition and convert its operand/result types.
    auto added = rewriter.create<TFHE::AddGLWEIntOp>(loc, op.getType(),
                                                     ciphertext, encoded);
    typing::TypeConverter converter;
    concretelang::convertOperandAndResultTypes(rewriter, added,
                                               converter.getConversionLambda());
    rewriter.replaceOp(op, {added.getResult()});
    return mlir::success();
  };
};
/// Rewriter for the `FHE::sub_int_eint` operation.
///
/// The clear operand is shift-encoded as a plaintext, then the ciphertext is
/// subtracted from it with `TFHE::SubGLWEIntOp`.
struct SubIntEintOpPattern : public ScalarOpPattern<FHE::SubIntEintOp> {
  SubIntEintOpPattern(mlir::MLIRContext *context,
                      mlir::PatternBenefit benefit = 1)
      : ScalarOpPattern<FHE::SubIntEintOp>(context, benefit) {}

  mlir::LogicalResult
  matchAndRewrite(FHE::SubIntEintOp op,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Location loc = op.getLoc();
    mlir::Value clear = op.a();
    mlir::Value ciphertext = op.b();

    // Shift-encode the clear operand as a plaintext.
    mlir::Value encoded = writePlaintextShiftEncoding(
        op.getLoc(), clear,
        ciphertext.getType().cast<FHE::EncryptedIntegerType>().getWidth(),
        rewriter);

    // Emit the TFHE subtraction and convert its operand/result types.
    auto subbed = rewriter.create<TFHE::SubGLWEIntOp>(loc, op.getType(),
                                                      encoded, ciphertext);
    typing::TypeConverter converter;
    concretelang::convertOperandAndResultTypes(rewriter, subbed,
                                               converter.getConversionLambda());
    rewriter.replaceOp(op, {subbed.getResult()});
    return mlir::success();
  };
};
/// Rewriter for the `FHE::sub_eint` operation.
///
/// Lowered as `a + (-b)`: the right-hand ciphertext is negated with
/// `TFHE::NegGLWEOp`, then added to the left-hand ciphertext.
struct SubEintOpPattern : public ScalarOpPattern<FHE::SubEintOp> {
  SubEintOpPattern(mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
      : ScalarOpPattern<FHE::SubEintOp>(context, benefit) {}

  mlir::LogicalResult
  matchAndRewrite(FHE::SubEintOp op,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Location loc = op.getLoc();
    mlir::Value lhs = op.a();
    mlir::Value rhs = op.b();
    typing::TypeConverter converter;

    // Negate the right-hand ciphertext and convert its types.
    auto negated =
        rewriter.create<TFHE::NegGLWEOp>(loc, rhs.getType(), rhs);
    concretelang::convertOperandAndResultTypes(rewriter, negated,
                                               converter.getConversionLambda());

    // Emit the TFHE addition and convert its types.
    auto added = rewriter.create<TFHE::AddGLWEOp>(loc, op.getType(), lhs,
                                                  negated.getResult());
    concretelang::convertOperandAndResultTypes(rewriter, added,
                                               converter.getConversionLambda());
    rewriter.replaceOp(op, {added.getResult()});
    return mlir::success();
  };
};
/// Rewriter for the `FHE::mul_eint_int` operation.
///
/// The clear operand is sign-extended to 64 bits (a cleartext needs no shift
/// encoding) and multiplied into the ciphertext with `TFHE::MulGLWEIntOp`.
struct MulEintIntOpPattern : public ScalarOpPattern<FHE::MulEintIntOp> {
  MulEintIntOpPattern(mlir::MLIRContext *context,
                      mlir::PatternBenefit benefit = 1)
      : ScalarOpPattern<FHE::MulEintIntOp>(context, benefit) {}

  mlir::LogicalResult
  matchAndRewrite(FHE::MulEintIntOp op,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Location loc = op.getLoc();
    mlir::Value ciphertext = op.a();
    mlir::Value clear = op.b();

    // "Encode" the cleartext: sign-extend to i64 (multiplication operates on
    // the raw value, so no shift is involved).
    mlir::Value widenedClear = rewriter.create<mlir::arith::ExtSIOp>(
        loc, rewriter.getIntegerType(64), clear);

    // Emit the TFHE multiplication and convert its operand/result types.
    auto product = rewriter.create<TFHE::MulGLWEIntOp>(
        loc, op.getType(), ciphertext, widenedClear);
    typing::TypeConverter converter;
    concretelang::convertOperandAndResultTypes(rewriter, product,
                                               converter.getConversionLambda());
    rewriter.replaceOp(op, {product.getResult()});
    return mlir::success();
  }
};
/// Rewriter for the `FHE::apply_lookup_table` operation.
///
/// Lowers a lookup-table application to the classic TFHE chain:
/// encode/expand the LUT for the bootstrap, keyswitch the input ciphertext,
/// then bootstrap it with the expanded LUT. Keyswitch/bootstrap parameters
/// are left unparametrized (-1) and filled in by a later parametrization
/// pass.
struct ApplyLookupTableEintOpPattern
    : public ScalarOpPattern<FHE::ApplyLookupTableEintOp> {
  ApplyLookupTableEintOpPattern(
      mlir::MLIRContext *context,
      concretelang::ScalarLoweringParameters loweringParams,
      mlir::PatternBenefit benefit = 1)
      : ScalarOpPattern<FHE::ApplyLookupTableEintOp>(context, benefit),
        loweringParameters(loweringParams) {}

  mlir::LogicalResult
  matchAndRewrite(FHE::ApplyLookupTableEintOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Precision of the result drives the LUT encoding.
    size_t outputBits =
        op.getResult().getType().cast<FHE::EncryptedIntegerType>().getWidth();
    // Encode and expand the LUT to a tensor of `polynomialSize` i64 values.
    mlir::Value newLut =
        rewriter
            .create<TFHE::EncodeExpandLutForBootstrapOp>(
                op.getLoc(),
                mlir::RankedTensorType::get(
                    mlir::ArrayRef<int64_t>(loweringParameters.polynomialSize),
                    rewriter.getI64Type()),
                op.lut(),
                rewriter.getI32IntegerAttr(loweringParameters.polynomialSize),
                rewriter.getI32IntegerAttr(outputBits))
            .getResult();
    // Insert keyswitch
    auto ksOp = rewriter.create<TFHE::KeySwitchGLWEOp>(
        op.getLoc(), op.a().getType(), op.a(), -1, -1);
    typing::TypeConverter converter;
    concretelang::convertOperandAndResultTypes(rewriter, ksOp,
                                               converter.getConversionLambda());
    // Insert bootstrap
    auto bsOp = rewriter.replaceOpWithNewOp<TFHE::BootstrapGLWEOp>(
        op, op.getType(), ksOp, newLut, -1, -1, -1, -1);
    concretelang::convertOperandAndResultTypes(rewriter, bsOp,
                                               converter.getConversionLambda());
    return mlir::success();
  };

private:
  // Polynomial size (and friends) used when encoding the LUT.
  concretelang::ScalarLoweringParameters loweringParameters;
};
} // namespace lowering
/// Conversion pass lowering the `FHE` dialect to the `TFHE` dialect using the
/// scalar strategy (one GLWE ciphertext per encrypted integer).
struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {

  FHEToTFHEScalarPass(concretelang::ScalarLoweringParameters loweringParams)
      : loweringParameters(loweringParams){};

  void runOnOperation() override {
    auto op = this->getOperation();

    mlir::ConversionTarget target(getContext());
    typing::TypeConverter converter;

    //------------------------------------------- Marking legal/illegal dialects
    target.addIllegalDialect<FHE::FHEDialect>();
    target.addLegalDialect<TFHE::TFHEDialect>();
    target.addLegalDialect<mlir::arith::ArithmeticDialect>();
    // Region-carrying ops are legal only once all their operand, result and
    // region-argument types have been converted.
    target.addDynamicallyLegalOp<mlir::linalg::GenericOp,
                                 mlir::tensor::GenerateOp>(
        [&](mlir::Operation *op) {
          return (
              converter.isLegal(op->getOperandTypes()) &&
              converter.isLegal(op->getResultTypes()) &&
              converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
        });
    // Functions are legal once their signature and body types are converted.
    target.addDynamicallyLegalOp<mlir::func::FuncOp>(
        [&](mlir::func::FuncOp funcOp) {
          return converter.isSignatureLegal(funcOp.getFunctionType()) &&
                 converter.isLegal(&funcOp.getBody());
        });
    target.addDynamicallyLegalOp<mlir::func::ConstantOp>(
        [&](mlir::func::ConstantOp op) {
          return FunctionConstantOpConversion<typing::TypeConverter>::isLegal(
              op, converter);
        });
    target.addLegalOp<mlir::func::CallOp>();

    // `RT` dialect ops are legal once their types are converted.
    concretelang::addDynamicallyLegalTypeOp<
        concretelang::RT::MakeReadyFutureOp>(target, converter);
    concretelang::addDynamicallyLegalTypeOp<concretelang::RT::AwaitFutureOp>(
        target, converter);
    concretelang::addDynamicallyLegalTypeOp<
        concretelang::RT::CreateAsyncTaskOp>(target, converter);
    concretelang::addDynamicallyLegalTypeOp<
        concretelang::RT::BuildReturnPtrPlaceholderOp>(target, converter);
    concretelang::addDynamicallyLegalTypeOp<
        concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>(target,
                                                                     converter);
    concretelang::addDynamicallyLegalTypeOp<
        concretelang::RT::DerefReturnPtrPlaceholderOp>(target, converter);
    concretelang::addDynamicallyLegalTypeOp<
        concretelang::RT::WorkFunctionReturnOp>(target, converter);
    concretelang::addDynamicallyLegalTypeOp<
        concretelang::RT::RegisterTaskWorkFunctionOp>(target, converter);

    //---------------------------------------------------------- Adding patterns
    mlir::RewritePatternSet patterns(&getContext());

    // Patterns for the `FHE` dialect operations
    patterns.add<
        //    |_ `FHE::zero_eint`
        concretelang::GenericTypeAndOpConverterPattern<FHE::ZeroEintOp,
                                                       TFHE::ZeroGLWEOp>,
        //    |_ `FHE::zero_tensor`
        concretelang::GenericTypeAndOpConverterPattern<FHE::ZeroTensorOp,
                                                       TFHE::ZeroTensorGLWEOp>,
        //    |_ `FHE::neg_eint`
        concretelang::GenericTypeAndOpConverterPattern<FHE::NegEintOp,
                                                       TFHE::NegGLWEOp>,
        //    |_ `FHE::add_eint`
        concretelang::GenericTypeAndOpConverterPattern<FHE::AddEintOp,
                                                       TFHE::AddGLWEOp>>(
        &getContext(), converter);
    //    |_ `FHE::add_eint_int`
    patterns.add<lowering::AddEintIntOpPattern,
                 //    |_ `FHE::sub_int_eint`
                 lowering::SubIntEintOpPattern,
                 //    |_ `FHE::sub_eint_int`
                 lowering::SubEintIntOpPattern,
                 //    |_ `FHE::sub_eint`
                 lowering::SubEintOpPattern,
                 //    |_ `FHE::mul_eint_int`
                 lowering::MulEintIntOpPattern>(&getContext());
    //    |_ `FHE::apply_lookup_table`
    patterns.add<lowering::ApplyLookupTableEintOpPattern>(&getContext(),
                                                          loweringParameters);

    // Patterns for the relics of the `FHELinalg` dialect operations.
    //    |_ `linalg::generic` turned to nested `scf::for`
    patterns
        .add<concretelang::GenericTypeConverterPattern<mlir::linalg::YieldOp>>(
            patterns.getContext(), converter);
    patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
                                              typing::TypeConverter>>(
        &getContext(), converter);
    patterns.add<
        RegionOpTypeConverterPattern<mlir::scf::ForOp, typing::TypeConverter>>(
        &getContext(), converter);
    concretelang::populateWithTensorTypeConverterPatterns(patterns, target,
                                                          converter);

    // Patterns for `func` dialect operations.
    mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
        patterns, converter);
    patterns
        .add<concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>>(
            patterns.getContext(), converter);
    patterns.add<FunctionConstantOpConversion<typing::TypeConverter>>(
        &getContext(), converter);

    // Patterns for the `RT` dialect operations.
    patterns
        .add<concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>,
             concretelang::GenericTypeConverterPattern<mlir::scf::YieldOp>,
             concretelang::GenericTypeConverterPattern<
                 concretelang::RT::MakeReadyFutureOp>,
             concretelang::GenericTypeConverterPattern<
                 concretelang::RT::AwaitFutureOp>,
             concretelang::GenericTypeConverterPattern<
                 concretelang::RT::CreateAsyncTaskOp>,
             concretelang::GenericTypeConverterPattern<
                 concretelang::RT::BuildReturnPtrPlaceholderOp>,
             concretelang::GenericTypeConverterPattern<
                 concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
             concretelang::GenericTypeConverterPattern<
                 concretelang::RT::DerefReturnPtrPlaceholderOp>,
             concretelang::GenericTypeConverterPattern<
                 concretelang::RT::WorkFunctionReturnOp>,
             concretelang::GenericTypeConverterPattern<
                 concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
                                                                converter);

    //--------------------------------------------------------- Apply conversion
    if (mlir::applyPartialConversion(op, target, std::move(patterns))
            .failed()) {
      this->signalPassFailure();
    }
  }

private:
  // Polynomial size etc. forwarded to the lookup-table lowering pattern.
  concretelang::ScalarLoweringParameters loweringParameters;
};
} // namespace fhe_to_tfhe_scalar_conversion
namespace mlir {
namespace concretelang {

/// Builds the pass lowering `FHE` ops to scalar `TFHE` ops.
std::unique_ptr<OperationPass<ModuleOp>>
createConvertFHEToTFHEScalarPass(ScalarLoweringParameters loweringParameters) {
  auto pass =
      std::make_unique<fhe_to_tfhe_scalar_conversion::FHEToTFHEScalarPass>(
          loweringParameters);
  return pass;
}
} // namespace concretelang
} // namespace mlir

View File

@@ -70,17 +70,12 @@ public:
auto dimension = cryptoParameters.getNBigLweDimension();
auto polynomialSize = 1;
auto precision = (signed)type.getP();
auto crtDecomposition =
cryptoParameters.largeInteger.hasValue()
? cryptoParameters.largeInteger->crtDecomposition
: mlir::concretelang::CRTDecomposition{};
if ((int)dimension == type.getDimension() &&
(int)polynomialSize == type.getPolynomialSize()) {
return type;
}
return TFHE::GLWECipherTextType::get(type.getContext(), dimension,
polynomialSize, bits, precision,
crtDecomposition);
polynomialSize, bits, precision);
}
TFHE::GLWECipherTextType glweLookupTableType(GLWECipherTextType &type) {
@@ -89,7 +84,7 @@ public:
auto polynomialSize = cryptoParameters.getPolynomialSize();
auto precision = (signed)type.getP();
return TFHE::GLWECipherTextType::get(type.getContext(), dimension,
polynomialSize, bits, precision, {});
polynomialSize, bits, precision);
}
TFHE::GLWECipherTextType glweIntraPBSType(GLWECipherTextType &type) {
@@ -98,7 +93,7 @@ public:
auto polynomialSize = 1;
auto precision = (signed)type.getP();
return TFHE::GLWECipherTextType::get(type.getContext(), dimension,
polynomialSize, bits, precision, {});
polynomialSize, bits, precision);
}
mlir::concretelang::V0Parameter cryptoParameters;
@@ -181,7 +176,7 @@ struct WopPBSGLWEOpPattern : public mlir::OpRewritePattern<TFHE::WopPBSGLWEOp> {
mlir::PatternRewriter &rewriter) const override {
auto newOp = rewriter.replaceOpWithNewOp<TFHE::WopPBSGLWEOp>(
wopPBSOp, converter.convertType(wopPBSOp.result().getType()),
wopPBSOp.ciphertext(), wopPBSOp.lookupTable(),
wopPBSOp.ciphertexts(), wopPBSOp.lookupTable(),
// Bootstrap parameters
cryptoParameters.brLevel, cryptoParameters.brLogBase,
// Keyswitch parameters
@@ -195,11 +190,18 @@ struct WopPBSGLWEOpPattern : public mlir::OpRewritePattern<TFHE::WopPBSGLWEOp> {
cryptoParameters.largeInteger->wopPBS.packingKeySwitch.baseLog,
// Circuit bootstrap parameters
cryptoParameters.largeInteger->wopPBS.circuitBootstrap.level,
cryptoParameters.largeInteger->wopPBS.circuitBootstrap.baseLog);
cryptoParameters.largeInteger->wopPBS.circuitBootstrap.baseLog,
// Crt decomposition
rewriter.getI64ArrayAttr(
cryptoParameters.largeInteger->crtDecomposition));
rewriter.startRootUpdate(newOp);
auto ctType =
wopPBSOp.ciphertexts().getType().cast<mlir::RankedTensorType>();
auto ciphertextType =
wopPBSOp.ciphertext().getType().cast<TFHE::GLWECipherTextType>();
newOp.ciphertext().setType(converter.glweInterPBSType(ciphertextType));
ctType.getElementType().cast<TFHE::GLWECipherTextType>();
auto newType = mlir::RankedTensorType::get(
ctType.getShape(), converter.glweInterPBSType(ciphertextType));
newOp.ciphertexts().setType(newType);
rewriter.finalizeRootUpdate(newOp);
return mlir::success();
};
@@ -290,6 +292,8 @@ void TFHEGlobalParametrizationPass::runOnOperation() {
target.addDynamicallyLegalOp<TFHE::WopPBSGLWEOp>(
[&](TFHE::WopPBSGLWEOp op) {
return !op.getType()
.cast<mlir::RankedTensorType>()
.getElementType()
.cast<TFHE::GLWECipherTextType>()
.hasUnparametrizedParameters();
});

View File

@@ -108,7 +108,7 @@ struct WopPBSGLWEOpPattern : public mlir::OpRewritePattern<TFHE::WopPBSGLWEOp> {
mlir::Type resultType = converter.convertType(wopOp.getType());
auto newOp = rewriter.replaceOpWithNewOp<Concrete::WopPBSLweOp>(
wopOp, resultType, wopOp.ciphertext(), wopOp.lookupTable(),
wopOp, resultType, wopOp.ciphertexts(), wopOp.lookupTable(),
// Bootstrap parameters
wopOp.bootstrapLevel(), wopOp.bootstrapBaseLog(),
// Keyswitch parameters
@@ -118,12 +118,14 @@ struct WopPBSGLWEOpPattern : public mlir::OpRewritePattern<TFHE::WopPBSGLWEOp> {
wopOp.packingKeySwitchoutputPolynomialSize(),
wopOp.packingKeySwitchLevel(), wopOp.packingKeySwitchBaseLog(),
// Circuit bootstrap parameters
wopOp.circuitBootstrapLevel(), wopOp.circuitBootstrapBaseLog());
wopOp.circuitBootstrapLevel(), wopOp.circuitBootstrapBaseLog(),
// Crt Decomposition
wopOp.crtDecomposition());
rewriter.startRootUpdate(newOp);
newOp.ciphertext().setType(
converter.convertType(wopOp.ciphertext().getType()));
newOp.ciphertexts().setType(
converter.convertType(wopOp.ciphertexts().getType()));
rewriter.finalizeRootUpdate(newOp);
return ::mlir::success();
@@ -179,6 +181,18 @@ void TFHEToConcretePass::runOnOperation() {
patterns.add<mlir::concretelang::GenericTypeAndOpConverterPattern<
mlir::concretelang::TFHE::ZeroTensorGLWEOp,
mlir::concretelang::Concrete::ZeroTensorLWEOp>>(&getContext(), converter);
patterns.add<mlir::concretelang::GenericTypeAndOpConverterPattern<
mlir::concretelang::TFHE::EncodeExpandLutForBootstrapOp,
mlir::concretelang::Concrete::EncodeExpandLutForBootstrapOp>>(
&getContext(), converter);
patterns.add<mlir::concretelang::GenericTypeAndOpConverterPattern<
mlir::concretelang::TFHE::EncodeExpandLutForWopPBSOp,
mlir::concretelang::Concrete::EncodeExpandLutForWopPBSOp>>(&getContext(),
converter);
patterns.add<mlir::concretelang::GenericTypeAndOpConverterPattern<
mlir::concretelang::TFHE::EncodePlaintextWithCrtOp,
mlir::concretelang::Concrete::EncodePlaintextWithCrtOp>>(&getContext(),
converter);
patterns.add<BootstrapGLWEOpPattern>(&getContext(), converter);
patterns.add<WopPBSGLWEOpPattern>(&getContext(), converter);
target.addDynamicallyLegalOp<Concrete::BootstrapLweOp>(

View File

@@ -135,5 +135,19 @@ void mlir::concretelang::BConcrete::
BConcrete::WopPBSCRTLweTensorOp::attachInterface<TensorToMemrefOp<
BConcrete::WopPBSCRTLweTensorOp, BConcrete::WopPBSCRTLweBufferOp>>(
*ctx);
// encode_plaintext_with_crt_tensor => encode_plaintext_with_crt_buffer
BConcrete::EncodePlaintextWithCrtTensorOp::attachInterface<
TensorToMemrefOp<BConcrete::EncodePlaintextWithCrtTensorOp,
BConcrete::EncodePlaintextWithCrtBufferOp>>(*ctx);
// encode_expand_lut_for_bootstrap_tensor =>
// encode_expand_lut_for_bootstrap_buffer
BConcrete::EncodeExpandLutForBootstrapTensorOp::attachInterface<
TensorToMemrefOp<BConcrete::EncodeExpandLutForBootstrapTensorOp,
BConcrete::EncodeExpandLutForBootstrapBufferOp>>(*ctx);
// encode_expand_lut_for_woppbs_tensor =>
// encode_expand_lut_for_woppbs_buffer
BConcrete::EncodeExpandLutForWopPBSTensorOp::attachInterface<
TensorToMemrefOp<BConcrete::EncodeExpandLutForWopPBSTensorOp,
BConcrete::EncodeExpandLutForWopPBSBufferOp>>(*ctx);
});
}

View File

@@ -2,7 +2,6 @@ add_mlir_dialect_library(
ConcretelangBConcreteTransforms
BufferizableOpInterfaceImpl.cpp
AddRuntimeContext.cpp
EliminateCRTOps.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/BConcrete
DEPENDS

View File

@@ -1,561 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "concretelang/ClientLib/CRT.h"
#include "concretelang/Conversion/Tools.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h"
#include "concretelang/Dialect/BConcrete/Transforms/Passes.h"
namespace arith = mlir::arith;
namespace tensor = mlir::tensor;
namespace bufferization = mlir::bufferization;
namespace scf = mlir::scf;
namespace BConcrete = mlir::concretelang::BConcrete;
namespace crt = concretelang::clientlib::crt;
namespace {
char encode_crt[] = "encode_crt";
// This template rewrite pattern transforms any instance of
// `BConcreteCRTOp` operators to `BConcreteOp` on
// each block.
//
// Example:
//
// ```mlir
// %0 = "BConcreteCRTOp"(%arg0, %arg1) {crtDecomposition = [...]}
// : (tensor<nbBlocksxlweSizexi64>, tensor<nbBlocksxlweSizexi64>) ->
// (tensor<nbBlocksxlweSizexi64>)
// ```
//
// becomes:
//
// ```mlir
// %c0 = arith.constant 0 : index
// %c1 = arith.constant 1 : index
// %cB = arith.constant nbBlocks : index
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
// (tensor<nbBlocksxlweSizexi64>) {
// %blockArg = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
// : tensor<lweSizexi64>
// %tmp = "BConcreteOp"(%blockArg)
// : (tensor<lweSizexi64>) -> (tensor<lweSizexi64>)
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, 1]
// : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
// }
// ```
template <typename BConcreteCRTOp, typename BConcreteOp>
struct BConcreteCRTUnaryOpPattern
: public mlir::OpRewritePattern<BConcreteCRTOp> {
BConcreteCRTUnaryOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<BConcreteCRTOp>(context, benefit) {}
mlir::LogicalResult
matchAndRewrite(BConcreteCRTOp op,
mlir::PatternRewriter &rewriter) const override {
auto resultTy =
((mlir::Type)op.getResult().getType()).cast<mlir::RankedTensorType>();
auto loc = op.getLoc();
assert(resultTy.getShape().size() == 2);
auto shape = resultTy.getShape();
// %c0 = arith.constant 0 : index
// %c1 = arith.constant 1 : index
// %cB = arith.constant nbBlocks : index
auto c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
auto cB = rewriter.create<arith::ConstantIndexOp>(loc, shape[0]);
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
op.getLoc(), resultTy, mlir::ValueRange{});
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
// (tensor<nbBlocksxlweSizexi64>) {
rewriter.replaceOpWithNewOp<scf::ForOp>(
op, c0, cB, c1, init,
[&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value i,
mlir::ValueRange iterArgs) {
// [%i, 0]
mlir::SmallVector<mlir::OpFoldResult> offsets{
i, rewriter.getI64IntegerAttr(0)};
// [1, lweSize]
mlir::SmallVector<mlir::OpFoldResult> sizes{
rewriter.getI64IntegerAttr(1),
rewriter.getI64IntegerAttr(shape[1])};
// [1, 1]
mlir::SmallVector<mlir::OpFoldResult> strides{
rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1)};
auto blockTy = mlir::RankedTensorType::get({shape[1]},
resultTy.getElementType());
// %blockArg = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
// : tensor<lweSizexi64>
auto blockArg = builder.create<tensor::ExtractSliceOp>(
loc, blockTy, op.ciphertext(), offsets, sizes, strides);
// %tmp = "BConcrete.add_lwe_buffer"(%blockArg0, %blockArg1)
// : (tensor<lweSizexi64>, tensor<lweSizexi64>) ->
// (tensor<lweSizexi64>)
auto tmp = builder.create<BConcreteOp>(loc, blockTy, blockArg);
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1,
// 1] : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
auto res = builder.create<tensor::InsertSliceOp>(
loc, tmp, iterArgs[0], offsets, sizes, strides);
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
builder.create<scf::YieldOp>(loc, (mlir::Value)res);
});
return mlir::success();
}
};
// This template rewrite pattern transforms any instance of
// `BConcreteCRTOp` operators to `BConcreteOp` on
// each block.
//
// Example:
//
// ```mlir
// %0 = "BConcreteCRTOp"(%arg0, %arg1) {crtDecomposition = [...]}
// : (tensor<nbBlocksxlweSizexi64>, tensor<nbBlocksxlweSizexi64>) ->
// (tensor<nbBlocksxlweSizexi64>)
// ```
//
// becomes:
//
// ```mlir
// %c0 = arith.constant 0 : index
// %c1 = arith.constant 1 : index
// %cB = arith.constant nbBlocks : index
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
// (tensor<nbBlocksxlweSizexi64>) {
// %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
// : tensor<lweSizexi64>
// %blockArg1 = tensor.extract_slice %arg1[%i, 0] [1, lweSize] [1, 1]
// : tensor<lweSizexi64>
// %tmp = "BConcreteOp"(%blockArg0, %blockArg1)
// : (tensor<lweSizexi64>, tensor<lweSizexi64>) ->
// (tensor<lweSizexi64>)
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, 1]
// : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
// }
// ```
template <typename BConcreteCRTOp, typename BConcreteOp>
struct BConcreteCRTBinaryOpPattern
: public mlir::OpRewritePattern<BConcreteCRTOp> {
BConcreteCRTBinaryOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<BConcreteCRTOp>(context, benefit) {}
mlir::LogicalResult
matchAndRewrite(BConcreteCRTOp op,
mlir::PatternRewriter &rewriter) const override {
auto resultTy =
((mlir::Type)op.getResult().getType()).cast<mlir::RankedTensorType>();
auto loc = op.getLoc();
assert(resultTy.getShape().size() == 2);
auto shape = resultTy.getShape();
// %c0 = arith.constant 0 : index
// %c1 = arith.constant 1 : index
// %cB = arith.constant nbBlocks : index
auto c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
auto cB = rewriter.create<arith::ConstantIndexOp>(loc, shape[0]);
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
op.getLoc(), resultTy, mlir::ValueRange{});
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
// (tensor<nbBlocksxlweSizexi64>) {
rewriter.replaceOpWithNewOp<scf::ForOp>(
op, c0, cB, c1, init,
[&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value i,
mlir::ValueRange iterArgs) {
// [%i, 0]
mlir::SmallVector<mlir::OpFoldResult> offsets{
i, rewriter.getI64IntegerAttr(0)};
// [1, lweSize]
mlir::SmallVector<mlir::OpFoldResult> sizes{
rewriter.getI64IntegerAttr(1),
rewriter.getI64IntegerAttr(shape[1])};
// [1, 1]
mlir::SmallVector<mlir::OpFoldResult> strides{
rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1)};
auto blockTy = mlir::RankedTensorType::get({shape[1]},
resultTy.getElementType());
// %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
// : tensor<lweSizexi64>
auto blockArg0 = builder.create<tensor::ExtractSliceOp>(
loc, blockTy, op.lhs(), offsets, sizes, strides);
// %blockArg1 = tensor.extract_slice %arg1[%i, 0] [1, lweSize] [1, 1]
// : tensor<lweSizexi64>
auto blockArg1 = builder.create<tensor::ExtractSliceOp>(
loc, blockTy, op.rhs(), offsets, sizes, strides);
// %tmp = "BConcrete.add_lwe_buffer"(%blockArg0, %blockArg1)
// : (tensor<lweSizexi64>, tensor<lweSizexi64>) ->
// (tensor<lweSizexi64>)
auto tmp =
builder.create<BConcreteOp>(loc, blockTy, blockArg0, blockArg1);
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1,
// 1] : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
auto res = builder.create<tensor::InsertSliceOp>(
loc, tmp, iterArgs[0], offsets, sizes, strides);
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
builder.create<scf::YieldOp>(loc, (mlir::Value)res);
});
return mlir::success();
}
};
// This template rewrite pattern transforms any instance of
// `BConcreteCRTOp` operators to `BConcreteOp` on
// each block with the crt decomposition of the cleartext.
//
// Example:
//
// ```mlir
// %0 = "BConcreteCRTOp"(%arg0, %x) {crtDecomposition = [d0...dn]}
// : (tensor<nbBlocksxlweSizexi64>, i64) -> (tensor<nbBlocksxlweSizexi64>)
// ```
//
// becomes:
//
// ```mlir
// // Build the decomposition of the plaintext
// %x0_a = arith.constant 64/d0 : f64
// %x0_b = arith.mulf %x, %x0_a : i64
// %x0 = arith.fptoui %x0_b : f64 to i64
// ...
// %xn_a = arith.constant 64/dn : f64
// %xn_b = arith.mulf %x, %xn_a : i64
// %xn = arith.fptoui %xn_b : f64 to i64
// %x_decomp = tensor.from_elements %x0, ..., %xn : tensor<nbBlocksxi64>
// // Loop on blocks
// %c0 = arith.constant 0 : index
// %c1 = arith.constant 1 : index
// %cB = arith.constant nbBlocks : index
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
// (tensor<nbBlocksxlweSizexi64>) {
// %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
// : tensor<lweSizexi64>
// %blockArg1 = tensor.extract %x_decomp[%i] : tensor<nbBlocksxi64>
// %tmp = "BConcreteOp"(%blockArg0, %blockArg1)
// : (tensor<lweSizexi64>, i64) -> (tensor<lweSizexi64>)
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, 1]
// : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
// }
// ```
/// Lowers `BConcrete.AddPlaintextCRTLweTensorOp` to one
/// `BConcrete.AddPlaintextLweTensorOp` per crt block: the plaintext operand
/// is encoded once per crt modulus (at compile time when it is a constant,
/// through the `encode_crt` runtime function otherwise), then added
/// block-wise inside an `scf.for` loop over the blocks.
struct AddPlaintextCRTLweTensorOpPattern
    : public mlir::OpRewritePattern<BConcrete::AddPlaintextCRTLweTensorOp> {
  AddPlaintextCRTLweTensorOpPattern(mlir::MLIRContext *context,
                                    mlir::PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<BConcrete::AddPlaintextCRTLweTensorOp>(context,
                                                                      benefit) {
  }
  mlir::LogicalResult
  matchAndRewrite(BConcrete::AddPlaintextCRTLweTensorOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Result is a rank-2 tensor: one row of lweSize words per crt block.
    auto resultTy =
        ((mlir::Type)op.getResult().getType()).cast<mlir::RankedTensorType>();
    auto loc = op.getLoc();
    assert(resultTy.getShape().size() == 2);
    auto shape = resultTy.getShape();
    auto rhs = op.rhs();
    // One encoded plaintext value per crt modulus.
    mlir::SmallVector<mlir::Value, 5> plaintextElements;
    // Product of all crt moduli, required by the crt encoding.
    uint64_t moduliProduct = 1;
    for (mlir::Attribute di : op.crtDecomposition()) {
      moduliProduct *= di.cast<mlir::IntegerAttr>().getValue().getZExtValue();
    }
    if (auto cst =
            mlir::dyn_cast_or_null<arith::ConstantIntOp>(rhs.getDefiningOp())) {
      auto apCst = cst.getValue().cast<mlir::IntegerAttr>().getValue();
      auto value = apCst.getSExtValue();
      // constant value, encode at compile time
      for (mlir::Attribute di : op.crtDecomposition()) {
        auto modulus = di.cast<mlir::IntegerAttr>().getValue().getZExtValue();
        auto encoded = crt::encode(value, modulus, moduliProduct);
        plaintextElements.push_back(
            rewriter.create<arith::ConstantIntOp>(loc, encoded, 64));
      }
    } else {
      // dynamic value, encode at runtime through the `encode_crt` runtime
      // function; forward-declare it if it is not declared yet
      if (insertForwardDeclaration(
              op, rewriter, encode_crt,
              mlir::FunctionType::get(rewriter.getContext(),
                                      {rewriter.getI64Type(),
                                       rewriter.getI64Type(),
                                       rewriter.getI64Type()},
                                      {rewriter.getI64Type()}))
              .failed()) {
        return mlir::failure();
      }
      // `encode_crt` takes an i64 plaintext: sign-extend the operand first.
      auto extOp =
          rewriter.create<arith::ExtSIOp>(loc, rewriter.getI64Type(), rhs);
      auto moduliProductOp =
          rewriter.create<arith::ConstantIntOp>(loc, moduliProduct, 64);
      for (mlir::Attribute di : op.crtDecomposition()) {
        auto modulus = di.cast<mlir::IntegerAttr>().getValue().getZExtValue();
        auto modulusOp =
            rewriter.create<arith::ConstantIntOp>(loc, modulus, 64);
        plaintextElements.push_back(
            rewriter
                .create<mlir::func::CallOp>(
                    loc, encode_crt, mlir::TypeRange{rewriter.getI64Type()},
                    mlir::ValueRange{extOp, modulusOp, moduliProductOp})
                .getResult(0));
      }
    }
    // %x_decomp = tensor.from_elements %x0, ..., %xn : tensor<nbBlocksxi64>
    auto x_decomp =
        rewriter.create<tensor::FromElementsOp>(loc, plaintextElements);
    // %c0 = arith.constant 0 : index
    // %c1 = arith.constant 1 : index
    // %cB = arith.constant nbBlocks : index
    auto c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto cB = rewriter.create<arith::ConstantIndexOp>(loc, shape[0]);
    // %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
    mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
        op.getLoc(), resultTy, mlir::ValueRange{});
    // %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
    // (tensor<nbBlocksxlweSizexi64>) {
    rewriter.replaceOpWithNewOp<scf::ForOp>(
        op, c0, cB, c1, init,
        [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value i,
            mlir::ValueRange iterArgs) {
          // [%i, 0]
          mlir::SmallVector<mlir::OpFoldResult> offsets{
              i, rewriter.getI64IntegerAttr(0)};
          // [1, lweSize]
          mlir::SmallVector<mlir::OpFoldResult> sizes{
              rewriter.getI64IntegerAttr(1),
              rewriter.getI64IntegerAttr(shape[1])};
          // [1, 1]
          mlir::SmallVector<mlir::OpFoldResult> strides{
              rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1)};
          auto blockTy = mlir::RankedTensorType::get({shape[1]},
                                                     resultTy.getElementType());
          // %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
          // : tensor<lweSizexi64>
          auto blockArg0 = builder.create<tensor::ExtractSliceOp>(
              loc, blockTy, op.lhs(), offsets, sizes, strides);
          // %blockArg1 = tensor.extract %x_decomp[%i] : tensor<nbBlocksxi64>
          auto blockArg1 = builder.create<tensor::ExtractOp>(loc, x_decomp, i);
          // %tmp = "BConcreteOp"(%blockArg0, %blockArg1)
          // : (tensor<lweSizexi64>, i64) -> (tensor<lweSizexi64>)
          auto tmp = builder.create<BConcrete::AddPlaintextLweTensorOp>(
              loc, blockTy, blockArg0, blockArg1);
          // %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1,
          // 1] : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
          auto res = builder.create<tensor::InsertSliceOp>(
              loc, tmp, iterArgs[0], offsets, sizes, strides);
          // scf.yield %res : tensor<nbBlocksxlweSizexi64>
          builder.create<scf::YieldOp>(loc, (mlir::Value)res);
        });
    return mlir::success();
  }
};
// This rewrite pattern transforms any instance of
// `BConcrete.MulCleartextCRTLweTensorOp` into one
// `BConcrete.MulCleartextLweTensorOp` per block of the crt decomposition.
// Unlike the plaintext addition, the cleartext operand is not crt-encoded:
// every block is multiplied by the same sign-extended value.
//
// Example:
//
// ```mlir
// %0 = "BConcrete.mul_cleartext_crt_lwe_tensor"(%arg0, %x)
// : (tensor<nbBlocksxlweSizexi64>, iN) -> (tensor<nbBlocksxlweSizexi64>)
// ```
//
// becomes:
//
// ```mlir
// %c0 = arith.constant 0 : index
// %c1 = arith.constant 1 : index
// %cB = arith.constant nbBlocks : index
// %init = bufferization.alloc_tensor() : tensor<nbBlocksxlweSizexi64>
// %x64 = arith.extsi %x : iN to i64
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
// (tensor<nbBlocksxlweSizexi64>) {
//   %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
//     : tensor<lweSizexi64>
//   %tmp = "BConcrete.mul_cleartext_lwe_tensor"(%blockArg0, %x64)
//     : (tensor<lweSizexi64>, i64) -> (tensor<lweSizexi64>)
//   %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, 1]
//     : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
//   scf.yield %res : tensor<nbBlocksxlweSizexi64>
// }
// ```
/// Lowers `BConcrete.MulCleartextCRTLweTensorOp` to one
/// `BConcrete.MulCleartextLweTensorOp` per crt block. The cleartext is not
/// crt-encoded: each block is multiplied by the same sign-extended value
/// inside an `scf.for` loop over the blocks.
struct MulCleartextCRTLweTensorOpPattern
    : public mlir::OpRewritePattern<BConcrete::MulCleartextCRTLweTensorOp> {
  MulCleartextCRTLweTensorOpPattern(mlir::MLIRContext *context,
                                    mlir::PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<BConcrete::MulCleartextCRTLweTensorOp>(context,
                                                                      benefit) {
  }
  mlir::LogicalResult
  matchAndRewrite(BConcrete::MulCleartextCRTLweTensorOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Result is a rank-2 tensor: one row of lweSize words per crt block.
    auto resultTy =
        ((mlir::Type)op.getResult().getType()).cast<mlir::RankedTensorType>();
    auto loc = op.getLoc();
    assert(resultTy.getShape().size() == 2);
    auto shape = resultTy.getShape();
    // %c0 = arith.constant 0 : index
    // %c1 = arith.constant 1 : index
    // %cB = arith.constant nbBlocks : index
    auto c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto cB = rewriter.create<arith::ConstantIndexOp>(loc, shape[0]);
    // %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
    mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
        op.getLoc(), resultTy, mlir::ValueRange{});
    // The block-wise op takes an i64 cleartext: sign-extend once, outside
    // the loop.
    auto rhs = rewriter.create<arith::ExtSIOp>(op.getLoc(),
                                               rewriter.getI64Type(), op.rhs());
    // %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
    // (tensor<nbBlocksxlweSizexi64>) {
    rewriter.replaceOpWithNewOp<scf::ForOp>(
        op, c0, cB, c1, init,
        [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value i,
            mlir::ValueRange iterArgs) {
          // [%i, 0]
          mlir::SmallVector<mlir::OpFoldResult> offsets{
              i, rewriter.getI64IntegerAttr(0)};
          // [1, lweSize]
          mlir::SmallVector<mlir::OpFoldResult> sizes{
              rewriter.getI64IntegerAttr(1),
              rewriter.getI64IntegerAttr(shape[1])};
          // [1, 1]
          mlir::SmallVector<mlir::OpFoldResult> strides{
              rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1)};
          auto blockTy = mlir::RankedTensorType::get({shape[1]},
                                                     resultTy.getElementType());
          // %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
          // : tensor<lweSizexi64>
          auto blockArg0 = builder.create<tensor::ExtractSliceOp>(
              loc, blockTy, op.lhs(), offsets, sizes, strides);
          // %tmp = BConcrete.mul_cleartext_lwe_buffer(%blockArg0, %x)
          // : (tensor<lweSizexi64>, i64) -> (tensor<lweSizexi64>)
          auto tmp = builder.create<BConcrete::MulCleartextLweTensorOp>(
              loc, blockTy, blockArg0, rhs);
          // %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1,
          // 1] : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
          auto res = builder.create<tensor::InsertSliceOp>(
              loc, tmp, iterArgs[0], offsets, sizes, strides);
          // scf.yield %res : tensor<nbBlocksxlweSizexi64>
          builder.create<scf::YieldOp>(loc, (mlir::Value)res);
        });
    return mlir::success();
  }
};
/// Pass that lowers every crt BConcrete operation of a function to loops of
/// per-block (non-crt) BConcrete operations.
struct EliminateCRTOpsPass : public EliminateCRTOpsBase<EliminateCRTOpsPass> {
  void runOnOperation() final;
};
void EliminateCRTOpsPass::runOnOperation() {
auto op = getOperation();
mlir::ConversionTarget target(getContext());
mlir::RewritePatternSet patterns(&getContext());
// add_crt_lwe_buffers
target.addIllegalOp<BConcrete::AddCRTLweTensorOp>();
patterns.add<BConcreteCRTBinaryOpPattern<BConcrete::AddCRTLweTensorOp,
BConcrete::AddLweTensorOp>>(
&getContext());
// add_plaintext_crt_lwe_buffers
target.addIllegalOp<BConcrete::AddPlaintextCRTLweTensorOp>();
patterns.add<AddPlaintextCRTLweTensorOpPattern>(&getContext());
// mul_cleartext_crt_lwe_buffer
target.addIllegalOp<BConcrete::MulCleartextCRTLweTensorOp>();
patterns.add<MulCleartextCRTLweTensorOpPattern>(&getContext());
target.addIllegalOp<BConcrete::NegateCRTLweTensorOp>();
patterns.add<BConcreteCRTUnaryOpPattern<BConcrete::NegateCRTLweTensorOp,
BConcrete::NegateLweTensorOp>>(
&getContext());
// This dialect are used to transforms crt ops to bconcrete ops
target
.addLegalDialect<arith::ArithmeticDialect, tensor::TensorDialect,
scf::SCFDialect, bufferization::BufferizationDialect,
mlir::func::FuncDialect, BConcrete::BConcreteDialect>();
// Apply the conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
this->signalPassFailure();
return;
}
}
} // namespace
namespace mlir {
namespace concretelang {
/// Returns a new instance of the pass that lowers crt BConcrete operations
/// to their per-block counterparts.
std::unique_ptr<OperationPass<func::FuncOp>> createEliminateCRTOps() {
  auto pass = std::make_unique<EliminateCRTOpsPass>();
  return pass;
}
} // namespace concretelang
} // namespace mlir

View File

@@ -67,18 +67,6 @@ void GlweCiphertextType::print(mlir::AsmPrinter &p) const {
void LweCiphertextType::print(mlir::AsmPrinter &p) const {
p << "<";
// decomposition parameters if any
auto crt = getCrtDecomposition();
if (!crt.empty()) {
p << "crt=[";
for (auto c : crt.drop_back(1)) {
printSigned(p, c);
p << ",";
}
printSigned(p, crt.back());
p << "]";
p << ",";
}
printSigned(p, getDimension());
p << ",";
printSigned(p, getP());
@@ -89,29 +77,6 @@ mlir::Type LweCiphertextType::parse(mlir::AsmParser &parser) {
if (parser.parseLess())
return mlir::Type();
// Parse for the crt decomposition if any
std::vector<int64_t> crtDecomposition;
if (!parser.parseOptionalKeyword("crt")) {
if (parser.parseEqual() || parser.parseLSquare())
return mlir::Type();
while (true) {
int64_t c = -1;
if (parser.parseOptionalKeyword("_") && parser.parseInteger(c)) {
return mlir::Type();
}
crtDecomposition.push_back(c);
if (parser.parseOptionalComma()) {
if (parser.parseRSquare()) {
return mlir::Type();
} else {
break;
}
}
}
if (parser.parseComma())
return mlir::Type();
}
int dimension = -1;
if (parser.parseOptionalKeyword("_") && parser.parseInteger(dimension))
return mlir::Type();
@@ -125,7 +90,7 @@ mlir::Type LweCiphertextType::parse(mlir::AsmParser &parser) {
mlir::Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
return getChecked(loc, loc.getContext(), dimension, p, crtDecomposition);
return getChecked(loc, loc.getContext(), dimension, p);
}
void CleartextType::print(mlir::AsmPrinter &p) const {

View File

@@ -32,8 +32,7 @@ void TFHEDialect::initialize() {
/// - The bits parameter is 64 (we support only this for v0)
::mlir::LogicalResult GLWECipherTextType::verify(
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
signed dimension, signed polynomialSize, signed bits, signed p,
llvm::ArrayRef<int64_t>) {
signed dimension, signed polynomialSize, signed bits, signed p) {
if (bits != -1 && bits != 64) {
emitError() << "GLWE bits parameter can only be 64";
return ::mlir::failure();

View File

@@ -40,16 +40,14 @@ mlir::LogicalResult _verifyGLWEIntegerOperator(mlir::OpState &op,
emitOpErrorForIncompatibleGLWEParameter(op, "p");
return mlir::failure();
}
if (a.getCrtDecomposition() != result.getCrtDecomposition()) {
emitOpErrorForIncompatibleGLWEParameter(op, "crt");
return mlir::failure();
}
if ((int)b.getWidth() != 64) {
op.emitOpError() << "should have the width of `b` equals to 64.";
}
// verify consistency of width of inputs
if ((int)b.getWidth() > a.getP() + 1) {
op.emitOpError()
<< "should have the width of `b` equals or less than 'p'+1: "
<< b.getWidth() << " <= " << a.getP() << "+ 1";
if ((int)b.getWidth() != 64) {
op.emitOpError() << "should have the width of `b` equals 64 : "
<< b.getWidth() << " != 64";
return mlir::failure();
}
return mlir::success();
@@ -111,11 +109,6 @@ mlir::LogicalResult verifyBinaryGLWEOperator(Operator &op) {
emitOpErrorForIncompatibleGLWEParameter(op, "p");
return mlir::failure();
}
if (a.getCrtDecomposition() != b.getCrtDecomposition() ||
a.getCrtDecomposition() != result.getCrtDecomposition()) {
emitOpErrorForIncompatibleGLWEParameter(op, "crt");
return mlir::failure();
}
return mlir::success();
}
@@ -146,10 +139,6 @@ mlir::LogicalResult verifyUnaryGLWEOperator(Operator &op) {
emitOpErrorForIncompatibleGLWEParameter(op, "p");
return mlir::failure();
}
if (a.getCrtDecomposition() != result.getCrtDecomposition()) {
emitOpErrorForIncompatibleGLWEParameter(op, "crt");
return mlir::failure();
}
return mlir::success();
}

View File

@@ -18,16 +18,6 @@ void printSigned(mlir::AsmPrinter &p, signed i) {
void GLWECipherTextType::print(mlir::AsmPrinter &p) const {
p << "<";
auto crt = getCrtDecomposition();
if (!crt.empty()) {
p << "crt=[";
for (auto c : crt.drop_back(1)) {
printSigned(p, c);
p << ",";
}
printSigned(p, crt.back());
p << "]";
}
p << "{";
printSigned(p, getDimension());
p << ",";
@@ -45,27 +35,6 @@ mlir::Type GLWECipherTextType::parse(AsmParser &parser) {
if (parser.parseLess())
return mlir::Type();
// Parse for the crt decomposition if any
std::vector<int64_t> crtDecomposition;
if (!parser.parseOptionalKeyword("crt")) {
if (parser.parseEqual() || parser.parseLSquare())
return mlir::Type();
while (true) {
signed c = -1;
if (parser.parseOptionalKeyword("_") && parser.parseInteger(c)) {
return mlir::Type();
}
crtDecomposition.push_back(c);
if (parser.parseOptionalComma()) {
if (parser.parseRSquare()) {
return mlir::Type();
} else {
break;
}
}
}
}
if (parser.parseLBrace())
return mlir::Type();
@@ -98,8 +67,7 @@ mlir::Type GLWECipherTextType::parse(AsmParser &parser) {
if (parser.parseGreater())
return mlir::Type();
Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
return getChecked(loc, loc.getContext(), dimension, polynomialSize, bits, p,
llvm::ArrayRef<int64_t>(crtDecomposition));
return getChecked(loc, loc.getContext(), dimension, polynomialSize, bits, p);
}
} // namespace TFHE
} // namespace concretelang

View File

@@ -23,34 +23,6 @@ DefaultEngine *get_levelled_engine() {
return levelled_engine;
}
/// Encodes `lut` into the full-size accumulator `output`: each of the
/// `lut_size` entries is shifted into the `out_MESSAGE_BITS` + 1 most
/// significant bits and replicated over `output_size / lut_size`
/// consecutive cells ("mega cases"), rotated by half a mega case — the
/// first half mega case holds lut[0] and the wrapped-around trailing half
/// holds -lut[0] (two's-complement negation).
void encode_and_expand_lut(uint64_t *output, size_t output_size,
                           size_t out_MESSAGE_BITS, const uint64_t *lut,
                           size_t lut_size) {
  assert((output_size % lut_size) == 0);
  const size_t mega_case_size = output_size / lut_size;
  assert((mega_case_size % 2) == 0);
  const size_t half_case = mega_case_size / 2;
  const size_t shift = 64 - out_MESSAGE_BITS - 1;
  const uint64_t first_value = lut[0] << shift;

  // First half mega case: lut[0].
  for (size_t idx = 0; idx < half_case; ++idx)
    output[idx] = first_value;
  // Wrapped-around trailing half mega case: -lut[0].
  for (size_t idx = (lut_size - 1) * mega_case_size + half_case;
       idx < output_size; ++idx)
    output[idx] = -first_value;
  // Every other entry fills one full mega case, shifted by half a case.
  for (size_t lut_idx = 1; lut_idx < lut_size; ++lut_idx) {
    const uint64_t value = lut[lut_idx] << shift;
    const size_t begin = (lut_idx - 1) * mega_case_size + half_case;
    for (size_t off = 0; off < mega_case_size; ++off)
      output[begin + off] = value;
  }
}
#include "concretelang/ClientLib/CRT.h"
#include "concretelang/Runtime/wrappers.h"
@@ -217,13 +189,11 @@ void memref_batched_bootstrap_lwe_cuda_u64(
uint64_t glwe_ct_len = poly_size * (glwe_dim + 1);
uint64_t glwe_ct_size = glwe_ct_len * sizeof(uint64_t);
uint64_t *glwe_ct = (uint64_t *)malloc(glwe_ct_size);
std::vector<uint64_t> expanded_tabulated_function_array(poly_size);
encode_and_expand_lut(expanded_tabulated_function_array.data(), poly_size,
precision, tlu_aligned + tlu_offset, tlu_size);
CAPI_ASSERT_ERROR(
default_engine_discard_trivially_encrypt_glwe_ciphertext_u64_raw_ptr_buffers(
get_levelled_engine(), glwe_ct, glwe_ct_len,
expanded_tabulated_function_array.data(), poly_size));
get_levelled_engine(), glwe_ct, glwe_ct_len, tlu_aligned + tlu_offset,
poly_size));
// Move the glwe accumulator to the GPU
void *glwe_ct_gpu =
@@ -261,34 +231,117 @@ void memref_batched_bootstrap_lwe_cuda_u64(
#endif
void memref_expand_lut_in_trivial_glwe_ct_u64(
uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
uint32_t poly_size, uint32_t glwe_dimension, uint32_t out_precision,
uint64_t *lut_allocated, uint64_t *lut_aligned, uint64_t lut_offset,
uint64_t lut_size, uint64_t lut_stride) {
void memref_encode_plaintext_with_crt(
uint64_t *output_allocated, uint64_t *output_aligned,
uint64_t output_offset, uint64_t output_size, uint64_t output_stride,
uint64_t input, uint64_t *mods_allocated, uint64_t *mods_aligned,
uint64_t mods_offset, uint64_t mods_size, uint64_t mods_stride,
uint64_t mods_product) {
assert(lut_stride == 1 && "Runtime: stride not equal to 1, check "
"memref_expand_lut_in_trivial_glwe_ct_u64");
assert(output_stride == 1 && "Runtime: stride not equal to 1, check "
"memref_encode_plaintext_with_crt");
assert(glwe_ct_stride == 1 && "Runtime: stride not equal to 1, check "
"memref_expand_lut_in_trivial_glwe_ct_u64");
assert(mods_stride == 1 && "Runtime: stride not equal to 1, check "
"memref_encode_plaintext_with_crt");
assert(glwe_ct_size == poly_size * (glwe_dimension + 1));
std::vector<uint64_t> expanded_tabulated_function_array(poly_size);
encode_and_expand_lut(expanded_tabulated_function_array.data(), poly_size,
out_precision, lut_aligned + lut_offset, lut_size);
CAPI_ASSERT_ERROR(
default_engine_discard_trivially_encrypt_glwe_ciphertext_u64_raw_ptr_buffers(
get_levelled_engine(), glwe_ct_aligned + glwe_ct_offset, glwe_ct_size,
expanded_tabulated_function_array.data(), poly_size));
for (size_t i = 0; i < (size_t)mods_size; ++i) {
output_aligned[output_offset + i] =
encode_crt(input, mods_aligned[mods_offset + i], mods_product);
}
return;
}
/// Encodes `input_lut` and expands it into the bootstrap accumulator
/// `output_lut`: every one of the `input_lut_size` entries is shifted into
/// the `out_MESSAGE_BITS` + 1 most significant bits and replicated over
/// `output_lut_size / input_lut_size` consecutive cells ("mega cases"),
/// rotated by half a mega case — the first half mega case holds lut[0] and
/// the wrapped-around trailing half holds -lut[0].
///
/// `output_lut_size` must be a multiple of `input_lut_size` and the
/// resulting mega case size must be even. `poly_size` and the
/// `*_allocated` pointers are part of the memref ABI and unused here.
void memref_encode_expand_lut_for_bootstrap(
    uint64_t *output_lut_allocated, uint64_t *output_lut_aligned,
    uint64_t output_lut_offset, uint64_t output_lut_size,
    uint64_t output_lut_stride, uint64_t *input_lut_allocated,
    uint64_t *input_lut_aligned, uint64_t input_lut_offset,
    uint64_t input_lut_size, uint64_t input_lut_stride, uint32_t poly_size,
    uint32_t out_MESSAGE_BITS) {
  assert(input_lut_stride == 1 && "Runtime: stride not equal to 1, check "
                                  "memref_encode_expand_lut_bootstrap");
  assert(output_lut_stride == 1 && "Runtime: stride not equal to 1, check "
                                   "memref_encode_expand_lut_bootstrap");
  // Same invariant as the former encode_and_expand_lut helper: each input
  // entry must expand to a whole number of output cells.
  assert((output_lut_size % input_lut_size) == 0);
  size_t mega_case_size = output_lut_size / input_lut_size;
  assert((mega_case_size % 2) == 0);

  // First half mega case: lut[0].
  for (size_t idx = 0; idx < mega_case_size / 2; ++idx) {
    output_lut_aligned[output_lut_offset + idx] =
        input_lut_aligned[input_lut_offset] << (64 - out_MESSAGE_BITS - 1);
  }
  // Wrapped-around trailing half mega case: -lut[0].
  for (size_t idx = (input_lut_size - 1) * mega_case_size + mega_case_size / 2;
       idx < output_lut_size; ++idx) {
    output_lut_aligned[output_lut_offset + idx] =
        -(input_lut_aligned[input_lut_offset] << (64 - out_MESSAGE_BITS - 1));
  }
  // Every other entry fills one full mega case, shifted by half a case.
  for (size_t lut_idx = 1; lut_idx < input_lut_size; ++lut_idx) {
    uint64_t lut_value = input_lut_aligned[input_lut_offset + lut_idx]
                         << (64 - out_MESSAGE_BITS - 1);
    size_t start = mega_case_size * (lut_idx - 1) + mega_case_size / 2;
    for (size_t output_idx = start; output_idx < start + mega_case_size;
         ++output_idx) {
      output_lut_aligned[output_lut_offset + output_idx] = lut_value;
    }
  }
  return;
}
/// Encodes `input_lut` with the crt encoding and expands it into
/// `output_lut` for the wop-pbs: for every input value an index is built
/// from the crt residues of the value (each residue scaled into its
/// per-block `crt_bits` bit budget), and the crt encoding of the
/// corresponding lut entry is written at that index in each of the
/// `crt_decomposition_size` per-block luts of `output_lut`.
/// The `*_allocated` pointers, `crt_bits_size`, `crt_bits_stride` and
/// `poly_size` are part of the memref ABI and unused here.
void memref_encode_expand_lut_for_woppbs(
    // Output encoded/expanded lut
    uint64_t *output_lut_allocated, uint64_t *output_lut_aligned,
    uint64_t output_lut_offset, uint64_t output_lut_size,
    uint64_t output_lut_stride,
    // Input lut
    uint64_t *input_lut_allocated, uint64_t *input_lut_aligned,
    uint64_t input_lut_offset, uint64_t input_lut_size,
    uint64_t input_lut_stride,
    // Crt coprimes
    uint64_t *crt_decomposition_allocated, uint64_t *crt_decomposition_aligned,
    uint64_t crt_decomposition_offset, uint64_t crt_decomposition_size,
    uint64_t crt_decomposition_stride,
    // Crt number of bits
    uint64_t *crt_bits_allocated, uint64_t *crt_bits_aligned,
    uint64_t crt_bits_offset, uint64_t crt_bits_size, uint64_t crt_bits_stride,
    // Crypto parameters
    uint32_t poly_size, uint32_t modulus_product) {
  assert(input_lut_stride == 1 && "Runtime: stride not equal to 1, check "
                                  "memref_encode_expand_lut_woppbs");
  assert(output_lut_stride == 1 && "Runtime: stride not equal to 1, check "
                                   "memref_encode_expand_lut_woppbs");
  assert(modulus_product > input_lut_size);
  // Size of the lut dedicated to one crt block.
  uint64_t lut_crt_size = output_lut_size / crt_decomposition_size;
  for (uint64_t value = 0; value < input_lut_size; value++) {
    // Lut index selected by `value`: concatenation of its crt residues,
    // each scaled into its per-block bit budget.
    uint64_t index_lut = 0;
    uint64_t tmp = 1;
    for (size_t block = 0; block < crt_decomposition_size; block++) {
      auto base = crt_decomposition_aligned[crt_decomposition_offset + block];
      auto bits = crt_bits_aligned[crt_bits_offset + block];
      index_lut += (((value % base) << bits) / base) * tmp;
      tmp <<= bits;
    }
    // Write the crt encoding of lut[value] at that index in every
    // per-block lut.
    for (size_t block = 0; block < crt_decomposition_size; block++) {
      auto base = crt_decomposition_aligned[crt_decomposition_offset + block];
      auto v = encode_crt(input_lut_aligned[input_lut_offset + value], base,
                          modulus_product);
      output_lut_aligned[output_lut_offset + block * lut_crt_size + index_lut] =
          v;
    }
  }
}
void memref_add_lwe_ciphertexts_u64(
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
@@ -387,15 +440,10 @@ void memref_bootstrap_lwe_u64(
uint64_t glwe_ct_size = poly_size * (glwe_dim + 1);
uint64_t *glwe_ct = (uint64_t *)malloc(glwe_ct_size * sizeof(uint64_t));
std::vector<uint64_t> expanded_tabulated_function_array(poly_size);
encode_and_expand_lut(expanded_tabulated_function_array.data(), poly_size,
precision, tlu_aligned + tlu_offset, tlu_size);
CAPI_ASSERT_ERROR(
default_engine_discard_trivially_encrypt_glwe_ciphertext_u64_raw_ptr_buffers(
get_levelled_engine(), glwe_ct, glwe_ct_size,
expanded_tabulated_function_array.data(), poly_size));
tlu_aligned + tlu_offset, poly_size));
CAPI_ASSERT_ERROR(
fft_engine_lwe_ciphertext_discarding_bootstrap_u64_raw_ptr_buffers(
@@ -430,40 +478,6 @@ uint64_t encode_crt(int64_t plaintext, uint64_t modulus, uint64_t product) {
return concretelang::clientlib::crt::encode(plaintext, modulus, product);
}
/// Builds (and heap-allocates through the out-parameter `luts_crt`) the
/// crt-encoded luts used by the wop-pbs vertical packing; the caller owns
/// and must free `luts_crt`. Each of the `crt_size` blocks gets a lut of
/// max(2^total_number_of_bits, polynomialSize) entries: for every input
/// value the crt encoding of lut[value] is written at an index built from
/// the crt residues of the value (each residue scaled into its per-block
/// bit budget).
void generate_luts_crt_without_padding(
    uint64_t *&luts_crt, uint64_t &total_luts_crt_size, uint64_t *crt_decomp,
    uint64_t *number_of_bits_per_block, size_t crt_size, uint64_t *lut,
    uint64_t lut_size, uint64_t total_number_of_bits, uint64_t modulus,
    uint64_t polynomialSize) {
  // One lut per crt block, large enough for 2^total_number_of_bits entries
  // and at least one polynomial.
  uint64_t lut_crt_size = uint64_t(1) << total_number_of_bits;
  lut_crt_size = std::max(lut_crt_size, polynomialSize);
  total_luts_crt_size = crt_size * lut_crt_size;
  luts_crt = (uint64_t *)aligned_alloc(U64_ALIGNMENT,
                                       sizeof(uint64_t) * total_luts_crt_size);
  assert(modulus > lut_size);
  for (uint64_t value = 0; value < lut_size; value++) {
    // Lut index selected by `value`: concatenation of its crt residues,
    // each scaled into its per-block bit budget.
    uint64_t index_lut = 0;
    uint64_t tmp = 1;
    for (size_t block = 0; block < crt_size; block++) {
      auto base = crt_decomp[block];
      auto bits = number_of_bits_per_block[block];
      index_lut += (((value % base) << bits) / base) * tmp;
      tmp <<= bits;
    }
    // Write the crt encoding of lut[value] in every per-block lut.
    for (size_t block = 0; block < crt_size; block++) {
      auto base = crt_decomp[block];
      auto v = encode_crt(lut[value], base, modulus);
      luts_crt[block * lut_crt_size + index_lut] = v;
    }
  }
}
void memref_wop_pbs_crt_buffer(
// Output 2D memref
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
@@ -543,23 +557,14 @@ void memref_wop_pbs_crt_buffer(
in_block, nb_bits_to_extract, delta_log));
}
uint64_t *luts_crt;
uint64_t luts_crt_size;
generate_luts_crt_without_padding(
luts_crt, luts_crt_size, crt_decomp_aligned, number_of_bits_per_block,
crt_decomp_size, lut_ct_aligned, lut_ct_size,
total_number_of_bits_per_block, message_modulus, polynomial_size);
// Vertical packing
CAPI_ASSERT_ERROR(
fft_engine_lwe_ciphertext_vector_discarding_circuit_bootstrap_boolean_vertical_packing_u64_raw_ptr_buffers(
context->get_fft_engine(), context->get_default_engine(),
context->get_fft_fourier_bsk(), out_aligned, lwe_big_size,
crt_decomp_size, extract_bits_output_buffer, lwe_small_size,
total_number_of_bits_per_block, luts_crt, luts_crt_size,
cbs_level_count, cbs_base_log, context->get_fpksk()));
free(luts_crt);
total_number_of_bits_per_block, lut_ct_aligned + lut_ct_offset,
lut_ct_size, cbs_level_count, cbs_base_log, context->get_fpksk()));
}
void memref_copy_one_rank(uint64_t *src_allocated, uint64_t *src_aligned,

View File

@@ -21,7 +21,8 @@ add_mlir_library(
FHELinalgDialect
FHELinalgDialectTransforms
FHETensorOpsToLinalg
FHEToTFHE
FHEToTFHECrt
FHEToTFHEScalar
ExtractSDFGOps
MLIRLowerableDialectsToLLVM
FHEDialectAnalysis

View File

@@ -311,31 +311,14 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
if (target == Target::FHE)
return std::move(res);
// FHE -> TFHE
if (mlir::concretelang::pipeline::lowerFHEToTFHE(mlirContext, module,
res.fheContext, enablePass)
// FHELinalg -> FHE
if (mlir::concretelang::pipeline::lowerFHELinalgToFHE(
mlirContext, module, res.fheContext, enablePass, loopParallelize,
options.batchConcreteOps)
.failed()) {
return errorDiag("Lowering from FHE to TFHE failed");
return errorDiag("Lowering from FHELinalg to FHE failed");
}
if (target == Target::TFHE)
return std::move(res);
// TFHE -> Concrete
if (mlir::concretelang::pipeline::lowerTFHEToConcrete(
mlirContext, module, res.fheContext, this->enablePass)
.failed()) {
return errorDiag("Lowering from TFHE to Concrete failed");
}
// Optimizing Concrete
if (this->compilerOptions.optimizeConcrete &&
mlir::concretelang::pipeline::optimizeConcrete(mlirContext, module,
this->enablePass)
.failed()) {
return errorDiag("Optimizing Concrete failed");
}
if (target == Target::CONCRETE)
if (target == Target::FHE_NO_LINALG)
return std::move(res);
// Generate client parameters if requested
@@ -371,19 +354,33 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
}
}
// Concrete with linalg ops -> Concrete with loop ops
if (mlir::concretelang::pipeline::lowerConcreteLinalgToLoops(
mlirContext, module, this->enablePass, loopParallelize,
options.batchConcreteOps)
// FHE -> TFHE
if (mlir::concretelang::pipeline::lowerFHEToTFHE(mlirContext, module,
res.fheContext, enablePass)
.failed()) {
return StreamStringError(
"Lowering from Concrete with linalg ops to Concrete with loops failed");
return errorDiag("Lowering from FHE to TFHE failed");
}
if (target == Target::TFHE)
return std::move(res);
// TFHE -> Concrete
if (mlir::concretelang::pipeline::lowerTFHEToConcrete(
mlirContext, module, res.fheContext, this->enablePass)
.failed()) {
return errorDiag("Lowering from TFHE to Concrete failed");
}
if (target == Target::CONCRETEWITHLOOPS) {
return std::move(res);
// Optimizing Concrete
if (this->compilerOptions.optimizeConcrete &&
mlir::concretelang::pipeline::optimizeConcrete(mlirContext, module,
this->enablePass)
.failed()) {
return errorDiag("Optimizing Concrete failed");
}
if (target == Target::CONCRETE)
return std::move(res);
// Concrete -> BConcrete
if (mlir::concretelang::pipeline::lowerConcreteToBConcrete(
mlirContext, module, this->enablePass, loopParallelize)

View File

@@ -183,27 +183,56 @@ markFHELinalgForTiling(mlir::MLIRContext &context, mlir::ModuleOp &module,
return pm.run(module.getOperation());
}
mlir::LogicalResult
lowerFHELinalgToFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
llvm::Optional<V0FHEContext> &fheContext,
std::function<bool(mlir::Pass *)> enablePass,
bool parallelizeLoops, bool batchOperations) {
mlir::PassManager pm(&context);
pipelinePrinting("FHELinalgToFHE", pm, context);
addPotentiallyNestedPass(
pm, mlir::concretelang::createConvertFHETensorOpsToLinalg(), enablePass);
addPotentiallyNestedPass(pm, mlir::createLinalgGeneralizationPass(),
enablePass);
addPotentiallyNestedPass(
pm,
mlir::concretelang::createLinalgGenericOpWithTensorsToLoopsPass(
parallelizeLoops),
enablePass);
if (batchOperations) {
addPotentiallyNestedPass(pm, mlir::concretelang::createBatchingPass(),
enablePass);
}
return pm.run(module.getOperation());
}
mlir::LogicalResult
lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
llvm::Optional<V0FHEContext> &fheContext,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);
pipelinePrinting("FHEToTFHE", pm, context);
addPotentiallyNestedPass(
pm, mlir::concretelang::createConvertFHETensorOpsToLinalg(), enablePass);
// FHETensorOpsToLinalg does generate linalg named ops that need to be lowered
// to linalg.generic operations
addPotentiallyNestedPass(pm, mlir::createLinalgGeneralizationPass(),
enablePass);
mlir::concretelang::ApplyLookupTableLowering lowerStrategy =
mlir::concretelang::KeySwitchBoostrapLowering;
if (fheContext.hasValue() && fheContext->parameter.largeInteger.hasValue()) {
lowerStrategy = mlir::concretelang::WopPBSLowering;
pipelinePrinting("FHEToTFHECrt", pm, context);
auto dec =
fheContext.value().parameter.largeInteger.value().crtDecomposition;
auto mods = mlir::SmallVector<int64_t>(dec.begin(), dec.end());
auto polySize = fheContext.value().parameter.getPolynomialSize();
addPotentiallyNestedPass(
pm,
mlir::concretelang::createConvertFHEToTFHECrtPass(
mlir::concretelang::CrtLoweringParameters(mods, polySize)),
enablePass);
} else if (fheContext.hasValue()) {
pipelinePrinting("FHEToTFHEScalar", pm, context);
size_t polySize = fheContext.value().parameter.getPolynomialSize();
addPotentiallyNestedPass(
pm,
mlir::concretelang::createConvertFHEToTFHEScalarPass(
mlir::concretelang::ScalarLoweringParameters(polySize)),
enablePass);
}
addPotentiallyNestedPass(
pm, mlir::concretelang::createConvertFHEToTFHEPass(lowerStrategy),
enablePass);
return pm.run(module.getOperation());
}
@@ -240,27 +269,6 @@ optimizeConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
return pm.run(module.getOperation());
}
mlir::LogicalResult
lowerConcreteLinalgToLoops(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool parallelizeLoops, bool batchOperations) {
mlir::PassManager pm(&context);
pipelinePrinting("ConcreteLinalgToLoops", pm, context);
addPotentiallyNestedPass(
pm,
mlir::concretelang::createLinalgGenericOpWithTensorsToLoopsPass(
parallelizeLoops),
enablePass);
if (batchOperations) {
addPotentiallyNestedPass(pm, mlir::concretelang::createBatchingPass(),
enablePass);
}
return pm.run(module.getOperation());
}
mlir::LogicalResult
lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
@@ -293,8 +301,6 @@ lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);
pipelinePrinting("BConcreteToStd", pm, context);
addPotentiallyNestedPass(pm, mlir::concretelang::createEliminateCRTOps(),
enablePass);
addPotentiallyNestedPass(pm, mlir::concretelang::createAddRuntimeContext(),
enablePass);
return pm.run(module.getOperation());

View File

@@ -14,6 +14,7 @@
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
#include "concretelang/Support/Error.h"
#include "concretelang/Support/V0Curves.h"
@@ -34,7 +35,8 @@ const auto keyFormat = KEY_FORMAT_BINARY;
const auto v0Curve = getV0Curves(securityLevel, keyFormat);
/// For the v0 the secretKeyID and precision are the same for all gates.
llvm::Expected<CircuitGate> gateFromMLIRType(LweSecretKeyID secretKeyID,
llvm::Expected<CircuitGate> gateFromMLIRType(V0FHEContext fheContext,
LweSecretKeyID secretKeyID,
Variance variance,
mlir::Type type) {
if (type.isIntOrIndex()) {
@@ -58,29 +60,35 @@ llvm::Expected<CircuitGate> gateFromMLIRType(LweSecretKeyID secretKeyID,
};
}
if (auto lweTy = type.dyn_cast_or_null<
mlir::concretelang::Concrete::LweCiphertextType>()) {
mlir::concretelang::FHE::EncryptedIntegerType>()) {
bool sign = lweTy.isSignedInteger();
std::vector<int64_t> crt;
if (fheContext.parameter.largeInteger.has_value()) {
crt = fheContext.parameter.largeInteger.value().crtDecomposition;
}
return CircuitGate{
/* .encryption = */ llvm::Optional<EncryptionGate>({
/* .secretKeyID = */ secretKeyID,
/* .variance = */ variance,
/* .encoding = */
{
/* .precision = */ (size_t)lweTy.getP(),
/* .crt = */ lweTy.getCrtDecomposition().vec(),
/* .precision = */ lweTy.getWidth(),
/* .crt = */ crt,
},
}),
/*.shape = */
{/*.width = */ (size_t)lweTy.getP(),
/*.dimensions = */ std::vector<int64_t>(),
/*.size = */ 0,
/* .sign */ sign},
{
/*.width = */ (size_t)lweTy.getWidth(),
/*.dimensions = */ std::vector<int64_t>(),
/*.size = */ 0,
/*.sign = */ sign,
},
};
}
auto tensor = type.dyn_cast_or_null<mlir::RankedTensorType>();
if (tensor != nullptr) {
auto gate =
gateFromMLIRType(secretKeyID, variance, tensor.getElementType());
auto gate = gateFromMLIRType(fheContext, secretKeyID, variance,
tensor.getElementType());
if (auto err = gate.takeError()) {
return std::move(err);
}
@@ -179,14 +187,13 @@ createClientParametersForV0(V0FHEContext fheContext,
auto funcType = (*funcOp).getFunctionType();
auto inputs = funcType.getInputs();
bool hasContext =
inputs.empty()
? false
: inputs.back().isa<mlir::concretelang::Concrete::ContextType>();
auto gateFromType = [&](mlir::Type ty) {
return gateFromMLIRType(clientlib::BIG_KEY, inputVariance, ty);
return gateFromMLIRType(fheContext, clientlib::BIG_KEY, inputVariance, ty);
};
for (auto inType = funcType.getInputs().begin();
inType < funcType.getInputs().end() - hasContext; inType++) {

View File

@@ -46,9 +46,9 @@ namespace clientlib = concretelang::clientlib;
enum Action {
ROUND_TRIP,
DUMP_FHE,
DUMP_FHE_NO_LINALG,
DUMP_TFHE,
DUMP_CONCRETE,
DUMP_CONCRETEWITHLOOPS,
DUMP_BCONCRETE,
DUMP_SDFG,
DUMP_STD,
@@ -119,13 +119,13 @@ static llvm::cl::opt<enum Action> action(
"Parse input module and regenerate textual representation")),
llvm::cl::values(clEnumValN(Action::DUMP_FHE, "dump-fhe",
"Dump FHE module")),
llvm::cl::values(clEnumValN(Action::DUMP_FHE_NO_LINALG,
"dump-fhe-no-linalg",
"Lower FHELinalg to FHE and dump result")),
llvm::cl::values(clEnumValN(Action::DUMP_TFHE, "dump-tfhe",
"Lower to TFHE and dump result")),
llvm::cl::values(clEnumValN(Action::DUMP_CONCRETE, "dump-concrete",
"Lower to Concrete and dump result")),
llvm::cl::values(clEnumValN(
Action::DUMP_CONCRETEWITHLOOPS, "dump-concrete-with-loops",
"Lower to Concrete, replace linalg ops with loops and dump result")),
llvm::cl::values(
clEnumValN(Action::DUMP_BCONCRETE, "dump-bconcrete",
"Lower to Bufferized Concrete and dump result")),
@@ -369,9 +369,9 @@ cmdlineCompilationOptions() {
"The large-integers options should all be set",
llvm::inconvertibleErrorCode());
}
if (cmdline::largeIntegerPackingKeyswitch.size() != 5) {
if (cmdline::largeIntegerPackingKeyswitch.size() != 4) {
return llvm::make_error<llvm::StringError>(
"The large-integers-packing-keyswitch must be a list of 5 integer",
"The large-integers-packing-keyswitch must be a list of 4 integer",
llvm::inconvertibleErrorCode());
}
if (cmdline::largeIntegerCircuitBootstrap.size() != 2) {
@@ -488,15 +488,15 @@ mlir::LogicalResult processInputBuffer(
case Action::DUMP_FHE:
target = mlir::concretelang::CompilerEngine::Target::FHE;
break;
case Action::DUMP_FHE_NO_LINALG:
target = mlir::concretelang::CompilerEngine::Target::FHE_NO_LINALG;
break;
case Action::DUMP_TFHE:
target = mlir::concretelang::CompilerEngine::Target::TFHE;
break;
case Action::DUMP_CONCRETE:
target = mlir::concretelang::CompilerEngine::Target::CONCRETE;
break;
case Action::DUMP_CONCRETEWITHLOOPS:
target = mlir::concretelang::CompilerEngine::Target::CONCRETEWITHLOOPS;
break;
case Action::DUMP_BCONCRETE:
target = mlir::concretelang::CompilerEngine::Target::BCONCRETE;
break;

View File

@@ -6,3 +6,10 @@ func.func @main(%arg0: !FHE.eint<15>, %cst: tensor<32768xi64>) -> tensor<1x!FHE.
%6 = tensor.from_elements %1 : tensor<1x!FHE.eint<15>> // ERROR HERE line 4
return %6 : tensor<1x!FHE.eint<15>>
}
// Ensures that tensors of multiple elements can be constructed as well.
func.func @main2(%arg0: !FHE.eint<15>, %cst: tensor<32768xi64>) -> tensor<2x!FHE.eint<15>> {
%1 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<15>, tensor<32768xi64>) -> !FHE.eint<15>
%6 = tensor.from_elements %1, %arg0 : tensor<2x!FHE.eint<15>> // ERROR HERE line 4
return %6 : tensor<2x!FHE.eint<15>>
}

View File

@@ -8,12 +8,3 @@ func.func @add_lwe_ciphertexts(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !
%0 = "Concrete.add_lwe_ciphertexts"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7>
return %0 : !Concrete.lwe_ciphertext<2048,7>
}
//CHECK: func @add_crt_lwe_ciphertexts(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.add_crt_lwe_tensor"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> tensor<5x2049xi64>
//CHECK: return %[[V0]] : tensor<5x2049xi64>
//CHECK: }
func.func @add_crt_lwe_ciphertexts(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>, %arg1: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7> {
%0 = "Concrete.add_lwe_ciphertexts"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>, !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>
return %0 : !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>
}

View File

@@ -2,39 +2,21 @@
//CHECK: func.func @add_glwe_const_int(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> {
//CHECK: %c1_i8 = arith.constant 1 : i8
//CHECK: %[[V0:.*]] = arith.extui %c1_i8 : i8 to i64
//CHECK: %c56_i64 = arith.constant 56 : i64
//CHECK: %[[V1:.*]] = arith.shli %[[V0]], %c56_i64 : i64
//CHECK: %[[V2:.*]] = "BConcrete.add_plaintext_lwe_tensor"(%[[A0]], %[[V1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: %c1_i64 = arith.constant 1 : i64
//CHECK: %[[V2:.*]] = "BConcrete.add_plaintext_lwe_tensor"(%[[A0]], %c1_i64) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: return %[[V2]] : tensor<1025xi64>
//CHECK: }
func.func @add_glwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
%0 = arith.constant 1 : i8
%2 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<1024,7>, i8) -> !Concrete.lwe_ciphertext<1024,7>
%0 = arith.constant 1 : i64
%2 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<1024,7>, i64) -> !Concrete.lwe_ciphertext<1024,7>
return %2 : !Concrete.lwe_ciphertext<1024,7>
}
//CHECK: func.func @add_glwe_int(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: i5) -> tensor<1025xi64> {
//CHECK: %[[V0:.*]] = arith.extui %[[A1]] : i5 to i64
//CHECK: %c59_i64 = arith.constant 59 : i64
//CHECK: %[[V1:.*]] = arith.shli %[[V0]], %c59_i64 : i64
//CHECK: %[[V2:.*]] = "BConcrete.add_plaintext_lwe_tensor"(%[[A0]], %[[V1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: func.func @add_glwe_int(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: i64) -> tensor<1025xi64> {
//CHECK: %[[V2:.*]] = "BConcrete.add_plaintext_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: return %[[V2]] : tensor<1025xi64>
//CHECK: }
func.func @add_glwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4> {
%1 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<1024,4>, i5) -> !Concrete.lwe_ciphertext<1024,4>
func.func @add_glwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i64) -> !Concrete.lwe_ciphertext<1024,4> {
%1 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<1024,4>, i64) -> !Concrete.lwe_ciphertext<1024,4>
return %1 : !Concrete.lwe_ciphertext<1024,4>
}
//CHECK: func.func @add_plaintext_lwe_ciphertext(%[[A0:.*]]: tensor<5x1025xi64>) -> tensor<5x1025xi64> {
//CHECK: %c1_i8 = arith.constant 1 : i8
//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_crt_lwe_tensor"(%[[A0]], %c1_i8) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>, i8) -> tensor<5x1025xi64>
//CHECK: return %[[V0]] : tensor<5x1025xi64>
//CHECK: }
func.func @add_plaintext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7> {
%0 = arith.constant 1 : i8
%2 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>, i8) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>
return %2 : !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>
}

View File

@@ -0,0 +1,10 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
// CHECK: func.func @apply_lookup_table(%arg0: tensor<4xi64>) -> tensor<1024xi64> {
// CHECK-NEXT: %0 = "BConcrete.encode_expand_lut_for_bootstrap_tensor"(%arg0) {outputBits = 3 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<1024xi64>
// CHECK-NEXT: return %0 : tensor<1024xi64>
// CHECK-NEXT: }
func.func @apply_lookup_table(%arg0: tensor<4xi64>) -> tensor<1024xi64> {
%0 = "Concrete.encode_expand_lut_for_bootstrap"(%arg0) {outputBits = 3 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<1024xi64>
return %0 : tensor<1024xi64>
}

View File

@@ -0,0 +1,10 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
// CHECK: func.func @main(%arg0: tensor<4xi64>) -> tensor<40960xi64> {
// CHECK-NEXT: %0 = "BConcrete.encode_expand_lut_for_woppbs_tensor"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64>
// CHECK-NEXT: return %0 : tensor<40960xi64>
// CHECK-NEXT: }
func.func @main(%arg0: tensor<4xi64>) -> tensor<40960xi64> {
%0 = "Concrete.encode_expand_lut_for_woppbs"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64>
return %0 : tensor<40960xi64>
}

View File

@@ -0,0 +1,10 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
// CHECK: func.func @main(%arg0: i64) -> tensor<5xi64> {
// CHECK-NEXT: %0 = "BConcrete.encode_plaintext_with_crt_tensor"(%arg0) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64>
// CHECK-NEXT: return %0 : tensor<5xi64>
// CHECK-NEXT: }
func.func @main(%arg0: i64) -> tensor<5xi64> {
%0 = "Concrete.encode_plaintext_with_crt"(%arg0) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64>
return %0 : tensor<5xi64>
}

View File

@@ -6,10 +6,3 @@
func.func @identity(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
return %arg0 : !Concrete.lwe_ciphertext<1024,7>
}
// CHECK: func.func @identity_crt(%arg0: tensor<5x1025xi64>) -> tensor<5x1025xi64> {
// CHECK-NEXT: return %arg0 : tensor<5x1025xi64>
// CHECK-NEXT: }
func.func @identity_crt(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7> {
return %arg0 : !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>
}

View File

@@ -1,33 +1,21 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
//CHECK: func.func @mul_lwe_const_int(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> {
//CHECK: %c1_i8 = arith.constant 1 : i8
//CHECK: %[[V0:.*]] = arith.extui %c1_i8 : i8 to i64
//CHECK: %[[V1:.*]] = "BConcrete.mul_cleartext_lwe_tensor"(%[[A0]], %[[V0]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: %c1_i64 = arith.constant 1 : i64
//CHECK: %[[V1:.*]] = "BConcrete.mul_cleartext_lwe_tensor"(%[[A0]], %c1_i64) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: return %[[V1]] : tensor<1025xi64>
//CHECK: }
func.func @mul_lwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
%0 = arith.constant 1 : i8
%2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<1024,7>, i8) -> !Concrete.lwe_ciphertext<1024,7>
%0 = arith.constant 1 : i64
%2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<1024,7>, i64) -> !Concrete.lwe_ciphertext<1024,7>
return %2 : !Concrete.lwe_ciphertext<1024,7>
}
//CHECK: func.func @mul_lwe_int(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: i5) -> tensor<1025xi64> {
//CHECK: %[[V0:.*]] = arith.extui %[[A1]] : i5 to i64
//CHECK: %[[V1:.*]] = "BConcrete.mul_cleartext_lwe_tensor"(%[[A0]], %[[V0]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: func.func @mul_lwe_int(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: i64) -> tensor<1025xi64> {
//CHECK: %[[V1:.*]] = "BConcrete.mul_cleartext_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: return %[[V1]] : tensor<1025xi64>
//CHECK: }
func.func @mul_lwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4> {
%1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<1024,4>, i5) -> !Concrete.lwe_ciphertext<1024,4>
func.func @mul_lwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i64) -> !Concrete.lwe_ciphertext<1024,4> {
%1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<1024,4>, i64) -> !Concrete.lwe_ciphertext<1024,4>
return %1 : !Concrete.lwe_ciphertext<1024,4>
}
//CHECK: func.func @mul_cleartext_lwe_ciphertext_crt(%[[A0:.*]]: tensor<5x1025xi64>, %[[A1:.*]]: i5) -> tensor<5x1025xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_crt_lwe_tensor"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>, i5) -> tensor<5x1025xi64>
//CHECK: return %[[V0]] : tensor<5x1025xi64>
//CHECK: }
func.func @mul_cleartext_lwe_ciphertext_crt(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4> {
%1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>, i5) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>
return %1 : !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>
}

View File

@@ -8,12 +8,3 @@ func.func @neg_lwe(%arg0: !Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_cip
%0 = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4>
return %0 : !Concrete.lwe_ciphertext<1024,4>
}
//CHECK: func.func @negate_lwe_ciphertext(%[[A0:.*]]: tensor<5x1025xi64>) -> tensor<5x1025xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.negate_crt_lwe_tensor"(%[[A0]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>) -> tensor<5x1025xi64>
//CHECK: return %[[V0]] : tensor<5x1025xi64>
//CHECK: }
func.func @negate_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4> {
%0 = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>
return %0 : !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>
}

View File

@@ -31,16 +31,6 @@ func.func @tensor_collatenspse_shape(%arg0: tensor<2x3x2x3x4x!Concrete.lwe_ciphe
return %0 : tensor<6x2x12x!Concrete.lwe_ciphertext<1024,4>>
}
// -----
//CHECK: func.func @tensor_collapse_shape_crt(%[[A0:.*]]: tensor<2x3x4x5x6x5x1025xi64>) -> tensor<720x5x1025xi64> {
//CHECK: %[[V0:.*]] = tensor.collapse_shape %[[A0]] [[_:\[\[0, 1, 2, 3, 4\], \[5\], \[6\]\]]] : tensor<2x3x4x5x6x5x1025xi64> into tensor<720x5x1025xi64>
//CHECK: return %[[V0]] : tensor<720x5x1025xi64>
//CHECK: }
func.func @tensor_collapse_shape_crt(%arg0: tensor<2x3x4x5x6x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>>) -> tensor<720x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>> {
%0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3, 4]] {MANP = 1 : ui1}: tensor<2x3x4x5x6x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>> into tensor<720x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>>
return %0 : tensor<720x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>>
}
// -----
//CHECK: func.func @tensor_expand_shape_crt(%[[A0:.*]]: tensor<30x1025xi64>) -> tensor<5x6x1025xi64> {
//CHECK: %[[V0:.*]] = tensor.expand_shape %[[A0]] [[_:\[\[0, 1\], \[2\]\]]] : tensor<30x1025xi64> into tensor<5x6x1025xi64>

View File

@@ -6,10 +6,3 @@
func.func @tensor_identity(%arg0: tensor<2x3x4x!Concrete.lwe_ciphertext<1024,7>>) -> tensor<2x3x4x!Concrete.lwe_ciphertext<1024,7>> {
return %arg0 : tensor<2x3x4x!Concrete.lwe_ciphertext<1024,7>>
}
// CHECK: func.func @tensor_identity_crt(%arg0: tensor<2x3x4x5x1025xi64>) -> tensor<2x3x4x5x1025xi64> {
// CHECK-NEXT: return %arg0 : tensor<2x3x4x5x1025xi64>
// CHECK-NEXT: }
func.func @tensor_identity_crt(%arg0: tensor<2x3x4x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>>) -> tensor<2x3x4x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>> {
return %arg0 : tensor<2x3x4x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>>
}

View File

@@ -1,12 +0,0 @@
// RUN: concretecompiler %s --passes fhe-to-tfhe --action=dump-tfhe 2>&1| FileCheck %s
// CHECK-LABEL: func.func @add_eint_int(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
func.func @add_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8
// CHECK-NEXT: %[[V2:.*]] = "TFHE.add_glwe_int"(%arg0, %[[V1]]) : (!TFHE.glwe<{_,_,_}{7}>, i8) -> !TFHE.glwe<{_,_,_}{7}>
// CHECK-NEXT: return %[[V2]] : !TFHE.glwe<{_,_,_}{7}>
%0 = arith.constant 1 : i8
%1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}

View File

@@ -1,10 +0,0 @@
// RUN: concretecompiler %s --passes fhe-to-tfhe --action=dump-tfhe 2>&1| FileCheck %s
// CHECK: func.func @apply_lookup_table(%[[A0:.*]]: !TFHE.glwe<{_,_,_}{2}>, %[[LUT:.*]]: tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{3}> {
// CHECK-NEXT: %[[V1:.*]] = "TFHE.keyswitch_glwe"(%[[A0]]) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
// CHECK-NEXT: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %[[LUT]]) {baseLog = -1 : i32, glweDimension = -1 : i32, level = -1 : i32, polySize = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{3}>
// CHECK-NEXT: return %[[V2]] : !TFHE.glwe<{_,_,_}{3}>
func.func @apply_lookup_table(%arg0: !FHE.eint<2>, %arg1: tensor<4xi64>) -> !FHE.eint<3> {
%1 = "FHE.apply_lookup_table"(%arg0, %arg1): (!FHE.eint<2>, tensor<4xi64>) -> (!FHE.eint<3>)
return %1: !FHE.eint<3>
}

View File

@@ -1,29 +0,0 @@
// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s
//CHECK: #map0 = affine_map<(d0, d1, d2, d3) -> (d1)>
//CHECK-NEXT: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
//CHECK-NEXT: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
//CHECK-NEXT: #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
//CHECK-NEXT: #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
//CHECK-NEXT: module {
//CHECK-NEXT: func.func @conv2d(%arg0: tensor<100x3x28x28x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<4x3x14x14xi3>, %arg2: tensor<4xi3>) -> tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>> {
//CHECK-NEXT: %0 = "TFHE.zero_tensor"() : () -> tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %1 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<4xi3>) outs(%0 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: ^bb0(%arg3: i3, %arg4: !TFHE.glwe<{_,_,_}{2}>):
//CHECK-NEXT: %3 = "TFHE.add_glwe_int"(%arg4, %arg3) : (!TFHE.glwe<{_,_,_}{2}>, i3) -> !TFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: linalg.yield %3 : !TFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: } -> tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<100x3x28x28x!TFHE.glwe<{_,_,_}{2}>>, tensor<4x3x14x14xi3>) outs(%1 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: ^bb0(%arg3: !TFHE.glwe<{_,_,_}{2}>, %arg4: i3, %arg5: !TFHE.glwe<{_,_,_}{2}>):
//CHECK-NEXT: %3 = "TFHE.mul_glwe_int"(%arg3, %arg4) : (!TFHE.glwe<{_,_,_}{2}>, i3) -> !TFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: %4 = "TFHE.add_glwe"(%arg5, %3) : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: linalg.yield %4 : !TFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: } -> tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: return %2 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: }
//CHECK-NEXT: }
func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0, 0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
}

View File

@@ -1,29 +0,0 @@
// RUN: concretecompiler %s --passes fhe-to-tfhe --action=dump-tfhe 2>&1| FileCheck %s
// CHECK: #map0 = affine_map<(d0) -> (d0)>
// CHECK-NEXT: #map1 = affine_map<(d0) -> (0)>
// CHECK-NEXT: module {
// CHECK-NEXT: func.func @linalg_generic(%arg0: tensor<2x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<2xi3>, %arg2: tensor<1x!TFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT: %0 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!TFHE.glwe<{_,_,_}{2}>>, tensor<2xi3>) outs(%arg2 : tensor<1x!TFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT: ^bb0(%arg3: !TFHE.glwe<{_,_,_}{2}>, %arg4: i3, %arg5: !TFHE.glwe<{_,_,_}{2}>):
// CHECK-NEXT: %1 = "TFHE.mul_glwe_int"(%arg3, %arg4) : (!TFHE.glwe<{_,_,_}{2}>, i3) -> !TFHE.glwe<{_,_,_}{2}>
// CHECK-NEXT: %2 = "TFHE.add_glwe"(%1, %arg5) : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
// CHECK-NEXT: linalg.yield %2 : !TFHE.glwe<{_,_,_}{2}>
// CHECK-NEXT: } -> tensor<1x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: return
// CHECK-NEXT: }
// CHECK-NEXT: }
#map0 = affine_map<(d0) -> (d0)>
#map1 = affine_map<(d0) -> (0)>
module {
func.func @linalg_generic(%arg0: tensor<2x!FHE.eint<2>>, %arg1: tensor<2xi3>, %acc: tensor<1x!FHE.eint<2>>) {
%2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!FHE.eint<2>>, tensor<2xi3>) outs(%acc : tensor<1x!FHE.eint<2>>) {
^bb0(%arg2: !FHE.eint<2>, %arg3: i3, %arg4: !FHE.eint<2>): // no predecessors
%4 = "FHE.mul_eint_int"(%arg2, %arg3) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
%5 = "FHE.add_eint"(%4, %arg4) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
linalg.yield %5 : !FHE.eint<2>
} -> tensor<1x!FHE.eint<2>>
return
}
}

View File

@@ -1,12 +0,0 @@
// RUN: concretecompiler %s --passes fhe-to-tfhe --action=dump-tfhe 2>&1| FileCheck %s
// CHECK-LABEL: func.func @mul_eint_int(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
func.func @mul_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
// CHECK-NEXT: %[[V1:.*]] = arith.constant 2 : i8
// CHECK-NEXT: %[[V2:.*]] = "TFHE.mul_glwe_int"(%arg0, %[[V1]]) : (!TFHE.glwe<{_,_,_}{7}>, i8) -> !TFHE.glwe<{_,_,_}{7}>
// CHECK-NEXT: return %[[V2]] : !TFHE.glwe<{_,_,_}{7}>
%0 = arith.constant 2 : i8
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}

View File

@@ -1,10 +0,0 @@
// RUN: concretecompiler %s --passes fhe-to-tfhe --action=dump-tfhe 2>&1| FileCheck %s
// CHECK-LABEL: func.func @neg_eint(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
func.func @neg_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
// CHECK-NEXT: %[[V1:.*]] = "TFHE.neg_glwe"(%arg0) : (!TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
// CHECK-NEXT: return %[[V1]] : !TFHE.glwe<{_,_,_}{7}>
%1 = "FHE.neg_eint"(%arg0): (!FHE.eint<7>) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}

View File

@@ -1,12 +0,0 @@
// RUN: concretecompiler %s --passes fhe-to-tfhe --action=dump-tfhe 2>&1| FileCheck %s
// CHECK-LABEL: func.func @sub_int_eint(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
func.func @sub_int_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8
// CHECK-NEXT: %[[V2:.*]] = "TFHE.sub_int_glwe"(%[[V1]], %arg0) : (i8, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
// CHECK-NEXT: return %[[V2]] : !TFHE.glwe<{_,_,_}{7}>
%0 = arith.constant 1 : i8
%1 = "FHE.sub_int_eint"(%0, %arg0): (i8, !FHE.eint<7>) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}

View File

@@ -0,0 +1,19 @@
// RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
// CHECK-LABEL: func.func @add_eint(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>, %arg1: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>>
func.func @add_eint(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
// CHECK-NEXT: %c0 = arith.constant 0 : index
// CHECK-NEXT: %c1 = arith.constant 1 : index
// CHECK-NEXT: %c5 = arith.constant 5 : index
// CHECK-NEXT: %0:2 = scf.for %arg2 = %c0 to %c5 step %c1 iter_args(%arg3 = %arg0, %arg4 = %arg1) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>, tensor<5x!TFHE.glwe<{_,_,_}{7}>>) {
// CHECK-NEXT: %1 = tensor.extract %arg3[%arg2] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: %2 = tensor.extract %arg4[%arg2] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: %3 = "TFHE.add_glwe"(%1, %2) : (!TFHE.glwe<{_,_,_}{7}>, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
// CHECK-NEXT: %4 = tensor.insert %3 into %arg3[%arg2] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: scf.yield %4, %arg4 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>, tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: }
// CHECK-NEXT: return %0#0 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}

View File

@@ -0,0 +1,23 @@
// RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
// CHECK-LABEL: func.func @add_eint_int(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>>
func.func @add_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
// CHECK-NEXT: %c1_i8 = arith.constant 1 : i8
// CHECK-NEXT: %0 = arith.extui %c1_i8 : i8 to i64
// CHECK-NEXT: %1 = "TFHE.encode_plaintext_with_crt"(%0) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64>
// CHECK-NEXT: %c0 = arith.constant 0 : index
// CHECK-NEXT: %c1 = arith.constant 1 : index
// CHECK-NEXT: %c5 = arith.constant 5 : index
// CHECK-NEXT: %2:2 = scf.for %arg1 = %c0 to %c5 step %c1 iter_args(%arg2 = %arg0, %arg3 = %1) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>, tensor<5xi64>) {
// CHECK-NEXT: %3 = tensor.extract %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: %4 = tensor.extract %arg3[%arg1] : tensor<5xi64>
// CHECK-NEXT: %5 = "TFHE.add_glwe_int"(%3, %4) : (!TFHE.glwe<{_,_,_}{7}>, i64) -> !TFHE.glwe<{_,_,_}{7}>
// CHECK-NEXT: %6 = tensor.insert %5 into %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: scf.yield %6, %arg3 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>, tensor<5xi64>
// CHECK-NEXT: }
// CHECK-NEXT: return %2#0 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
%0 = arith.constant 1 : i8
%1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}

View File

@@ -0,0 +1,10 @@
// RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
// CHECK: func.func @apply_lookup_table(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<4xi64>) -> tensor<5x!TFHE.glwe<{_,_,_}{3}>>
// CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_woppbs"(%arg1) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64>
// CHECK-NEXT: %1 = "TFHE.wop_pbs_glwe"(%arg0, %0) {bootstrapBaseLog = -1 : i32, bootstrapLevel = -1 : i32, circuitBootstrapBaseLog = -1 : i32, circuitBootstrapLevel = -1 : i32, crtDecomposition = [], keyswitchBaseLog = -1 : i32, keyswitchLevel = -1 : i32, packingKeySwitchBaseLog = -1 : i32, packingKeySwitchInputLweDimension = -1 : i32, packingKeySwitchLevel = -1 : i32, packingKeySwitchoutputPolynomialSize = -1 : i32} : (tensor<5x!TFHE.glwe<{_,_,_}{2}>>, tensor<40960xi64>) -> tensor<5x!TFHE.glwe<{_,_,_}{3}>>
// CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<{_,_,_}{3}>>
func.func @apply_lookup_table(%arg0: !FHE.eint<2>, %arg1: tensor<4xi64>) -> !FHE.eint<3> {
%1 = "FHE.apply_lookup_table"(%arg0, %arg1): (!FHE.eint<2>, tensor<4xi64>) -> (!FHE.eint<3>)
return %1: !FHE.eint<3>
}

View File

@@ -1,11 +1,10 @@
// RUN: concretecompiler %s --passes fhe-to-tfhe --action=dump-tfhe 2>&1| FileCheck %s
// RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
//CHECK: func.func @apply_lookup_table_cst(%[[A0:.*]]: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> {
//CHECK-NEXT: %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A0
00000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64>
//CHECK-NEXT: %[[V1:.*]] = "TFHE.keyswitch_glwe"(%[[A0]]) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
//CHECK-NEXT: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %cst) {baseLog = -1 : i32, glweDimension = -1 : i32, level = -1 : i32, polySize = -1 : i32} : (!TFHE.glwe<{_,_,_}{7}>, tensor<128xi64>) -> !TFHE.glwe<{_,_,_}{7}>
//CHECK-NEXT: return %[[V2]] : !TFHE.glwe<{_,_,_}{7}>
//CHECK-NEXT: }
// CHECK: func.func @apply_lookup_table_cst(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> {
// CHECK-NEXT: %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A
000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64>
// CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_woppbs"(%cst) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<128xi64>) -> tensor<40960xi64>
// CHECK-NEXT: %1 = "TFHE.wop_pbs_glwe"(%arg0, %0) {bootstrapBaseLog = -1 : i32, bootstrapLevel = -1 : i32, circuitBootstrapBaseLog = -1 : i32, circuitBootstrapLevel = -1 : i32, crtDecomposition = [], keyswitchBaseLog = -1 : i32, keyswitchLevel = -1 : i32, packingKeySwitchBaseLog = -1 : i32, packingKeySwitchInputLweDimension = -1 : i32, packingKeySwitchLevel = -1 : i32, packingKeySwitchoutputPolynomialSize = -1 : i32} : (tensor<5x!TFHE.glwe<{_,_,_}{7}>>, tensor<40960xi64>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
func.func @apply_lookup_table_cst(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
%tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64>
%1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.eint<7>, tensor<128xi64>) -> (!FHE.eint<7>)

View File

@@ -0,0 +1,95 @@
// RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
// CHECK: func.func @conv2d(%arg0: tensor<100x3x28x28x5x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<4x3x14x14xi3>, %arg2: tensor<4xi3>) -> tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: %c4 = arith.constant 4 : index
// CHECK-NEXT: %c100 = arith.constant 100 : index
// CHECK-NEXT: %c15 = arith.constant 15 : index
// CHECK-NEXT: %c0 = arith.constant 0 : index
// CHECK-NEXT: %c1 = arith.constant 1 : index
// CHECK-NEXT: %c3 = arith.constant 3 : index
// CHECK-NEXT: %c14 = arith.constant 14 : index
// CHECK-NEXT: %0 = "TFHE.zero_tensor"() : () -> tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: %1 = scf.for %arg3 = %c0 to %c100 step %c1 iter_args(%arg4 = %0) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT: %3 = scf.for %arg5 = %c0 to %c4 step %c1 iter_args(%arg6 = %arg4) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT: %4 = scf.for %arg7 = %c0 to %c15 step %c1 iter_args(%arg8 = %arg6) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT: %5 = scf.for %arg9 = %c0 to %c15 step %c1 iter_args(%arg10 = %arg8) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT: %6 = tensor.extract %arg2[%arg5] : tensor<4xi3>
// CHECK-NEXT: %c0_0 = arith.constant 0 : index
// CHECK-NEXT: %7 = tensor.extract_slice %0[%arg3, %arg5, %arg7, %arg9, %c0_0] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> to tensor<5x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: %8 = arith.extui %6 : i3 to i64
// CHECK-NEXT: %9 = "TFHE.encode_plaintext_with_crt"(%8) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64>
// CHECK-NEXT: %c0_1 = arith.constant 0 : index
// CHECK-NEXT: %c1_2 = arith.constant 1 : index
// CHECK-NEXT: %c5 = arith.constant 5 : index
// CHECK-NEXT: %10:2 = scf.for %arg11 = %c0_1 to %c5 step %c1_2 iter_args(%arg12 = %7, %arg13 = %9) -> (tensor<5x!TFHE.glwe<{_,_,_}{2}>>, tensor<5xi64>) {
// CHECK-NEXT: %12 = tensor.extract %arg12[%arg11] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: %13 = tensor.extract %arg13[%arg11] : tensor<5xi64>
// CHECK-NEXT: %14 = "TFHE.add_glwe_int"(%12, %13) : (!TFHE.glwe<{_,_,_}{2}>, i64) -> !TFHE.glwe<{_,_,_}{2}>
// CHECK-NEXT: %15 = tensor.insert %14 into %arg12[%arg11] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: scf.yield %15, %arg13 : tensor<5x!TFHE.glwe<{_,_,_}{2}>>, tensor<5xi64>
// CHECK-NEXT: }
// CHECK-NEXT: %c0_3 = arith.constant 0 : index
// CHECK-NEXT: %11 = tensor.insert_slice %10#0 into %arg10[%arg3, %arg5, %arg7, %arg9, %c0_3] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<5x!TFHE.glwe<{_,_,_}{2}>> into tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: scf.yield %11 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %5 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %4 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %3 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: }
// CHECK-NEXT: %2 = scf.for %arg3 = %c0 to %c100 step %c1 iter_args(%arg4 = %1) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT: %3 = scf.for %arg5 = %c0 to %c4 step %c1 iter_args(%arg6 = %arg4) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT: %4 = scf.for %arg7 = %c0 to %c15 step %c1 iter_args(%arg8 = %arg6) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT: %5 = scf.for %arg9 = %c0 to %c15 step %c1 iter_args(%arg10 = %arg8) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT: %6 = scf.for %arg11 = %c0 to %c3 step %c1 iter_args(%arg12 = %arg10) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT: %7 = scf.for %arg13 = %c0 to %c14 step %c1 iter_args(%arg14 = %arg12) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT: %8 = scf.for %arg15 = %c0 to %c14 step %c1 iter_args(%arg16 = %arg14) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT: %9 = affine.apply #map(%arg7, %arg13)
// CHECK-NEXT: %10 = affine.apply #map(%arg9, %arg15)
// CHECK-NEXT: %c0_0 = arith.constant 0 : index
// CHECK-NEXT: %11 = tensor.extract_slice %arg0[%arg3, %arg11, %9, %10, %c0_0] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<100x3x28x28x5x!TFHE.glwe<{_,_,_}{2}>> to tensor<5x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: %12 = tensor.extract %arg1[%arg5, %arg11, %arg13, %arg15] : tensor<4x3x14x14xi3>
// CHECK-NEXT: %c0_1 = arith.constant 0 : index
// CHECK-NEXT: %13 = tensor.extract_slice %1[%arg3, %arg5, %arg7, %arg9, %c0_1] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> to tensor<5x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: %14 = arith.extsi %12 : i3 to i64
// CHECK-NEXT: %c0_2 = arith.constant 0 : index
// CHECK-NEXT: %c1_3 = arith.constant 1 : index
// CHECK-NEXT: %c5 = arith.constant 5 : index
// CHECK-NEXT: %15 = scf.for %arg17 = %c0_2 to %c5 step %c1_3 iter_args(%arg18 = %11) -> (tensor<5x!TFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT: %18 = tensor.extract %arg18[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: %19 = "TFHE.mul_glwe_int"(%18, %14) : (!TFHE.glwe<{_,_,_}{2}>, i64) -> !TFHE.glwe<{_,_,_}{2}>
// CHECK-NEXT: %20 = tensor.insert %19 into %arg18[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: scf.yield %20 : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: }
// CHECK-NEXT: %c0_4 = arith.constant 0 : index
// CHECK-NEXT: %c1_5 = arith.constant 1 : index
// CHECK-NEXT: %c5_6 = arith.constant 5 : index
// CHECK-NEXT: %16:2 = scf.for %arg17 = %c0_4 to %c5_6 step %c1_5 iter_args(%arg18 = %13, %arg19 = %15) -> (tensor<5x!TFHE.glwe<{_,_,_}{2}>>, tensor<5x!TFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT: %18 = tensor.extract %arg18[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: %19 = tensor.extract %arg19[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: %20 = "TFHE.add_glwe"(%18, %19) : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
// CHECK-NEXT: %21 = tensor.insert %20 into %arg18[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: scf.yield %21, %arg19 : tensor<5x!TFHE.glwe<{_,_,_}{2}>>, tensor<5x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: }
// CHECK-NEXT: %c0_7 = arith.constant 0 : index
// CHECK-NEXT: %17 = tensor.insert_slice %16#0 into %arg16[%arg3, %arg5, %arg7, %arg9, %c0_7] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<5x!TFHE.glwe<{_,_,_}{2}>> into tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: scf.yield %17 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %8 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %7 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %6 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %5 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %4 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %3 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: }
// CHECK-NEXT: return %2 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0, 0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
}

View File

@@ -0,0 +1,21 @@
// RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
// CHECK-LABEL: func.func @mul_eint_int(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>>
func.func @mul_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
// CHECK-NEXT: %c2_i8 = arith.constant 2 : i8
// CHECK-NEXT: %0 = arith.extsi %c2_i8 : i8 to i64
// CHECK-NEXT: %c0 = arith.constant 0 : index
// CHECK-NEXT: %c1 = arith.constant 1 : index
// CHECK-NEXT: %c5 = arith.constant 5 : index
// CHECK-NEXT: %1 = scf.for %arg1 = %c0 to %c5 step %c1 iter_args(%arg2 = %arg0) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>) {
// CHECK-NEXT: %2 = tensor.extract %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: %3 = "TFHE.mul_glwe_int"(%2, %0) : (!TFHE.glwe<{_,_,_}{7}>, i64) -> !TFHE.glwe<{_,_,_}{7}>
// CHECK-NEXT: %4 = tensor.insert %3 into %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: scf.yield %4 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: }
// CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
%0 = arith.constant 2 : i8
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}

View File

@@ -0,0 +1,18 @@
// RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
// CHECK-LABEL: func.func @neg_eint(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>>
func.func @neg_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
// CHECK-NEXT: %c0 = arith.constant 0 : index
// CHECK-NEXT: %c1 = arith.constant 1 : index
// CHECK-NEXT: %c5 = arith.constant 5 : index
// CHECK-NEXT: %0 = scf.for %arg1 = %c0 to %c5 step %c1 iter_args(%arg2 = %arg0) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>) {
// CHECK-NEXT: %1 = tensor.extract %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: %2 = "TFHE.neg_glwe"(%1) : (!TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
// CHECK-NEXT: %3 = tensor.insert %2 into %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: scf.yield %3 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: }
// CHECK-NEXT: return %0 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
%1 = "FHE.neg_eint"(%arg0): (!FHE.eint<7>) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}

View File

@@ -0,0 +1,24 @@
// RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
// RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
// CHECK-LABEL: func.func @sub_int_eint(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>>
func.func @sub_int_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
// CHECK-NEXT: %c1_i8 = arith.constant 1 : i8
// CHECK-NEXT: %0 = arith.extui %c1_i8 : i8 to i64
// CHECK-NEXT: %1 = "TFHE.encode_plaintext_with_crt"(%0) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64>
// CHECK-NEXT: %c0 = arith.constant 0 : index
// CHECK-NEXT: %c1 = arith.constant 1 : index
// CHECK-NEXT: %c5 = arith.constant 5 : index
// CHECK-NEXT: %2:2 = scf.for %arg1 = %c0 to %c5 step %c1 iter_args(%arg2 = %arg0, %arg3 = %1) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>, tensor<5xi64>) {
// CHECK-NEXT: %3 = tensor.extract %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: %4 = tensor.extract %arg3[%arg1] : tensor<5xi64>
// CHECK-NEXT: %5 = "TFHE.sub_int_glwe"(%4, %3) : (i64, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
// CHECK-NEXT: %6 = tensor.insert %5 into %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
// CHECK-NEXT: scf.yield %6, %arg3 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>, tensor<5xi64>
// CHECK-NEXT: }
// CHECK-NEXT: return %2#0 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
%0 = arith.constant 1 : i8
%1 = "FHE.sub_int_eint"(%0, %arg0): (i8, !FHE.eint<7>) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}

View File

@@ -1,8 +1,8 @@
// RUN: concretecompiler %s --passes fhe-to-tfhe --action=dump-tfhe 2>&1| FileCheck %s
// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s
// CHECK-LABEL: func.func @add_eint(%arg0: !TFHE.glwe<{_,_,_}{7}>, %arg1: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
func.func @add_eint(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
// CHECK-NEXT: %[[V1:.*]] = "TFHE.add_glwe"(%arg0, %arg1) : (!TFHE.glwe<{_,_,_}{7}>, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
// CHECK-NEXT: %[[V1:.*]] = "TFHE.add_glwe"(%arg0, %arg1) {MANP = 2 : ui3} : (!TFHE.glwe<{_,_,_}{7}>, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
// CHECK-NEXT: return %[[V1]] : !TFHE.glwe<{_,_,_}{7}>
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)

View File

@@ -0,0 +1,16 @@
// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s
// CHECK-LABEL: func.func @add_eint_int(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
func.func @add_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
// CHECK-NEXT: %c1_i8 = arith.constant 1 : i8
// CHECK-NEXT: %0 = arith.extui %c1_i8 : i8 to i64
// CHECK-NEXT: %c56_i64 = arith.constant 56 : i64
// CHECK-NEXT: %1 = arith.shli %0, %c56_i64 : i64
// CHECK-NEXT: %2 = "TFHE.add_glwe_int"(%arg0, %1) : (!TFHE.glwe<{_,_,_}{7}>, i64) -> !TFHE.glwe<{_,_,_}{7}>
// CHECK-NEXT: return %2 : !TFHE.glwe<{_,_,_}{7}>
%0 = arith.constant 1 : i8
%1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}

View File

@@ -0,0 +1,11 @@
// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s
// CHECK: func.func @apply_lookup_table(%arg0: !TFHE.glwe<{_,_,_}{2}>, %arg1: tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{3}> {
// CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_bootstrap"(%arg1) {outputBits = 3 : i32, polySize = 256 : i32} : (tensor<4xi64>) -> tensor<256xi64>
// CHECK-NEXT: %1 = "TFHE.keyswitch_glwe"(%arg0) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
// CHECK-NEXT: %2 = "TFHE.bootstrap_glwe"(%1, %0) {baseLog = -1 : i32, glweDimension = -1 : i32, level = -1 : i32, polySize = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>, tensor<256xi64>) -> !TFHE.glwe<{_,_,_}{3}>
// CHECK-NEXT: return %2 : !TFHE.glwe<{_,_,_}{3}>
func.func @apply_lookup_table(%arg0: !FHE.eint<2>, %arg1: tensor<4xi64>) -> !FHE.eint<3> {
%1 = "FHE.apply_lookup_table"(%arg0, %arg1): (!FHE.eint<2>, tensor<4xi64>) -> (!FHE.eint<3>)
return %1: !FHE.eint<3>
}

View File

@@ -0,0 +1,14 @@
// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s
//CHECK: func.func @apply_lookup_table_cst(%[[A0:.*]]: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> {
//CHECK-NEXT: %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A0
00000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64>
//CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_bootstrap"(%cst) {outputBits = 7 : i32, polySize = 8192 : i32} : (tensor<128xi64>) -> tensor<8192xi64>
//CHECK-NEXT: %1 = "TFHE.keyswitch_glwe"(%arg0) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
//CHECK-NEXT: %2 = "TFHE.bootstrap_glwe"(%1, %0) {baseLog = -1 : i32, glweDimension = -1 : i32, level = -1 : i32, polySize = -1 : i32} : (!TFHE.glwe<{_,_,_}{7}>, tensor<8192xi64>) -> !TFHE.glwe<{_,_,_}{7}>
//CHECK-NEXT: return %2 : !TFHE.glwe<{_,_,_}{7}>
func.func @apply_lookup_table_cst(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
%tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64>
%1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.eint<7>, tensor<128xi64>) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}

View File

@@ -0,0 +1,65 @@
// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s
//CHECK: func.func @conv2d(%arg0: tensor<100x3x28x28x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<4x3x14x14xi3>, %arg2: tensor<4xi3>) -> tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: %c4 = arith.constant 4 : index
// CHECK-NEXT: %c100 = arith.constant 100 : index
// CHECK-NEXT: %c15 = arith.constant 15 : index
// CHECK-NEXT: %c0 = arith.constant 0 : index
// CHECK-NEXT: %c1 = arith.constant 1 : index
// CHECK-NEXT: %c3 = arith.constant 3 : index
// CHECK-NEXT: %c14 = arith.constant 14 : index
// CHECK-NEXT: %0 = "TFHE.zero_tensor"() : () -> tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: %1 = scf.for %arg3 = %c0 to %c100 step %c1 iter_args(%arg4 = %0) -> (tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT: %3 = scf.for %arg5 = %c0 to %c4 step %c1 iter_args(%arg6 = %arg4) -> (tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT: %4 = scf.for %arg7 = %c0 to %c15 step %c1 iter_args(%arg8 = %arg6) -> (tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT: %5 = scf.for %arg9 = %c0 to %c15 step %c1 iter_args(%arg10 = %arg8) -> (tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT: %6 = tensor.extract %arg2[%arg5] : tensor<4xi3>
// CHECK-NEXT: %7 = tensor.extract %0[%arg3, %arg5, %arg7, %arg9] : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: %8 = arith.extui %6 : i3 to i64
// CHECK-NEXT: %c61_i64 = arith.constant 61 : i64
// CHECK-NEXT: %9 = arith.shli %8, %c61_i64 : i64
// CHECK-NEXT: %10 = "TFHE.add_glwe_int"(%7, %9) : (!TFHE.glwe<{_,_,_}{2}>, i64) -> !TFHE.glwe<{_,_,_}{2}>
// CHECK-NEXT: %11 = tensor.insert %10 into %arg10[%arg3, %arg5, %arg7, %arg9] : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: scf.yield %11 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %5 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %4 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %3 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: }
// CHECK-NEXT: %2 = scf.for %arg3 = %c0 to %c100 step %c1 iter_args(%arg4 = %1) -> (tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT: %3 = scf.for %arg5 = %c0 to %c4 step %c1 iter_args(%arg6 = %arg4) -> (tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT: %4 = scf.for %arg7 = %c0 to %c15 step %c1 iter_args(%arg8 = %arg6) -> (tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT: %5 = scf.for %arg9 = %c0 to %c15 step %c1 iter_args(%arg10 = %arg8) -> (tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT: %6 = scf.for %arg11 = %c0 to %c3 step %c1 iter_args(%arg12 = %arg10) -> (tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT: %7 = scf.for %arg13 = %c0 to %c14 step %c1 iter_args(%arg14 = %arg12) -> (tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT: %8 = scf.for %arg15 = %c0 to %c14 step %c1 iter_args(%arg16 = %arg14) -> (tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
// CHECK-NEXT: %9 = affine.apply #map(%arg7, %arg13)
// CHECK-NEXT: %10 = affine.apply #map(%arg9, %arg15)
// CHECK-NEXT: %11 = tensor.extract %arg0[%arg3, %arg11, %9, %10] : tensor<100x3x28x28x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: %12 = tensor.extract %arg1[%arg5, %arg11, %arg13, %arg15] : tensor<4x3x14x14xi3>
// CHECK-NEXT: %13 = tensor.extract %1[%arg3, %arg5, %arg7, %arg9] : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: %14 = arith.extsi %12 : i3 to i64
// CHECK-NEXT: %15 = "TFHE.mul_glwe_int"(%11, %14) : (!TFHE.glwe<{_,_,_}{2}>, i64) -> !TFHE.glwe<{_,_,_}{2}>
// CHECK-NEXT: %16 = "TFHE.add_glwe"(%13, %15) : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
// CHECK-NEXT: %17 = tensor.insert %16 into %arg16[%arg3, %arg5, %arg7, %arg9] : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: scf.yield %17 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %8 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %7 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %6 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %5 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %4 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %3 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
// CHECK-NEXT: }
// CHECK-NEXT: return %2 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0, 0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
}

View File

@@ -0,0 +1,13 @@
// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s
// CHECK-LABEL: func.func @mul_eint_int(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
func.func @mul_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
// CHECK-NEXT: %c2_i8 = arith.constant 2 : i8
// CHECK-NEXT: %0 = arith.extsi %c2_i8 : i8 to i64
// CHECK-NEXT: %1 = "TFHE.mul_glwe_int"(%arg0, %0) : (!TFHE.glwe<{_,_,_}{7}>, i64) -> !TFHE.glwe<{_,_,_}{7}>
// CHECK-NEXT: return %1 : !TFHE.glwe<{_,_,_}{7}>
%0 = arith.constant 2 : i8
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}

View File

@@ -0,0 +1,10 @@
// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s
// CHECK-LABEL: func.func @neg_eint(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
func.func @neg_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
// CHECK-NEXT: %0 = "TFHE.neg_glwe"(%arg0) {MANP = 1 : ui1} : (!TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
// CHECK-NEXT: return %0 : !TFHE.glwe<{_,_,_}{7}>
%1 = "FHE.neg_eint"(%arg0): (!FHE.eint<7>) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}

View File

@@ -0,0 +1,15 @@
// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s
// CHECK-LABEL: func.func @sub_int_eint(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
func.func @sub_int_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
// CHECK-NEXT: %c1_i8 = arith.constant 1 : i8
// CHECK-NEXT: %0 = arith.extui %c1_i8 : i8 to i64
// CHECK-NEXT: %c56_i64 = arith.constant 56 : i64
// CHECK-NEXT: %1 = arith.shli %0, %c56_i64 : i64
// CHECK-NEXT: %2 = "TFHE.sub_int_glwe"(%1, %arg0) : (i64, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
// CHECK-NEXT: return %2 : !TFHE.glwe<{_,_,_}{7}>
%0 = arith.constant 1 : i8
%1 = "FHE.sub_int_eint"(%0, %arg0): (i8, !FHE.eint<7>) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}

View File

@@ -1,22 +1,22 @@
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
//CHECK: func.func @add_glwe_const_int(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
//CHECK: %c1_i8 = arith.constant 1 : i8
//CHECK: %[[V0:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[A0]], %c1_i8) : (!Concrete.lwe_ciphertext<1024,7>, i8) -> !Concrete.lwe_ciphertext<1024,7>
//CHECK: %c1_i64 = arith.constant 1 : i64
//CHECK: %[[V0:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[A0]], %c1_i64) : (!Concrete.lwe_ciphertext<1024,7>, i64) -> !Concrete.lwe_ciphertext<1024,7>
//CHECK: return %[[V0]] : !Concrete.lwe_ciphertext<1024,7>
//CHECK: }
func.func @add_glwe_const_int(%arg0: !TFHE.glwe<{1024,1,64}{7}>) -> !TFHE.glwe<{1024,1,64}{7}> {
%0 = arith.constant 1 : i8
%1 = "TFHE.add_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,1,64}{7}>, i8) -> (!TFHE.glwe<{1024,1,64}{7}>)
%0 = arith.constant 1 : i64
%1 = "TFHE.add_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,1,64}{7}>, i64) -> (!TFHE.glwe<{1024,1,64}{7}>)
return %1: !TFHE.glwe<{1024,1,64}{7}>
}
//CHECK: func.func @add_glwe_int(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,4>, %[[A1:.*]]: i5) -> !Concrete.lwe_ciphertext<1024,4> {
//CHECK: %[[V0:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[A0]], %[[A1]]) : (!Concrete.lwe_ciphertext<1024,4>, i5) -> !Concrete.lwe_ciphertext<1024,4>
//CHECK: func.func @add_glwe_int(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,4>, %[[A1:.*]]: i64) -> !Concrete.lwe_ciphertext<1024,4> {
//CHECK: %[[V0:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[A0]], %[[A1]]) : (!Concrete.lwe_ciphertext<1024,4>, i64) -> !Concrete.lwe_ciphertext<1024,4>
//CHECK: return %[[V0]] : !Concrete.lwe_ciphertext<1024,4>
//CHECK: }
func.func @add_glwe_int(%arg0: !TFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !TFHE.glwe<{1024,1,64}{4}> {
%1 = "TFHE.add_glwe_int"(%arg0, %arg1): (!TFHE.glwe<{1024,1,64}{4}>, i5) -> (!TFHE.glwe<{1024,1,64}{4}>)
func.func @add_glwe_int(%arg0: !TFHE.glwe<{1024,1,64}{4}>, %arg1: i64) -> !TFHE.glwe<{1024,1,64}{4}> {
%1 = "TFHE.add_glwe_int"(%arg0, %arg1): (!TFHE.glwe<{1024,1,64}{4}>, i64) -> (!TFHE.glwe<{1024,1,64}{4}>)
return %1: !TFHE.glwe<{1024,1,64}{4}>
}

View File

@@ -0,0 +1,10 @@
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
// CHECK: func.func @apply_lookup_table(%arg0: tensor<4xi64>) -> tensor<1024xi64> {
// CHECK-NEXT: %0 = "Concrete.encode_expand_lut_for_bootstrap"(%arg0) {outputBits = 3 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<1024xi64>
// CHECK-NEXT: return %0 : tensor<1024xi64>
// CHECK-NEXT: }
func.func @apply_lookup_table(%arg1: tensor<4xi64>) -> tensor<1024xi64> {
%0 = "TFHE.encode_expand_lut_for_bootstrap"(%arg1) {outputBits = 3 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<1024xi64>
return %0: tensor<1024xi64>
}

View File

@@ -0,0 +1,10 @@
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
// CHECK: func.func @main(%arg0: tensor<4xi64>) -> tensor<40960xi64> {
// CHECK-NEXT: %0 = "Concrete.encode_expand_lut_for_woppbs"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64>
// CHECK-NEXT: return %0 : tensor<40960xi64>
// CHECK-NEXT: }
func.func @main(%arg1: tensor<4xi64>) -> tensor<40960xi64> {
%0 = "TFHE.encode_expand_lut_for_woppbs"(%arg1) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64>
return %0: tensor<40960xi64>
}

View File

@@ -0,0 +1,10 @@
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
// CHECK: func.func @main(%arg0: i64) -> tensor<5xi64> {
// CHECK-NEXT: %0 = "Concrete.encode_plaintext_with_crt"(%arg0) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64>
// CHECK-NEXT: return %0 : tensor<5xi64>
// CHECK-NEXT: }
func.func @main(%arg1: i64) -> tensor<5xi64> {
%0 = "TFHE.encode_plaintext_with_crt"(%arg1) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64>
return %0: tensor<5xi64>
}

View File

@@ -1,22 +1,22 @@
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
//CHECK: func.func @mul_glwe_const_int(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
//CHECK: %c1_i8 = arith.constant 1 : i8
//CHECK: %[[V0:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%[[A0]], %c1_i8) : (!Concrete.lwe_ciphertext<1024,7>, i8) -> !Concrete.lwe_ciphertext<1024,7>
//CHECK: %c1_i64 = arith.constant 1 : i64
//CHECK: %[[V0:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%[[A0]], %c1_i64) : (!Concrete.lwe_ciphertext<1024,7>, i64) -> !Concrete.lwe_ciphertext<1024,7>
//CHECK: return %[[V0]] : !Concrete.lwe_ciphertext<1024,7>
//CHECK: }
func.func @mul_glwe_const_int(%arg0: !TFHE.glwe<{1024,1,64}{7}>) -> !TFHE.glwe<{1024,1,64}{7}> {
%0 = arith.constant 1 : i8
%1 = "TFHE.mul_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,1,64}{7}>, i8) -> (!TFHE.glwe<{1024,1,64}{7}>)
%0 = arith.constant 1 : i64
%1 = "TFHE.mul_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,1,64}{7}>, i64) -> (!TFHE.glwe<{1024,1,64}{7}>)
return %1: !TFHE.glwe<{1024,1,64}{7}>
}
//CHECK: func.func @mul_glwe_int(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,4>, %[[A1:.*]]: i5) -> !Concrete.lwe_ciphertext<1024,4> {
//CHECK: %[[V0:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%[[A0]], %[[A1]]) : (!Concrete.lwe_ciphertext<1024,4>, i5) -> !Concrete.lwe_ciphertext<1024,4>
//CHECK: func.func @mul_glwe_int(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,4>, %[[A1:.*]]: i64) -> !Concrete.lwe_ciphertext<1024,4> {
//CHECK: %[[V0:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%[[A0]], %[[A1]]) : (!Concrete.lwe_ciphertext<1024,4>, i64) -> !Concrete.lwe_ciphertext<1024,4>
//CHECK: return %[[V0]] : !Concrete.lwe_ciphertext<1024,4>
//CHECK: }
func.func @mul_glwe_int(%arg0: !TFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !TFHE.glwe<{1024,1,64}{4}> {
%1 = "TFHE.mul_glwe_int"(%arg0, %arg1): (!TFHE.glwe<{1024,1,64}{4}>, i5) -> (!TFHE.glwe<{1024,1,64}{4}>)
func.func @mul_glwe_int(%arg0: !TFHE.glwe<{1024,1,64}{4}>, %arg1: i64) -> !TFHE.glwe<{1024,1,64}{4}> {
%1 = "TFHE.mul_glwe_int"(%arg0, %arg1): (!TFHE.glwe<{1024,1,64}{4}>, i64) -> (!TFHE.glwe<{1024,1,64}{4}>)
return %1: !TFHE.glwe<{1024,1,64}{4}>
}

View File

@@ -1,23 +1,23 @@
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
//CHECK: func.func @sub_const_int_glwe(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
//CHECK: %c1_i8 = arith.constant 1 : i8
//CHECK: %c1_i64 = arith.constant 1 : i64
//CHECK: %[[V0:.*]] = "Concrete.negate_lwe_ciphertext"(%[[A0]]) : (!Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7>
//CHECK: %[[V1:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[V0]], %c1_i8) : (!Concrete.lwe_ciphertext<1024,7>, i8) -> !Concrete.lwe_ciphertext<1024,7>
//CHECK: %[[V1:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[V0]], %c1_i64) : (!Concrete.lwe_ciphertext<1024,7>, i64) -> !Concrete.lwe_ciphertext<1024,7>
//CHECK: return %[[V1]] : !Concrete.lwe_ciphertext<1024,7>
//CHECK: }
func.func @sub_const_int_glwe(%arg0: !TFHE.glwe<{1024,1,64}{7}>) -> !TFHE.glwe<{1024,1,64}{7}> {
%0 = arith.constant 1 : i8
%1 = "TFHE.sub_int_glwe"(%0, %arg0): (i8, !TFHE.glwe<{1024,1,64}{7}>) -> (!TFHE.glwe<{1024,1,64}{7}>)
%0 = arith.constant 1 : i64
%1 = "TFHE.sub_int_glwe"(%0, %arg0): (i64, !TFHE.glwe<{1024,1,64}{7}>) -> (!TFHE.glwe<{1024,1,64}{7}>)
return %1: !TFHE.glwe<{1024,1,64}{7}>
}
//CHECK: func.func @sub_int_glwe(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,4>, %[[A1:.*]]: i5) -> !Concrete.lwe_ciphertext<1024,4> {
//CHECK: func.func @sub_int_glwe(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,4>, %[[A1:.*]]: i64) -> !Concrete.lwe_ciphertext<1024,4> {
//CHECK: %[[V0:.*]] = "Concrete.negate_lwe_ciphertext"(%[[A0]]) : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4>
//CHECK: %[[V1:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[V0]], %[[A1]]) : (!Concrete.lwe_ciphertext<1024,4>, i5) -> !Concrete.lwe_ciphertext<1024,4>
//CHECK: %[[V1:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[V0]], %[[A1]]) : (!Concrete.lwe_ciphertext<1024,4>, i64) -> !Concrete.lwe_ciphertext<1024,4>
//CHECK: return %[[V1]] : !Concrete.lwe_ciphertext<1024,4>
//CHECK: }
func.func @sub_int_glwe(%arg0: !TFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !TFHE.glwe<{1024,1,64}{4}> {
%1 = "TFHE.sub_int_glwe"(%arg1, %arg0): (i5, !TFHE.glwe<{1024,1,64}{4}>) -> (!TFHE.glwe<{1024,1,64}{4}>)
func.func @sub_int_glwe(%arg0: !TFHE.glwe<{1024,1,64}{4}>, %arg1: i64) -> !TFHE.glwe<{1024,1,64}{4}> {
%1 = "TFHE.sub_int_glwe"(%arg1, %arg0): (i64, !TFHE.glwe<{1024,1,64}{4}>) -> (!TFHE.glwe<{1024,1,64}{4}>)
return %1: !TFHE.glwe<{1024,1,64}{4}>
}

View File

@@ -9,15 +9,6 @@ func.func @add_lwe_ciphertexts(%arg0: tensor<2049xi64>, %arg1: tensor<2049xi64>)
return %0 : tensor<2049xi64>
}
//CHECK: func.func @add_crt_lwe_ciphertexts(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.add_crt_lwe_tensor"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> tensor<5x2049xi64>
//CHECK: return %[[V0]] : tensor<5x2049xi64>
//CHECK: }
func.func @add_crt_lwe_ciphertexts(%arg0: tensor<5x2049xi64>, %arg1: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
%0 = "BConcrete.add_crt_lwe_tensor"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> ( tensor<5x2049xi64>)
return %0 : tensor<5x2049xi64>
}
//CHECK: func.func @add_plaintext_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: i64) -> tensor<2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, i64) -> tensor<2049xi64>
//CHECK: return %[[V0]] : tensor<2049xi64>
@@ -27,15 +18,6 @@ func.func @add_plaintext_lwe_ciphertext(%arg0: tensor<2049xi64>, %arg1: i64) ->
return %0 : tensor<2049xi64>
}
//CHECK: func.func @add_plaintext_crt_lwe_ciphertext(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: i64) -> tensor<5x2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_crt_lwe_tensor"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> tensor<5x2049xi64>
//CHECK: return %[[V0]] : tensor<5x2049xi64>
//CHECK: }
func.func @add_plaintext_crt_lwe_ciphertext(%arg0: tensor<5x2049xi64>, %arg1: i64) -> tensor<5x2049xi64> {
%0 = "BConcrete.add_plaintext_crt_lwe_tensor"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> ( tensor<5x2049xi64>)
return %0 : tensor<5x2049xi64>
}
//CHECK: func @mul_cleartext_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: i64) -> tensor<2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, i64) -> tensor<2049xi64>
//CHECK: return %[[V0]] : tensor<2049xi64>
@@ -45,15 +27,6 @@ func.func @mul_cleartext_lwe_ciphertext(%arg0: tensor<2049xi64>, %arg1: i64) ->
return %0 : tensor<2049xi64>
}
//CHECK: func.func @mul_cleartext_crt_lwe_ciphertext(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: i64) -> tensor<5x2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_crt_lwe_tensor"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> tensor<5x2049xi64>
//CHECK: return %[[V0]] : tensor<5x2049xi64>
//CHECK: }
func.func @mul_cleartext_crt_lwe_ciphertext(%arg0: tensor<5x2049xi64>, %arg1: i64) -> tensor<5x2049xi64> {
%0 = "BConcrete.mul_cleartext_crt_lwe_tensor"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> (tensor<5x2049xi64>)
return %0 : tensor<5x2049xi64>
}
//CHECK: func.func @negate_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>) -> tensor<2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.negate_lwe_tensor"(%[[A0]]) : (tensor<2049xi64>) -> tensor<2049xi64>
//CHECK: return %[[V0]] : tensor<2049xi64>
@@ -63,15 +36,6 @@ func.func @negate_lwe_ciphertext(%arg0: tensor<2049xi64>) -> tensor<2049xi64> {
return %0 : tensor<2049xi64>
}
//CHECK: func.func @negate_crt_lwe_ciphertext(%[[A0:.*]]: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.negate_crt_lwe_tensor"(%[[A0]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>) -> tensor<5x2049xi64>
//CHECK: return %[[V0]] : tensor<5x2049xi64>
//CHECK: }
func.func @negate_crt_lwe_ciphertext(%arg0: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
%0 = "BConcrete.negate_crt_lwe_tensor"(%arg0) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>) -> (tensor<5x2049xi64>)
return %0 : tensor<5x2049xi64>
}
//CHECK: func.func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<16xi64>) -> tensor<2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.bootstrap_lwe_tensor"(%arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<2049xi64>, tensor<16xi64>) -> tensor<2049xi64>
//CHECK: return %[[V0]] : tensor<2049xi64>

View File

@@ -13,12 +13,6 @@ func.func @type_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Conc
return %arg0: !Concrete.lwe_ciphertext<2048,7>
}
// CHECK-LABEL: func @type_lwe_ciphertext_with_crt(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>
func.func @type_lwe_ciphertext_with_crt(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7> {
// CHECK-NEXT: return %arg0 : !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>
return %arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>
}
// CHECK-LABEL: func @type_cleartext(%arg0: !Concrete.cleartext<5>) -> !Concrete.cleartext<5>
func.func @type_cleartext(%arg0: !Concrete.cleartext<5>) -> !Concrete.cleartext<5> {
// CHECK-NEXT: return %arg0 : !Concrete.cleartext<5>

View File

@@ -1,21 +1,24 @@
// RUN: concretecompiler %s --action=dump-tfhe 2>&1 | FileCheck %s
// RUN: concretecompiler %s --action=dump-fhe-no-linalg 2>&1 | FileCheck %s
//CHECK: #map0 = affine_map<(d0) -> (d0)>
//CHECK-NEXT: #map1 = affine_map<(d0) -> (0)>
//CHECK-NEXT: module {
//CHECK-NEXT: func.func @dot_eint_int(%arg0: tensor<2x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<2xi3>) -> !TFHE.glwe<{_,_,_}{2}> {
//CHECK-NEXT: %c0 = arith.constant 0 : index
//CHECK-NEXT: %0 = "TFHE.zero_tensor"() : () -> tensor<1x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %1 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!TFHE.glwe<{_,_,_}{2}>>, tensor<2xi3>) outs(%0 : tensor<1x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: ^bb0(%arg2: !TFHE.glwe<{_,_,_}{2}>, %arg3: i3, %arg4: !TFHE.glwe<{_,_,_}{2}>):
//CHECK-NEXT: %3 = "TFHE.mul_glwe_int"(%arg2, %arg3) : (!TFHE.glwe<{_,_,_}{2}>, i3) -> !TFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: %4 = "TFHE.add_glwe"(%3, %arg4) : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: linalg.yield %4 : !TFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: } -> tensor<1x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %2 = tensor.extract %1[%c0] : tensor<1x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: return %2 : !TFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: }
//CHECK-NEXT: }
// CHECK: module {
// CHECK-NEXT: func.func @dot_eint_int(%arg0: tensor<2x!FHE.eint<2>>, %arg1: tensor<2xi3>) -> !FHE.eint<2> {
// CHECK-NEXT: %c2 = arith.constant 2 : index
// CHECK-NEXT: %c0 = arith.constant 0 : index
// CHECK-NEXT: %c1 = arith.constant 1 : index
// CHECK-NEXT: %0 = "FHE.zero_tensor"() : () -> tensor<1x!FHE.eint<2>>
// CHECK-NEXT: %1 = scf.for %arg2 = %c0 to %c2 step %c1 iter_args(%arg3 = %0) -> (tensor<1x!FHE.eint<2>>) {
// CHECK-NEXT: %3 = tensor.extract %arg0[%arg2] : tensor<2x!FHE.eint<2>>
// CHECK-NEXT: %4 = tensor.extract %arg1[%arg2] : tensor<2xi3>
// CHECK-NEXT: %5 = tensor.extract %0[%c0] : tensor<1x!FHE.eint<2>>
// CHECK-NEXT: %6 = "FHE.mul_eint_int"(%3, %4) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
// CHECK-NEXT: %7 = "FHE.add_eint"(%6, %5) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
// CHECK-NEXT: %8 = tensor.insert %7 into %arg3[%c0] : tensor<1x!FHE.eint<2>>
// CHECK-NEXT: scf.yield %8 : tensor<1x!FHE.eint<2>>
// CHECK-NEXT: }
// CHECK-NEXT: %2 = tensor.extract %1[%c0] : tensor<1x!FHE.eint<2>>
// CHECK-NEXT: return %2 : !FHE.eint<2>
// CHECK-NEXT: }
// CHECK-NEXT: }
func.func @dot_eint_int(%arg0: tensor<2x!FHE.eint<2>>,
%arg1: tensor<2xi3>) -> !FHE.eint<2>
{

View File

@@ -51,21 +51,3 @@ func.func @add_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>, %arg1: !TFHE.glwe<{1024,
%1 = "TFHE.add_glwe"(%arg0, %arg1): (!TFHE.glwe<{1024,12,64}{7}>, !TFHE.glwe<{1024,11,64}{7}>) -> (!TFHE.glwe<{1024,12,64}{7}>)
return %1: !TFHE.glwe<{1024,12,64}{7}>
}
// -----
// GLWE polynomialSize parameter result
func.func @add_glwe(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>, %arg1: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}> {
// expected-error @+1 {{'TFHE.add_glwe' op should have the same GLWE 'crt' parameter}}
%1 = "TFHE.add_glwe"(%arg0, %arg1): (!TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>, !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> (!TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>)
return %1: !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>
}
// -----
// GLWE polynomialSize parameter inputs
func.func @add_glwe(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>, %arg1: !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>) -> !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}> {
// expected-error @+1 {{'TFHE.add_glwe' op should have the same GLWE 'crt' parameter}}
%1 = "TFHE.add_glwe"(%arg0, %arg1): (!TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>, !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>) -> (!TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>)
return %1: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>
}

View File

@@ -27,23 +27,3 @@ func.func @add_glwe_int(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,
%1 = "TFHE.add_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,12,64}{7}>, i8) -> (!TFHE.glwe<{1024,11,64}{7}>)
return %1: !TFHE.glwe<{1024,11,64}{7}>
}
// -----
// GLWE crt parameter
func.func @add_glwe_int(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}> {
%0 = arith.constant 1 : i8
// expected-error @+1 {{'TFHE.add_glwe_int' op should have the same GLWE 'crt' parameter}}
%1 = "TFHE.add_glwe_int"(%arg0, %0): (!TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>, i8) -> (!TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>)
return %1: !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>
}
// -----
// integer width doesn't match GLWE parameter
func.func @add_glwe_int(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,12,64}{7}> {
%0 = arith.constant 1 : i9
// expected-error @+1 {{'TFHE.add_glwe_int' op should have the width of `b` equals or less than 'p'+1}}
%1 = "TFHE.add_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,12,64}{7}>, i9) -> (!TFHE.glwe<{1024,12,64}{7}>)
return %1: !TFHE.glwe<{1024,12,64}{7}>
}

View File

@@ -2,11 +2,11 @@
// CHECK-LABEL: func.func @add_glwe_int(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,12,64}{7}>
func.func @add_glwe_int(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,12,64}{7}> {
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8
// CHECK-NEXT: %[[V2:.*]] = "TFHE.add_glwe_int"(%arg0, %[[V1]]) : (!TFHE.glwe<{1024,12,64}{7}>, i8) -> !TFHE.glwe<{1024,12,64}{7}>
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i64
// CHECK-NEXT: %[[V2:.*]] = "TFHE.add_glwe_int"(%arg0, %[[V1]]) : (!TFHE.glwe<{1024,12,64}{7}>, i64) -> !TFHE.glwe<{1024,12,64}{7}>
// CHECK-NEXT: return %[[V2]] : !TFHE.glwe<{1024,12,64}{7}>
%0 = arith.constant 1 : i8
%1 = "TFHE.add_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,12,64}{7}>, i8) -> (!TFHE.glwe<{1024,12,64}{7}>)
%0 = arith.constant 1 : i64
%1 = "TFHE.add_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,12,64}{7}>, i64) -> (!TFHE.glwe<{1024,12,64}{7}>)
return %1: !TFHE.glwe<{1024,12,64}{7}>
}

View File

@@ -27,23 +27,3 @@ func.func @mul_glwe_int(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,
%1 = "TFHE.mul_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,12,64}{7}>, i8) -> (!TFHE.glwe<{1024,11,64}{7}>)
return %1: !TFHE.glwe<{1024,11,64}{7}>
}
// -----
// GLWE crt parameter
func.func @mul_glwe_int(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}> {
%0 = arith.constant 1 : i8
// expected-error @+1 {{'TFHE.mul_glwe_int' op should have the same GLWE 'crt' parameter}}
%1 = "TFHE.mul_glwe_int"(%arg0, %0): (!TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>, i8) -> (!TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>)
return %1: !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>
}
// -----
// integer width doesn't match GLWE parameter
func.func @mul_glwe_int(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,12,64}{7}> {
%0 = arith.constant 1 : i9
// expected-error @+1 {{'TFHE.mul_glwe_int' op should have the width of `b` equals or less than 'p'+1}}
%1 = "TFHE.mul_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,12,64}{7}>, i9) -> (!TFHE.glwe<{1024,12,64}{7}>)
return %1: !TFHE.glwe<{1024,12,64}{7}>
}

View File

@@ -2,11 +2,11 @@
// CHECK-LABEL: func.func @mul_glwe_int(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,12,64}{7}>
func.func @mul_glwe_int(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,12,64}{7}> {
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8
// CHECK-NEXT: %[[V2:.*]] = "TFHE.mul_glwe_int"(%arg0, %[[V1]]) : (!TFHE.glwe<{1024,12,64}{7}>, i8) -> !TFHE.glwe<{1024,12,64}{7}>
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i64
// CHECK-NEXT: %[[V2:.*]] = "TFHE.mul_glwe_int"(%arg0, %[[V1]]) : (!TFHE.glwe<{1024,12,64}{7}>, i64) -> !TFHE.glwe<{1024,12,64}{7}>
// CHECK-NEXT: return %[[V2]] : !TFHE.glwe<{1024,12,64}{7}>
%0 = arith.constant 1 : i8
%1 = "TFHE.mul_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,12,64}{7}>, i8) -> (!TFHE.glwe<{1024,12,64}{7}>)
%0 = arith.constant 1 : i64
%1 = "TFHE.mul_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,12,64}{7}>, i64) -> (!TFHE.glwe<{1024,12,64}{7}>)
return %1: !TFHE.glwe<{1024,12,64}{7}>
}

View File

@@ -27,15 +27,6 @@ func.func @neg_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,11,6
// -----
// GLWE crt parameter
func.func @neg_glwe(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}> {
// expected-error @+1 {{'TFHE.neg_glwe' op should have the same GLWE 'crt' parameter}}
%1 = "TFHE.neg_glwe"(%arg0): (!TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> (!TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>)
return %1: !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>
}
// -----
// integer width doesn't match GLWE parameter
func.func @neg_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,11,64}{7}> {
// expected-error @+1 {{'TFHE.neg_glwe' op should have the same GLWE 'polynomialSize' parameter}}

View File

@@ -28,16 +28,6 @@ func.func @sub_int_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,
return %1: !TFHE.glwe<{1024,11,64}{7}>
}
// -----
// GLWE polynomialSize parameter
func.func @sub_int_glwe(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}> {
%0 = arith.constant 1 : i8
// expected-error @+1 {{'TFHE.sub_int_glwe' op should have the same GLWE 'crt' parameter}}
%1 = "TFHE.sub_int_glwe"(%0, %arg0): (i8, !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> (!TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>)
return %1: !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>
}
// -----

View File

@@ -2,11 +2,11 @@
// CHECK-LABEL: func.func @sub_int_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,12,64}{7}>
func.func @sub_int_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,12,64}{7}> {
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8
// CHECK-NEXT: %[[V2:.*]] = "TFHE.sub_int_glwe"(%[[V1]], %arg0) : (i8, !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,12,64}{7}>
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i64
// CHECK-NEXT: %[[V2:.*]] = "TFHE.sub_int_glwe"(%[[V1]], %arg0) : (i64, !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,12,64}{7}>
// CHECK-NEXT: return %[[V2]] : !TFHE.glwe<{1024,12,64}{7}>
%0 = arith.constant 1 : i8
%1 = "TFHE.sub_int_glwe"(%0, %arg0): (i8, !TFHE.glwe<{1024,12,64}{7}>) -> (!TFHE.glwe<{1024,12,64}{7}>)
%0 = arith.constant 1 : i64
%1 = "TFHE.sub_int_glwe"(%0, %arg0): (i64, !TFHE.glwe<{1024,12,64}{7}>) -> (!TFHE.glwe<{1024,12,64}{7}>)
return %1: !TFHE.glwe<{1024,12,64}{7}>
}

View File

@@ -11,15 +11,3 @@ func.func @glwe_1(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> {
// CHECK-LABEL: return %arg0 : !TFHE.glwe<{_,_,_}{7}>
return %arg0: !TFHE.glwe<{_,_,_}{7}>
}
// CHECK-LABEL: func.func @glwe_crt(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{_,_,_}{7}>) -> !TFHE.glwe<crt=[2,3,5,7,11]{_,_,_}{7}>
func.func @glwe_crt(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{_,_,_}{7}>) -> !TFHE.glwe<crt=[2,3,5,7,11]{_,_,_}{7}> {
// CHECK-LABEL: return %arg0 : !TFHE.glwe<crt=[2,3,5,7,11]{_,_,_}{7}>
return %arg0: !TFHE.glwe<crt=[2,3,5,7,11]{_,_,_}{7}>
}
// CHECK-LABEL: func.func @glwe_crt_undef(%arg0: !TFHE.glwe<crt=[_,_,_,_,_]{_,_,_}{7}>) -> !TFHE.glwe<crt=[_,_,_,_,_]{_,_,_}{7}>
func.func @glwe_crt_undef(%arg0: !TFHE.glwe<crt=[_,_,_,_,_]{_,_,_}{7}>) -> !TFHE.glwe<crt=[_,_,_,_,_]{_,_,_}{7}> {
// CHECK-LABEL: return %arg0 : !TFHE.glwe<crt=[_,_,_,_,_]{_,_,_}{7}>
return %arg0: !TFHE.glwe<crt=[_,_,_,_,_]{_,_,_}{7}>
}

View File

@@ -1,4 +1,4 @@
// RUN: concretecompiler --split-input-file --action=dump-concrete-with-loops --batch-concrete-ops %s 2>&1| FileCheck %s
// RUN: concretecompiler --split-input-file --action=dump-concrete --batch-concrete-ops %s 2>&1| FileCheck %s
// CHECK-LABEL: func.func @batch_continuous_slice_keyswitch(%arg0: tensor<2x3x4x!Concrete.lwe_ciphertext<572,2>>) -> tensor<2x3x4x!Concrete.lwe_ciphertext<572,2>> {
func.func @batch_continuous_slice_keyswitch(%arg0: tensor<2x3x4x!Concrete.lwe_ciphertext<572,2>>) -> tensor<2x3x4x!Concrete.lwe_ciphertext<572,2>> {