mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
refactor(encodings): raise plaintext/lut encodings higher up in the pipeline
This commit is contained in:
@@ -3,5 +3,4 @@ mlir_tablegen(Passes.h.inc -gen-pass-decls -name Conversion)
|
||||
add_public_tablegen_target(ConcretelangConversionPassIncGen)
|
||||
add_dependencies(mlir-headers ConcretelangConversionPassIncGen)
|
||||
|
||||
add_subdirectory(FHEToTFHE)
|
||||
add_subdirectory(TFHEToConcrete)
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Patterns.td)
|
||||
mlir_tablegen(Patterns.h.inc -gen-rewriters -name FHE)
|
||||
add_public_tablegen_target(FHEToTFHEPatternsIncGen)
|
||||
add_dependencies(mlir-headers FHEToTFHEPatternsIncGen)
|
||||
|
||||
add_concretelang_doc(Patterns FHEToTFHEPatterns concretelang/ -gen-pass-doc)
|
||||
@@ -1,27 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_FHETOTFHE_PASS_H_
|
||||
#define CONCRETELANG_CONVERSION_FHETOTFHE_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
// ApplyLookupTableLowering indicates the strategy to lower an
|
||||
// FHE.apply_loopup_table ops
|
||||
enum ApplyLookupTableLowering {
|
||||
KeySwitchBoostrapLowering,
|
||||
WopPBSLowering,
|
||||
};
|
||||
|
||||
/// Create a pass to convert `FHE` dialect to `TFHE` dialect.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertFHEToTFHEPass(ApplyLookupTableLowering lower);
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -1,89 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_FHETOTFHE_PATTERNS_H_
|
||||
#define CONCRETELANG_CONVERSION_FHETOTFHE_PATTERNS_H_
|
||||
|
||||
#include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
using FHE::EncryptedIntegerType;
|
||||
using TFHE::GLWECipherTextType;
|
||||
|
||||
/// Converts FHE::EncryptedInteger into TFHE::GlweCiphetext
|
||||
GLWECipherTextType
|
||||
convertTypeEncryptedIntegerToGLWE(mlir::MLIRContext *context,
|
||||
EncryptedIntegerType eint) {
|
||||
return GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth(),
|
||||
llvm::ArrayRef<int64_t>());
|
||||
}
|
||||
|
||||
/// Converts the type `t` to `TFHE::GlweCiphetext` if `t` is a
|
||||
/// `FHE::EncryptedInteger`, otherwise just returns `t`.
|
||||
mlir::Type convertTypeToGLWEIfEncryptedIntegerType(mlir::MLIRContext *context,
|
||||
mlir::Type t) {
|
||||
if (auto eint = t.dyn_cast<EncryptedIntegerType>())
|
||||
return convertTypeEncryptedIntegerToGLWE(context, eint);
|
||||
|
||||
return t;
|
||||
}
|
||||
|
||||
mlir::Value createZeroGLWEOpFromFHE(mlir::PatternRewriter &rewriter,
|
||||
mlir::Location loc, mlir::OpResult result) {
|
||||
mlir::SmallVector<mlir::Value> args{};
|
||||
mlir::SmallVector<mlir::NamedAttribute, 0> attrs;
|
||||
mlir::SmallVector<mlir::Type, 1> resTypes{result.getType()};
|
||||
TFHE::ZeroGLWEOp op =
|
||||
rewriter.create<TFHE::ZeroGLWEOp>(loc, resTypes, args, attrs);
|
||||
convertOperandAndResultTypes(rewriter, op,
|
||||
convertTypeToGLWEIfEncryptedIntegerType);
|
||||
return op.getODSResults(0).front();
|
||||
}
|
||||
|
||||
template <class Operator>
|
||||
mlir::Value createGLWEOpFromFHE(mlir::PatternRewriter &rewriter,
|
||||
mlir::Location loc, mlir::Value arg0,
|
||||
mlir::Value arg1, mlir::OpResult result) {
|
||||
mlir::SmallVector<mlir::Value, 2> args{arg0, arg1};
|
||||
mlir::SmallVector<mlir::NamedAttribute, 0> attrs;
|
||||
mlir::SmallVector<mlir::Type, 1> resTypes{result.getType()};
|
||||
Operator op = rewriter.create<Operator>(loc, resTypes, args, attrs);
|
||||
convertOperandAndResultTypes(rewriter, op,
|
||||
convertTypeToGLWEIfEncryptedIntegerType);
|
||||
return op.getODSResults(0).front();
|
||||
}
|
||||
|
||||
template <class Operator>
|
||||
mlir::Value createGLWEOpFromFHE(mlir::PatternRewriter &rewriter,
|
||||
mlir::Location loc, mlir::Value arg0,
|
||||
mlir::OpResult result) {
|
||||
mlir::SmallVector<mlir::Value, 1> args{arg0};
|
||||
mlir::SmallVector<mlir::NamedAttribute, 0> attrs;
|
||||
mlir::SmallVector<mlir::Type, 1> resTypes{result.getType()};
|
||||
Operator op = rewriter.create<Operator>(loc, resTypes, args, attrs);
|
||||
convertOperandAndResultTypes(rewriter, op,
|
||||
convertTypeToGLWEIfEncryptedIntegerType);
|
||||
return op.getODSResults(0).front();
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
namespace {
|
||||
#include "concretelang/Conversion/FHEToTFHE/Patterns.h.inc"
|
||||
}
|
||||
|
||||
void populateWithGeneratedFHEToTFHE(mlir::RewritePatternSet &patterns) {
|
||||
populateWithGenerated(patterns);
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -1,45 +0,0 @@
|
||||
#ifndef CONCRETELANG_CONVERSION_FHETOTFHE_PATTERNS
|
||||
#define CONCRETELANG_CONVERSION_FHETOTFHE_PATTERNS
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
include "mlir/IR/PatternBase.td"
|
||||
include "concretelang/Dialect/FHE/IR/FHEOps.td"
|
||||
include "concretelang/Dialect/TFHE/IR/TFHEOps.td"
|
||||
|
||||
def createZeroGLWEOp : NativeCodeCall<"mlir::concretelang::createZeroGLWEOpFromFHE($_builder, $_loc, $0)">;
|
||||
|
||||
def ZeroEintPattern : Pat<
|
||||
(FHE_ZeroEintOp:$result),
|
||||
(createZeroGLWEOp $result)>;
|
||||
|
||||
def createAddGLWEIntOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromFHE<mlir::concretelang::TFHE::AddGLWEIntOp>($_builder, $_loc, $0, $1, $2)">;
|
||||
|
||||
def AddEintIntPattern : Pat<
|
||||
(FHE_AddEintIntOp:$result $arg0, $arg1),
|
||||
(createAddGLWEIntOp $arg0, $arg1, $result)>;
|
||||
|
||||
def createAddGLWEOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromFHE<mlir::concretelang::TFHE::AddGLWEOp>($_builder, $_loc, $0, $1, $2)">;
|
||||
|
||||
def AddEintPattern : Pat<
|
||||
(FHE_AddEintOp:$result $arg0, $arg1),
|
||||
(createAddGLWEOp $arg0, $arg1, $result)>;
|
||||
|
||||
def createSubGLWEIntOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromFHE<mlir::concretelang::TFHE::SubGLWEIntOp>($_builder, $_loc, $0, $1, $2)">;
|
||||
|
||||
def SubIntEintPattern : Pat<
|
||||
(FHE_SubIntEintOp:$result $arg0, $arg1),
|
||||
(createSubGLWEIntOp $arg0, $arg1, $result)>;
|
||||
|
||||
def createNegGLWEOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromFHE<mlir::concretelang::TFHE::NegGLWEOp>($_builder, $_loc, $0, $1)">;
|
||||
|
||||
def NegEintPattern : Pat<
|
||||
(FHE_NegEintOp:$result $arg0),
|
||||
(createNegGLWEOp $arg0, $result)>;
|
||||
|
||||
def createMulGLWEIntOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromFHE<mlir::concretelang::TFHE::MulGLWEIntOp>($_builder, $_loc, $0, $1, $2)">;
|
||||
|
||||
def MulEintIntPattern : Pat<
|
||||
(FHE_MulEintIntOp:$result $arg0, $arg1),
|
||||
(createMulGLWEIntOp $arg0, $arg1, $result)>;
|
||||
|
||||
#endif
|
||||
51
compiler/include/concretelang/Conversion/FHEToTFHECrt/Pass.h
Normal file
51
compiler/include/concretelang/Conversion/FHEToTFHECrt/Pass.h
Normal file
@@ -0,0 +1,51 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_FHETOTFHECRT_PASS_H_
|
||||
#define CONCRETELANG_CONVERSION_FHETOTFHECRT_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include <list>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
struct CrtLoweringParameters {
|
||||
mlir::SmallVector<int64_t> mods;
|
||||
mlir::SmallVector<int64_t> bits;
|
||||
size_t nMods;
|
||||
size_t modsProd;
|
||||
size_t bitsTotal;
|
||||
size_t polynomialSize;
|
||||
size_t lutSize;
|
||||
|
||||
CrtLoweringParameters(mlir::SmallVector<int64_t> mods, size_t polySize)
|
||||
: mods(mods), polynomialSize(polySize) {
|
||||
nMods = mods.size();
|
||||
modsProd = 1;
|
||||
bitsTotal = 0;
|
||||
bits.clear();
|
||||
for (auto &mod : mods) {
|
||||
modsProd *= mod;
|
||||
uint64_t nbits =
|
||||
static_cast<uint64_t>(ceil(log2(static_cast<double>(mod))));
|
||||
bits.push_back(nbits);
|
||||
bitsTotal += nbits;
|
||||
}
|
||||
size_t lutCrtSize = size_t(1) << bitsTotal;
|
||||
lutCrtSize = std::max(lutCrtSize, polynomialSize);
|
||||
lutSize = mods.size() * lutCrtSize;
|
||||
}
|
||||
};
|
||||
|
||||
/// Create a pass to convert `FHE` dialect to `TFHE` dialect with the crt
|
||||
// strategy.
|
||||
std::unique_ptr<OperationPass<mlir::ModuleOp>>
|
||||
createConvertFHEToTFHECrtPass(CrtLoweringParameters lowering);
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,28 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_FHETOTFHESCALAR_PASS_H_
|
||||
#define CONCRETELANG_CONVERSION_FHETOTFHESCALAR_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include <list>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
struct ScalarLoweringParameters {
|
||||
size_t polynomialSize;
|
||||
ScalarLoweringParameters(size_t polySize) : polynomialSize(polySize){};
|
||||
};
|
||||
|
||||
/// Create a pass to convert `FHE` dialect to `TFHE` dialect with the scalar
|
||||
// strategy.
|
||||
std::unique_ptr<OperationPass<mlir::ModuleOp>>
|
||||
createConvertFHEToTFHEScalarPass(ScalarLoweringParameters loweringParameters);
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -17,7 +17,8 @@
|
||||
#include "concretelang/Conversion/ConcreteToBConcrete/Pass.h"
|
||||
#include "concretelang/Conversion/ExtractSDFGOps/Pass.h"
|
||||
#include "concretelang/Conversion/FHETensorOpsToLinalg/Pass.h"
|
||||
#include "concretelang/Conversion/FHEToTFHE/Pass.h"
|
||||
#include "concretelang/Conversion/FHEToTFHECrt/Pass.h"
|
||||
#include "concretelang/Conversion/FHEToTFHEScalar/Pass.h"
|
||||
#include "concretelang/Conversion/LinalgExtras/Passes.h"
|
||||
#include "concretelang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h"
|
||||
#include "concretelang/Conversion/SDFGToStreamEmulator/Pass.h"
|
||||
|
||||
@@ -9,10 +9,18 @@ def FHETensorOpsToLinalg : Pass<"fhe-tensor-ops-to-linalg", "::mlir::func::FuncO
|
||||
let dependentDialects = ["mlir::linalg::LinalgDialect"];
|
||||
}
|
||||
|
||||
def FHEToTFHE : Pass<"fhe-to-tfhe", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from the FHE dialect to TFHE";
|
||||
def FHEToTFHEScalar : Pass<"fhe-to-tfhe-scalar", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from the FHE dialect to TFHE using the scalar strategy.";
|
||||
let description = [{ Lowers operations from the FHE dialect to Std + Math }];
|
||||
let constructor = "mlir::concretelang::createConvertFHEToTFHEPass()";
|
||||
let constructor = "mlir::concretelang::createConvertFHEToTFHEScalarPass()";
|
||||
let options = [];
|
||||
let dependentDialects = ["mlir::linalg::LinalgDialect"];
|
||||
}
|
||||
|
||||
def FHEToTFHECrt : Pass<"fhe-to-tfhe-crt", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from the FHE dialect to TFHE using the crt strategy.";
|
||||
let description = [{ Lowers operations from the FHE dialect to Std + Math }];
|
||||
let constructor = "mlir::concretelang::createConvertFHEToTFHECrtPass()";
|
||||
let options = [];
|
||||
let dependentDialects = ["mlir::linalg::LinalgDialect"];
|
||||
}
|
||||
|
||||
@@ -26,8 +26,7 @@ LweCiphertextType convertTypeToLWE(mlir::MLIRContext *context,
|
||||
auto glwe = type.dyn_cast_or_null<GLWECipherTextType>();
|
||||
if (glwe != nullptr) {
|
||||
assert(glwe.getPolynomialSize() == 1);
|
||||
return LweCiphertextType::get(context, glwe.getDimension(), glwe.getP(),
|
||||
glwe.getCrtDecomposition());
|
||||
return LweCiphertextType::get(context, glwe.getDimension(), glwe.getP());
|
||||
}
|
||||
auto lwe = type.dyn_cast_or_null<LweCiphertextType>();
|
||||
if (lwe != nullptr) {
|
||||
|
||||
@@ -25,56 +25,21 @@ def BConcrete_AddLweTensorOp : BConcrete_Op<"add_lwe_tensor", [NoSideEffect]> {
|
||||
let results = (outs 1DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
def BConcrete_AddCRTLweTensorOp : BConcrete_Op<"add_crt_lwe_tensor", [NoSideEffect]> {
|
||||
let arguments = (ins
|
||||
2DTensorOf<[I64]>:$lhs,
|
||||
2DTensorOf<[I64]>:$rhs,
|
||||
I64ArrayAttr:$crtDecomposition
|
||||
);
|
||||
let results = (outs 2DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
def BConcrete_AddPlaintextLweTensorOp : BConcrete_Op<"add_plaintext_lwe_tensor", [NoSideEffect]> {
|
||||
let arguments = (ins 1DTensorOf<[I64]>:$lhs, I64:$rhs);
|
||||
let results = (outs 1DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
def BConcrete_AddPlaintextCRTLweTensorOp : BConcrete_Op<"add_plaintext_crt_lwe_tensor", [NoSideEffect]> {
|
||||
let arguments = (ins
|
||||
2DTensorOf<[I64]>:$lhs,
|
||||
AnyInteger:$rhs,
|
||||
I64ArrayAttr:$crtDecomposition
|
||||
);
|
||||
let results = (outs 2DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
def BConcrete_MulCleartextLweTensorOp : BConcrete_Op<"mul_cleartext_lwe_tensor", [NoSideEffect]> {
|
||||
let arguments = (ins 1DTensorOf<[I64]>:$lhs, I64:$rhs);
|
||||
let results = (outs 1DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
def BConcrete_MulCleartextCRTLweTensorOp : BConcrete_Op<"mul_cleartext_crt_lwe_tensor", [NoSideEffect]> {
|
||||
let arguments = (ins
|
||||
2DTensorOf<[I64]>:$lhs,
|
||||
AnyInteger:$rhs,
|
||||
I64ArrayAttr:$crtDecomposition
|
||||
);
|
||||
let results = (outs 2DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
def BConcrete_NegateLweTensorOp : BConcrete_Op<"negate_lwe_tensor", [NoSideEffect]> {
|
||||
let arguments = (ins 1DTensorOf<[I64]>:$ciphertext);
|
||||
let results = (outs 1DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
def BConcrete_NegateCRTLweTensorOp : BConcrete_Op<"negate_crt_lwe_tensor", [NoSideEffect]> {
|
||||
let arguments = (ins
|
||||
2DTensorOf<[I64]>:$ciphertext,
|
||||
I64ArrayAttr:$crtDecomposition
|
||||
);
|
||||
let results = (outs 2DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
def BConcrete_KeySwitchLweTensorOp : BConcrete_Op<"keyswitch_lwe_tensor", [NoSideEffect]> {
|
||||
let arguments = (ins
|
||||
// LweKeySwitchKeyType:$keyswitch_key,
|
||||
@@ -99,6 +64,49 @@ def BConcrete_BatchedKeySwitchLweTensorOp : BConcrete_Op<"batched_keyswitch_lwe_
|
||||
let results = (outs 2DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
def BConcrete_EncodeExpandLutForBootstrapTensorOp : BConcrete_Op<"encode_expand_lut_for_bootstrap_tensor", [NoSideEffect]> {
|
||||
let summary =
|
||||
"Encode and expand a lookup table so that it can be used for a bootstrap.";
|
||||
|
||||
let arguments = (ins
|
||||
1DTensorOf<[I64]> : $input_lookup_table,
|
||||
I32Attr: $polySize,
|
||||
I32Attr: $outputBits
|
||||
);
|
||||
|
||||
let results = (outs 1DTensorOf<[I64]> : $result);
|
||||
}
|
||||
|
||||
def BConcrete_EncodeExpandLutForWopPBSTensorOp : BConcrete_Op<"encode_expand_lut_for_woppbs_tensor", [NoSideEffect]> {
|
||||
let summary =
|
||||
"Encode and expand a lookup table so that it can be used for a wop pbs.";
|
||||
|
||||
let arguments = (ins
|
||||
1DTensorOf<[I64]> : $input_lookup_table,
|
||||
I64ArrayAttr: $crtDecomposition,
|
||||
I64ArrayAttr: $crtBits,
|
||||
I32Attr : $polySize,
|
||||
I32Attr : $modulusProduct
|
||||
);
|
||||
|
||||
let results = (outs 1DTensorOf<[I64]> : $result);
|
||||
}
|
||||
|
||||
|
||||
def BConcrete_EncodePlaintextWithCrtTensorOp : BConcrete_Op<"encode_plaintext_with_crt_tensor", [NoSideEffect]> {
|
||||
let summary =
|
||||
"Encodes a plaintext by decomposing it on a crt basis.";
|
||||
|
||||
let arguments = (ins
|
||||
I64 : $input,
|
||||
I64ArrayAttr: $mods,
|
||||
I64Attr: $modsProd
|
||||
);
|
||||
|
||||
let results = (outs 1DTensorOf<[I64]> : $result);
|
||||
}
|
||||
|
||||
|
||||
def BConcrete_BootstrapLweTensorOp : BConcrete_Op<"bootstrap_lwe_tensor", [NoSideEffect]> {
|
||||
let arguments = (ins
|
||||
1DTensorOf<[I64]>:$input_ciphertext,
|
||||
@@ -144,8 +152,7 @@ def BConcrete_WopPBSCRTLweTensorOp : BConcrete_Op<"wop_pbs_crt_lwe_tensor", [NoS
|
||||
I32Attr : $packingKeySwitchBaseLog,
|
||||
// Circuit bootstrap parameters
|
||||
I32Attr : $circuitBootstrapLevel,
|
||||
I32Attr : $circuitBootstrapBaseLog,
|
||||
I64ArrayAttr:$crtDecomposition
|
||||
I32Attr : $circuitBootstrapBaseLog
|
||||
);
|
||||
let results = (outs 2DTensorOf<[I64]>:$result);
|
||||
}
|
||||
@@ -153,6 +160,8 @@ def BConcrete_WopPBSCRTLweTensorOp : BConcrete_Op<"wop_pbs_crt_lwe_tensor", [NoS
|
||||
// BConcrete memref operators /////////////////////////////////////////////////
|
||||
|
||||
def BConcrete_LweBuffer : MemRefRankOf<[I64], [1]>;
|
||||
def BConcrete_LutBuffer : MemRefRankOf<[I64], [1]>;
|
||||
def BConcrete_CrtPlaintextBuffer : MemRefRankOf<[I64], [1]>;
|
||||
def BConcrete_LweCRTBuffer : MemRefRankOf<[I64], [2]>;
|
||||
def BConcrete_BatchLweBuffer : MemRefRankOf<[I64], [2]>;
|
||||
|
||||
@@ -209,11 +218,49 @@ def BConcrete_BatchedKeySwitchLweBufferOp : BConcrete_Op<"batched_keyswitch_lwe_
|
||||
);
|
||||
}
|
||||
|
||||
def BConcrete_EncodeExpandLutForBootstrapBufferOp : BConcrete_Op<"encode_expand_lut_for_bootstrap_buffer"> {
|
||||
let summary =
|
||||
"Encode and expand a lookup table so that it can be used for a bootstrap.";
|
||||
|
||||
let arguments = (ins
|
||||
BConcrete_LutBuffer: $result,
|
||||
BConcrete_LutBuffer: $input_lookup_table,
|
||||
I32Attr: $polySize,
|
||||
I32Attr: $outputBits
|
||||
);
|
||||
}
|
||||
|
||||
def BConcrete_EncodeExpandLutForWopPBSBufferOp : BConcrete_Op<"encode_expand_lut_for_woppbs_buffer"> {
|
||||
let summary =
|
||||
"Encode and expand a lookup table so that it can be used for a wop pbs.";
|
||||
|
||||
let arguments = (ins
|
||||
BConcrete_LutBuffer : $result,
|
||||
BConcrete_LutBuffer : $input_lookup_table,
|
||||
I64ArrayAttr: $crtDecomposition,
|
||||
I64ArrayAttr: $crtBits,
|
||||
I32Attr : $polySize,
|
||||
I32Attr : $modulusProduct
|
||||
);
|
||||
}
|
||||
|
||||
def BConcrete_EncodePlaintextWithCrtBufferOp : BConcrete_Op<"encode_plaintext_with_crt_buffer"> {
|
||||
let summary =
|
||||
"Encodes a plaintext by decomposing it on a crt basis.";
|
||||
|
||||
let arguments = (ins
|
||||
BConcrete_CrtPlaintextBuffer: $result,
|
||||
I64 : $input,
|
||||
I64ArrayAttr: $mods,
|
||||
I64Attr: $modsProd
|
||||
);
|
||||
}
|
||||
|
||||
def BConcrete_BootstrapLweBufferOp : BConcrete_Op<"bootstrap_lwe_buffer"> {
|
||||
let arguments = (ins
|
||||
BConcrete_LweBuffer:$result,
|
||||
BConcrete_LweBuffer:$input_ciphertext,
|
||||
MemRefRankOf<[I64], [1]>:$lookup_table,
|
||||
BConcrete_LutBuffer:$lookup_table,
|
||||
I32Attr:$inputLweDim,
|
||||
I32Attr:$polySize,
|
||||
I32Attr:$level,
|
||||
@@ -227,7 +274,7 @@ def BConcrete_BatchedBootstrapLweBufferOp : BConcrete_Op<"batched_bootstrap_lwe_
|
||||
let arguments = (ins
|
||||
BConcrete_BatchLweBuffer:$result,
|
||||
BConcrete_BatchLweBuffer:$input_ciphertext,
|
||||
MemRefRankOf<[I64], [1]>:$lookup_table,
|
||||
BConcrete_LutBuffer:$lookup_table,
|
||||
I32Attr:$inputLweDim,
|
||||
I32Attr:$polySize,
|
||||
I32Attr:$level,
|
||||
@@ -241,7 +288,7 @@ def BConcrete_WopPBSCRTLweBufferOp : BConcrete_Op<"wop_pbs_crt_lwe_buffer"> {
|
||||
let arguments = (ins
|
||||
BConcrete_LweCRTBuffer:$result,
|
||||
BConcrete_LweCRTBuffer:$ciphertext,
|
||||
MemRefRankOf<[I64], [1]>:$lookup_table,
|
||||
BConcrete_LutBuffer:$lookup_table,
|
||||
// Bootstrap parameters
|
||||
I32Attr : $bootstrapLevel,
|
||||
I32Attr : $bootstrapBaseLog,
|
||||
|
||||
@@ -53,6 +53,47 @@ def Concrete_NegateLweCiphertextOp : Concrete_Op<"negate_lwe_ciphertext"> {
|
||||
let results = (outs Concrete_LweCiphertextType:$result);
|
||||
}
|
||||
|
||||
def Concrete_EncodeExpandLutForBootstrapOp : Concrete_Op<"encode_expand_lut_for_bootstrap"> {
|
||||
let summary =
|
||||
"Encode and expand a lookup table so that it can be used for a bootstrap.";
|
||||
|
||||
let arguments = (ins
|
||||
1DTensorOf<[I64]> : $input_lookup_table,
|
||||
I32Attr: $polySize,
|
||||
I32Attr: $outputBits
|
||||
);
|
||||
|
||||
let results = (outs 1DTensorOf<[I64]> : $result);
|
||||
}
|
||||
|
||||
def Concrete_EncodeExpandLutForWopPBSOp : Concrete_Op<"encode_expand_lut_for_woppbs"> {
|
||||
let summary =
|
||||
"Encode and expand a lookup table so that it can be used for a wop pbs.";
|
||||
|
||||
let arguments = (ins
|
||||
1DTensorOf<[I64]> : $input_lookup_table,
|
||||
I64ArrayAttr: $crtDecomposition,
|
||||
I64ArrayAttr: $crtBits,
|
||||
I32Attr : $polySize,
|
||||
I32Attr : $modulusProduct
|
||||
);
|
||||
|
||||
let results = (outs 1DTensorOf<[I64]> : $result);
|
||||
}
|
||||
|
||||
def Concrete_EncodePlaintextWithCrtOp : Concrete_Op<"encode_plaintext_with_crt"> {
|
||||
let summary =
|
||||
"Encodes a plaintext by decomposing it on a crt basis.";
|
||||
|
||||
let arguments = (ins
|
||||
I64 : $input,
|
||||
I64ArrayAttr: $mods,
|
||||
I64Attr: $modsProd
|
||||
);
|
||||
|
||||
let results = (outs 1DTensorOf<[I64]> : $result);
|
||||
}
|
||||
|
||||
def Concrete_BootstrapLweOp : Concrete_Op<"bootstrap_lwe", [BatchableOpInterface]> {
|
||||
let summary = "Bootstraps a LWE ciphertext with a GLWE trivial encryption of the lookup table";
|
||||
|
||||
@@ -153,7 +194,7 @@ def Concrete_WopPBSLweOp : Concrete_Op<"wop_pbs_lwe"> {
|
||||
let summary = "";
|
||||
|
||||
let arguments = (ins
|
||||
Concrete_LweCiphertextType:$ciphertext,
|
||||
Type<And<[TensorOf<[Concrete_LweCiphertextType]>.predicate, HasStaticShapePred]>>:$ciphertexts,
|
||||
1DTensorOf<[I64]>:$accumulator,
|
||||
// Bootstrap parameters
|
||||
I32Attr : $bootstrapLevel,
|
||||
@@ -168,9 +209,11 @@ def Concrete_WopPBSLweOp : Concrete_Op<"wop_pbs_lwe"> {
|
||||
I32Attr : $packingKeySwitchBaseLog,
|
||||
// Circuit bootstrap parameters
|
||||
I32Attr : $circuitBootstrapLevel,
|
||||
I32Attr : $circuitBootstrapBaseLog
|
||||
I32Attr : $circuitBootstrapBaseLog,
|
||||
// Crt decomposition
|
||||
I64ArrayAttr: $crtDecomposition
|
||||
);
|
||||
let results = (outs Concrete_LweCiphertextType:$result);
|
||||
let results = (outs Type<And<[TensorOf<[Concrete_LweCiphertextType]>.predicate, HasStaticShapePred]>>:$result);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -40,9 +40,7 @@ def Concrete_LweCiphertextType : Concrete_Type<"LweCiphertext", [MemRefElementTy
|
||||
// The dimension of the lwe ciphertext
|
||||
"signed":$dimension,
|
||||
// Precision of the lwe ciphertext
|
||||
"signed":$p,
|
||||
// CRT decomposition for large integers
|
||||
ArrayRefParameter<"int64_t", "CRT decomposition">:$crtDecomposition
|
||||
"signed":$p
|
||||
|
||||
);
|
||||
|
||||
|
||||
@@ -19,6 +19,48 @@ include "concretelang/Dialect/TFHE/IR/TFHETypes.td"
|
||||
class TFHE_Op<string mnemonic, list<Trait> traits = []>
|
||||
: Op<TFHE_Dialect, mnemonic, traits>;
|
||||
|
||||
def TFHE_EncodeExpandLutForBootstrapOp : TFHE_Op<"encode_expand_lut_for_bootstrap"> {
|
||||
let summary =
|
||||
"Encode and expand a lookup table so that it can be used for a bootstrap.";
|
||||
|
||||
let arguments = (ins
|
||||
1DTensorOf<[I64]> : $input_lookup_table,
|
||||
I32Attr: $polySize,
|
||||
I32Attr: $outputBits
|
||||
);
|
||||
|
||||
let results = (outs 1DTensorOf<[I64]> : $result);
|
||||
}
|
||||
|
||||
def TFHE_EncodeExpandLutForWopPBSOp : TFHE_Op<"encode_expand_lut_for_woppbs"> {
|
||||
let summary =
|
||||
"Encode and expand a lookup table so that it can be used for a wop pbs.";
|
||||
|
||||
let arguments = (ins
|
||||
1DTensorOf<[I64]> : $input_lookup_table,
|
||||
I64ArrayAttr: $crtDecomposition,
|
||||
I64ArrayAttr: $crtBits,
|
||||
I32Attr : $polySize,
|
||||
I32Attr : $modulusProduct
|
||||
);
|
||||
|
||||
let results = (outs 1DTensorOf<[I64]> : $result);
|
||||
}
|
||||
|
||||
def TFHE_EncodePlaintextWithCrtOp : TFHE_Op<"encode_plaintext_with_crt"> {
|
||||
let summary =
|
||||
"Encodes a plaintext by decomposing it on a crt basis.";
|
||||
|
||||
let arguments = (ins
|
||||
I64 : $input,
|
||||
I64ArrayAttr: $mods,
|
||||
I64Attr: $modsProd
|
||||
);
|
||||
|
||||
let results = (outs 1DTensorOf<[I64]> : $result);
|
||||
}
|
||||
|
||||
|
||||
def TFHE_ZeroGLWEOp : TFHE_Op<"zero"> {
|
||||
let summary = "Returns a trivial encyption of 0";
|
||||
|
||||
@@ -92,6 +134,7 @@ def TFHE_KeySwitchGLWEOp : TFHE_Op<"keyswitch_glwe"> {
|
||||
let results = (outs TFHE_GLWECipherTextType : $result);
|
||||
}
|
||||
|
||||
|
||||
def TFHE_BootstrapGLWEOp : TFHE_Op<"bootstrap_glwe"> {
|
||||
let summary =
|
||||
"Programmable bootstraping of a GLWE ciphertext with a lookup table";
|
||||
@@ -112,7 +155,7 @@ def TFHE_WopPBSGLWEOp : TFHE_Op<"wop_pbs_glwe"> {
|
||||
let summary = "";
|
||||
|
||||
let arguments = (ins
|
||||
TFHE_GLWECipherTextType : $ciphertext,
|
||||
Type<And<[TensorOf<[TFHE_GLWECipherTextType]>.predicate, HasStaticShapePred]>>: $ciphertexts,
|
||||
1DTensorOf<[I64]> : $lookupTable,
|
||||
// Bootstrap parameters
|
||||
I32Attr : $bootstrapLevel,
|
||||
@@ -127,9 +170,11 @@ def TFHE_WopPBSGLWEOp : TFHE_Op<"wop_pbs_glwe"> {
|
||||
I32Attr : $packingKeySwitchBaseLog,
|
||||
// Circuit bootstrap parameters
|
||||
I32Attr : $circuitBootstrapLevel,
|
||||
I32Attr : $circuitBootstrapBaseLog
|
||||
I32Attr : $circuitBootstrapBaseLog,
|
||||
// Crt decomposition
|
||||
I64ArrayAttr: $crtDecomposition
|
||||
);
|
||||
let results = (outs TFHE_GLWECipherTextType:$result);
|
||||
let results = (outs Type<And<[TensorOf<[TFHE_GLWECipherTextType]>.predicate, HasStaticShapePred]>>:$result);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -25,9 +25,7 @@ def TFHE_GLWECipherTextType
|
||||
// Number of bits of the ciphertext
|
||||
"signed":$bits,
|
||||
// Number of bits of the plain text representation
|
||||
"signed":$p,
|
||||
// CRT decomposition for large integers
|
||||
ArrayRefParameter<"int64_t", "CRT decomposition">:$crtDecomposition
|
||||
"signed":$p
|
||||
);
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
@@ -21,16 +21,33 @@ extern "C" {
|
||||
/// \param out_MESSAGE_BITS number of bits of message to be used
|
||||
/// \param lut original LUT
|
||||
/// \param lut_size
|
||||
void encode_and_expand_lut(uint64_t *output, size_t output_size,
|
||||
size_t out_MESSAGE_BITS, const uint64_t *lut,
|
||||
size_t lut_size);
|
||||
void memref_encode_expand_lut_for_bootstrap(
|
||||
uint64_t *output_lut_allocated, uint64_t *output_lut_aligned,
|
||||
uint64_t output_lut_offset, uint64_t output_lut_size,
|
||||
uint64_t output_lut_stride, uint64_t *input_lut_allocated,
|
||||
uint64_t *input_lut_aligned, uint64_t input_lut_offset,
|
||||
uint64_t input_lut_size, uint64_t input_lut_stride, uint32_t poly_size,
|
||||
uint32_t out_MESSAGE_BITS);
|
||||
|
||||
void memref_expand_lut_in_trivial_glwe_ct_u64(
|
||||
uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
|
||||
uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
|
||||
uint32_t poly_size, uint32_t glwe_dimension, uint32_t out_precision,
|
||||
uint64_t *lut_allocated, uint64_t *lut_aligned, uint64_t lut_offset,
|
||||
uint64_t lut_size, uint64_t lut_stride);
|
||||
void memref_encode_expand_lut_for_woppbs(
|
||||
uint64_t *output_lut_allocated, uint64_t *output_lut_aligned,
|
||||
uint64_t output_lut_offset, uint64_t output_lut_size,
|
||||
uint64_t output_lut_stride, uint64_t *input_lut_allocated,
|
||||
uint64_t *input_lut_aligned, uint64_t input_lut_offset,
|
||||
uint64_t input_lut_size, uint64_t input_lut_stride,
|
||||
uint64_t *crt_decomposition_allocated, uint64_t *crt_decomposition_aligned,
|
||||
uint64_t crt_decomposition_offset, uint64_t crt_decomposition_size,
|
||||
uint64_t crt_decomposition_stride, uint64_t *crt_bits_allocated,
|
||||
uint64_t *crt_bits_aligned, uint64_t crt_bits_offset,
|
||||
uint64_t crt_bits_size, uint64_t crt_bits_stride, uint32_t poly_size,
|
||||
uint32_t modulus_product);
|
||||
|
||||
void memref_encode_plaintext_with_crt(
|
||||
uint64_t *output_allocated, uint64_t *output_aligned,
|
||||
uint64_t output_offset, uint64_t output_size, uint64_t output_stride,
|
||||
uint64_t input, uint64_t *mods_allocated, uint64_t *output_lut_aligned,
|
||||
uint64_t mods_offset, uint64_t output_lut_size, uint64_t mods_stride,
|
||||
uint64_t mods_product);
|
||||
|
||||
void memref_add_lwe_ciphertexts_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
|
||||
@@ -192,6 +192,10 @@ public:
|
||||
/// Read sources and exit before any lowering
|
||||
FHE,
|
||||
|
||||
/// Read sources and lower all the FHELinalg operations to FHE operations
|
||||
/// and scf loops
|
||||
FHE_NO_LINALG,
|
||||
|
||||
/// Read sources and lower all FHE operations to TFHE
|
||||
/// operations
|
||||
TFHE,
|
||||
@@ -200,10 +204,6 @@ public:
|
||||
/// operations
|
||||
CONCRETE,
|
||||
|
||||
/// Read sources and lower all FHE and TFHE operations to Concrete
|
||||
/// operations with all linalg ops replaced by loops
|
||||
CONCRETEWITHLOOPS,
|
||||
|
||||
/// Read sources and lower all FHE, TFHE and Concrete operations to
|
||||
/// BConcrete operations
|
||||
BCONCRETE,
|
||||
|
||||
@@ -34,6 +34,12 @@ markFHELinalgForTiling(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
llvm::ArrayRef<int64_t> tileSizes,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerFHELinalgToFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
llvm::Optional<V0FHEContext> &fheContext,
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
bool parallelize, bool batch);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
llvm::Optional<V0FHEContext> &fheContext,
|
||||
|
||||
@@ -43,6 +43,12 @@ char memref_expand_lut_in_trivial_glwe_ct_u64[] =
|
||||
|
||||
char memref_wop_pbs_crt_buffer[] = "memref_wop_pbs_crt_buffer";
|
||||
|
||||
char memref_encode_plaintext_with_crt[] = "memref_encode_plaintext_with_crt";
|
||||
char memref_encode_expand_lut_for_bootstrap[] =
|
||||
"memref_encode_expand_lut_for_bootstrap";
|
||||
char memref_encode_expand_lut_for_woppbs[] =
|
||||
"memref_encode_expand_lut_for_woppbs";
|
||||
|
||||
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
|
||||
size_t rank) {
|
||||
std::vector<int64_t> shape(rank, -1);
|
||||
@@ -161,6 +167,24 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
|
||||
contextType,
|
||||
},
|
||||
{});
|
||||
|
||||
} else if (funcName == memref_encode_plaintext_with_crt) {
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{memref1DType, rewriter.getI64Type(),
|
||||
memref1DType, rewriter.getI64Type()},
|
||||
{});
|
||||
} else if (funcName == memref_encode_expand_lut_for_bootstrap) {
|
||||
funcType =
|
||||
mlir::FunctionType::get(rewriter.getContext(),
|
||||
{memref1DType, memref1DType,
|
||||
rewriter.getI32Type(), rewriter.getI32Type()},
|
||||
{});
|
||||
} else if (funcName == memref_encode_expand_lut_for_woppbs) {
|
||||
funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{memref1DType, memref1DType, memref1DType, memref1DType,
|
||||
rewriter.getI32Type(), rewriter.getI32Type()},
|
||||
{});
|
||||
} else {
|
||||
op->emitError("unknwon external function") << funcName;
|
||||
return mlir::failure();
|
||||
@@ -301,6 +325,92 @@ void wopPBSAddOperands(BConcrete::WopPBSCRTLweBufferOp op,
|
||||
operands.push_back(getContextArgument(op));
|
||||
}
|
||||
|
||||
void encodePlaintextWithCrtAddOperands(
|
||||
BConcrete::EncodePlaintextWithCrtBufferOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands, mlir::RewriterBase &rewriter) {
|
||||
// mods
|
||||
mlir::Type modsType = mlir::RankedTensorType::get({(int)op.modsAttr().size()},
|
||||
rewriter.getI64Type());
|
||||
std::vector<int64_t> modsValues;
|
||||
for (auto a : op.mods()) {
|
||||
modsValues.push_back(a.cast<mlir::IntegerAttr>().getValue().getZExtValue());
|
||||
}
|
||||
auto modsAttr = rewriter.getI64TensorAttr(modsValues);
|
||||
auto modsOp =
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), modsAttr, modsType);
|
||||
auto modsGlobalMemref = mlir::bufferization::getGlobalFor(modsOp, 0);
|
||||
rewriter.eraseOp(modsOp);
|
||||
assert(!failed(modsGlobalMemref));
|
||||
auto modsGlobalRef = rewriter.create<memref::GetGlobalOp>(
|
||||
op.getLoc(), (*modsGlobalMemref).type(), (*modsGlobalMemref).getName());
|
||||
operands.push_back(getCastedMemRef(rewriter, modsGlobalRef));
|
||||
|
||||
// mods_prod
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.modsProdAttr()));
|
||||
}
|
||||
|
||||
void encodeExpandLutForBootstrapAddOperands(
|
||||
BConcrete::EncodeExpandLutForBootstrapBufferOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands, mlir::RewriterBase &rewriter) {
|
||||
// poly_size
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.polySizeAttr()));
|
||||
// output bits
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.outputBitsAttr()));
|
||||
}
|
||||
|
||||
void encodeExpandLutForWopPBSAddOperands(
|
||||
BConcrete::EncodeExpandLutForWopPBSBufferOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands, mlir::RewriterBase &rewriter) {
|
||||
|
||||
// crt_decomposition
|
||||
mlir::Type crtDecompositionType = mlir::RankedTensorType::get(
|
||||
{(int)op.crtDecompositionAttr().size()}, rewriter.getI64Type());
|
||||
std::vector<int64_t> crtDecompositionValues;
|
||||
for (auto a : op.crtDecomposition()) {
|
||||
crtDecompositionValues.push_back(
|
||||
a.cast<mlir::IntegerAttr>().getValue().getZExtValue());
|
||||
}
|
||||
auto crtDecompositionAttr = rewriter.getI64TensorAttr(crtDecompositionValues);
|
||||
auto crtDecompositionOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), crtDecompositionAttr, crtDecompositionType);
|
||||
auto crtDecompositionGlobalMemref =
|
||||
mlir::bufferization::getGlobalFor(crtDecompositionOp, 0);
|
||||
rewriter.eraseOp(crtDecompositionOp);
|
||||
assert(!failed(crtDecompositionGlobalMemref));
|
||||
auto crtDecompositionGlobalRef = rewriter.create<memref::GetGlobalOp>(
|
||||
op.getLoc(), (*crtDecompositionGlobalMemref).type(),
|
||||
(*crtDecompositionGlobalMemref).getName());
|
||||
operands.push_back(getCastedMemRef(rewriter, crtDecompositionGlobalRef));
|
||||
|
||||
// crt_bits
|
||||
mlir::Type crtBitsType = mlir::RankedTensorType::get(
|
||||
{(int)op.crtBitsAttr().size()}, rewriter.getI64Type());
|
||||
std::vector<int64_t> crtBitsValues;
|
||||
for (auto a : op.crtBits()) {
|
||||
crtBitsValues.push_back(
|
||||
a.cast<mlir::IntegerAttr>().getValue().getZExtValue());
|
||||
}
|
||||
auto crtBitsAttr = rewriter.getI64TensorAttr(crtBitsValues);
|
||||
auto crtBitsOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), crtBitsAttr, crtBitsType);
|
||||
auto crtBitsGlobalMemref = mlir::bufferization::getGlobalFor(crtBitsOp, 0);
|
||||
rewriter.eraseOp(crtBitsOp);
|
||||
assert(!failed(crtBitsGlobalMemref));
|
||||
auto crtBitsGlobalRef = rewriter.create<memref::GetGlobalOp>(
|
||||
op.getLoc(), (*crtBitsGlobalMemref).type(),
|
||||
(*crtBitsGlobalMemref).getName());
|
||||
operands.push_back(getCastedMemRef(rewriter, crtBitsGlobalRef));
|
||||
// poly_size
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.polySizeAttr()));
|
||||
// modulus_product
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.modulusProductAttr()));
|
||||
}
|
||||
|
||||
struct BConcreteToCAPIPass : public BConcreteToCAPIBase<BConcreteToCAPIPass> {
|
||||
|
||||
BConcreteToCAPIPass(bool gpu) : gpu(gpu) {}
|
||||
@@ -334,6 +444,18 @@ struct BConcreteToCAPIPass : public BConcreteToCAPIBase<BConcreteToCAPIPass> {
|
||||
patterns.add<BConcreteToCAPICallPattern<BConcrete::NegateLweBufferOp,
|
||||
memref_negate_lwe_ciphertext_u64>>(
|
||||
&getContext());
|
||||
patterns.add<
|
||||
BConcreteToCAPICallPattern<BConcrete::EncodePlaintextWithCrtBufferOp,
|
||||
memref_encode_plaintext_with_crt>>(
|
||||
&getContext(), encodePlaintextWithCrtAddOperands);
|
||||
patterns.add<BConcreteToCAPICallPattern<
|
||||
BConcrete::EncodeExpandLutForBootstrapBufferOp,
|
||||
memref_encode_expand_lut_for_bootstrap>>(
|
||||
&getContext(), encodeExpandLutForBootstrapAddOperands);
|
||||
patterns.add<
|
||||
BConcreteToCAPICallPattern<BConcrete::EncodeExpandLutForWopPBSBufferOp,
|
||||
memref_encode_expand_lut_for_woppbs>>(
|
||||
&getContext(), encodeExpandLutForWopPBSAddOperands);
|
||||
if (gpu) {
|
||||
patterns.add<BConcreteToCAPICallPattern<BConcrete::KeySwitchLweBufferOp,
|
||||
memref_keyswitch_lwe_cuda_u64>>(
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
add_subdirectory(FHEToTFHE)
|
||||
add_subdirectory(FHEToTFHEScalar)
|
||||
add_subdirectory(FHEToTFHECrt)
|
||||
add_subdirectory(TFHEGlobalParametrization)
|
||||
add_subdirectory(TFHEToConcrete)
|
||||
add_subdirectory(FHETensorOpsToLinalg)
|
||||
|
||||
@@ -68,10 +68,6 @@ public:
|
||||
addConversion([&](mlir::concretelang::Concrete::LweCiphertextType type) {
|
||||
assert(type.getDimension() != -1);
|
||||
llvm::SmallVector<int64_t, 2> shape;
|
||||
auto crt = type.getCrtDecomposition();
|
||||
if (!crt.empty()) {
|
||||
shape.push_back(crt.size());
|
||||
}
|
||||
shape.push_back(type.getDimension() + 1);
|
||||
return mlir::RankedTensorType::get(
|
||||
shape, mlir::IntegerType::get(type.getContext(), 64));
|
||||
@@ -95,10 +91,6 @@ public:
|
||||
mlir::SmallVector<int64_t> newShape;
|
||||
newShape.reserve(type.getShape().size() + 1);
|
||||
newShape.append(type.getShape().begin(), type.getShape().end());
|
||||
auto crt = lwe.getCrtDecomposition();
|
||||
if (!crt.empty()) {
|
||||
newShape.push_back(crt.size());
|
||||
}
|
||||
newShape.push_back(lwe.getDimension() + 1);
|
||||
mlir::Type r = mlir::RankedTensorType::get(
|
||||
newShape, mlir::IntegerType::get(type.getContext(), 64));
|
||||
@@ -146,7 +138,7 @@ struct ZeroOpPattern : public mlir::OpRewritePattern<ZeroOp> {
|
||||
};
|
||||
};
|
||||
|
||||
template <typename ConcreteOp, typename BConcreteOp, typename BConcreteCRTOp>
|
||||
template <typename ConcreteOp, typename BConcreteOp>
|
||||
struct LowToBConcrete : public mlir::OpRewritePattern<ConcreteOp> {
|
||||
LowToBConcrete(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<ConcreteOp>(context, benefit) {}
|
||||
@@ -161,29 +153,9 @@ struct LowToBConcrete : public mlir::OpRewritePattern<ConcreteOp> {
|
||||
concreteOp.getOperation()->getAttrs();
|
||||
|
||||
mlir::Operation *bConcreteOp;
|
||||
if (resultTyRange.size() == 1 &&
|
||||
resultTyRange.front()
|
||||
.isa<mlir::concretelang::Concrete::LweCiphertextType>()) {
|
||||
auto crt = resultTyRange.front()
|
||||
.cast<mlir::concretelang::Concrete::LweCiphertextType>()
|
||||
.getCrtDecomposition();
|
||||
if (crt.empty()) {
|
||||
bConcreteOp = rewriter.replaceOpWithNewOp<BConcreteOp>(
|
||||
concreteOp, resultTyRange, concreteOp.getOperation()->getOperands(),
|
||||
attributes);
|
||||
} else {
|
||||
auto newAttributes = attributes.vec();
|
||||
newAttributes.push_back(rewriter.getNamedAttr(
|
||||
"crtDecomposition", rewriter.getI64ArrayAttr(crt)));
|
||||
bConcreteOp = rewriter.replaceOpWithNewOp<BConcreteCRTOp>(
|
||||
concreteOp, resultTyRange, concreteOp.getOperation()->getOperands(),
|
||||
newAttributes);
|
||||
}
|
||||
} else {
|
||||
bConcreteOp = rewriter.replaceOpWithNewOp<BConcreteOp>(
|
||||
concreteOp, resultTyRange, concreteOp.getOperation()->getOperands(),
|
||||
attributes);
|
||||
}
|
||||
bConcreteOp = rewriter.replaceOpWithNewOp<BConcreteOp>(
|
||||
concreteOp, resultTyRange, concreteOp.getOperation()->getOperands(),
|
||||
attributes);
|
||||
|
||||
mlir::concretelang::convertOperandAndResultTypes(
|
||||
rewriter, bConcreteOp, [&](mlir::MLIRContext *, mlir::Type t) {
|
||||
@@ -361,7 +333,6 @@ struct AddPlaintextLweCiphertextOpPattern
|
||||
matchAndRewrite(Concrete::AddPlaintextLweCiphertextOp concreteOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
ConcreteToBConcreteTypeConverter converter;
|
||||
auto loc = concreteOp.getLoc();
|
||||
mlir::concretelang::Concrete::LweCiphertextType resultTy =
|
||||
((mlir::Type)concreteOp->getResult(0).getType())
|
||||
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
|
||||
@@ -371,31 +342,11 @@ struct AddPlaintextLweCiphertextOpPattern
|
||||
llvm::ArrayRef<::mlir::NamedAttribute> attributes =
|
||||
concreteOp.getOperation()->getAttrs();
|
||||
|
||||
auto crt = resultTy.getCrtDecomposition();
|
||||
mlir::Operation *bConcreteOp;
|
||||
if (crt.empty()) {
|
||||
// Encode the plaintext value
|
||||
mlir::Value castedInt = rewriter.create<mlir::arith::ExtUIOp>(
|
||||
loc, rewriter.getIntegerType(64), concreteOp.rhs());
|
||||
mlir::Value constantShiftOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
loc,
|
||||
rewriter.getI64IntegerAttr(64 - concreteOp.getType().getP() - 1));
|
||||
auto encoded = rewriter.create<mlir::arith::ShLIOp>(
|
||||
loc, rewriter.getI64Type(), castedInt, constantShiftOp);
|
||||
bConcreteOp =
|
||||
rewriter.replaceOpWithNewOp<BConcrete::AddPlaintextLweTensorOp>(
|
||||
concreteOp, newResultTy,
|
||||
mlir::ValueRange{concreteOp.lhs(), encoded}, attributes);
|
||||
} else {
|
||||
// The encoding is done when we eliminate CRT ops
|
||||
auto newAttributes = attributes.vec();
|
||||
newAttributes.push_back(rewriter.getNamedAttr(
|
||||
"crtDecomposition", rewriter.getI64ArrayAttr(crt)));
|
||||
bConcreteOp =
|
||||
rewriter.replaceOpWithNewOp<BConcrete::AddPlaintextCRTLweTensorOp>(
|
||||
concreteOp, newResultTy, concreteOp.getOperation()->getOperands(),
|
||||
newAttributes);
|
||||
}
|
||||
bConcreteOp =
|
||||
rewriter.replaceOpWithNewOp<BConcrete::AddPlaintextLweTensorOp>(
|
||||
concreteOp, newResultTy,
|
||||
mlir::ValueRange{concreteOp.lhs(), concreteOp.rhs()}, attributes);
|
||||
|
||||
mlir::concretelang::convertOperandAndResultTypes(
|
||||
rewriter, bConcreteOp, [&](mlir::MLIRContext *, mlir::Type t) {
|
||||
@@ -417,7 +368,6 @@ struct MulCleartextLweCiphertextOpPattern
|
||||
matchAndRewrite(Concrete::MulCleartextLweCiphertextOp concreteOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
ConcreteToBConcreteTypeConverter converter;
|
||||
auto loc = concreteOp.getLoc();
|
||||
mlir::concretelang::Concrete::LweCiphertextType resultTy =
|
||||
((mlir::Type)concreteOp->getResult(0).getType())
|
||||
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
|
||||
@@ -427,25 +377,11 @@ struct MulCleartextLweCiphertextOpPattern
|
||||
llvm::ArrayRef<::mlir::NamedAttribute> attributes =
|
||||
concreteOp.getOperation()->getAttrs();
|
||||
|
||||
auto crt = resultTy.getCrtDecomposition();
|
||||
mlir::Operation *bConcreteOp;
|
||||
if (crt.empty()) {
|
||||
// Encode the plaintext value
|
||||
mlir::Value castedInt = rewriter.create<mlir::arith::ExtUIOp>(
|
||||
loc, rewriter.getIntegerType(64), concreteOp.rhs());
|
||||
bConcreteOp =
|
||||
rewriter.replaceOpWithNewOp<BConcrete::MulCleartextLweTensorOp>(
|
||||
concreteOp, newResultTy,
|
||||
mlir::ValueRange{concreteOp.lhs(), castedInt}, attributes);
|
||||
} else {
|
||||
auto newAttributes = attributes.vec();
|
||||
newAttributes.push_back(rewriter.getNamedAttr(
|
||||
"crtDecomposition", rewriter.getI64ArrayAttr(crt)));
|
||||
bConcreteOp =
|
||||
rewriter.replaceOpWithNewOp<BConcrete::MulCleartextCRTLweTensorOp>(
|
||||
concreteOp, newResultTy, concreteOp.getOperation()->getOperands(),
|
||||
newAttributes);
|
||||
}
|
||||
bConcreteOp =
|
||||
rewriter.replaceOpWithNewOp<BConcrete::MulCleartextLweTensorOp>(
|
||||
concreteOp, newResultTy,
|
||||
mlir::ValueRange{concreteOp.lhs(), concreteOp.rhs()}, attributes);
|
||||
|
||||
mlir::concretelang::convertOperandAndResultTypes(
|
||||
rewriter, bConcreteOp, [&](mlir::MLIRContext *, mlir::Type t) {
|
||||
@@ -468,11 +404,6 @@ struct ExtractSliceOpPattern
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
ConcreteToBConcreteTypeConverter converter;
|
||||
auto resultTy = extractSliceOp.result().getType();
|
||||
auto lweResultTy =
|
||||
resultTy.cast<mlir::RankedTensorType>()
|
||||
.getElementType()
|
||||
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
|
||||
auto nbBlock = lweResultTy.getCrtDecomposition().size();
|
||||
auto newResultTy =
|
||||
converter.convertType(resultTy).cast<mlir::RankedTensorType>();
|
||||
|
||||
@@ -480,19 +411,12 @@ struct ExtractSliceOpPattern
|
||||
mlir::SmallVector<mlir::Attribute> staticOffsets;
|
||||
staticOffsets.append(extractSliceOp.static_offsets().begin(),
|
||||
extractSliceOp.static_offsets().end());
|
||||
if (nbBlock != 0) {
|
||||
staticOffsets.push_back(rewriter.getI64IntegerAttr(0));
|
||||
}
|
||||
staticOffsets.push_back(rewriter.getI64IntegerAttr(0));
|
||||
|
||||
// add the lweSize to the sizes
|
||||
mlir::SmallVector<mlir::Attribute> staticSizes;
|
||||
staticSizes.append(extractSliceOp.static_sizes().begin(),
|
||||
extractSliceOp.static_sizes().end());
|
||||
if (nbBlock != 0) {
|
||||
staticSizes.push_back(rewriter.getI64IntegerAttr(
|
||||
newResultTy.getDimSize(newResultTy.getRank() - 2)));
|
||||
}
|
||||
staticSizes.push_back(rewriter.getI64IntegerAttr(
|
||||
newResultTy.getDimSize(newResultTy.getRank() - 1)));
|
||||
|
||||
@@ -500,9 +424,6 @@ struct ExtractSliceOpPattern
|
||||
mlir::SmallVector<mlir::Attribute> staticStrides;
|
||||
staticStrides.append(extractSliceOp.static_strides().begin(),
|
||||
extractSliceOp.static_strides().end());
|
||||
if (nbBlock != 0) {
|
||||
staticStrides.push_back(rewriter.getI64IntegerAttr(1));
|
||||
}
|
||||
staticStrides.push_back(rewriter.getI64IntegerAttr(1));
|
||||
|
||||
// replace tensor.extract_slice to the new one
|
||||
@@ -545,29 +466,20 @@ struct ExtractOpPattern
|
||||
if (lweResultTy == nullptr) {
|
||||
return mlir::failure();
|
||||
}
|
||||
auto nbBlock = lweResultTy.getCrtDecomposition().size();
|
||||
auto newResultTy =
|
||||
converter.convertType(lweResultTy).cast<mlir::RankedTensorType>();
|
||||
auto rankOfResult = extractOp.indices().size() +
|
||||
/* for the lwe dimension */ 1 +
|
||||
/* for the block dimension */
|
||||
(nbBlock == 0 ? 0 : 1);
|
||||
auto rankOfResult = extractOp.indices().size() + 1;
|
||||
|
||||
// [min..., 0] for static_offsets ()
|
||||
mlir::SmallVector<mlir::Attribute> staticOffsets(
|
||||
rankOfResult,
|
||||
rewriter.getI64IntegerAttr(std::numeric_limits<int64_t>::min()));
|
||||
if (nbBlock != 0) {
|
||||
staticOffsets[staticOffsets.size() - 2] = rewriter.getI64IntegerAttr(0);
|
||||
}
|
||||
staticOffsets[staticOffsets.size() - 1] = rewriter.getI64IntegerAttr(0);
|
||||
|
||||
// [1..., lweDimension+1] for static_sizes or
|
||||
// [1..., nbBlock, lweDimension+1]
|
||||
mlir::SmallVector<mlir::Attribute> staticSizes(
|
||||
rankOfResult, rewriter.getI64IntegerAttr(1));
|
||||
if (nbBlock != 0) {
|
||||
staticSizes[staticSizes.size() - 2] = rewriter.getI64IntegerAttr(nbBlock);
|
||||
}
|
||||
staticSizes[staticSizes.size() - 1] = rewriter.getI64IntegerAttr(
|
||||
newResultTy.getDimSize(newResultTy.getRank() - 1));
|
||||
|
||||
@@ -577,14 +489,8 @@ struct ExtractOpPattern
|
||||
|
||||
// replace tensor.extract_slice to the new one
|
||||
mlir::SmallVector<int64_t> extractedSliceShape(rankOfResult, 1);
|
||||
if (nbBlock != 0) {
|
||||
extractedSliceShape[extractedSliceShape.size() - 2] = nbBlock;
|
||||
extractedSliceShape[extractedSliceShape.size() - 1] =
|
||||
newResultTy.getDimSize(1);
|
||||
} else {
|
||||
extractedSliceShape[extractedSliceShape.size() - 1] =
|
||||
newResultTy.getDimSize(0);
|
||||
}
|
||||
extractedSliceShape[extractedSliceShape.size() - 1] =
|
||||
newResultTy.getDimSize(0);
|
||||
|
||||
auto extractedSliceType =
|
||||
mlir::RankedTensorType::get(extractedSliceShape, rewriter.getI64Type());
|
||||
@@ -601,17 +507,12 @@ struct ExtractOpPattern
|
||||
});
|
||||
|
||||
mlir::ReassociationIndices reassociation;
|
||||
for (int64_t i = 0;
|
||||
i < extractedSliceType.getRank() - (nbBlock == 0 ? 0 : 1); i++) {
|
||||
for (int64_t i = 0; i < extractedSliceType.getRank(); i++) {
|
||||
reassociation.push_back(i);
|
||||
}
|
||||
|
||||
mlir::SmallVector<mlir::ReassociationIndices> reassocs{reassociation};
|
||||
|
||||
if (nbBlock != 0) {
|
||||
reassocs.push_back({extractedSliceType.getRank() - 1});
|
||||
}
|
||||
|
||||
mlir::tensor::CollapseShapeOp collapseOp =
|
||||
rewriter.replaceOpWithNewOp<mlir::tensor::CollapseShapeOp>(
|
||||
extractOp, newResultTy, extractedSlice, reassocs);
|
||||
@@ -644,7 +545,6 @@ struct InsertSliceOpPattern
|
||||
if (lweResultTy == nullptr) {
|
||||
return mlir::failure();
|
||||
}
|
||||
auto nbBlock = lweResultTy.getCrtDecomposition().size();
|
||||
auto newResultTy =
|
||||
converter.convertType(resultTy).cast<mlir::RankedTensorType>();
|
||||
|
||||
@@ -652,19 +552,12 @@ struct InsertSliceOpPattern
|
||||
mlir::SmallVector<mlir::Attribute> staticOffsets;
|
||||
staticOffsets.append(insertSliceOp.static_offsets().begin(),
|
||||
insertSliceOp.static_offsets().end());
|
||||
if (nbBlock != 0) {
|
||||
staticOffsets.push_back(rewriter.getI64IntegerAttr(0));
|
||||
}
|
||||
staticOffsets.push_back(rewriter.getI64IntegerAttr(0));
|
||||
|
||||
// add lweDimension+1 to static_sizes
|
||||
mlir::SmallVector<mlir::Attribute> staticSizes;
|
||||
staticSizes.append(insertSliceOp.static_sizes().begin(),
|
||||
insertSliceOp.static_sizes().end());
|
||||
if (nbBlock != 0) {
|
||||
staticSizes.push_back(rewriter.getI64IntegerAttr(
|
||||
newResultTy.getDimSize(newResultTy.getRank() - 2)));
|
||||
}
|
||||
staticSizes.push_back(rewriter.getI64IntegerAttr(
|
||||
newResultTy.getDimSize(newResultTy.getRank() - 1)));
|
||||
|
||||
@@ -672,9 +565,6 @@ struct InsertSliceOpPattern
|
||||
mlir::SmallVector<mlir::Attribute> staticStrides;
|
||||
staticStrides.append(insertSliceOp.static_strides().begin(),
|
||||
insertSliceOp.static_strides().end());
|
||||
if (nbBlock != 0) {
|
||||
staticStrides.push_back(rewriter.getI64IntegerAttr(1));
|
||||
}
|
||||
staticStrides.push_back(rewriter.getI64IntegerAttr(1));
|
||||
|
||||
// replace tensor.insert_slice with the new one
|
||||
@@ -710,7 +600,6 @@ struct InsertOpPattern : public mlir::OpRewritePattern<mlir::tensor::InsertOp> {
|
||||
if (lweResultTy == nullptr) {
|
||||
return mlir::failure();
|
||||
};
|
||||
auto hasBlock = lweResultTy.getCrtDecomposition().size() != 0;
|
||||
mlir::RankedTensorType newResultTy =
|
||||
converter.convertType(resultTy).cast<mlir::RankedTensorType>();
|
||||
|
||||
@@ -718,9 +607,6 @@ struct InsertOpPattern : public mlir::OpRewritePattern<mlir::tensor::InsertOp> {
|
||||
mlir::SmallVector<mlir::OpFoldResult> offsets;
|
||||
offsets.append(insertOp.indices().begin(), insertOp.indices().end());
|
||||
offsets.push_back(rewriter.getIndexAttr(0));
|
||||
if (hasBlock) {
|
||||
offsets.push_back(rewriter.getIndexAttr(0));
|
||||
}
|
||||
|
||||
// Inserting a smaller tensor into a (potentially) bigger one. Set
|
||||
// dimensions for all leading dimensions of the target tensor not
|
||||
@@ -729,10 +615,6 @@ struct InsertOpPattern : public mlir::OpRewritePattern<mlir::tensor::InsertOp> {
|
||||
rewriter.getI64IntegerAttr(1));
|
||||
|
||||
// Add size for the bufferized source element
|
||||
if (hasBlock) {
|
||||
sizes.push_back(rewriter.getI64IntegerAttr(
|
||||
newResultTy.getDimSize(newResultTy.getRank() - 2)));
|
||||
}
|
||||
sizes.push_back(rewriter.getI64IntegerAttr(
|
||||
newResultTy.getDimSize(newResultTy.getRank() - 1)));
|
||||
|
||||
@@ -871,9 +753,6 @@ struct TensorShapeOpPattern : public mlir::OpRewritePattern<ShapeOp> {
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
ConcreteToBConcreteTypeConverter converter;
|
||||
auto resultTy = ((mlir::Type)shapeOp.result().getType()).cast<VecTy>();
|
||||
auto lweResultTy =
|
||||
((mlir::Type)resultTy.getElementType())
|
||||
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
|
||||
|
||||
auto newResultTy =
|
||||
((mlir::Type)converter.convertType(resultTy)).cast<VecTy>();
|
||||
@@ -886,12 +765,6 @@ struct TensorShapeOpPattern : public mlir::OpRewritePattern<ShapeOp> {
|
||||
auto oldReassocs = shapeOp.getReassociationIndices();
|
||||
mlir::SmallVector<mlir::ReassociationIndices> newReassocs;
|
||||
newReassocs.append(oldReassocs.begin(), oldReassocs.end());
|
||||
// add [rank-1] to reassociations if crt decomp
|
||||
if (!lweResultTy.getCrtDecomposition().empty()) {
|
||||
mlir::ReassociationIndices lweAssoc;
|
||||
lweAssoc.push_back(reassocTy.getRank() - 2);
|
||||
newReassocs.push_back(lweAssoc);
|
||||
}
|
||||
|
||||
// add [rank] to reassociations
|
||||
{
|
||||
@@ -1020,14 +893,21 @@ void ConcreteToBConcretePass::runOnOperation() {
|
||||
LowerBootstrap, LowerBatchedBootstrap, LowerKeySwitch,
|
||||
LowerBatchedKeySwitch,
|
||||
LowToBConcrete<mlir::concretelang::Concrete::AddLweCiphertextsOp,
|
||||
mlir::concretelang::BConcrete::AddLweTensorOp,
|
||||
BConcrete::AddCRTLweTensorOp>,
|
||||
mlir::concretelang::BConcrete::AddLweTensorOp>,
|
||||
AddPlaintextLweCiphertextOpPattern, MulCleartextLweCiphertextOpPattern,
|
||||
LowToBConcrete<
|
||||
mlir::concretelang::Concrete::EncodeExpandLutForBootstrapOp,
|
||||
mlir::concretelang::BConcrete::EncodeExpandLutForBootstrapTensorOp>,
|
||||
LowToBConcrete<
|
||||
mlir::concretelang::Concrete::EncodeExpandLutForWopPBSOp,
|
||||
mlir::concretelang::BConcrete::EncodeExpandLutForWopPBSTensorOp>,
|
||||
LowToBConcrete<
|
||||
mlir::concretelang::Concrete::EncodePlaintextWithCrtOp,
|
||||
mlir::concretelang::BConcrete::EncodePlaintextWithCrtTensorOp>,
|
||||
LowToBConcrete<mlir::concretelang::Concrete::NegateLweCiphertextOp,
|
||||
mlir::concretelang::BConcrete::NegateLweTensorOp,
|
||||
BConcrete::NegateCRTLweTensorOp>,
|
||||
LowToBConcrete<Concrete::WopPBSLweOp, BConcrete::WopPBSCRTLweTensorOp,
|
||||
BConcrete::WopPBSCRTLweTensorOp>>(&getContext());
|
||||
mlir::concretelang::BConcrete::NegateLweTensorOp>,
|
||||
LowToBConcrete<Concrete::WopPBSLweOp, BConcrete::WopPBSCRTLweTensorOp>>(
|
||||
&getContext());
|
||||
|
||||
// Add patterns to rewrite tensor operators that works on encrypted
|
||||
// tensors
|
||||
|
||||
@@ -1,391 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <iostream>
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Dialect/Linalg/IR/Linalg.h>
|
||||
#include <mlir/IR/Operation.h>
|
||||
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "concretelang/Conversion/FHEToTFHE/Patterns.h"
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
|
||||
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
|
||||
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
|
||||
#include "concretelang/Dialect/RT/IR/RTDialect.h"
|
||||
#include "concretelang/Dialect/RT/IR/RTOps.h"
|
||||
#include "concretelang/Dialect/RT/IR/RTTypes.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
|
||||
|
||||
namespace FHE = mlir::concretelang::FHE;
|
||||
namespace TFHE = mlir::concretelang::TFHE;
|
||||
|
||||
namespace {
|
||||
|
||||
using mlir::concretelang::FHE::EncryptedIntegerType;
|
||||
using mlir::concretelang::TFHE::GLWECipherTextType;
|
||||
|
||||
/// FHEToTFHETypeConverter is a TypeConverter that transform
|
||||
/// `FHE.eint<p>` to `TFHE.glwe<{_,_,_}{p}>`
|
||||
class FHEToTFHETypeConverter : public mlir::TypeConverter {
|
||||
|
||||
public:
|
||||
FHEToTFHETypeConverter() {
|
||||
addConversion([](mlir::Type type) { return type; });
|
||||
addConversion([](EncryptedIntegerType type) {
|
||||
return mlir::concretelang::convertTypeEncryptedIntegerToGLWE(
|
||||
type.getContext(), type);
|
||||
});
|
||||
addConversion([](mlir::RankedTensorType type) {
|
||||
auto eint =
|
||||
type.getElementType().dyn_cast_or_null<EncryptedIntegerType>();
|
||||
if (eint == nullptr) {
|
||||
return (mlir::Type)(type);
|
||||
}
|
||||
mlir::Type r = mlir::RankedTensorType::get(
|
||||
type.getShape(),
|
||||
mlir::concretelang::convertTypeEncryptedIntegerToGLWE(
|
||||
eint.getContext(), eint));
|
||||
return r;
|
||||
});
|
||||
addConversion([&](mlir::concretelang::RT::FutureType type) {
|
||||
return mlir::concretelang::RT::FutureType::get(
|
||||
this->convertType(type.dyn_cast<mlir::concretelang::RT::FutureType>()
|
||||
.getElementType()));
|
||||
});
|
||||
addConversion([&](mlir::concretelang::RT::PointerType type) {
|
||||
return mlir::concretelang::RT::PointerType::get(
|
||||
this->convertType(type.dyn_cast<mlir::concretelang::RT::PointerType>()
|
||||
.getElementType()));
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/// This rewrite pattern transforms any instance of `FHE.apply_lookup_table`
|
||||
/// operators.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %0 = "FHE.apply_lookup_table"(%ct, %lut): (!FHE.eint<2>, tensor<4xi64>)
|
||||
/// ->(!FHE.eint<2>)
|
||||
/// ```
|
||||
///
|
||||
/// becomes:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %glwe_ks = "TFHE.keyswitch_glwe"(%ct)
|
||||
/// {baseLog = -1 : i32, level = -1 : i32}
|
||||
/// : (!TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
/// %0 = "TFHE.bootstrap_glwe"(%glwe_ks, %lut)
|
||||
/// {baseLog = -1 : i32, glweDimension = -1 : i32, level = -1 : i32,
|
||||
/// polynomialSize = -1 : i32}
|
||||
/// : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) ->
|
||||
/// !TFHE.glwe<{_,_,_}{2}>
|
||||
/// ```
|
||||
struct ApplyLookupTableEintOpToKeyswitchBootstrapPattern
|
||||
: public mlir::OpRewritePattern<FHE::ApplyLookupTableEintOp> {
|
||||
ApplyLookupTableEintOpToKeyswitchBootstrapPattern(
|
||||
mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<FHE::ApplyLookupTableEintOp>(context,
|
||||
benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(FHE::ApplyLookupTableEintOp lutOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
FHEToTFHETypeConverter converter;
|
||||
auto inputTy = converter.convertType(lutOp.a().getType())
|
||||
.cast<TFHE::GLWECipherTextType>();
|
||||
auto resultTy = converter.convertType(lutOp.getType());
|
||||
auto glweKs = rewriter.create<TFHE::KeySwitchGLWEOp>(
|
||||
lutOp.getLoc(), inputTy, lutOp.a(), -1, -1);
|
||||
mlir::concretelang::convertOperandAndResultTypes(
|
||||
rewriter, glweKs, [&](mlir::MLIRContext *, mlir::Type t) {
|
||||
return converter.convertType(t);
|
||||
});
|
||||
// %0 = "TFHE.bootstrap_glwe"(%glwe_ks, %glwe_lut)
|
||||
rewriter.replaceOpWithNewOp<TFHE::BootstrapGLWEOp>(
|
||||
lutOp, resultTy, glweKs, lutOp.lut(), -1, -1, -1, -1);
|
||||
return ::mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
/// This rewrite pattern transforms any instance of `FHE.apply_lookup_table`
|
||||
/// operators.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %0 = "FHE.apply_lookup_table"(%ct, %lut): (!FHE.eint<2>, tensor<4xi64>)
|
||||
/// ->(!FHE.eint<2>)
|
||||
/// ```
|
||||
///
|
||||
/// becomes:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %0 = "TFHE.wop_pbs_glwe"(%ct, %lut)
|
||||
/// : (!TFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) ->
|
||||
/// (!TFHE.glwe<{_,_,_}{2}>)
|
||||
/// ```
|
||||
struct ApplyLookupTableEintOpToWopPBSPattern
|
||||
: public mlir::OpRewritePattern<FHE::ApplyLookupTableEintOp> {
|
||||
ApplyLookupTableEintOpToWopPBSPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<FHE::ApplyLookupTableEintOp>(context,
|
||||
benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(FHE::ApplyLookupTableEintOp lutOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
FHEToTFHETypeConverter converter;
|
||||
auto resultTy = converter.convertType(lutOp.getType());
|
||||
// %0 = "TFHE.wop_pbs_glwe"(%ct, %lut)
|
||||
// : (!TFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) ->
|
||||
// (!TFHE.glwe<{_,_,_}{2}>)
|
||||
auto wopPBS = rewriter.replaceOpWithNewOp<TFHE::WopPBSGLWEOp>(
|
||||
lutOp, resultTy, lutOp.a(), lutOp.lut(), -1, -1, -1, -1, -1, -1, -1, -1,
|
||||
-1, -1);
|
||||
mlir::concretelang::convertOperandAndResultTypes(
|
||||
rewriter, wopPBS, [&](mlir::MLIRContext *, mlir::Type t) {
|
||||
return converter.convertType(t);
|
||||
});
|
||||
return ::mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
/// This rewrite pattern transforms any instance of `FHE.sub_eint_int`
|
||||
/// operators to a negation and an addition.
|
||||
struct SubEintIntOpPattern : public mlir::OpRewritePattern<FHE::SubEintIntOp> {
|
||||
SubEintIntOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<FHE::SubEintIntOp>(context, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(FHE::SubEintIntOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::Location location = op.getLoc();
|
||||
|
||||
mlir::Value lhs = op.getOperand(0);
|
||||
mlir::Value rhs = op.getOperand(1);
|
||||
|
||||
mlir::Type rhsType = rhs.getType();
|
||||
mlir::Attribute minusOneAttr = mlir::IntegerAttr::get(rhsType, -1);
|
||||
mlir::Value minusOne =
|
||||
rewriter.create<mlir::arith::ConstantOp>(location, minusOneAttr)
|
||||
.getResult();
|
||||
|
||||
mlir::Value negative =
|
||||
rewriter.create<mlir::arith::MulIOp>(location, rhs, minusOne)
|
||||
.getResult();
|
||||
|
||||
FHEToTFHETypeConverter converter;
|
||||
auto resultTy = converter.convertType(op.getType());
|
||||
|
||||
auto addition =
|
||||
rewriter.create<TFHE::AddGLWEIntOp>(location, resultTy, lhs, negative);
|
||||
|
||||
mlir::concretelang::convertOperandAndResultTypes(
|
||||
rewriter, addition, [&](mlir::MLIRContext *, mlir::Type t) {
|
||||
return converter.convertType(t);
|
||||
});
|
||||
|
||||
rewriter.replaceOp(op, {addition.getResult()});
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
/// This rewrite pattern transforms any instance of `FHE.sub_eint`
|
||||
/// operators to a negation and an addition.
|
||||
struct SubEintOpPattern : public mlir::OpRewritePattern<FHE::SubEintOp> {
|
||||
SubEintOpPattern(mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<FHE::SubEintOp>(context, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(FHE::SubEintOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::Location location = op.getLoc();
|
||||
|
||||
mlir::Value lhs = op.getOperand(0);
|
||||
mlir::Value rhs = op.getOperand(1);
|
||||
|
||||
FHEToTFHETypeConverter converter;
|
||||
|
||||
auto rhsTy = converter.convertType(rhs.getType());
|
||||
auto negative = rewriter.create<TFHE::NegGLWEOp>(location, rhsTy, rhs);
|
||||
|
||||
mlir::concretelang::convertOperandAndResultTypes(
|
||||
rewriter, negative, [&](mlir::MLIRContext *, mlir::Type t) {
|
||||
return converter.convertType(t);
|
||||
});
|
||||
|
||||
auto resultTy = converter.convertType(op.getType());
|
||||
auto addition = rewriter.create<TFHE::AddGLWEOp>(location, resultTy, lhs,
|
||||
negative.getResult());
|
||||
|
||||
mlir::concretelang::convertOperandAndResultTypes(
|
||||
rewriter, addition, [&](mlir::MLIRContext *, mlir::Type t) {
|
||||
return converter.convertType(t);
|
||||
});
|
||||
|
||||
rewriter.replaceOp(op, {addition.getResult()});
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
struct FHEToTFHEPass : public FHEToTFHEBase<FHEToTFHEPass> {
|
||||
|
||||
FHEToTFHEPass(mlir::concretelang::ApplyLookupTableLowering lutLowerStrategy)
|
||||
: lutLowerStrategy(lutLowerStrategy) {}
|
||||
|
||||
void runOnOperation() override {
|
||||
auto op = this->getOperation();
|
||||
|
||||
mlir::ConversionTarget target(getContext());
|
||||
FHEToTFHETypeConverter converter;
|
||||
|
||||
// Mark ops from the target dialect as legal operations
|
||||
target.addLegalDialect<mlir::concretelang::TFHE::TFHEDialect>();
|
||||
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
|
||||
|
||||
// Make sure that no ops from `FHE` remain after the lowering
|
||||
target.addIllegalDialect<mlir::concretelang::FHE::FHEDialect>();
|
||||
|
||||
// Make sure that no ops `linalg.generic` that have illegal types
|
||||
target.addDynamicallyLegalOp<mlir::linalg::GenericOp,
|
||||
mlir::tensor::GenerateOp>(
|
||||
[&](mlir::Operation *op) {
|
||||
return (
|
||||
converter.isLegal(op->getOperandTypes()) &&
|
||||
converter.isLegal(op->getResultTypes()) &&
|
||||
converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
|
||||
});
|
||||
|
||||
// Make sure that func has legal signature
|
||||
target.addDynamicallyLegalOp<mlir::func::FuncOp>(
|
||||
[&](mlir::func::FuncOp funcOp) {
|
||||
return converter.isSignatureLegal(funcOp.getFunctionType()) &&
|
||||
converter.isLegal(&funcOp.getBody());
|
||||
});
|
||||
target.addDynamicallyLegalOp<mlir::func::ConstantOp>(
|
||||
[&](mlir::func::ConstantOp op) {
|
||||
return FunctionConstantOpConversion<FHEToTFHETypeConverter>::isLegal(
|
||||
op, converter);
|
||||
});
|
||||
|
||||
// Add all patterns required to lower all ops from `FHE` to
|
||||
// `TFHE`
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
|
||||
populateWithGeneratedFHEToTFHE(patterns);
|
||||
|
||||
patterns.add<
|
||||
mlir::concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>>(
|
||||
patterns.getContext(), converter);
|
||||
|
||||
switch (lutLowerStrategy) {
|
||||
case mlir::concretelang::KeySwitchBoostrapLowering:
|
||||
patterns.add<ApplyLookupTableEintOpToKeyswitchBootstrapPattern>(
|
||||
&getContext());
|
||||
break;
|
||||
case mlir::concretelang::WopPBSLowering:
|
||||
patterns.add<ApplyLookupTableEintOpToWopPBSPattern>(&getContext());
|
||||
break;
|
||||
}
|
||||
|
||||
patterns.add<SubEintOpPattern>(&getContext());
|
||||
patterns.add<SubEintIntOpPattern>(&getContext());
|
||||
patterns.add<FunctionConstantOpConversion<FHEToTFHETypeConverter>>(
|
||||
&getContext(), converter);
|
||||
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::linalg::GenericOp,
|
||||
FHEToTFHETypeConverter>>(
|
||||
&getContext(), converter);
|
||||
|
||||
patterns.add<
|
||||
mlir::concretelang::GenericTypeConverterPattern<mlir::linalg::YieldOp>>(
|
||||
patterns.getContext(), converter);
|
||||
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
|
||||
FHEToTFHETypeConverter>>(
|
||||
&getContext(), converter);
|
||||
|
||||
patterns.add<
|
||||
RegionOpTypeConverterPattern<mlir::scf::ForOp, FHEToTFHETypeConverter>>(
|
||||
&getContext(), converter);
|
||||
patterns.add<mlir::concretelang::GenericTypeAndOpConverterPattern<
|
||||
mlir::concretelang::FHE::ZeroTensorOp,
|
||||
mlir::concretelang::TFHE::ZeroTensorGLWEOp>>(&getContext(), converter);
|
||||
|
||||
mlir::concretelang::populateWithTensorTypeConverterPatterns(
|
||||
patterns, target, converter);
|
||||
|
||||
mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
|
||||
patterns, converter);
|
||||
|
||||
// Conversion of RT Dialect Ops
|
||||
patterns.add<
|
||||
mlir::concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<mlir::scf::YieldOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::MakeReadyFutureOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::AwaitFutureOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::CreateAsyncTaskOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::WorkFunctionReturnOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
|
||||
converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::MakeReadyFutureOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::AwaitFutureOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::CreateAsyncTaskOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>(
|
||||
target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::WorkFunctionReturnOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>(target, converter);
|
||||
|
||||
// Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns))
|
||||
.failed()) {
|
||||
this->signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
mlir::concretelang::ApplyLookupTableLowering lutLowerStrategy;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertFHEToTFHEPass(ApplyLookupTableLowering lower) {
|
||||
return std::make_unique<FHEToTFHEPass>(lower);
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -1,6 +1,6 @@
|
||||
add_mlir_dialect_library(
|
||||
FHEToTFHE
|
||||
FHEToTFHE.cpp
|
||||
FHEToTFHECrt
|
||||
FHEToTFHECrt.cpp
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE
|
||||
DEPENDS
|
||||
@@ -12,4 +12,4 @@ add_mlir_dialect_library(
|
||||
MLIRTransforms
|
||||
MLIRMathDialect)
|
||||
|
||||
target_link_libraries(FHEToTFHE PUBLIC MLIRIR)
|
||||
target_link_libraries(FHEToTFHECrt PUBLIC MLIRIR)
|
||||
960
compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp
Normal file
960
compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp
Normal file
@@ -0,0 +1,960 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <iostream>
|
||||
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
|
||||
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Dialect/Linalg/IR/Linalg.h>
|
||||
#include <mlir/IR/Operation.h>
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "concretelang/Conversion/FHEToTFHECrt/Pass.h"
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
|
||||
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
|
||||
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
|
||||
#include "concretelang/Dialect/RT/IR/RTDialect.h"
|
||||
#include "concretelang/Dialect/RT/IR/RTOps.h"
|
||||
#include "concretelang/Dialect/RT/IR/RTTypes.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
|
||||
|
||||
namespace FHE = mlir::concretelang::FHE;
|
||||
namespace TFHE = mlir::concretelang::TFHE;
|
||||
namespace concretelang = mlir::concretelang;
|
||||
|
||||
namespace fhe_to_tfhe_crt_conversion {
|
||||
|
||||
namespace typing {
|
||||
|
||||
/// Converts `FHE::EncryptedInteger` into `Tensor<TFHE::GlweCiphetext>`.
|
||||
mlir::RankedTensorType convertEint(mlir::MLIRContext *context,
|
||||
FHE::EncryptedIntegerType eint,
|
||||
uint64_t crtLength) {
|
||||
return mlir::RankedTensorType::get(
|
||||
mlir::ArrayRef<int64_t>((int64_t)crtLength),
|
||||
TFHE::GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth()));
|
||||
}
|
||||
|
||||
/// Converts `Tensor<FHE::EncryptedInteger>` into a
|
||||
/// `Tensor<TFHE::GlweCiphertext>` if the element type is appropriate. Otherwise
|
||||
/// return the input type.
|
||||
mlir::Type maybeConvertEintTensor(mlir::MLIRContext *context,
|
||||
mlir::RankedTensorType maybeEintTensor,
|
||||
uint64_t crtLength) {
|
||||
if (!maybeEintTensor.getElementType().isa<FHE::EncryptedIntegerType>()) {
|
||||
return (mlir::Type)(maybeEintTensor);
|
||||
}
|
||||
auto eint =
|
||||
maybeEintTensor.getElementType().cast<FHE::EncryptedIntegerType>();
|
||||
auto currentShape = maybeEintTensor.getShape();
|
||||
mlir::SmallVector<int64_t> newShape =
|
||||
mlir::SmallVector<int64_t>(currentShape.begin(), currentShape.end());
|
||||
newShape.push_back((int64_t)crtLength);
|
||||
return mlir::RankedTensorType::get(
|
||||
llvm::ArrayRef<int64_t>(newShape),
|
||||
TFHE::GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth()));
|
||||
}
|
||||
|
||||
/// Converts the type `FHE::EncryptedInteger` to `Tensor<TFHE::GlweCiphetext>`
|
||||
/// if the input type is appropriate. Otherwise return the input type.
|
||||
mlir::Type maybeConvertEint(mlir::MLIRContext *context, mlir::Type t,
|
||||
uint64_t crtLength) {
|
||||
if (auto eint = t.dyn_cast<FHE::EncryptedIntegerType>())
|
||||
return convertEint(context, eint, crtLength);
|
||||
|
||||
return t;
|
||||
}
|
||||
|
||||
/// The type converter used to convert `FHE` to `TFHE` types using the crt
|
||||
/// strategy.
|
||||
class TypeConverter : public mlir::TypeConverter {
|
||||
|
||||
public:
|
||||
TypeConverter(concretelang::CrtLoweringParameters loweringParameters) {
|
||||
size_t nMods = loweringParameters.nMods;
|
||||
addConversion([](mlir::Type type) { return type; });
|
||||
addConversion([=](FHE::EncryptedIntegerType type) {
|
||||
return convertEint(type.getContext(), type, nMods);
|
||||
});
|
||||
addConversion([=](mlir::RankedTensorType type) {
|
||||
return maybeConvertEintTensor(type.getContext(), type, nMods);
|
||||
});
|
||||
addConversion([&](concretelang::RT::FutureType type) {
|
||||
return concretelang::RT::FutureType::get(this->convertType(
|
||||
type.dyn_cast<concretelang::RT::FutureType>().getElementType()));
|
||||
});
|
||||
addConversion([&](concretelang::RT::PointerType type) {
|
||||
return concretelang::RT::PointerType::get(this->convertType(
|
||||
type.dyn_cast<concretelang::RT::PointerType>().getElementType()));
|
||||
});
|
||||
}
|
||||
|
||||
/// Returns a lambda that uses this converter to turn one type into another.
|
||||
std::function<mlir::Type(mlir::MLIRContext *, mlir::Type)>
|
||||
getConversionLambda() {
|
||||
return [&](mlir::MLIRContext *, mlir::Type t) { return convertType(t); };
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace typing
|
||||
|
||||
namespace lowering {
|
||||
|
||||
/// A pattern rewriter superclass used by most op rewriters during the
|
||||
/// conversion.
|
||||
template <typename T> struct CrtOpPattern : public mlir::OpRewritePattern<T> {
|
||||
|
||||
/// The lowering parameters are bound to the op rewriter.
|
||||
concretelang::CrtLoweringParameters loweringParameters;
|
||||
|
||||
CrtOpPattern(mlir::MLIRContext *context,
|
||||
concretelang::CrtLoweringParameters params,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<T>(context, benefit),
|
||||
loweringParameters(params) {}
|
||||
|
||||
/// Writes an `scf::for` that loops over the crt dimension of two tensors and
|
||||
/// execute the input lambda to write the loop body. Returns the first result
|
||||
/// of the op.
|
||||
///
|
||||
/// Note:
|
||||
/// -----
|
||||
///
|
||||
/// + The type of `firstArgTensor` type is used as output type.
|
||||
mlir::Value writeBinaryTensorLoop(
|
||||
mlir::Location location, mlir::Value firstTensor,
|
||||
mlir::Value secondTensor, mlir::PatternRewriter &rewriter,
|
||||
mlir::function_ref<void(mlir::OpBuilder &, mlir::Location, mlir::Value,
|
||||
mlir::ValueRange)>
|
||||
body) const {
|
||||
|
||||
// Create the loop
|
||||
mlir::arith::ConstantOp zeroConstantOp =
|
||||
rewriter.create<mlir::arith::ConstantIndexOp>(location, 0);
|
||||
mlir::arith::ConstantOp oneConstantOp =
|
||||
rewriter.create<mlir::arith::ConstantIndexOp>(location, 1);
|
||||
mlir::arith::ConstantOp crtSizeConstantOp =
|
||||
rewriter.create<mlir::arith::ConstantIndexOp>(location,
|
||||
loweringParameters.nMods);
|
||||
mlir::scf::ForOp newOp = rewriter.create<mlir::scf::ForOp>(
|
||||
location, zeroConstantOp, crtSizeConstantOp, oneConstantOp,
|
||||
mlir::ValueRange{firstTensor, secondTensor}, body);
|
||||
|
||||
// Convert the types of the new operation
|
||||
typing::TypeConverter converter(loweringParameters);
|
||||
concretelang::convertOperandAndResultTypes(rewriter, newOp,
|
||||
converter.getConversionLambda());
|
||||
|
||||
return newOp.getResult(0);
|
||||
}
|
||||
|
||||
/// Writes an `scf::for` that loops over the crt dimension of one tensor and
|
||||
/// execute the input lambda to write the loop body. Returns the first result
|
||||
/// of the op.
|
||||
///
|
||||
/// Note:
|
||||
/// -----
|
||||
///
|
||||
/// + The type of `firstArgTensor` type is used as output type.
|
||||
mlir::Value writeUnaryTensorLoop(
|
||||
mlir::Location location, mlir::Value tensor,
|
||||
mlir::PatternRewriter &rewriter,
|
||||
mlir::function_ref<void(mlir::OpBuilder &, mlir::Location, mlir::Value,
|
||||
mlir::ValueRange)>
|
||||
body) const {
|
||||
|
||||
// Create the loop
|
||||
mlir::arith::ConstantOp zeroConstantOp =
|
||||
rewriter.create<mlir::arith::ConstantIndexOp>(location, 0);
|
||||
mlir::arith::ConstantOp oneConstantOp =
|
||||
rewriter.create<mlir::arith::ConstantIndexOp>(location, 1);
|
||||
mlir::arith::ConstantOp crtSizeConstantOp =
|
||||
rewriter.create<mlir::arith::ConstantIndexOp>(location,
|
||||
loweringParameters.nMods);
|
||||
mlir::scf::ForOp newOp = rewriter.create<mlir::scf::ForOp>(
|
||||
location, zeroConstantOp, crtSizeConstantOp, oneConstantOp, tensor,
|
||||
body);
|
||||
|
||||
// Convert the types of the new operation
|
||||
typing::TypeConverter converter(loweringParameters);
|
||||
concretelang::convertOperandAndResultTypes(rewriter, newOp,
|
||||
converter.getConversionLambda());
|
||||
|
||||
return newOp.getResult(0);
|
||||
}
|
||||
|
||||
/// Writes the crt encoding of a plaintext of arbitrary precision.
|
||||
mlir::Value writePlaintextCrtEncoding(mlir::Location location,
|
||||
mlir::Value rawPlaintext,
|
||||
mlir::PatternRewriter &rewriter) const {
|
||||
mlir::Value castedPlaintext = rewriter.create<mlir::arith::ExtUIOp>(
|
||||
location, rewriter.getI64Type(), rawPlaintext);
|
||||
return rewriter.create<TFHE::EncodePlaintextWithCrtOp>(
|
||||
location,
|
||||
mlir::RankedTensorType::get(
|
||||
mlir::ArrayRef<int64_t>(loweringParameters.nMods),
|
||||
rewriter.getI64Type()),
|
||||
castedPlaintext, rewriter.getI64ArrayAttr(loweringParameters.mods),
|
||||
rewriter.getI64IntegerAttr(loweringParameters.modsProd));
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewriter for the `FHE::add_eint_int` operation.
|
||||
struct AddEintIntOpPattern : public CrtOpPattern<FHE::AddEintIntOp> {
|
||||
|
||||
AddEintIntOpPattern(mlir::MLIRContext *context,
|
||||
concretelang::CrtLoweringParameters params,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: CrtOpPattern<FHE::AddEintIntOp>(context, params, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(FHE::AddEintIntOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value eintOperand = op.a();
|
||||
mlir::Value intOperand = op.b();
|
||||
|
||||
// Convert operand type to glwe tensor.
|
||||
typing::TypeConverter converter(loweringParameters);
|
||||
intOperand.setType(converter.convertType(intOperand.getType()));
|
||||
eintOperand.setType(converter.convertType(eintOperand.getType()));
|
||||
|
||||
// Write plaintext encoding
|
||||
mlir::Value encodedPlaintextTensor =
|
||||
writePlaintextCrtEncoding(op.getLoc(), intOperand, rewriter);
|
||||
|
||||
// Write add loop.
|
||||
mlir::Type ciphertextScalarType =
|
||||
converter.convertType(eintOperand.getType())
|
||||
.cast<mlir::RankedTensorType>()
|
||||
.getElementType();
|
||||
mlir::Value output = writeBinaryTensorLoop(
|
||||
location, eintOperand, encodedPlaintextTensor, rewriter,
|
||||
[&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter,
|
||||
mlir::ValueRange args) {
|
||||
mlir::Value extractedEint =
|
||||
builder.create<mlir::tensor::ExtractOp>(loc, args[0], iter);
|
||||
mlir::Value extractedInt =
|
||||
builder.create<mlir::tensor::ExtractOp>(loc, args[1], iter);
|
||||
mlir::Value output = builder.create<TFHE::AddGLWEIntOp>(
|
||||
loc, ciphertextScalarType, extractedEint, extractedInt);
|
||||
mlir::Value newTensor = builder.create<mlir::tensor::InsertOp>(
|
||||
loc, output, args[0], iter);
|
||||
builder.create<mlir::scf::YieldOp>(
|
||||
loc, mlir::ValueRange{newTensor, args[1]});
|
||||
});
|
||||
|
||||
// Rewrite original op.
|
||||
rewriter.replaceOp(op, output);
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewriter for the `FHE::sub_int_eint` operation.
|
||||
struct SubIntEintOpPattern : public CrtOpPattern<FHE::SubIntEintOp> {
|
||||
|
||||
SubIntEintOpPattern(mlir::MLIRContext *context,
|
||||
concretelang::CrtLoweringParameters params,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: CrtOpPattern<FHE::SubIntEintOp>(context, params, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(FHE::SubIntEintOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value intOperand = op.a();
|
||||
mlir::Value eintOperand = op.b();
|
||||
|
||||
// Convert operand type to glwe tensor.
|
||||
typing::TypeConverter converter(loweringParameters);
|
||||
intOperand.setType(converter.convertType(intOperand.getType()));
|
||||
eintOperand.setType(converter.convertType(eintOperand.getType()));
|
||||
|
||||
// Write plaintext encoding
|
||||
mlir::Value encodedPlaintextTensor =
|
||||
writePlaintextCrtEncoding(op.getLoc(), intOperand, rewriter);
|
||||
|
||||
// Write add loop.
|
||||
mlir::Type ciphertextScalarType =
|
||||
converter.convertType(eintOperand.getType())
|
||||
.cast<mlir::RankedTensorType>()
|
||||
.getElementType();
|
||||
mlir::Value output = writeBinaryTensorLoop(
|
||||
location, eintOperand, encodedPlaintextTensor, rewriter,
|
||||
[&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter,
|
||||
mlir::ValueRange args) {
|
||||
mlir::Value extractedEint =
|
||||
builder.create<mlir::tensor::ExtractOp>(loc, args[0], iter);
|
||||
mlir::Value extractedInt =
|
||||
builder.create<mlir::tensor::ExtractOp>(loc, args[1], iter);
|
||||
mlir::Value output = builder.create<TFHE::SubGLWEIntOp>(
|
||||
loc, ciphertextScalarType, extractedInt, extractedEint);
|
||||
mlir::Value newTensor = builder.create<mlir::tensor::InsertOp>(
|
||||
loc, output, args[0], iter);
|
||||
builder.create<mlir::scf::YieldOp>(
|
||||
loc, mlir::ValueRange{newTensor, args[1]});
|
||||
});
|
||||
|
||||
// Rewrite original op.
|
||||
rewriter.replaceOp(op, output);
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewriter for the `FHE::sub_eint_int` operation.
|
||||
struct SubEintIntOpPattern : public CrtOpPattern<FHE::SubEintIntOp> {
|
||||
|
||||
SubEintIntOpPattern(mlir::MLIRContext *context,
|
||||
concretelang::CrtLoweringParameters params,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: CrtOpPattern<FHE::SubEintIntOp>(context, params, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(FHE::SubEintIntOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value eintOperand = op.a();
|
||||
mlir::Value intOperand = op.b();
|
||||
|
||||
// Convert operand type to glwe tensor.
|
||||
typing::TypeConverter converter(loweringParameters);
|
||||
intOperand.setType(converter.convertType(intOperand.getType()));
|
||||
eintOperand.setType(converter.convertType(eintOperand.getType()));
|
||||
|
||||
// Write plaintext negation
|
||||
mlir::Type intType = intOperand.getType();
|
||||
mlir::Attribute minusOneAttr = mlir::IntegerAttr::get(intType, -1);
|
||||
mlir::Value minusOne =
|
||||
rewriter.create<mlir::arith::ConstantOp>(location, minusOneAttr)
|
||||
.getResult();
|
||||
mlir::Value negative =
|
||||
rewriter.create<mlir::arith::MulIOp>(location, intOperand, minusOne)
|
||||
.getResult();
|
||||
|
||||
// Write plaintext encoding
|
||||
mlir::Value encodedPlaintextTensor =
|
||||
writePlaintextCrtEncoding(op.getLoc(), negative, rewriter);
|
||||
|
||||
// Write add loop.
|
||||
mlir::Type ciphertextScalarType =
|
||||
converter.convertType(eintOperand.getType())
|
||||
.cast<mlir::RankedTensorType>()
|
||||
.getElementType();
|
||||
mlir::Value output = writeBinaryTensorLoop(
|
||||
location, eintOperand, encodedPlaintextTensor, rewriter,
|
||||
[&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter,
|
||||
mlir::ValueRange args) {
|
||||
mlir::Value extractedEint =
|
||||
builder.create<mlir::tensor::ExtractOp>(loc, args[0], iter);
|
||||
mlir::Value extractedInt =
|
||||
builder.create<mlir::tensor::ExtractOp>(loc, args[1], iter);
|
||||
mlir::Value output = builder.create<TFHE::AddGLWEIntOp>(
|
||||
loc, ciphertextScalarType, extractedEint, extractedInt);
|
||||
mlir::Value newTensor = builder.create<mlir::tensor::InsertOp>(
|
||||
loc, output, args[0], iter);
|
||||
builder.create<mlir::scf::YieldOp>(
|
||||
loc, mlir::ValueRange{newTensor, args[1]});
|
||||
});
|
||||
|
||||
// Rewrite original op.
|
||||
rewriter.replaceOp(op, output);
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewriter for the `FHE::add_eint` operation.
|
||||
struct AddEintOpPattern : CrtOpPattern<FHE::AddEintOp> {
|
||||
|
||||
AddEintOpPattern(mlir::MLIRContext *context,
|
||||
concretelang::CrtLoweringParameters params,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: CrtOpPattern<FHE::AddEintOp>(context, params, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(FHE::AddEintOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value lhsOperand = op.a();
|
||||
mlir::Value rhsOperand = op.b();
|
||||
|
||||
// Convert operand type to glwe tensor.
|
||||
typing::TypeConverter converter(loweringParameters);
|
||||
lhsOperand.setType(converter.convertType(lhsOperand.getType()));
|
||||
rhsOperand.setType(converter.convertType(rhsOperand.getType()));
|
||||
|
||||
// Write add loop.
|
||||
mlir::Type ciphertextScalarType =
|
||||
converter.convertType(lhsOperand.getType())
|
||||
.cast<mlir::RankedTensorType>()
|
||||
.getElementType();
|
||||
mlir::Value output = writeBinaryTensorLoop(
|
||||
location, lhsOperand, rhsOperand, rewriter,
|
||||
[&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter,
|
||||
mlir::ValueRange args) {
|
||||
mlir::Value extractedLhs =
|
||||
builder.create<mlir::tensor::ExtractOp>(loc, args[0], iter);
|
||||
mlir::Value extractedRhs =
|
||||
builder.create<mlir::tensor::ExtractOp>(loc, args[1], iter);
|
||||
mlir::Value output = builder.create<TFHE::AddGLWEOp>(
|
||||
loc, ciphertextScalarType, extractedLhs, extractedRhs);
|
||||
mlir::Value newTensor = builder.create<mlir::tensor::InsertOp>(
|
||||
loc, output, args[0], iter);
|
||||
builder.create<mlir::scf::YieldOp>(
|
||||
loc, mlir::ValueRange{newTensor, args[1]});
|
||||
});
|
||||
|
||||
// Rewrite original op.
|
||||
rewriter.replaceOp(op, output);
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewriter for the `FHE::sub_eint` operation.
|
||||
struct SubEintOpPattern : CrtOpPattern<FHE::SubEintOp> {
|
||||
|
||||
SubEintOpPattern(mlir::MLIRContext *context,
|
||||
concretelang::CrtLoweringParameters params,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: CrtOpPattern<FHE::SubEintOp>(context, params, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(FHE::SubEintOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value lhsOperand = op.a();
|
||||
mlir::Value rhsOperand = op.b();
|
||||
|
||||
// Convert operand type to glwe tensor.
|
||||
typing::TypeConverter converter(loweringParameters);
|
||||
lhsOperand.setType(converter.convertType(lhsOperand.getType()));
|
||||
rhsOperand.setType(converter.convertType(rhsOperand.getType()));
|
||||
|
||||
// Write sub loop.
|
||||
mlir::Type ciphertextScalarType =
|
||||
converter.convertType(lhsOperand.getType())
|
||||
.cast<mlir::RankedTensorType>()
|
||||
.getElementType();
|
||||
mlir::Value output = writeBinaryTensorLoop(
|
||||
location, lhsOperand, rhsOperand, rewriter,
|
||||
[&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter,
|
||||
mlir::ValueRange args) {
|
||||
mlir::Value extractedLhs =
|
||||
builder.create<mlir::tensor::ExtractOp>(loc, args[0], iter);
|
||||
mlir::Value extractedRhs =
|
||||
builder.create<mlir::tensor::ExtractOp>(loc, args[1], iter);
|
||||
mlir::Value negatedRhs = builder.create<TFHE::NegGLWEOp>(
|
||||
loc, ciphertextScalarType, extractedRhs);
|
||||
mlir::Value output = builder.create<TFHE::AddGLWEOp>(
|
||||
loc, ciphertextScalarType, extractedLhs, negatedRhs);
|
||||
mlir::Value newTensor = builder.create<mlir::tensor::InsertOp>(
|
||||
loc, output, args[0], iter);
|
||||
builder.create<mlir::scf::YieldOp>(
|
||||
loc, mlir::ValueRange{newTensor, args[1]});
|
||||
});
|
||||
|
||||
// Rewrite original op.
|
||||
rewriter.replaceOp(op, output);
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewriter for the `FHE::neg_eint` operation.
|
||||
struct NegEintOpPattern : CrtOpPattern<FHE::NegEintOp> {
|
||||
|
||||
NegEintOpPattern(mlir::MLIRContext *context,
|
||||
concretelang::CrtLoweringParameters params,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: CrtOpPattern<FHE::NegEintOp>(context, params, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(FHE::NegEintOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value operand = op.a();
|
||||
|
||||
// Convert operand type to glwe tensor.
|
||||
typing::TypeConverter converter{loweringParameters};
|
||||
operand.setType(converter.convertType(operand.getType()));
|
||||
|
||||
// Write the loop nest.
|
||||
mlir::Type ciphertextScalarType = converter.convertType(operand.getType())
|
||||
.cast<mlir::RankedTensorType>()
|
||||
.getElementType();
|
||||
mlir::Value loopRes = writeUnaryTensorLoop(
|
||||
location, operand, rewriter,
|
||||
[&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter,
|
||||
mlir::ValueRange args) {
|
||||
mlir::Value extractedCiphertext =
|
||||
builder.create<mlir::tensor::ExtractOp>(loc, args[0], iter);
|
||||
mlir::Value negatedCiphertext = builder.create<TFHE::NegGLWEOp>(
|
||||
loc, ciphertextScalarType, extractedCiphertext);
|
||||
mlir::Value newTensor = builder.create<mlir::tensor::InsertOp>(
|
||||
loc, negatedCiphertext, args[0], iter);
|
||||
builder.create<mlir::scf::YieldOp>(loc, mlir::ValueRange{newTensor});
|
||||
});
|
||||
|
||||
// Rewrite original op.
|
||||
rewriter.replaceOp(op, loopRes);
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewriter for the `FHE::mul_eint_int` operation.
|
||||
struct MulEintIntOpPattern : CrtOpPattern<FHE::MulEintIntOp> {
|
||||
|
||||
MulEintIntOpPattern(mlir::MLIRContext *context,
|
||||
concretelang::CrtLoweringParameters params,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: CrtOpPattern<FHE::MulEintIntOp>(context, params, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(FHE::MulEintIntOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value eintOperand = op.a();
|
||||
mlir::Value intOperand = op.b();
|
||||
|
||||
// Convert operand type to glwe tensor.
|
||||
typing::TypeConverter converter{loweringParameters};
|
||||
eintOperand.setType(converter.convertType(eintOperand.getType()));
|
||||
|
||||
// Write cleartext "encoding"
|
||||
mlir::Value encodedCleartext = rewriter.create<mlir::arith::ExtSIOp>(
|
||||
location, rewriter.getI64Type(), intOperand);
|
||||
|
||||
// Write the loop nest.
|
||||
mlir::Type ciphertextScalarType =
|
||||
converter.convertType(eintOperand.getType())
|
||||
.cast<mlir::RankedTensorType>()
|
||||
.getElementType();
|
||||
mlir::Value loopRes = writeUnaryTensorLoop(
|
||||
location, eintOperand, rewriter,
|
||||
[&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter,
|
||||
mlir::ValueRange args) {
|
||||
mlir::Value extractedCiphertext =
|
||||
builder.create<mlir::tensor::ExtractOp>(loc, args[0], iter);
|
||||
mlir::Value negatedCiphertext = builder.create<TFHE::MulGLWEIntOp>(
|
||||
loc, ciphertextScalarType, extractedCiphertext, encodedCleartext);
|
||||
mlir::Value newTensor = builder.create<mlir::tensor::InsertOp>(
|
||||
loc, negatedCiphertext, args[0], iter);
|
||||
builder.create<mlir::scf::YieldOp>(loc, mlir::ValueRange{newTensor});
|
||||
});
|
||||
|
||||
// Rewrite original op.
|
||||
rewriter.replaceOp(op, loopRes);
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewriter for the `FHE::apply_lookup_table` operation.
|
||||
struct ApplyLookupTableEintOpPattern
|
||||
: public CrtOpPattern<FHE::ApplyLookupTableEintOp> {
|
||||
|
||||
ApplyLookupTableEintOpPattern(mlir::MLIRContext *context,
|
||||
concretelang::CrtLoweringParameters params,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: CrtOpPattern<FHE::ApplyLookupTableEintOp>(context, params, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(FHE::ApplyLookupTableEintOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
typing::TypeConverter converter(loweringParameters);
|
||||
|
||||
mlir::Value newLut =
|
||||
rewriter
|
||||
.create<TFHE::EncodeExpandLutForWopPBSOp>(
|
||||
op.getLoc(),
|
||||
mlir::RankedTensorType::get(
|
||||
mlir::ArrayRef<int64_t>(loweringParameters.lutSize),
|
||||
rewriter.getI64Type()),
|
||||
op.lut(),
|
||||
rewriter.getI64ArrayAttr(
|
||||
mlir::ArrayRef<int64_t>(loweringParameters.mods)),
|
||||
rewriter.getI64ArrayAttr(
|
||||
mlir::ArrayRef<int64_t>(loweringParameters.bits)),
|
||||
rewriter.getI32IntegerAttr(loweringParameters.polynomialSize),
|
||||
rewriter.getI32IntegerAttr(loweringParameters.modsProd))
|
||||
.getResult();
|
||||
|
||||
// Replace the lut with an encoded / expanded one.
|
||||
auto wopPBS = rewriter.create<TFHE::WopPBSGLWEOp>(
|
||||
op.getLoc(), op.getType(), op.a(), newLut, -1, -1, -1, -1, -1, -1, -1,
|
||||
-1, -1, -1, rewriter.getI64ArrayAttr({}));
|
||||
|
||||
concretelang::convertOperandAndResultTypes(rewriter, wopPBS,
|
||||
converter.getConversionLambda());
|
||||
rewriter.replaceOp(op, {wopPBS.getResult()});
|
||||
return ::mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
/// Rewriter for the `tensor::extract` operation.
|
||||
struct TensorExtractOpPattern : public CrtOpPattern<mlir::tensor::ExtractOp> {
|
||||
|
||||
TensorExtractOpPattern(mlir::MLIRContext *context,
|
||||
concretelang::CrtLoweringParameters params,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: CrtOpPattern<mlir::tensor::ExtractOp>(context, params, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::tensor::ExtractOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
if (!op.getTensor()
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
.isa<FHE::EncryptedIntegerType>() &&
|
||||
!op.getTensor()
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
.isa<TFHE::GLWECipherTextType>()) {
|
||||
return mlir::success();
|
||||
}
|
||||
typing::TypeConverter converter{loweringParameters};
|
||||
mlir::SmallVector<mlir::OpFoldResult> offsets;
|
||||
mlir::SmallVector<mlir::OpFoldResult> sizes;
|
||||
mlir::SmallVector<mlir::OpFoldResult> strides;
|
||||
for (auto index : op.getIndices()) {
|
||||
offsets.push_back(index);
|
||||
sizes.push_back(rewriter.getI64IntegerAttr(1));
|
||||
strides.push_back(rewriter.getI64IntegerAttr(1));
|
||||
}
|
||||
offsets.push_back(
|
||||
rewriter.create<mlir::arith::ConstantIndexOp>(op.getLoc(), 0)
|
||||
.getResult());
|
||||
sizes.push_back(rewriter.getI64IntegerAttr(loweringParameters.nMods));
|
||||
strides.push_back(rewriter.getI64IntegerAttr(1));
|
||||
auto newOp = rewriter.create<mlir::tensor::ExtractSliceOp>(
|
||||
op.getLoc(),
|
||||
converter.convertType(op.getResult().getType())
|
||||
.cast<mlir::RankedTensorType>(),
|
||||
op.getTensor(), offsets, sizes, strides);
|
||||
concretelang::convertOperandAndResultTypes(rewriter, newOp,
|
||||
converter.getConversionLambda());
|
||||
rewriter.replaceOp(op, {newOp.getResult()});
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewriter for the `tensor::extract` operation.
|
||||
struct TensorInsertOpPattern : public CrtOpPattern<mlir::tensor::InsertOp> {
|
||||
|
||||
TensorInsertOpPattern(mlir::MLIRContext *context,
|
||||
concretelang::CrtLoweringParameters params,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: CrtOpPattern<mlir::tensor::InsertOp>(context, params, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::tensor::InsertOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
if (!op.getDest()
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
.isa<FHE::EncryptedIntegerType>() &&
|
||||
!op.getDest()
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
.isa<TFHE::GLWECipherTextType>()) {
|
||||
return mlir::success();
|
||||
}
|
||||
typing::TypeConverter converter{loweringParameters};
|
||||
mlir::SmallVector<mlir::OpFoldResult> offsets;
|
||||
mlir::SmallVector<mlir::OpFoldResult> sizes;
|
||||
mlir::SmallVector<mlir::OpFoldResult> strides;
|
||||
for (auto index : op.getIndices()) {
|
||||
offsets.push_back(index);
|
||||
sizes.push_back(rewriter.getI64IntegerAttr(1));
|
||||
strides.push_back(rewriter.getI64IntegerAttr(1));
|
||||
}
|
||||
offsets.push_back(
|
||||
rewriter.create<mlir::arith::ConstantIndexOp>(op.getLoc(), 0)
|
||||
.getResult());
|
||||
sizes.push_back(rewriter.getI64IntegerAttr(loweringParameters.nMods));
|
||||
strides.push_back(rewriter.getI64IntegerAttr(1));
|
||||
auto newOp = rewriter.create<mlir::tensor::InsertSliceOp>(
|
||||
op.getLoc(), op.getScalar(), op.getDest(), offsets, sizes, strides);
|
||||
concretelang::convertOperandAndResultTypes(rewriter, newOp,
|
||||
converter.getConversionLambda());
|
||||
rewriter.replaceOp(op, {newOp});
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewriter for the `tensor::from_elements` operation.
|
||||
struct TensorFromElementsOpPattern
|
||||
: public CrtOpPattern<mlir::tensor::FromElementsOp> {
|
||||
|
||||
TensorFromElementsOpPattern(mlir::MLIRContext *context,
|
||||
concretelang::CrtLoweringParameters params,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: CrtOpPattern<mlir::tensor::FromElementsOp>(context, params, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::tensor::FromElementsOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
if (!op.getResult()
|
||||
.getType()
|
||||
.cast<mlir::RankedTensorType>()
|
||||
.getElementType()
|
||||
.isa<FHE::EncryptedIntegerType>() &&
|
||||
!op.getResult()
|
||||
.getType()
|
||||
.cast<mlir::RankedTensorType>()
|
||||
.getElementType()
|
||||
.isa<TFHE::GLWECipherTextType>()) {
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
typing::TypeConverter converter{loweringParameters};
|
||||
|
||||
// Create dest tensor allocation op
|
||||
mlir::Value outputTensor =
|
||||
rewriter.create<mlir::bufferization::AllocTensorOp>(
|
||||
op.getLoc(),
|
||||
converter.convertType(op.getResult().getType())
|
||||
.cast<mlir::RankedTensorType>(),
|
||||
mlir::ValueRange{});
|
||||
|
||||
// Create insert_slice ops to insert the different pieces.
|
||||
auto outputShape =
|
||||
outputTensor.getType().cast<mlir::RankedTensorType>().getShape();
|
||||
mlir::SmallVector<mlir::OpFoldResult> offsets{
|
||||
rewriter.getI64IntegerAttr(0)};
|
||||
mlir::SmallVector<mlir::OpFoldResult> sizes{rewriter.getI64IntegerAttr(1)};
|
||||
|
||||
mlir::SmallVector<mlir::OpFoldResult> strides{
|
||||
rewriter.getI64IntegerAttr(1)};
|
||||
for (size_t dimIndex = 1; dimIndex < outputShape.size(); ++dimIndex) {
|
||||
sizes.push_back(rewriter.getI64IntegerAttr(outputShape[dimIndex]));
|
||||
strides.push_back(rewriter.getI64IntegerAttr(1));
|
||||
offsets.push_back(rewriter.getI64IntegerAttr(0));
|
||||
}
|
||||
for (size_t insertionIndex = 0; insertionIndex < op.getElements().size();
|
||||
++insertionIndex) {
|
||||
offsets[0] = rewriter.getI64IntegerAttr(insertionIndex);
|
||||
mlir::tensor::InsertSliceOp insertOp =
|
||||
rewriter.create<mlir::tensor::InsertSliceOp>(
|
||||
op.getLoc(), op.getElements()[insertionIndex], outputTensor,
|
||||
offsets, sizes, strides);
|
||||
concretelang::convertOperandAndResultTypes(
|
||||
rewriter, insertOp, converter.getConversionLambda());
|
||||
outputTensor = insertOp.getResult();
|
||||
}
|
||||
rewriter.replaceOp(op, {outputTensor});
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace lowering
|
||||
|
||||
struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
|
||||
|
||||
FHEToTFHECrtPass(concretelang::CrtLoweringParameters params)
|
||||
: loweringParameters(params) {}
|
||||
|
||||
void runOnOperation() override {
|
||||
auto op = this->getOperation();
|
||||
|
||||
mlir::ConversionTarget target(getContext());
|
||||
typing::TypeConverter converter(loweringParameters);
|
||||
|
||||
//------------------------------------------- Marking legal/illegal dialects
|
||||
target.addIllegalDialect<FHE::FHEDialect>();
|
||||
target.addLegalDialect<TFHE::TFHEDialect>();
|
||||
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
|
||||
target.addDynamicallyLegalOp<mlir::tensor::GenerateOp, mlir::scf::ForOp>(
|
||||
[&](mlir::Operation *op) {
|
||||
return (
|
||||
converter.isLegal(op->getOperandTypes()) &&
|
||||
converter.isLegal(op->getResultTypes()) &&
|
||||
converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
|
||||
});
|
||||
target.addDynamicallyLegalOp<mlir::tensor::InsertOp,
|
||||
mlir::tensor::ExtractOp, mlir::scf::YieldOp>(
|
||||
[&](mlir::Operation *op) {
|
||||
return (converter.isLegal(op->getOperandTypes()) &&
|
||||
converter.isLegal(op->getResultTypes()));
|
||||
});
|
||||
target.addDynamicallyLegalOp<mlir::func::FuncOp>(
|
||||
[&](mlir::func::FuncOp funcOp) {
|
||||
return converter.isSignatureLegal(funcOp.getFunctionType()) &&
|
||||
converter.isLegal(&funcOp.getBody());
|
||||
});
|
||||
target.addDynamicallyLegalOp<mlir::func::ConstantOp>(
|
||||
[&](mlir::func::ConstantOp op) {
|
||||
return FunctionConstantOpConversion<typing::TypeConverter>::isLegal(
|
||||
op, converter);
|
||||
});
|
||||
target.addLegalOp<mlir::func::CallOp>();
|
||||
target.addLegalOp<mlir::bufferization::AllocTensorOp>();
|
||||
concretelang::addDynamicallyLegalTypeOp<mlir::tensor::ExtractSliceOp>(
|
||||
target, converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<mlir::tensor::InsertSliceOp>(
|
||||
target, converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<mlir::tensor::FromElementsOp>(
|
||||
target, converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<mlir::tensor::ExpandShapeOp>(
|
||||
target, converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<mlir::tensor::CollapseShapeOp>(
|
||||
target, converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<
|
||||
concretelang::RT::MakeReadyFutureOp>(target, converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<concretelang::RT::AwaitFutureOp>(
|
||||
target, converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<
|
||||
concretelang::RT::CreateAsyncTaskOp>(target, converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<
|
||||
concretelang::RT::BuildReturnPtrPlaceholderOp>(target, converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<
|
||||
concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>(target,
|
||||
converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<
|
||||
concretelang::RT::DerefReturnPtrPlaceholderOp>(target, converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<
|
||||
concretelang::RT::WorkFunctionReturnOp>(target, converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<
|
||||
concretelang::RT::RegisterTaskWorkFunctionOp>(target, converter);
|
||||
|
||||
//---------------------------------------------------------- Adding patterns
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
|
||||
// Patterns for the `FHE` dialect operations
|
||||
patterns.add<
|
||||
// |_ `FHE::zero_eint`
|
||||
concretelang::GenericTypeAndOpConverterPattern<FHE::ZeroEintOp,
|
||||
TFHE::ZeroGLWEOp>,
|
||||
// |_ `FHE::zero_tensor`
|
||||
concretelang::GenericTypeAndOpConverterPattern<FHE::ZeroTensorOp,
|
||||
TFHE::ZeroTensorGLWEOp>>(
|
||||
&getContext(), converter);
|
||||
// |_ `FHE::add_eint_int`
|
||||
patterns.add<lowering::AddEintIntOpPattern,
|
||||
// |_ `FHE::add_eint`
|
||||
lowering::AddEintOpPattern,
|
||||
// |_ `FHE::sub_int_eint`
|
||||
lowering::SubIntEintOpPattern,
|
||||
// |_ `FHE::sub_eint_int`
|
||||
lowering::SubEintIntOpPattern,
|
||||
// |_ `FHE::sub_eint`
|
||||
lowering::SubEintOpPattern,
|
||||
// |_ `FHE::neg_eint`
|
||||
lowering::NegEintOpPattern,
|
||||
// |_ `FHE::mul_eint_int`
|
||||
lowering::MulEintIntOpPattern,
|
||||
// |_ `FHE::apply_lookup_table`
|
||||
lowering::ApplyLookupTableEintOpPattern>(&getContext(),
|
||||
loweringParameters);
|
||||
|
||||
// Patterns for the relics of the `FHELinalg` dialect operations.
|
||||
// |_ `linalg::generic` turned to nested `scf::for`
|
||||
patterns.add<concretelang::GenericTypeConverterPattern<mlir::scf::ForOp>>(
|
||||
patterns.getContext(), converter);
|
||||
patterns.add<concretelang::GenericTypeConverterPattern<mlir::scf::YieldOp>>(
|
||||
patterns.getContext(), converter);
|
||||
patterns.add<
|
||||
RegionOpTypeConverterPattern<mlir::scf::ForOp, typing::TypeConverter>>(
|
||||
&getContext(), converter);
|
||||
patterns.add<lowering::TensorExtractOpPattern>(&getContext(),
|
||||
loweringParameters);
|
||||
patterns.add<lowering::TensorInsertOpPattern>(&getContext(),
|
||||
loweringParameters);
|
||||
patterns.add<concretelang::GenericTypeConverterPattern<
|
||||
mlir::tensor::ExtractSliceOp>>(patterns.getContext(), converter);
|
||||
patterns.add<
|
||||
concretelang::GenericTypeConverterPattern<mlir::tensor::InsertSliceOp>>(
|
||||
patterns.getContext(), converter);
|
||||
patterns.add<concretelang::GenericTypeConverterPattern<
|
||||
mlir::tensor::CollapseShapeOp>>(patterns.getContext(), converter);
|
||||
patterns.add<
|
||||
concretelang::GenericTypeConverterPattern<mlir::tensor::ExpandShapeOp>>(
|
||||
patterns.getContext(), converter);
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
|
||||
typing::TypeConverter>>(
|
||||
&getContext(), converter);
|
||||
|
||||
// Patterns for `func` dialect operations.
|
||||
mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
|
||||
patterns, converter);
|
||||
patterns
|
||||
.add<concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>>(
|
||||
patterns.getContext(), converter);
|
||||
patterns.add<FunctionConstantOpConversion<typing::TypeConverter>>(
|
||||
&getContext(), converter);
|
||||
|
||||
// Pattern for the `tensor::from_element` op.
|
||||
patterns.add<lowering::TensorFromElementsOpPattern>(patterns.getContext(),
|
||||
loweringParameters);
|
||||
|
||||
// Patterns for the `RT` dialect operations.
|
||||
patterns
|
||||
.add<concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>,
|
||||
concretelang::GenericTypeConverterPattern<mlir::scf::YieldOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::MakeReadyFutureOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::AwaitFutureOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::CreateAsyncTaskOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::BuildReturnPtrPlaceholderOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::DerefReturnPtrPlaceholderOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::WorkFunctionReturnOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
|
||||
converter);
|
||||
|
||||
//--------------------------------------------------------- Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns))
|
||||
.failed()) {
|
||||
this->signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
concretelang::CrtLoweringParameters loweringParameters;
|
||||
};
|
||||
} // namespace fhe_to_tfhe_crt_conversion
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<OperationPass<mlir::ModuleOp>>
|
||||
createConvertFHEToTFHECrtPass(CrtLoweringParameters lowering) {
|
||||
return std::make_unique<fhe_to_tfhe_crt_conversion::FHEToTFHECrtPass>(
|
||||
lowering);
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
15
compiler/lib/Conversion/FHEToTFHEScalar/CMakeLists.txt
Normal file
15
compiler/lib/Conversion/FHEToTFHEScalar/CMakeLists.txt
Normal file
@@ -0,0 +1,15 @@
|
||||
add_mlir_dialect_library(
|
||||
FHEToTFHEScalar
|
||||
FHEToTFHEScalar.cpp
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE
|
||||
DEPENDS
|
||||
FHEDialect
|
||||
mlir-headers
|
||||
LINK_LIBS
|
||||
PUBLIC
|
||||
MLIRIR
|
||||
MLIRTransforms
|
||||
MLIRMathDialect)
|
||||
|
||||
target_link_libraries(FHEToTFHEScalar PUBLIC MLIRIR)
|
||||
518
compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp
Normal file
518
compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp
Normal file
@@ -0,0 +1,518 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <iostream>
|
||||
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
|
||||
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Dialect/Linalg/IR/Linalg.h>
|
||||
#include <mlir/IR/Operation.h>
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "concretelang/Conversion/FHEToTFHEScalar/Pass.h"
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Conversion/Tools.h"
|
||||
#include "concretelang/Conversion/Utils/FuncConstOpConversion.h"
|
||||
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
|
||||
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
|
||||
#include "concretelang/Dialect/RT/IR/RTDialect.h"
|
||||
#include "concretelang/Dialect/RT/IR/RTOps.h"
|
||||
#include "concretelang/Dialect/RT/IR/RTTypes.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
|
||||
|
||||
namespace FHE = mlir::concretelang::FHE;
|
||||
namespace TFHE = mlir::concretelang::TFHE;
|
||||
namespace concretelang = mlir::concretelang;
|
||||
|
||||
namespace fhe_to_tfhe_scalar_conversion {
|
||||
|
||||
namespace typing {
|
||||
|
||||
/// Converts `FHE::EncryptedInteger` into `TFHE::GlweCiphetext`.
|
||||
TFHE::GLWECipherTextType convertEint(mlir::MLIRContext *context,
|
||||
FHE::EncryptedIntegerType eint) {
|
||||
return TFHE::GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth());
|
||||
}
|
||||
|
||||
/// Converts `Tensor<FHE::EncryptedInteger>` into a
|
||||
/// `Tensor<TFHE::GlweCiphertext>` if the element type is appropriate.
|
||||
/// Otherwise return the input type.
|
||||
mlir::Type maybeConvertEintTensor(mlir::MLIRContext *context,
|
||||
mlir::RankedTensorType maybeEintTensor) {
|
||||
if (!maybeEintTensor.getElementType().isa<FHE::EncryptedIntegerType>()) {
|
||||
return (mlir::Type)(maybeEintTensor);
|
||||
}
|
||||
auto eint =
|
||||
maybeEintTensor.getElementType().cast<FHE::EncryptedIntegerType>();
|
||||
auto currentShape = maybeEintTensor.getShape();
|
||||
return mlir::RankedTensorType::get(
|
||||
currentShape,
|
||||
TFHE::GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth()));
|
||||
}
|
||||
|
||||
/// Converts the type `FHE::EncryptedInteger` to `TFHE::GlweCiphetext` if the
|
||||
/// input type is appropriate. Otherwise return the input type.
|
||||
mlir::Type maybeConvertEint(mlir::MLIRContext *context, mlir::Type t) {
|
||||
if (auto eint = t.dyn_cast<FHE::EncryptedIntegerType>())
|
||||
return convertEint(context, eint);
|
||||
|
||||
return t;
|
||||
}
|
||||
|
||||
/// The type converter used to convert `FHE` to `TFHE` types using the scalar
|
||||
/// strategy.
|
||||
class TypeConverter : public mlir::TypeConverter {
|
||||
|
||||
public:
|
||||
TypeConverter() {
|
||||
addConversion([](mlir::Type type) { return type; });
|
||||
addConversion([](FHE::EncryptedIntegerType type) {
|
||||
return convertEint(type.getContext(), type);
|
||||
});
|
||||
addConversion([](mlir::RankedTensorType type) {
|
||||
return maybeConvertEintTensor(type.getContext(), type);
|
||||
});
|
||||
addConversion([&](concretelang::RT::FutureType type) {
|
||||
return concretelang::RT::FutureType::get(this->convertType(
|
||||
type.dyn_cast<concretelang::RT::FutureType>().getElementType()));
|
||||
});
|
||||
addConversion([&](concretelang::RT::PointerType type) {
|
||||
return concretelang::RT::PointerType::get(this->convertType(
|
||||
type.dyn_cast<concretelang::RT::PointerType>().getElementType()));
|
||||
});
|
||||
}
|
||||
|
||||
/// Returns a lambda that uses this converter to turn one type into another.
|
||||
std::function<mlir::Type(mlir::MLIRContext *, mlir::Type)>
|
||||
getConversionLambda() {
|
||||
return [&](mlir::MLIRContext *, mlir::Type t) { return convertType(t); };
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace typing
|
||||
|
||||
namespace lowering {
|
||||
|
||||
/// A pattern rewriter superclass used by most op rewriters during the
|
||||
/// conversion.
|
||||
template <typename T>
|
||||
struct ScalarOpPattern : public mlir::OpRewritePattern<T> {
|
||||
|
||||
ScalarOpPattern(mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<T>(context, benefit) {}
|
||||
|
||||
/// Writes the encoding of a plaintext of arbitrary precision using shift.
|
||||
mlir::Value
|
||||
writePlaintextShiftEncoding(mlir::Location location, mlir::Value rawPlaintext,
|
||||
int64_t encryptedWidth,
|
||||
mlir::PatternRewriter &rewriter) const {
|
||||
int64_t intShift = 64 - 1 - encryptedWidth;
|
||||
mlir::Value castedInt = rewriter.create<mlir::arith::ExtUIOp>(
|
||||
location, rewriter.getIntegerType(64), rawPlaintext);
|
||||
mlir::Value constantShiftOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
location, rewriter.getI64IntegerAttr(intShift));
|
||||
mlir::Value encodedInt = rewriter.create<mlir::arith::ShLIOp>(
|
||||
location, rewriter.getI64Type(), castedInt, constantShiftOp);
|
||||
return encodedInt;
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewriter for the `FHE::zero` operation.
|
||||
struct ZeroEintOpPattern : public mlir::OpRewritePattern<FHE::ZeroEintOp> {
|
||||
ZeroEintOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<FHE::ZeroEintOp>(context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(FHE::ZeroEintOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::Location location = op.getLoc();
|
||||
typing::TypeConverter converter;
|
||||
TFHE::ZeroGLWEOp newOp =
|
||||
rewriter.create<TFHE::ZeroGLWEOp>(location, op.getType());
|
||||
concretelang::convertOperandAndResultTypes(rewriter, newOp,
|
||||
converter.getConversionLambda());
|
||||
rewriter.replaceOp(op, {newOp.getResult()});
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewriter for the `FHE::add_eint_int` operation.
|
||||
struct AddEintIntOpPattern : public ScalarOpPattern<FHE::AddEintIntOp> {
|
||||
AddEintIntOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ScalarOpPattern<FHE::AddEintIntOp>(context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(FHE::AddEintIntOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value eintOperand = op.a();
|
||||
mlir::Value intOperand = op.b();
|
||||
|
||||
// Write the plaintext encoding
|
||||
mlir::Value encodedInt = writePlaintextShiftEncoding(
|
||||
op.getLoc(), intOperand,
|
||||
eintOperand.getType().cast<FHE::EncryptedIntegerType>().getWidth(),
|
||||
rewriter);
|
||||
|
||||
// Write the new op
|
||||
auto newOp = rewriter.create<TFHE::AddGLWEIntOp>(location, op.getType(),
|
||||
eintOperand, encodedInt);
|
||||
typing::TypeConverter converter;
|
||||
concretelang::convertOperandAndResultTypes(rewriter, newOp,
|
||||
converter.getConversionLambda());
|
||||
|
||||
rewriter.replaceOp(op, {newOp.getResult()});
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewriter for the `FHE::sub_eint_int` operation.
|
||||
struct SubEintIntOpPattern : public ScalarOpPattern<FHE::SubEintIntOp> {
|
||||
SubEintIntOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ScalarOpPattern<FHE::SubEintIntOp>(context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(FHE::SubEintIntOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value eintOperand = op.a();
|
||||
mlir::Value intOperand = op.b();
|
||||
|
||||
// Write the integer negation
|
||||
mlir::Type intType = intOperand.getType();
|
||||
mlir::Attribute minusOneAttr = mlir::IntegerAttr::get(intType, -1);
|
||||
mlir::Value minusOne =
|
||||
rewriter.create<mlir::arith::ConstantOp>(location, minusOneAttr)
|
||||
.getResult();
|
||||
mlir::Value negative =
|
||||
rewriter.create<mlir::arith::MulIOp>(location, intOperand, minusOne)
|
||||
.getResult();
|
||||
|
||||
// Write the plaintext encoding
|
||||
mlir::Value encodedInt = writePlaintextShiftEncoding(
|
||||
op.getLoc(), negative,
|
||||
eintOperand.getType().cast<FHE::EncryptedIntegerType>().getWidth(),
|
||||
rewriter);
|
||||
|
||||
// Write the new op
|
||||
auto newOp = rewriter.create<TFHE::AddGLWEIntOp>(location, op.getType(),
|
||||
eintOperand, encodedInt);
|
||||
typing::TypeConverter converter;
|
||||
|
||||
// Convert the types
|
||||
concretelang::convertOperandAndResultTypes(rewriter, newOp,
|
||||
converter.getConversionLambda());
|
||||
|
||||
rewriter.replaceOp(op, {newOp.getResult()});
|
||||
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
/// Rewriter for the `FHE::sub_int_eint` operation.
|
||||
struct SubIntEintOpPattern : public ScalarOpPattern<FHE::SubIntEintOp> {
|
||||
SubIntEintOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ScalarOpPattern<FHE::SubIntEintOp>(context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(FHE::SubIntEintOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value intOperand = op.a();
|
||||
mlir::Value eintOperand = op.b();
|
||||
|
||||
// Write the plaintext encoding
|
||||
mlir::Value encodedInt = writePlaintextShiftEncoding(
|
||||
op.getLoc(), intOperand,
|
||||
eintOperand.getType().cast<FHE::EncryptedIntegerType>().getWidth(),
|
||||
rewriter);
|
||||
|
||||
// Write the new op
|
||||
auto newOp = rewriter.create<TFHE::SubGLWEIntOp>(location, op.getType(),
|
||||
encodedInt, eintOperand);
|
||||
typing::TypeConverter converter;
|
||||
concretelang::convertOperandAndResultTypes(rewriter, newOp,
|
||||
converter.getConversionLambda());
|
||||
|
||||
rewriter.replaceOp(op, {newOp.getResult()});
|
||||
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
/// Rewriter for the `FHE::sub_eint` operation.
|
||||
struct SubEintOpPattern : public ScalarOpPattern<FHE::SubEintOp> {
|
||||
SubEintOpPattern(mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
|
||||
: ScalarOpPattern<FHE::SubEintOp>(context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(FHE::SubEintOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value lhsOperand = op.a();
|
||||
mlir::Value rhsOperand = op.b();
|
||||
|
||||
// Write rhs negation
|
||||
auto negative = rewriter.create<TFHE::NegGLWEOp>(
|
||||
location, rhsOperand.getType(), rhsOperand);
|
||||
typing::TypeConverter converter;
|
||||
concretelang::convertOperandAndResultTypes(rewriter, negative,
|
||||
converter.getConversionLambda());
|
||||
|
||||
// Write new op.
|
||||
auto newOp = rewriter.create<TFHE::AddGLWEOp>(
|
||||
location, op.getType(), lhsOperand, negative.getResult());
|
||||
concretelang::convertOperandAndResultTypes(rewriter, newOp,
|
||||
converter.getConversionLambda());
|
||||
|
||||
rewriter.replaceOp(op, {newOp.getResult()});
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
/// Rewriter for the `FHE::mul_eint_int` operation.
|
||||
struct MulEintIntOpPattern : public ScalarOpPattern<FHE::MulEintIntOp> {
|
||||
MulEintIntOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ScalarOpPattern<FHE::MulEintIntOp>(context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(FHE::MulEintIntOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
mlir::Location location = op.getLoc();
|
||||
mlir::Value eintOperand = op.a();
|
||||
mlir::Value intOperand = op.b();
|
||||
|
||||
// Write the cleartext "encoding"
|
||||
mlir::Value castedCleartext = rewriter.create<mlir::arith::ExtSIOp>(
|
||||
location, rewriter.getIntegerType(64), intOperand);
|
||||
|
||||
// Write the new op.
|
||||
auto newOp = rewriter.create<TFHE::MulGLWEIntOp>(
|
||||
location, op.getType(), eintOperand, castedCleartext);
|
||||
typing::TypeConverter converter;
|
||||
concretelang::convertOperandAndResultTypes(rewriter, newOp,
|
||||
converter.getConversionLambda());
|
||||
|
||||
rewriter.replaceOp(op, {newOp.getResult()});
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewriter for the `FHE::apply_lookup_table` operation.
|
||||
struct ApplyLookupTableEintOpPattern
|
||||
: public ScalarOpPattern<FHE::ApplyLookupTableEintOp> {
|
||||
ApplyLookupTableEintOpPattern(
|
||||
mlir::MLIRContext *context,
|
||||
concretelang::ScalarLoweringParameters loweringParams,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ScalarOpPattern<FHE::ApplyLookupTableEintOp>(context, benefit),
|
||||
loweringParameters(loweringParams) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(FHE::ApplyLookupTableEintOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
size_t outputBits =
|
||||
op.getResult().getType().cast<FHE::EncryptedIntegerType>().getWidth();
|
||||
mlir::Value newLut =
|
||||
rewriter
|
||||
.create<TFHE::EncodeExpandLutForBootstrapOp>(
|
||||
op.getLoc(),
|
||||
mlir::RankedTensorType::get(
|
||||
mlir::ArrayRef<int64_t>(loweringParameters.polynomialSize),
|
||||
rewriter.getI64Type()),
|
||||
op.lut(),
|
||||
rewriter.getI32IntegerAttr(loweringParameters.polynomialSize),
|
||||
rewriter.getI32IntegerAttr(outputBits))
|
||||
.getResult();
|
||||
|
||||
// Insert keyswitch
|
||||
auto ksOp = rewriter.create<TFHE::KeySwitchGLWEOp>(
|
||||
op.getLoc(), op.a().getType(), op.a(), -1, -1);
|
||||
typing::TypeConverter converter;
|
||||
concretelang::convertOperandAndResultTypes(rewriter, ksOp,
|
||||
converter.getConversionLambda());
|
||||
|
||||
// Insert bootstrap
|
||||
auto bsOp = rewriter.replaceOpWithNewOp<TFHE::BootstrapGLWEOp>(
|
||||
op, op.getType(), ksOp, newLut, -1, -1, -1, -1);
|
||||
concretelang::convertOperandAndResultTypes(rewriter, bsOp,
|
||||
converter.getConversionLambda());
|
||||
|
||||
return mlir::success();
|
||||
};
|
||||
|
||||
private:
|
||||
concretelang::ScalarLoweringParameters loweringParameters;
|
||||
};
|
||||
|
||||
} // namespace lowering
|
||||
|
||||
struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
|
||||
|
||||
FHEToTFHEScalarPass(concretelang::ScalarLoweringParameters loweringParams)
|
||||
: loweringParameters(loweringParams){};
|
||||
|
||||
void runOnOperation() override {
|
||||
auto op = this->getOperation();
|
||||
|
||||
mlir::ConversionTarget target(getContext());
|
||||
typing::TypeConverter converter;
|
||||
|
||||
//------------------------------------------- Marking legal/illegal dialects
|
||||
target.addIllegalDialect<FHE::FHEDialect>();
|
||||
target.addLegalDialect<TFHE::TFHEDialect>();
|
||||
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
|
||||
target.addDynamicallyLegalOp<mlir::linalg::GenericOp,
|
||||
mlir::tensor::GenerateOp>(
|
||||
[&](mlir::Operation *op) {
|
||||
return (
|
||||
converter.isLegal(op->getOperandTypes()) &&
|
||||
converter.isLegal(op->getResultTypes()) &&
|
||||
converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
|
||||
});
|
||||
target.addDynamicallyLegalOp<mlir::func::FuncOp>(
|
||||
[&](mlir::func::FuncOp funcOp) {
|
||||
return converter.isSignatureLegal(funcOp.getFunctionType()) &&
|
||||
converter.isLegal(&funcOp.getBody());
|
||||
});
|
||||
target.addDynamicallyLegalOp<mlir::func::ConstantOp>(
|
||||
[&](mlir::func::ConstantOp op) {
|
||||
return FunctionConstantOpConversion<typing::TypeConverter>::isLegal(
|
||||
op, converter);
|
||||
});
|
||||
target.addLegalOp<mlir::func::CallOp>();
|
||||
concretelang::addDynamicallyLegalTypeOp<
|
||||
concretelang::RT::MakeReadyFutureOp>(target, converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<concretelang::RT::AwaitFutureOp>(
|
||||
target, converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<
|
||||
concretelang::RT::CreateAsyncTaskOp>(target, converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<
|
||||
concretelang::RT::BuildReturnPtrPlaceholderOp>(target, converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<
|
||||
concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>(target,
|
||||
converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<
|
||||
concretelang::RT::DerefReturnPtrPlaceholderOp>(target, converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<
|
||||
concretelang::RT::WorkFunctionReturnOp>(target, converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<
|
||||
concretelang::RT::RegisterTaskWorkFunctionOp>(target, converter);
|
||||
|
||||
//---------------------------------------------------------- Adding patterns
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
|
||||
// Patterns for the `FHE` dialect operations
|
||||
patterns.add<
|
||||
// |_ `FHE::zero_eint`
|
||||
concretelang::GenericTypeAndOpConverterPattern<FHE::ZeroEintOp,
|
||||
TFHE::ZeroGLWEOp>,
|
||||
// |_ `FHE::zero_tensor`
|
||||
concretelang::GenericTypeAndOpConverterPattern<FHE::ZeroTensorOp,
|
||||
TFHE::ZeroTensorGLWEOp>,
|
||||
// |_ `FHE::neg_eint`
|
||||
concretelang::GenericTypeAndOpConverterPattern<FHE::NegEintOp,
|
||||
TFHE::NegGLWEOp>,
|
||||
// |_ `FHE::add_eint`
|
||||
concretelang::GenericTypeAndOpConverterPattern<FHE::AddEintOp,
|
||||
TFHE::AddGLWEOp>>(
|
||||
&getContext(), converter);
|
||||
// |_ `FHE::add_eint_int`
|
||||
patterns.add<lowering::AddEintIntOpPattern,
|
||||
// |_ `FHE::sub_int_eint`
|
||||
lowering::SubIntEintOpPattern,
|
||||
// |_ `FHE::sub_eint_int`
|
||||
lowering::SubEintIntOpPattern,
|
||||
// |_ `FHE::sub_eint`
|
||||
lowering::SubEintOpPattern,
|
||||
// |_ `FHE::mul_eint_int`
|
||||
lowering::MulEintIntOpPattern>(&getContext());
|
||||
// |_ `FHE::apply_lookup_table`
|
||||
patterns.add<lowering::ApplyLookupTableEintOpPattern>(&getContext(),
|
||||
loweringParameters);
|
||||
|
||||
// Patterns for the relics of the `FHELinalg` dialect operations.
|
||||
// |_ `linalg::generic` turned to nested `scf::for`
|
||||
patterns
|
||||
.add<concretelang::GenericTypeConverterPattern<mlir::linalg::YieldOp>>(
|
||||
patterns.getContext(), converter);
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
|
||||
typing::TypeConverter>>(
|
||||
&getContext(), converter);
|
||||
patterns.add<
|
||||
RegionOpTypeConverterPattern<mlir::scf::ForOp, typing::TypeConverter>>(
|
||||
&getContext(), converter);
|
||||
concretelang::populateWithTensorTypeConverterPatterns(patterns, target,
|
||||
converter);
|
||||
|
||||
// Patterns for `func` dialect operations.
|
||||
mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
|
||||
patterns, converter);
|
||||
patterns
|
||||
.add<concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>>(
|
||||
patterns.getContext(), converter);
|
||||
patterns.add<FunctionConstantOpConversion<typing::TypeConverter>>(
|
||||
&getContext(), converter);
|
||||
|
||||
// Patterns for the `RT` dialect operations.
|
||||
patterns
|
||||
.add<concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>,
|
||||
concretelang::GenericTypeConverterPattern<mlir::scf::YieldOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::MakeReadyFutureOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::AwaitFutureOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::CreateAsyncTaskOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::BuildReturnPtrPlaceholderOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::DerefReturnPtrPlaceholderOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::WorkFunctionReturnOp>,
|
||||
concretelang::GenericTypeConverterPattern<
|
||||
concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
|
||||
converter);
|
||||
|
||||
//--------------------------------------------------------- Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns))
|
||||
.failed()) {
|
||||
this->signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
concretelang::ScalarLoweringParameters loweringParameters;
|
||||
};
|
||||
|
||||
} // namespace fhe_to_tfhe_scalar_conversion
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertFHEToTFHEScalarPass(ScalarLoweringParameters loweringParameters) {
|
||||
return std::make_unique<fhe_to_tfhe_scalar_conversion::FHEToTFHEScalarPass>(
|
||||
loweringParameters);
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -70,17 +70,12 @@ public:
|
||||
auto dimension = cryptoParameters.getNBigLweDimension();
|
||||
auto polynomialSize = 1;
|
||||
auto precision = (signed)type.getP();
|
||||
auto crtDecomposition =
|
||||
cryptoParameters.largeInteger.hasValue()
|
||||
? cryptoParameters.largeInteger->crtDecomposition
|
||||
: mlir::concretelang::CRTDecomposition{};
|
||||
if ((int)dimension == type.getDimension() &&
|
||||
(int)polynomialSize == type.getPolynomialSize()) {
|
||||
return type;
|
||||
}
|
||||
return TFHE::GLWECipherTextType::get(type.getContext(), dimension,
|
||||
polynomialSize, bits, precision,
|
||||
crtDecomposition);
|
||||
polynomialSize, bits, precision);
|
||||
}
|
||||
|
||||
TFHE::GLWECipherTextType glweLookupTableType(GLWECipherTextType &type) {
|
||||
@@ -89,7 +84,7 @@ public:
|
||||
auto polynomialSize = cryptoParameters.getPolynomialSize();
|
||||
auto precision = (signed)type.getP();
|
||||
return TFHE::GLWECipherTextType::get(type.getContext(), dimension,
|
||||
polynomialSize, bits, precision, {});
|
||||
polynomialSize, bits, precision);
|
||||
}
|
||||
|
||||
TFHE::GLWECipherTextType glweIntraPBSType(GLWECipherTextType &type) {
|
||||
@@ -98,7 +93,7 @@ public:
|
||||
auto polynomialSize = 1;
|
||||
auto precision = (signed)type.getP();
|
||||
return TFHE::GLWECipherTextType::get(type.getContext(), dimension,
|
||||
polynomialSize, bits, precision, {});
|
||||
polynomialSize, bits, precision);
|
||||
}
|
||||
|
||||
mlir::concretelang::V0Parameter cryptoParameters;
|
||||
@@ -181,7 +176,7 @@ struct WopPBSGLWEOpPattern : public mlir::OpRewritePattern<TFHE::WopPBSGLWEOp> {
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto newOp = rewriter.replaceOpWithNewOp<TFHE::WopPBSGLWEOp>(
|
||||
wopPBSOp, converter.convertType(wopPBSOp.result().getType()),
|
||||
wopPBSOp.ciphertext(), wopPBSOp.lookupTable(),
|
||||
wopPBSOp.ciphertexts(), wopPBSOp.lookupTable(),
|
||||
// Bootstrap parameters
|
||||
cryptoParameters.brLevel, cryptoParameters.brLogBase,
|
||||
// Keyswitch parameters
|
||||
@@ -195,11 +190,18 @@ struct WopPBSGLWEOpPattern : public mlir::OpRewritePattern<TFHE::WopPBSGLWEOp> {
|
||||
cryptoParameters.largeInteger->wopPBS.packingKeySwitch.baseLog,
|
||||
// Circuit bootstrap parameters
|
||||
cryptoParameters.largeInteger->wopPBS.circuitBootstrap.level,
|
||||
cryptoParameters.largeInteger->wopPBS.circuitBootstrap.baseLog);
|
||||
cryptoParameters.largeInteger->wopPBS.circuitBootstrap.baseLog,
|
||||
// Crt decomposition
|
||||
rewriter.getI64ArrayAttr(
|
||||
cryptoParameters.largeInteger->crtDecomposition));
|
||||
rewriter.startRootUpdate(newOp);
|
||||
auto ctType =
|
||||
wopPBSOp.ciphertexts().getType().cast<mlir::RankedTensorType>();
|
||||
auto ciphertextType =
|
||||
wopPBSOp.ciphertext().getType().cast<TFHE::GLWECipherTextType>();
|
||||
newOp.ciphertext().setType(converter.glweInterPBSType(ciphertextType));
|
||||
ctType.getElementType().cast<TFHE::GLWECipherTextType>();
|
||||
auto newType = mlir::RankedTensorType::get(
|
||||
ctType.getShape(), converter.glweInterPBSType(ciphertextType));
|
||||
newOp.ciphertexts().setType(newType);
|
||||
rewriter.finalizeRootUpdate(newOp);
|
||||
return mlir::success();
|
||||
};
|
||||
@@ -290,6 +292,8 @@ void TFHEGlobalParametrizationPass::runOnOperation() {
|
||||
target.addDynamicallyLegalOp<TFHE::WopPBSGLWEOp>(
|
||||
[&](TFHE::WopPBSGLWEOp op) {
|
||||
return !op.getType()
|
||||
.cast<mlir::RankedTensorType>()
|
||||
.getElementType()
|
||||
.cast<TFHE::GLWECipherTextType>()
|
||||
.hasUnparametrizedParameters();
|
||||
});
|
||||
|
||||
@@ -108,7 +108,7 @@ struct WopPBSGLWEOpPattern : public mlir::OpRewritePattern<TFHE::WopPBSGLWEOp> {
|
||||
mlir::Type resultType = converter.convertType(wopOp.getType());
|
||||
|
||||
auto newOp = rewriter.replaceOpWithNewOp<Concrete::WopPBSLweOp>(
|
||||
wopOp, resultType, wopOp.ciphertext(), wopOp.lookupTable(),
|
||||
wopOp, resultType, wopOp.ciphertexts(), wopOp.lookupTable(),
|
||||
// Bootstrap parameters
|
||||
wopOp.bootstrapLevel(), wopOp.bootstrapBaseLog(),
|
||||
// Keyswitch parameters
|
||||
@@ -118,12 +118,14 @@ struct WopPBSGLWEOpPattern : public mlir::OpRewritePattern<TFHE::WopPBSGLWEOp> {
|
||||
wopOp.packingKeySwitchoutputPolynomialSize(),
|
||||
wopOp.packingKeySwitchLevel(), wopOp.packingKeySwitchBaseLog(),
|
||||
// Circuit bootstrap parameters
|
||||
wopOp.circuitBootstrapLevel(), wopOp.circuitBootstrapBaseLog());
|
||||
wopOp.circuitBootstrapLevel(), wopOp.circuitBootstrapBaseLog(),
|
||||
// Crt Decomposition
|
||||
wopOp.crtDecomposition());
|
||||
|
||||
rewriter.startRootUpdate(newOp);
|
||||
|
||||
newOp.ciphertext().setType(
|
||||
converter.convertType(wopOp.ciphertext().getType()));
|
||||
newOp.ciphertexts().setType(
|
||||
converter.convertType(wopOp.ciphertexts().getType()));
|
||||
|
||||
rewriter.finalizeRootUpdate(newOp);
|
||||
return ::mlir::success();
|
||||
@@ -179,6 +181,18 @@ void TFHEToConcretePass::runOnOperation() {
|
||||
patterns.add<mlir::concretelang::GenericTypeAndOpConverterPattern<
|
||||
mlir::concretelang::TFHE::ZeroTensorGLWEOp,
|
||||
mlir::concretelang::Concrete::ZeroTensorLWEOp>>(&getContext(), converter);
|
||||
patterns.add<mlir::concretelang::GenericTypeAndOpConverterPattern<
|
||||
mlir::concretelang::TFHE::EncodeExpandLutForBootstrapOp,
|
||||
mlir::concretelang::Concrete::EncodeExpandLutForBootstrapOp>>(
|
||||
&getContext(), converter);
|
||||
patterns.add<mlir::concretelang::GenericTypeAndOpConverterPattern<
|
||||
mlir::concretelang::TFHE::EncodeExpandLutForWopPBSOp,
|
||||
mlir::concretelang::Concrete::EncodeExpandLutForWopPBSOp>>(&getContext(),
|
||||
converter);
|
||||
patterns.add<mlir::concretelang::GenericTypeAndOpConverterPattern<
|
||||
mlir::concretelang::TFHE::EncodePlaintextWithCrtOp,
|
||||
mlir::concretelang::Concrete::EncodePlaintextWithCrtOp>>(&getContext(),
|
||||
converter);
|
||||
patterns.add<BootstrapGLWEOpPattern>(&getContext(), converter);
|
||||
patterns.add<WopPBSGLWEOpPattern>(&getContext(), converter);
|
||||
target.addDynamicallyLegalOp<Concrete::BootstrapLweOp>(
|
||||
|
||||
@@ -135,5 +135,19 @@ void mlir::concretelang::BConcrete::
|
||||
BConcrete::WopPBSCRTLweTensorOp::attachInterface<TensorToMemrefOp<
|
||||
BConcrete::WopPBSCRTLweTensorOp, BConcrete::WopPBSCRTLweBufferOp>>(
|
||||
*ctx);
|
||||
// encode_plaintext_with_crt_tensor => encode_plaintext_with_crt_buffer
|
||||
BConcrete::EncodePlaintextWithCrtTensorOp::attachInterface<
|
||||
TensorToMemrefOp<BConcrete::EncodePlaintextWithCrtTensorOp,
|
||||
BConcrete::EncodePlaintextWithCrtBufferOp>>(*ctx);
|
||||
// encode_expand_lut_for_bootstrap_tensor =>
|
||||
// encode_expand_lut_for_bootstrap_buffer
|
||||
BConcrete::EncodeExpandLutForBootstrapTensorOp::attachInterface<
|
||||
TensorToMemrefOp<BConcrete::EncodeExpandLutForBootstrapTensorOp,
|
||||
BConcrete::EncodeExpandLutForBootstrapBufferOp>>(*ctx);
|
||||
// encode_expand_lut_for_woppbs_tensor =>
|
||||
// encode_expand_lut_for_woppbs_buffer
|
||||
BConcrete::EncodeExpandLutForWopPBSTensorOp::attachInterface<
|
||||
TensorToMemrefOp<BConcrete::EncodeExpandLutForWopPBSTensorOp,
|
||||
BConcrete::EncodeExpandLutForWopPBSBufferOp>>(*ctx);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ add_mlir_dialect_library(
|
||||
ConcretelangBConcreteTransforms
|
||||
BufferizableOpInterfaceImpl.cpp
|
||||
AddRuntimeContext.cpp
|
||||
EliminateCRTOps.cpp
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/BConcrete
|
||||
DEPENDS
|
||||
|
||||
@@ -1,561 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "concretelang/ClientLib/CRT.h"
|
||||
#include "concretelang/Conversion/Tools.h"
|
||||
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
|
||||
#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h"
|
||||
#include "concretelang/Dialect/BConcrete/Transforms/Passes.h"
|
||||
|
||||
namespace arith = mlir::arith;
|
||||
namespace tensor = mlir::tensor;
|
||||
namespace bufferization = mlir::bufferization;
|
||||
namespace scf = mlir::scf;
|
||||
namespace BConcrete = mlir::concretelang::BConcrete;
|
||||
namespace crt = concretelang::clientlib::crt;
|
||||
|
||||
namespace {
|
||||
|
||||
char encode_crt[] = "encode_crt";
|
||||
|
||||
// This template rewrite pattern transforms any instance of
|
||||
// `BConcreteCRTOp` operators to `BConcreteOp` on
|
||||
// each block.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ```mlir
|
||||
// %0 = "BConcreteCRTOp"(%arg0, %arg1) {crtDecomposition = [...]}
|
||||
// : (tensor<nbBlocksxlweSizexi64>, tensor<nbBlocksxlweSizexi64>) ->
|
||||
// (tensor<nbBlocksxlweSizexi64>)
|
||||
// ```
|
||||
//
|
||||
// becomes:
|
||||
//
|
||||
// ```mlir
|
||||
// %c0 = arith.constant 0 : index
|
||||
// %c1 = arith.constant 1 : index
|
||||
// %cB = arith.constant nbBlocks : index
|
||||
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
|
||||
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
|
||||
// (tensor<nbBlocksxlweSizexi64>) {
|
||||
// %blockArg = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
|
||||
// : tensor<lweSizexi64>
|
||||
// %tmp = "BConcreteOp"(%blockArg)
|
||||
// : (tensor<lweSizexi64>) -> (tensor<lweSizexi64>)
|
||||
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, 1]
|
||||
// : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
|
||||
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
|
||||
// }
|
||||
// ```
|
||||
template <typename BConcreteCRTOp, typename BConcreteOp>
|
||||
struct BConcreteCRTUnaryOpPattern
|
||||
: public mlir::OpRewritePattern<BConcreteCRTOp> {
|
||||
BConcreteCRTUnaryOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<BConcreteCRTOp>(context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(BConcreteCRTOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto resultTy =
|
||||
((mlir::Type)op.getResult().getType()).cast<mlir::RankedTensorType>();
|
||||
auto loc = op.getLoc();
|
||||
assert(resultTy.getShape().size() == 2);
|
||||
auto shape = resultTy.getShape();
|
||||
|
||||
// %c0 = arith.constant 0 : index
|
||||
// %c1 = arith.constant 1 : index
|
||||
// %cB = arith.constant nbBlocks : index
|
||||
auto c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
||||
auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
||||
auto cB = rewriter.create<arith::ConstantIndexOp>(loc, shape[0]);
|
||||
|
||||
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
|
||||
mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
|
||||
op.getLoc(), resultTy, mlir::ValueRange{});
|
||||
|
||||
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
|
||||
// (tensor<nbBlocksxlweSizexi64>) {
|
||||
rewriter.replaceOpWithNewOp<scf::ForOp>(
|
||||
op, c0, cB, c1, init,
|
||||
[&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value i,
|
||||
mlir::ValueRange iterArgs) {
|
||||
// [%i, 0]
|
||||
mlir::SmallVector<mlir::OpFoldResult> offsets{
|
||||
i, rewriter.getI64IntegerAttr(0)};
|
||||
// [1, lweSize]
|
||||
mlir::SmallVector<mlir::OpFoldResult> sizes{
|
||||
rewriter.getI64IntegerAttr(1),
|
||||
rewriter.getI64IntegerAttr(shape[1])};
|
||||
// [1, 1]
|
||||
mlir::SmallVector<mlir::OpFoldResult> strides{
|
||||
rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1)};
|
||||
|
||||
auto blockTy = mlir::RankedTensorType::get({shape[1]},
|
||||
resultTy.getElementType());
|
||||
|
||||
// %blockArg = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
|
||||
// : tensor<lweSizexi64>
|
||||
auto blockArg = builder.create<tensor::ExtractSliceOp>(
|
||||
loc, blockTy, op.ciphertext(), offsets, sizes, strides);
|
||||
// %tmp = "BConcrete.add_lwe_buffer"(%blockArg0, %blockArg1)
|
||||
// : (tensor<lweSizexi64>, tensor<lweSizexi64>) ->
|
||||
// (tensor<lweSizexi64>)
|
||||
auto tmp = builder.create<BConcreteOp>(loc, blockTy, blockArg);
|
||||
|
||||
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1,
|
||||
// 1] : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
|
||||
auto res = builder.create<tensor::InsertSliceOp>(
|
||||
loc, tmp, iterArgs[0], offsets, sizes, strides);
|
||||
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
|
||||
builder.create<scf::YieldOp>(loc, (mlir::Value)res);
|
||||
});
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
// This template rewrite pattern transforms any instance of
|
||||
// `BConcreteCRTOp` operators to `BConcreteOp` on
|
||||
// each block.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ```mlir
|
||||
// %0 = "BConcreteCRTOp"(%arg0, %arg1) {crtDecomposition = [...]}
|
||||
// : (tensor<nbBlocksxlweSizexi64>, tensor<nbBlocksxlweSizexi64>) ->
|
||||
// (tensor<nbBlocksxlweSizexi64>)
|
||||
// ```
|
||||
//
|
||||
// becomes:
|
||||
//
|
||||
// ```mlir
|
||||
// %c0 = arith.constant 0 : index
|
||||
// %c1 = arith.constant 1 : index
|
||||
// %cB = arith.constant nbBlocks : index
|
||||
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
|
||||
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
|
||||
// (tensor<nbBlocksxlweSizexi64>) {
|
||||
// %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
|
||||
// : tensor<lweSizexi64>
|
||||
// %blockArg1 = tensor.extract_slice %arg1[%i, 0] [1, lweSize] [1, 1]
|
||||
// : tensor<lweSizexi64>
|
||||
// %tmp = "BConcreteOp"(%blockArg0, %blockArg1)
|
||||
// : (tensor<lweSizexi64>, tensor<lweSizexi64>) ->
|
||||
// (tensor<lweSizexi64>)
|
||||
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, 1]
|
||||
// : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
|
||||
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
|
||||
// }
|
||||
// ```
|
||||
template <typename BConcreteCRTOp, typename BConcreteOp>
|
||||
struct BConcreteCRTBinaryOpPattern
|
||||
: public mlir::OpRewritePattern<BConcreteCRTOp> {
|
||||
BConcreteCRTBinaryOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<BConcreteCRTOp>(context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(BConcreteCRTOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto resultTy =
|
||||
((mlir::Type)op.getResult().getType()).cast<mlir::RankedTensorType>();
|
||||
auto loc = op.getLoc();
|
||||
assert(resultTy.getShape().size() == 2);
|
||||
auto shape = resultTy.getShape();
|
||||
|
||||
// %c0 = arith.constant 0 : index
|
||||
// %c1 = arith.constant 1 : index
|
||||
// %cB = arith.constant nbBlocks : index
|
||||
auto c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
||||
auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
||||
auto cB = rewriter.create<arith::ConstantIndexOp>(loc, shape[0]);
|
||||
|
||||
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
|
||||
mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
|
||||
op.getLoc(), resultTy, mlir::ValueRange{});
|
||||
|
||||
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
|
||||
// (tensor<nbBlocksxlweSizexi64>) {
|
||||
rewriter.replaceOpWithNewOp<scf::ForOp>(
|
||||
op, c0, cB, c1, init,
|
||||
[&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value i,
|
||||
mlir::ValueRange iterArgs) {
|
||||
// [%i, 0]
|
||||
mlir::SmallVector<mlir::OpFoldResult> offsets{
|
||||
i, rewriter.getI64IntegerAttr(0)};
|
||||
// [1, lweSize]
|
||||
mlir::SmallVector<mlir::OpFoldResult> sizes{
|
||||
rewriter.getI64IntegerAttr(1),
|
||||
rewriter.getI64IntegerAttr(shape[1])};
|
||||
// [1, 1]
|
||||
mlir::SmallVector<mlir::OpFoldResult> strides{
|
||||
rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1)};
|
||||
|
||||
auto blockTy = mlir::RankedTensorType::get({shape[1]},
|
||||
resultTy.getElementType());
|
||||
|
||||
// %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
|
||||
// : tensor<lweSizexi64>
|
||||
auto blockArg0 = builder.create<tensor::ExtractSliceOp>(
|
||||
loc, blockTy, op.lhs(), offsets, sizes, strides);
|
||||
// %blockArg1 = tensor.extract_slice %arg1[%i, 0] [1, lweSize] [1, 1]
|
||||
// : tensor<lweSizexi64>
|
||||
auto blockArg1 = builder.create<tensor::ExtractSliceOp>(
|
||||
loc, blockTy, op.rhs(), offsets, sizes, strides);
|
||||
// %tmp = "BConcrete.add_lwe_buffer"(%blockArg0, %blockArg1)
|
||||
// : (tensor<lweSizexi64>, tensor<lweSizexi64>) ->
|
||||
// (tensor<lweSizexi64>)
|
||||
auto tmp =
|
||||
builder.create<BConcreteOp>(loc, blockTy, blockArg0, blockArg1);
|
||||
|
||||
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1,
|
||||
// 1] : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
|
||||
auto res = builder.create<tensor::InsertSliceOp>(
|
||||
loc, tmp, iterArgs[0], offsets, sizes, strides);
|
||||
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
|
||||
builder.create<scf::YieldOp>(loc, (mlir::Value)res);
|
||||
});
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
// This template rewrite pattern transforms any instance of
|
||||
// `BConcreteCRTOp` operators to `BConcreteOp` on
|
||||
// each block with the crt decomposition of the cleartext.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ```mlir
|
||||
// %0 = "BConcreteCRTOp"(%arg0, %x) {crtDecomposition = [d0...dn]}
|
||||
// : (tensor<nbBlocksxlweSizexi64>, i64) -> (tensor<nbBlocksxlweSizexi64>)
|
||||
// ```
|
||||
//
|
||||
// becomes:
|
||||
//
|
||||
// ```mlir
|
||||
// // Build the decomposition of the plaintext
|
||||
// %x0_a = arith.constant 64/d0 : f64
|
||||
// %x0_b = arith.mulf %x, %x0_a : i64
|
||||
// %x0 = arith.fptoui %x0_b : f64 to i64
|
||||
// ...
|
||||
// %xn_a = arith.constant 64/dn : f64
|
||||
// %xn_b = arith.mulf %x, %xn_a : i64
|
||||
// %xn = arith.fptoui %xn_b : f64 to i64
|
||||
// %x_decomp = tensor.from_elements %x0, ..., %xn : tensor<nbBlocksxi64>
|
||||
// // Loop on blocks
|
||||
// %c0 = arith.constant 0 : index
|
||||
// %c1 = arith.constant 1 : index
|
||||
// %cB = arith.constant nbBlocks : index
|
||||
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
|
||||
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
|
||||
// (tensor<nbBlocksxlweSizexi64>) {
|
||||
// %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
|
||||
// : tensor<lweSizexi64>
|
||||
// %blockArg1 = tensor.extract %x_decomp[%i] : tensor<nbBlocksxi64>
|
||||
// %tmp = "BConcreteOp"(%blockArg0, %blockArg1)
|
||||
// : (tensor<lweSizexi64>, i64) -> (tensor<lweSizexi64>)
|
||||
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, 1]
|
||||
// : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
|
||||
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
|
||||
// }
|
||||
// ```
|
||||
struct AddPlaintextCRTLweTensorOpPattern
|
||||
: public mlir::OpRewritePattern<BConcrete::AddPlaintextCRTLweTensorOp> {
|
||||
AddPlaintextCRTLweTensorOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<BConcrete::AddPlaintextCRTLweTensorOp>(context,
|
||||
benefit) {
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(BConcrete::AddPlaintextCRTLweTensorOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto resultTy =
|
||||
((mlir::Type)op.getResult().getType()).cast<mlir::RankedTensorType>();
|
||||
auto loc = op.getLoc();
|
||||
assert(resultTy.getShape().size() == 2);
|
||||
auto shape = resultTy.getShape();
|
||||
|
||||
auto rhs = op.rhs();
|
||||
mlir::SmallVector<mlir::Value, 5> plaintextElements;
|
||||
uint64_t moduliProduct = 1;
|
||||
for (mlir::Attribute di : op.crtDecomposition()) {
|
||||
moduliProduct *= di.cast<mlir::IntegerAttr>().getValue().getZExtValue();
|
||||
}
|
||||
if (auto cst =
|
||||
mlir::dyn_cast_or_null<arith::ConstantIntOp>(rhs.getDefiningOp())) {
|
||||
auto apCst = cst.getValue().cast<mlir::IntegerAttr>().getValue();
|
||||
auto value = apCst.getSExtValue();
|
||||
|
||||
// constant value, encode at compile time
|
||||
for (mlir::Attribute di : op.crtDecomposition()) {
|
||||
auto modulus = di.cast<mlir::IntegerAttr>().getValue().getZExtValue();
|
||||
|
||||
auto encoded = crt::encode(value, modulus, moduliProduct);
|
||||
plaintextElements.push_back(
|
||||
rewriter.create<arith::ConstantIntOp>(loc, encoded, 64));
|
||||
}
|
||||
} else {
|
||||
// dynamic value, encode at runtime
|
||||
if (insertForwardDeclaration(
|
||||
op, rewriter, encode_crt,
|
||||
mlir::FunctionType::get(rewriter.getContext(),
|
||||
{rewriter.getI64Type(),
|
||||
rewriter.getI64Type(),
|
||||
rewriter.getI64Type()},
|
||||
{rewriter.getI64Type()}))
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
auto extOp =
|
||||
rewriter.create<arith::ExtSIOp>(loc, rewriter.getI64Type(), rhs);
|
||||
auto moduliProductOp =
|
||||
rewriter.create<arith::ConstantIntOp>(loc, moduliProduct, 64);
|
||||
for (mlir::Attribute di : op.crtDecomposition()) {
|
||||
auto modulus = di.cast<mlir::IntegerAttr>().getValue().getZExtValue();
|
||||
auto modulusOp =
|
||||
rewriter.create<arith::ConstantIntOp>(loc, modulus, 64);
|
||||
plaintextElements.push_back(
|
||||
rewriter
|
||||
.create<mlir::func::CallOp>(
|
||||
loc, encode_crt, mlir::TypeRange{rewriter.getI64Type()},
|
||||
mlir::ValueRange{extOp, modulusOp, moduliProductOp})
|
||||
.getResult(0));
|
||||
}
|
||||
}
|
||||
|
||||
// %x_decomp = tensor.from_elements %x0, ..., %xn : tensor<nbBlocksxi64>
|
||||
auto x_decomp =
|
||||
rewriter.create<tensor::FromElementsOp>(loc, plaintextElements);
|
||||
|
||||
// %c0 = arith.constant 0 : index
|
||||
// %c1 = arith.constant 1 : index
|
||||
// %cB = arith.constant nbBlocks : index
|
||||
auto c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
||||
auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
||||
auto cB = rewriter.create<arith::ConstantIndexOp>(loc, shape[0]);
|
||||
|
||||
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
|
||||
mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
|
||||
op.getLoc(), resultTy, mlir::ValueRange{});
|
||||
|
||||
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
|
||||
// (tensor<nbBlocksxlweSizexi64>) {
|
||||
rewriter.replaceOpWithNewOp<scf::ForOp>(
|
||||
op, c0, cB, c1, init,
|
||||
[&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value i,
|
||||
mlir::ValueRange iterArgs) {
|
||||
// [%i, 0]
|
||||
mlir::SmallVector<mlir::OpFoldResult> offsets{
|
||||
i, rewriter.getI64IntegerAttr(0)};
|
||||
// [1, lweSize]
|
||||
mlir::SmallVector<mlir::OpFoldResult> sizes{
|
||||
rewriter.getI64IntegerAttr(1),
|
||||
rewriter.getI64IntegerAttr(shape[1])};
|
||||
// [1, 1]
|
||||
mlir::SmallVector<mlir::OpFoldResult> strides{
|
||||
rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1)};
|
||||
|
||||
auto blockTy = mlir::RankedTensorType::get({shape[1]},
|
||||
resultTy.getElementType());
|
||||
|
||||
// %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
|
||||
// : tensor<lweSizexi64>
|
||||
auto blockArg0 = builder.create<tensor::ExtractSliceOp>(
|
||||
loc, blockTy, op.lhs(), offsets, sizes, strides);
|
||||
// %blockArg1 = tensor.extract %x_decomp[%i] : tensor<nbBlocksxi64>
|
||||
auto blockArg1 = builder.create<tensor::ExtractOp>(loc, x_decomp, i);
|
||||
// %tmp = "BConcreteOp"(%blockArg0, %blockArg1)
|
||||
// : (tensor<lweSizexi64>, i64) -> (tensor<lweSizexi64>)
|
||||
auto tmp = builder.create<BConcrete::AddPlaintextLweTensorOp>(
|
||||
loc, blockTy, blockArg0, blockArg1);
|
||||
|
||||
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1,
|
||||
// 1] : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
|
||||
auto res = builder.create<tensor::InsertSliceOp>(
|
||||
loc, tmp, iterArgs[0], offsets, sizes, strides);
|
||||
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
|
||||
builder.create<scf::YieldOp>(loc, (mlir::Value)res);
|
||||
});
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
// This template rewrite pattern transforms any instance of
|
||||
// `BConcreteCRTOp` operators to `BConcreteOp` on
|
||||
// each block with the crt decomposition of the cleartext.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ```mlir
|
||||
// %0 = "BConcreteCRTOp"(%arg0, %x) {crtDecomposition = [d0...dn]}
|
||||
// : (tensor<nbBlocksxlweSizexi64>, i64) -> (tensor<nbBlocksxlweSizexi64>)
|
||||
// ```
|
||||
//
|
||||
// becomes:
|
||||
//
|
||||
// ```mlir
|
||||
// // Build the decomposition of the plaintext
|
||||
// %x0_a = arith.constant 64/d0 : f64
|
||||
// %x0_b = arith.mulf %x, %x0_a : i64
|
||||
// %x0 = arith.fptoui %x0_b : f64 to i64
|
||||
// ...
|
||||
// %xn_a = arith.constant 64/dn : f64
|
||||
// %xn_b = arith.mulf %x, %xn_a : i64
|
||||
// %xn = arith.fptoui %xn_b : f64 to i64
|
||||
// %x_decomp = tensor.from_elements %x0, ..., %xn : tensor<nbBlocksxi64>
|
||||
// // Loop on blocks
|
||||
// %c0 = arith.constant 0 : index
|
||||
// %c1 = arith.constant 1 : index
|
||||
// %cB = arith.constant nbBlocks : index
|
||||
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
|
||||
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
|
||||
// (tensor<nbBlocksxlweSizexi64>) {
|
||||
// %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
|
||||
// : tensor<lweSizexi64>
|
||||
// %blockArg1 = tensor.extract %x_decomp[%i] : tensor<nbBlocksxi64>
|
||||
// %tmp = "BConcreteOp"(%blockArg0, %blockArg1)
|
||||
// : (tensor<lweSizexi64>, i64) -> (tensor<lweSizexi64>)
|
||||
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1, 1]
|
||||
// : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
|
||||
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
|
||||
// }
|
||||
// ```
|
||||
struct MulCleartextCRTLweTensorOpPattern
|
||||
: public mlir::OpRewritePattern<BConcrete::MulCleartextCRTLweTensorOp> {
|
||||
MulCleartextCRTLweTensorOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: mlir::OpRewritePattern<BConcrete::MulCleartextCRTLweTensorOp>(context,
|
||||
benefit) {
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(BConcrete::MulCleartextCRTLweTensorOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto resultTy =
|
||||
((mlir::Type)op.getResult().getType()).cast<mlir::RankedTensorType>();
|
||||
auto loc = op.getLoc();
|
||||
assert(resultTy.getShape().size() == 2);
|
||||
auto shape = resultTy.getShape();
|
||||
|
||||
// %c0 = arith.constant 0 : index
|
||||
// %c1 = arith.constant 1 : index
|
||||
// %cB = arith.constant nbBlocks : index
|
||||
auto c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
||||
auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
||||
auto cB = rewriter.create<arith::ConstantIndexOp>(loc, shape[0]);
|
||||
|
||||
// %init = linalg.tensor_init [B, lweSize] : tensor<nbBlocksxlweSizexi64>
|
||||
mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
|
||||
op.getLoc(), resultTy, mlir::ValueRange{});
|
||||
|
||||
auto rhs = rewriter.create<arith::ExtSIOp>(op.getLoc(),
|
||||
rewriter.getI64Type(), op.rhs());
|
||||
|
||||
// %0 = scf.for %i = %c0 to %cB step %c1 iter_args(%acc = %init) ->
|
||||
// (tensor<nbBlocksxlweSizexi64>) {
|
||||
rewriter.replaceOpWithNewOp<scf::ForOp>(
|
||||
op, c0, cB, c1, init,
|
||||
[&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value i,
|
||||
mlir::ValueRange iterArgs) {
|
||||
// [%i, 0]
|
||||
mlir::SmallVector<mlir::OpFoldResult> offsets{
|
||||
i, rewriter.getI64IntegerAttr(0)};
|
||||
// [1, lweSize]
|
||||
mlir::SmallVector<mlir::OpFoldResult> sizes{
|
||||
rewriter.getI64IntegerAttr(1),
|
||||
rewriter.getI64IntegerAttr(shape[1])};
|
||||
// [1, 1]
|
||||
mlir::SmallVector<mlir::OpFoldResult> strides{
|
||||
rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1)};
|
||||
|
||||
auto blockTy = mlir::RankedTensorType::get({shape[1]},
|
||||
resultTy.getElementType());
|
||||
|
||||
// %blockArg0 = tensor.extract_slice %arg0[%i, 0] [1, lweSize] [1, 1]
|
||||
// : tensor<lweSizexi64>
|
||||
auto blockArg0 = builder.create<tensor::ExtractSliceOp>(
|
||||
loc, blockTy, op.lhs(), offsets, sizes, strides);
|
||||
|
||||
// %tmp = BConcrete.mul_cleartext_lwe_buffer(%blockArg0, %x)
|
||||
// : (tensor<lweSizexi64>, i64) -> (tensor<lweSizexi64>)
|
||||
auto tmp = builder.create<BConcrete::MulCleartextLweTensorOp>(
|
||||
loc, blockTy, blockArg0, rhs);
|
||||
|
||||
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1,
|
||||
// 1] : tensor<lweSizexi64> into tensor<nbBlocksxlweSizexi64>
|
||||
auto res = builder.create<tensor::InsertSliceOp>(
|
||||
loc, tmp, iterArgs[0], offsets, sizes, strides);
|
||||
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
|
||||
builder.create<scf::YieldOp>(loc, (mlir::Value)res);
|
||||
});
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
struct EliminateCRTOpsPass : public EliminateCRTOpsBase<EliminateCRTOpsPass> {
|
||||
void runOnOperation() final;
|
||||
};
|
||||
|
||||
void EliminateCRTOpsPass::runOnOperation() {
|
||||
auto op = getOperation();
|
||||
|
||||
mlir::ConversionTarget target(getContext());
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
|
||||
// add_crt_lwe_buffers
|
||||
target.addIllegalOp<BConcrete::AddCRTLweTensorOp>();
|
||||
patterns.add<BConcreteCRTBinaryOpPattern<BConcrete::AddCRTLweTensorOp,
|
||||
BConcrete::AddLweTensorOp>>(
|
||||
&getContext());
|
||||
|
||||
// add_plaintext_crt_lwe_buffers
|
||||
target.addIllegalOp<BConcrete::AddPlaintextCRTLweTensorOp>();
|
||||
patterns.add<AddPlaintextCRTLweTensorOpPattern>(&getContext());
|
||||
|
||||
// mul_cleartext_crt_lwe_buffer
|
||||
target.addIllegalOp<BConcrete::MulCleartextCRTLweTensorOp>();
|
||||
patterns.add<MulCleartextCRTLweTensorOpPattern>(&getContext());
|
||||
|
||||
target.addIllegalOp<BConcrete::NegateCRTLweTensorOp>();
|
||||
patterns.add<BConcreteCRTUnaryOpPattern<BConcrete::NegateCRTLweTensorOp,
|
||||
BConcrete::NegateLweTensorOp>>(
|
||||
&getContext());
|
||||
|
||||
// This dialect are used to transforms crt ops to bconcrete ops
|
||||
target
|
||||
.addLegalDialect<arith::ArithmeticDialect, tensor::TensorDialect,
|
||||
scf::SCFDialect, bufferization::BufferizationDialect,
|
||||
mlir::func::FuncDialect, BConcrete::BConcreteDialect>();
|
||||
|
||||
// Apply the conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
|
||||
this->signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createEliminateCRTOps() {
|
||||
return std::make_unique<EliminateCRTOpsPass>();
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -67,18 +67,6 @@ void GlweCiphertextType::print(mlir::AsmPrinter &p) const {
|
||||
|
||||
void LweCiphertextType::print(mlir::AsmPrinter &p) const {
|
||||
p << "<";
|
||||
// decomposition parameters if any
|
||||
auto crt = getCrtDecomposition();
|
||||
if (!crt.empty()) {
|
||||
p << "crt=[";
|
||||
for (auto c : crt.drop_back(1)) {
|
||||
printSigned(p, c);
|
||||
p << ",";
|
||||
}
|
||||
printSigned(p, crt.back());
|
||||
p << "]";
|
||||
p << ",";
|
||||
}
|
||||
printSigned(p, getDimension());
|
||||
p << ",";
|
||||
printSigned(p, getP());
|
||||
@@ -89,29 +77,6 @@ mlir::Type LweCiphertextType::parse(mlir::AsmParser &parser) {
|
||||
if (parser.parseLess())
|
||||
return mlir::Type();
|
||||
|
||||
// Parse for the crt decomposition if any
|
||||
std::vector<int64_t> crtDecomposition;
|
||||
if (!parser.parseOptionalKeyword("crt")) {
|
||||
if (parser.parseEqual() || parser.parseLSquare())
|
||||
return mlir::Type();
|
||||
while (true) {
|
||||
int64_t c = -1;
|
||||
if (parser.parseOptionalKeyword("_") && parser.parseInteger(c)) {
|
||||
return mlir::Type();
|
||||
}
|
||||
crtDecomposition.push_back(c);
|
||||
if (parser.parseOptionalComma()) {
|
||||
if (parser.parseRSquare()) {
|
||||
return mlir::Type();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (parser.parseComma())
|
||||
return mlir::Type();
|
||||
}
|
||||
|
||||
int dimension = -1;
|
||||
if (parser.parseOptionalKeyword("_") && parser.parseInteger(dimension))
|
||||
return mlir::Type();
|
||||
@@ -125,7 +90,7 @@ mlir::Type LweCiphertextType::parse(mlir::AsmParser &parser) {
|
||||
|
||||
mlir::Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
|
||||
|
||||
return getChecked(loc, loc.getContext(), dimension, p, crtDecomposition);
|
||||
return getChecked(loc, loc.getContext(), dimension, p);
|
||||
}
|
||||
|
||||
void CleartextType::print(mlir::AsmPrinter &p) const {
|
||||
|
||||
@@ -32,8 +32,7 @@ void TFHEDialect::initialize() {
|
||||
/// - The bits parameter is 64 (we support only this for v0)
|
||||
::mlir::LogicalResult GLWECipherTextType::verify(
|
||||
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
|
||||
signed dimension, signed polynomialSize, signed bits, signed p,
|
||||
llvm::ArrayRef<int64_t>) {
|
||||
signed dimension, signed polynomialSize, signed bits, signed p) {
|
||||
if (bits != -1 && bits != 64) {
|
||||
emitError() << "GLWE bits parameter can only be 64";
|
||||
return ::mlir::failure();
|
||||
|
||||
@@ -40,16 +40,14 @@ mlir::LogicalResult _verifyGLWEIntegerOperator(mlir::OpState &op,
|
||||
emitOpErrorForIncompatibleGLWEParameter(op, "p");
|
||||
return mlir::failure();
|
||||
}
|
||||
if (a.getCrtDecomposition() != result.getCrtDecomposition()) {
|
||||
emitOpErrorForIncompatibleGLWEParameter(op, "crt");
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
if ((int)b.getWidth() != 64) {
|
||||
op.emitOpError() << "should have the width of `b` equals to 64.";
|
||||
}
|
||||
// verify consistency of width of inputs
|
||||
if ((int)b.getWidth() > a.getP() + 1) {
|
||||
op.emitOpError()
|
||||
<< "should have the width of `b` equals or less than 'p'+1: "
|
||||
<< b.getWidth() << " <= " << a.getP() << "+ 1";
|
||||
if ((int)b.getWidth() != 64) {
|
||||
op.emitOpError() << "should have the width of `b` equals 64 : "
|
||||
<< b.getWidth() << " != 64";
|
||||
return mlir::failure();
|
||||
}
|
||||
return mlir::success();
|
||||
@@ -111,11 +109,6 @@ mlir::LogicalResult verifyBinaryGLWEOperator(Operator &op) {
|
||||
emitOpErrorForIncompatibleGLWEParameter(op, "p");
|
||||
return mlir::failure();
|
||||
}
|
||||
if (a.getCrtDecomposition() != b.getCrtDecomposition() ||
|
||||
a.getCrtDecomposition() != result.getCrtDecomposition()) {
|
||||
emitOpErrorForIncompatibleGLWEParameter(op, "crt");
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
@@ -146,10 +139,6 @@ mlir::LogicalResult verifyUnaryGLWEOperator(Operator &op) {
|
||||
emitOpErrorForIncompatibleGLWEParameter(op, "p");
|
||||
return mlir::failure();
|
||||
}
|
||||
if (a.getCrtDecomposition() != result.getCrtDecomposition()) {
|
||||
emitOpErrorForIncompatibleGLWEParameter(op, "crt");
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
@@ -18,16 +18,6 @@ void printSigned(mlir::AsmPrinter &p, signed i) {
|
||||
|
||||
void GLWECipherTextType::print(mlir::AsmPrinter &p) const {
|
||||
p << "<";
|
||||
auto crt = getCrtDecomposition();
|
||||
if (!crt.empty()) {
|
||||
p << "crt=[";
|
||||
for (auto c : crt.drop_back(1)) {
|
||||
printSigned(p, c);
|
||||
p << ",";
|
||||
}
|
||||
printSigned(p, crt.back());
|
||||
p << "]";
|
||||
}
|
||||
p << "{";
|
||||
printSigned(p, getDimension());
|
||||
p << ",";
|
||||
@@ -45,27 +35,6 @@ mlir::Type GLWECipherTextType::parse(AsmParser &parser) {
|
||||
if (parser.parseLess())
|
||||
return mlir::Type();
|
||||
|
||||
// Parse for the crt decomposition if any
|
||||
std::vector<int64_t> crtDecomposition;
|
||||
if (!parser.parseOptionalKeyword("crt")) {
|
||||
if (parser.parseEqual() || parser.parseLSquare())
|
||||
return mlir::Type();
|
||||
while (true) {
|
||||
signed c = -1;
|
||||
if (parser.parseOptionalKeyword("_") && parser.parseInteger(c)) {
|
||||
return mlir::Type();
|
||||
}
|
||||
crtDecomposition.push_back(c);
|
||||
if (parser.parseOptionalComma()) {
|
||||
if (parser.parseRSquare()) {
|
||||
return mlir::Type();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (parser.parseLBrace())
|
||||
return mlir::Type();
|
||||
|
||||
@@ -98,8 +67,7 @@ mlir::Type GLWECipherTextType::parse(AsmParser &parser) {
|
||||
if (parser.parseGreater())
|
||||
return mlir::Type();
|
||||
Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
|
||||
return getChecked(loc, loc.getContext(), dimension, polynomialSize, bits, p,
|
||||
llvm::ArrayRef<int64_t>(crtDecomposition));
|
||||
return getChecked(loc, loc.getContext(), dimension, polynomialSize, bits, p);
|
||||
}
|
||||
} // namespace TFHE
|
||||
} // namespace concretelang
|
||||
|
||||
@@ -23,34 +23,6 @@ DefaultEngine *get_levelled_engine() {
|
||||
return levelled_engine;
|
||||
}
|
||||
|
||||
void encode_and_expand_lut(uint64_t *output, size_t output_size,
|
||||
size_t out_MESSAGE_BITS, const uint64_t *lut,
|
||||
size_t lut_size) {
|
||||
assert((output_size % lut_size) == 0);
|
||||
|
||||
size_t mega_case_size = output_size / lut_size;
|
||||
|
||||
assert((mega_case_size % 2) == 0);
|
||||
|
||||
for (size_t idx = 0; idx < mega_case_size / 2; ++idx) {
|
||||
output[idx] = lut[0] << (64 - out_MESSAGE_BITS - 1);
|
||||
}
|
||||
|
||||
for (size_t idx = (lut_size - 1) * mega_case_size + mega_case_size / 2;
|
||||
idx < output_size; ++idx) {
|
||||
output[idx] = -(lut[0] << (64 - out_MESSAGE_BITS - 1));
|
||||
}
|
||||
|
||||
for (size_t lut_idx = 1; lut_idx < lut_size; ++lut_idx) {
|
||||
uint64_t lut_value = lut[lut_idx] << (64 - out_MESSAGE_BITS - 1);
|
||||
size_t start = mega_case_size * (lut_idx - 1) + mega_case_size / 2;
|
||||
for (size_t output_idx = start; output_idx < start + mega_case_size;
|
||||
++output_idx) {
|
||||
output[output_idx] = lut_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#include "concretelang/ClientLib/CRT.h"
|
||||
#include "concretelang/Runtime/wrappers.h"
|
||||
|
||||
@@ -217,13 +189,11 @@ void memref_batched_bootstrap_lwe_cuda_u64(
|
||||
uint64_t glwe_ct_len = poly_size * (glwe_dim + 1);
|
||||
uint64_t glwe_ct_size = glwe_ct_len * sizeof(uint64_t);
|
||||
uint64_t *glwe_ct = (uint64_t *)malloc(glwe_ct_size);
|
||||
std::vector<uint64_t> expanded_tabulated_function_array(poly_size);
|
||||
encode_and_expand_lut(expanded_tabulated_function_array.data(), poly_size,
|
||||
precision, tlu_aligned + tlu_offset, tlu_size);
|
||||
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_engine_discard_trivially_encrypt_glwe_ciphertext_u64_raw_ptr_buffers(
|
||||
get_levelled_engine(), glwe_ct, glwe_ct_len,
|
||||
expanded_tabulated_function_array.data(), poly_size));
|
||||
get_levelled_engine(), glwe_ct, glwe_ct_len, tlu_aligned + tlu_offset,
|
||||
poly_size));
|
||||
|
||||
// Move the glwe accumulator to the GPU
|
||||
void *glwe_ct_gpu =
|
||||
@@ -261,34 +231,117 @@ void memref_batched_bootstrap_lwe_cuda_u64(
|
||||
|
||||
#endif
|
||||
|
||||
void memref_expand_lut_in_trivial_glwe_ct_u64(
|
||||
uint64_t *glwe_ct_allocated, uint64_t *glwe_ct_aligned,
|
||||
uint64_t glwe_ct_offset, uint64_t glwe_ct_size, uint64_t glwe_ct_stride,
|
||||
uint32_t poly_size, uint32_t glwe_dimension, uint32_t out_precision,
|
||||
uint64_t *lut_allocated, uint64_t *lut_aligned, uint64_t lut_offset,
|
||||
uint64_t lut_size, uint64_t lut_stride) {
|
||||
void memref_encode_plaintext_with_crt(
|
||||
uint64_t *output_allocated, uint64_t *output_aligned,
|
||||
uint64_t output_offset, uint64_t output_size, uint64_t output_stride,
|
||||
uint64_t input, uint64_t *mods_allocated, uint64_t *mods_aligned,
|
||||
uint64_t mods_offset, uint64_t mods_size, uint64_t mods_stride,
|
||||
uint64_t mods_product) {
|
||||
|
||||
assert(lut_stride == 1 && "Runtime: stride not equal to 1, check "
|
||||
"memref_expand_lut_in_trivial_glwe_ct_u64");
|
||||
assert(output_stride == 1 && "Runtime: stride not equal to 1, check "
|
||||
"memref_encode_plaintext_with_crt");
|
||||
|
||||
assert(glwe_ct_stride == 1 && "Runtime: stride not equal to 1, check "
|
||||
"memref_expand_lut_in_trivial_glwe_ct_u64");
|
||||
assert(mods_stride == 1 && "Runtime: stride not equal to 1, check "
|
||||
"memref_encode_plaintext_with_crt");
|
||||
|
||||
assert(glwe_ct_size == poly_size * (glwe_dimension + 1));
|
||||
|
||||
std::vector<uint64_t> expanded_tabulated_function_array(poly_size);
|
||||
|
||||
encode_and_expand_lut(expanded_tabulated_function_array.data(), poly_size,
|
||||
out_precision, lut_aligned + lut_offset, lut_size);
|
||||
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_engine_discard_trivially_encrypt_glwe_ciphertext_u64_raw_ptr_buffers(
|
||||
get_levelled_engine(), glwe_ct_aligned + glwe_ct_offset, glwe_ct_size,
|
||||
expanded_tabulated_function_array.data(), poly_size));
|
||||
for (size_t i = 0; i < (size_t)mods_size; ++i) {
|
||||
output_aligned[output_offset + i] =
|
||||
encode_crt(input, mods_aligned[mods_offset + i], mods_product);
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
void memref_encode_expand_lut_for_bootstrap(
|
||||
uint64_t *output_lut_allocated, uint64_t *output_lut_aligned,
|
||||
uint64_t output_lut_offset, uint64_t output_lut_size,
|
||||
uint64_t output_lut_stride, uint64_t *input_lut_allocated,
|
||||
uint64_t *input_lut_aligned, uint64_t input_lut_offset,
|
||||
uint64_t input_lut_size, uint64_t input_lut_stride, uint32_t poly_size,
|
||||
uint32_t out_MESSAGE_BITS) {
|
||||
|
||||
assert(input_lut_stride == 1 && "Runtime: stride not equal to 1, check "
|
||||
"memref_encode_expand_lut_bootstrap");
|
||||
|
||||
assert(output_lut_stride == 1 && "Runtime: stride not equal to 1, check "
|
||||
"memref_encode_expand_lut_bootstrap");
|
||||
|
||||
size_t mega_case_size = output_lut_size / input_lut_size;
|
||||
|
||||
assert((mega_case_size % 2) == 0);
|
||||
|
||||
for (size_t idx = 0; idx < mega_case_size / 2; ++idx) {
|
||||
output_lut_aligned[output_lut_offset + idx] =
|
||||
input_lut_aligned[input_lut_offset] << (64 - out_MESSAGE_BITS - 1);
|
||||
}
|
||||
|
||||
for (size_t idx = (input_lut_size - 1) * mega_case_size + mega_case_size / 2;
|
||||
idx < output_lut_size; ++idx) {
|
||||
output_lut_aligned[output_lut_offset + idx] =
|
||||
-(input_lut_aligned[input_lut_offset] << (64 - out_MESSAGE_BITS - 1));
|
||||
}
|
||||
|
||||
for (size_t lut_idx = 1; lut_idx < input_lut_size; ++lut_idx) {
|
||||
uint64_t lut_value = input_lut_aligned[input_lut_offset + lut_idx]
|
||||
<< (64 - out_MESSAGE_BITS - 1);
|
||||
size_t start = mega_case_size * (lut_idx - 1) + mega_case_size / 2;
|
||||
for (size_t output_idx = start; output_idx < start + mega_case_size;
|
||||
++output_idx) {
|
||||
output_lut_aligned[output_lut_offset + output_idx] = lut_value;
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
void memref_encode_expand_lut_for_woppbs(
|
||||
// Output encoded/expanded lut
|
||||
uint64_t *output_lut_allocated, uint64_t *output_lut_aligned,
|
||||
uint64_t output_lut_offset, uint64_t output_lut_size,
|
||||
uint64_t output_lut_stride,
|
||||
// Input lut
|
||||
uint64_t *input_lut_allocated, uint64_t *input_lut_aligned,
|
||||
uint64_t input_lut_offset, uint64_t input_lut_size,
|
||||
uint64_t input_lut_stride,
|
||||
// Crt coprimes
|
||||
uint64_t *crt_decomposition_allocated, uint64_t *crt_decomposition_aligned,
|
||||
uint64_t crt_decomposition_offset, uint64_t crt_decomposition_size,
|
||||
uint64_t crt_decomposition_stride,
|
||||
// Crt number of bits
|
||||
uint64_t *crt_bits_allocated, uint64_t *crt_bits_aligned,
|
||||
uint64_t crt_bits_offset, uint64_t crt_bits_size, uint64_t crt_bits_stride,
|
||||
// Crypto parameters
|
||||
uint32_t poly_size, uint32_t modulus_product) {
|
||||
|
||||
assert(input_lut_stride == 1 && "Runtime: stride not equal to 1, check "
|
||||
"memref_encode_expand_lut_woppbs");
|
||||
assert(output_lut_stride == 1 && "Runtime: stride not equal to 1, check "
|
||||
"memref_encode_expand_lut_woppbs");
|
||||
assert(modulus_product > input_lut_size);
|
||||
|
||||
uint64_t lut_crt_size = output_lut_size / crt_decomposition_size;
|
||||
|
||||
for (uint64_t value = 0; value < input_lut_size; value++) {
|
||||
uint64_t index_lut = 0;
|
||||
uint64_t tmp = 1;
|
||||
|
||||
for (size_t block = 0; block < crt_decomposition_size; block++) {
|
||||
auto base = crt_decomposition_aligned[crt_decomposition_offset + block];
|
||||
auto bits = crt_bits_aligned[crt_bits_offset + block];
|
||||
index_lut += (((value % base) << bits) / base) * tmp;
|
||||
tmp <<= bits;
|
||||
}
|
||||
|
||||
for (size_t block = 0; block < crt_decomposition_size; block++) {
|
||||
auto base = crt_decomposition_aligned[crt_decomposition_offset + block];
|
||||
auto v = encode_crt(input_lut_aligned[input_lut_offset + value], base,
|
||||
modulus_product);
|
||||
output_lut_aligned[output_lut_offset + block * lut_crt_size + index_lut] =
|
||||
v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void memref_add_lwe_ciphertexts_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
@@ -387,15 +440,10 @@ void memref_bootstrap_lwe_u64(
|
||||
uint64_t glwe_ct_size = poly_size * (glwe_dim + 1);
|
||||
uint64_t *glwe_ct = (uint64_t *)malloc(glwe_ct_size * sizeof(uint64_t));
|
||||
|
||||
std::vector<uint64_t> expanded_tabulated_function_array(poly_size);
|
||||
|
||||
encode_and_expand_lut(expanded_tabulated_function_array.data(), poly_size,
|
||||
precision, tlu_aligned + tlu_offset, tlu_size);
|
||||
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_engine_discard_trivially_encrypt_glwe_ciphertext_u64_raw_ptr_buffers(
|
||||
get_levelled_engine(), glwe_ct, glwe_ct_size,
|
||||
expanded_tabulated_function_array.data(), poly_size));
|
||||
tlu_aligned + tlu_offset, poly_size));
|
||||
|
||||
CAPI_ASSERT_ERROR(
|
||||
fft_engine_lwe_ciphertext_discarding_bootstrap_u64_raw_ptr_buffers(
|
||||
@@ -430,40 +478,6 @@ uint64_t encode_crt(int64_t plaintext, uint64_t modulus, uint64_t product) {
|
||||
return concretelang::clientlib::crt::encode(plaintext, modulus, product);
|
||||
}
|
||||
|
||||
void generate_luts_crt_without_padding(
|
||||
uint64_t *&luts_crt, uint64_t &total_luts_crt_size, uint64_t *crt_decomp,
|
||||
uint64_t *number_of_bits_per_block, size_t crt_size, uint64_t *lut,
|
||||
uint64_t lut_size, uint64_t total_number_of_bits, uint64_t modulus,
|
||||
uint64_t polynomialSize) {
|
||||
|
||||
uint64_t lut_crt_size = uint64_t(1) << total_number_of_bits;
|
||||
lut_crt_size = std::max(lut_crt_size, polynomialSize);
|
||||
|
||||
total_luts_crt_size = crt_size * lut_crt_size;
|
||||
luts_crt = (uint64_t *)aligned_alloc(U64_ALIGNMENT,
|
||||
sizeof(uint64_t) * total_luts_crt_size);
|
||||
|
||||
assert(modulus > lut_size);
|
||||
for (uint64_t value = 0; value < lut_size; value++) {
|
||||
uint64_t index_lut = 0;
|
||||
|
||||
uint64_t tmp = 1;
|
||||
|
||||
for (size_t block = 0; block < crt_size; block++) {
|
||||
auto base = crt_decomp[block];
|
||||
auto bits = number_of_bits_per_block[block];
|
||||
index_lut += (((value % base) << bits) / base) * tmp;
|
||||
tmp <<= bits;
|
||||
}
|
||||
|
||||
for (size_t block = 0; block < crt_size; block++) {
|
||||
auto base = crt_decomp[block];
|
||||
auto v = encode_crt(lut[value], base, modulus);
|
||||
luts_crt[block * lut_crt_size + index_lut] = v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void memref_wop_pbs_crt_buffer(
|
||||
// Output 2D memref
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
@@ -543,23 +557,14 @@ void memref_wop_pbs_crt_buffer(
|
||||
in_block, nb_bits_to_extract, delta_log));
|
||||
}
|
||||
|
||||
uint64_t *luts_crt;
|
||||
uint64_t luts_crt_size;
|
||||
generate_luts_crt_without_padding(
|
||||
luts_crt, luts_crt_size, crt_decomp_aligned, number_of_bits_per_block,
|
||||
crt_decomp_size, lut_ct_aligned, lut_ct_size,
|
||||
total_number_of_bits_per_block, message_modulus, polynomial_size);
|
||||
|
||||
// Vertical packing
|
||||
CAPI_ASSERT_ERROR(
|
||||
fft_engine_lwe_ciphertext_vector_discarding_circuit_bootstrap_boolean_vertical_packing_u64_raw_ptr_buffers(
|
||||
context->get_fft_engine(), context->get_default_engine(),
|
||||
context->get_fft_fourier_bsk(), out_aligned, lwe_big_size,
|
||||
crt_decomp_size, extract_bits_output_buffer, lwe_small_size,
|
||||
total_number_of_bits_per_block, luts_crt, luts_crt_size,
|
||||
cbs_level_count, cbs_base_log, context->get_fpksk()));
|
||||
|
||||
free(luts_crt);
|
||||
total_number_of_bits_per_block, lut_ct_aligned + lut_ct_offset,
|
||||
lut_ct_size, cbs_level_count, cbs_base_log, context->get_fpksk()));
|
||||
}
|
||||
|
||||
void memref_copy_one_rank(uint64_t *src_allocated, uint64_t *src_aligned,
|
||||
|
||||
@@ -21,7 +21,8 @@ add_mlir_library(
|
||||
FHELinalgDialect
|
||||
FHELinalgDialectTransforms
|
||||
FHETensorOpsToLinalg
|
||||
FHEToTFHE
|
||||
FHEToTFHECrt
|
||||
FHEToTFHEScalar
|
||||
ExtractSDFGOps
|
||||
MLIRLowerableDialectsToLLVM
|
||||
FHEDialectAnalysis
|
||||
|
||||
@@ -311,31 +311,14 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
if (target == Target::FHE)
|
||||
return std::move(res);
|
||||
|
||||
// FHE -> TFHE
|
||||
if (mlir::concretelang::pipeline::lowerFHEToTFHE(mlirContext, module,
|
||||
res.fheContext, enablePass)
|
||||
// FHELinalg -> FHE
|
||||
if (mlir::concretelang::pipeline::lowerFHELinalgToFHE(
|
||||
mlirContext, module, res.fheContext, enablePass, loopParallelize,
|
||||
options.batchConcreteOps)
|
||||
.failed()) {
|
||||
return errorDiag("Lowering from FHE to TFHE failed");
|
||||
return errorDiag("Lowering from FHELinalg to FHE failed");
|
||||
}
|
||||
if (target == Target::TFHE)
|
||||
return std::move(res);
|
||||
|
||||
// TFHE -> Concrete
|
||||
if (mlir::concretelang::pipeline::lowerTFHEToConcrete(
|
||||
mlirContext, module, res.fheContext, this->enablePass)
|
||||
.failed()) {
|
||||
return errorDiag("Lowering from TFHE to Concrete failed");
|
||||
}
|
||||
|
||||
// Optimizing Concrete
|
||||
if (this->compilerOptions.optimizeConcrete &&
|
||||
mlir::concretelang::pipeline::optimizeConcrete(mlirContext, module,
|
||||
this->enablePass)
|
||||
.failed()) {
|
||||
return errorDiag("Optimizing Concrete failed");
|
||||
}
|
||||
|
||||
if (target == Target::CONCRETE)
|
||||
if (target == Target::FHE_NO_LINALG)
|
||||
return std::move(res);
|
||||
|
||||
// Generate client parameters if requested
|
||||
@@ -371,19 +354,33 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
}
|
||||
}
|
||||
|
||||
// Concrete with linalg ops -> Concrete with loop ops
|
||||
if (mlir::concretelang::pipeline::lowerConcreteLinalgToLoops(
|
||||
mlirContext, module, this->enablePass, loopParallelize,
|
||||
options.batchConcreteOps)
|
||||
// FHE -> TFHE
|
||||
if (mlir::concretelang::pipeline::lowerFHEToTFHE(mlirContext, module,
|
||||
res.fheContext, enablePass)
|
||||
.failed()) {
|
||||
return StreamStringError(
|
||||
"Lowering from Concrete with linalg ops to Concrete with loops failed");
|
||||
return errorDiag("Lowering from FHE to TFHE failed");
|
||||
}
|
||||
if (target == Target::TFHE)
|
||||
return std::move(res);
|
||||
|
||||
// TFHE -> Concrete
|
||||
if (mlir::concretelang::pipeline::lowerTFHEToConcrete(
|
||||
mlirContext, module, res.fheContext, this->enablePass)
|
||||
.failed()) {
|
||||
return errorDiag("Lowering from TFHE to Concrete failed");
|
||||
}
|
||||
|
||||
if (target == Target::CONCRETEWITHLOOPS) {
|
||||
return std::move(res);
|
||||
// Optimizing Concrete
|
||||
if (this->compilerOptions.optimizeConcrete &&
|
||||
mlir::concretelang::pipeline::optimizeConcrete(mlirContext, module,
|
||||
this->enablePass)
|
||||
.failed()) {
|
||||
return errorDiag("Optimizing Concrete failed");
|
||||
}
|
||||
|
||||
if (target == Target::CONCRETE)
|
||||
return std::move(res);
|
||||
|
||||
// Concrete -> BConcrete
|
||||
if (mlir::concretelang::pipeline::lowerConcreteToBConcrete(
|
||||
mlirContext, module, this->enablePass, loopParallelize)
|
||||
|
||||
@@ -183,27 +183,56 @@ markFHELinalgForTiling(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerFHELinalgToFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
llvm::Optional<V0FHEContext> &fheContext,
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
bool parallelizeLoops, bool batchOperations) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("FHELinalgToFHE", pm, context);
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createConvertFHETensorOpsToLinalg(), enablePass);
|
||||
addPotentiallyNestedPass(pm, mlir::createLinalgGeneralizationPass(),
|
||||
enablePass);
|
||||
addPotentiallyNestedPass(
|
||||
pm,
|
||||
mlir::concretelang::createLinalgGenericOpWithTensorsToLoopsPass(
|
||||
parallelizeLoops),
|
||||
enablePass);
|
||||
|
||||
if (batchOperations) {
|
||||
addPotentiallyNestedPass(pm, mlir::concretelang::createBatchingPass(),
|
||||
enablePass);
|
||||
}
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
llvm::Optional<V0FHEContext> &fheContext,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("FHEToTFHE", pm, context);
|
||||
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createConvertFHETensorOpsToLinalg(), enablePass);
|
||||
// FHETensorOpsToLinalg does generate linalg named ops that need to be lowered
|
||||
// to linalg.generic operations
|
||||
addPotentiallyNestedPass(pm, mlir::createLinalgGeneralizationPass(),
|
||||
enablePass);
|
||||
mlir::concretelang::ApplyLookupTableLowering lowerStrategy =
|
||||
mlir::concretelang::KeySwitchBoostrapLowering;
|
||||
if (fheContext.hasValue() && fheContext->parameter.largeInteger.hasValue()) {
|
||||
lowerStrategy = mlir::concretelang::WopPBSLowering;
|
||||
pipelinePrinting("FHEToTFHECrt", pm, context);
|
||||
auto dec =
|
||||
fheContext.value().parameter.largeInteger.value().crtDecomposition;
|
||||
auto mods = mlir::SmallVector<int64_t>(dec.begin(), dec.end());
|
||||
auto polySize = fheContext.value().parameter.getPolynomialSize();
|
||||
addPotentiallyNestedPass(
|
||||
pm,
|
||||
mlir::concretelang::createConvertFHEToTFHECrtPass(
|
||||
mlir::concretelang::CrtLoweringParameters(mods, polySize)),
|
||||
enablePass);
|
||||
} else if (fheContext.hasValue()) {
|
||||
pipelinePrinting("FHEToTFHEScalar", pm, context);
|
||||
size_t polySize = fheContext.value().parameter.getPolynomialSize();
|
||||
addPotentiallyNestedPass(
|
||||
pm,
|
||||
mlir::concretelang::createConvertFHEToTFHEScalarPass(
|
||||
mlir::concretelang::ScalarLoweringParameters(polySize)),
|
||||
enablePass);
|
||||
}
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createConvertFHEToTFHEPass(lowerStrategy),
|
||||
enablePass);
|
||||
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
@@ -240,27 +269,6 @@ optimizeConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerConcreteLinalgToLoops(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
bool parallelizeLoops, bool batchOperations) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("ConcreteLinalgToLoops", pm, context);
|
||||
|
||||
addPotentiallyNestedPass(
|
||||
pm,
|
||||
mlir::concretelang::createLinalgGenericOpWithTensorsToLoopsPass(
|
||||
parallelizeLoops),
|
||||
enablePass);
|
||||
|
||||
if (batchOperations) {
|
||||
addPotentiallyNestedPass(pm, mlir::concretelang::createBatchingPass(),
|
||||
enablePass);
|
||||
}
|
||||
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
@@ -293,8 +301,6 @@ lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("BConcreteToStd", pm, context);
|
||||
addPotentiallyNestedPass(pm, mlir::concretelang::createEliminateCRTOps(),
|
||||
enablePass);
|
||||
addPotentiallyNestedPass(pm, mlir::concretelang::createAddRuntimeContext(),
|
||||
enablePass);
|
||||
return pm.run(module.getOperation());
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
#include "concretelang/Support/V0Curves.h"
|
||||
|
||||
@@ -34,7 +35,8 @@ const auto keyFormat = KEY_FORMAT_BINARY;
|
||||
const auto v0Curve = getV0Curves(securityLevel, keyFormat);
|
||||
|
||||
/// For the v0 the secretKeyID and precision are the same for all gates.
|
||||
llvm::Expected<CircuitGate> gateFromMLIRType(LweSecretKeyID secretKeyID,
|
||||
llvm::Expected<CircuitGate> gateFromMLIRType(V0FHEContext fheContext,
|
||||
LweSecretKeyID secretKeyID,
|
||||
Variance variance,
|
||||
mlir::Type type) {
|
||||
if (type.isIntOrIndex()) {
|
||||
@@ -58,29 +60,35 @@ llvm::Expected<CircuitGate> gateFromMLIRType(LweSecretKeyID secretKeyID,
|
||||
};
|
||||
}
|
||||
if (auto lweTy = type.dyn_cast_or_null<
|
||||
mlir::concretelang::Concrete::LweCiphertextType>()) {
|
||||
mlir::concretelang::FHE::EncryptedIntegerType>()) {
|
||||
bool sign = lweTy.isSignedInteger();
|
||||
std::vector<int64_t> crt;
|
||||
if (fheContext.parameter.largeInteger.has_value()) {
|
||||
crt = fheContext.parameter.largeInteger.value().crtDecomposition;
|
||||
}
|
||||
return CircuitGate{
|
||||
/* .encryption = */ llvm::Optional<EncryptionGate>({
|
||||
/* .secretKeyID = */ secretKeyID,
|
||||
/* .variance = */ variance,
|
||||
/* .encoding = */
|
||||
{
|
||||
/* .precision = */ (size_t)lweTy.getP(),
|
||||
/* .crt = */ lweTy.getCrtDecomposition().vec(),
|
||||
/* .precision = */ lweTy.getWidth(),
|
||||
/* .crt = */ crt,
|
||||
},
|
||||
}),
|
||||
/*.shape = */
|
||||
{/*.width = */ (size_t)lweTy.getP(),
|
||||
/*.dimensions = */ std::vector<int64_t>(),
|
||||
/*.size = */ 0,
|
||||
/* .sign */ sign},
|
||||
{
|
||||
/*.width = */ (size_t)lweTy.getWidth(),
|
||||
/*.dimensions = */ std::vector<int64_t>(),
|
||||
/*.size = */ 0,
|
||||
/*.sign = */ sign,
|
||||
},
|
||||
};
|
||||
}
|
||||
auto tensor = type.dyn_cast_or_null<mlir::RankedTensorType>();
|
||||
if (tensor != nullptr) {
|
||||
auto gate =
|
||||
gateFromMLIRType(secretKeyID, variance, tensor.getElementType());
|
||||
auto gate = gateFromMLIRType(fheContext, secretKeyID, variance,
|
||||
tensor.getElementType());
|
||||
if (auto err = gate.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
@@ -179,14 +187,13 @@ createClientParametersForV0(V0FHEContext fheContext,
|
||||
auto funcType = (*funcOp).getFunctionType();
|
||||
|
||||
auto inputs = funcType.getInputs();
|
||||
|
||||
bool hasContext =
|
||||
inputs.empty()
|
||||
? false
|
||||
: inputs.back().isa<mlir::concretelang::Concrete::ContextType>();
|
||||
|
||||
auto gateFromType = [&](mlir::Type ty) {
|
||||
return gateFromMLIRType(clientlib::BIG_KEY, inputVariance, ty);
|
||||
return gateFromMLIRType(fheContext, clientlib::BIG_KEY, inputVariance, ty);
|
||||
};
|
||||
for (auto inType = funcType.getInputs().begin();
|
||||
inType < funcType.getInputs().end() - hasContext; inType++) {
|
||||
|
||||
@@ -46,9 +46,9 @@ namespace clientlib = concretelang::clientlib;
|
||||
enum Action {
|
||||
ROUND_TRIP,
|
||||
DUMP_FHE,
|
||||
DUMP_FHE_NO_LINALG,
|
||||
DUMP_TFHE,
|
||||
DUMP_CONCRETE,
|
||||
DUMP_CONCRETEWITHLOOPS,
|
||||
DUMP_BCONCRETE,
|
||||
DUMP_SDFG,
|
||||
DUMP_STD,
|
||||
@@ -119,13 +119,13 @@ static llvm::cl::opt<enum Action> action(
|
||||
"Parse input module and regenerate textual representation")),
|
||||
llvm::cl::values(clEnumValN(Action::DUMP_FHE, "dump-fhe",
|
||||
"Dump FHE module")),
|
||||
llvm::cl::values(clEnumValN(Action::DUMP_FHE_NO_LINALG,
|
||||
"dump-fhe-no-linalg",
|
||||
"Lower FHELinalg to FHE and dump result")),
|
||||
llvm::cl::values(clEnumValN(Action::DUMP_TFHE, "dump-tfhe",
|
||||
"Lower to TFHE and dump result")),
|
||||
llvm::cl::values(clEnumValN(Action::DUMP_CONCRETE, "dump-concrete",
|
||||
"Lower to Concrete and dump result")),
|
||||
llvm::cl::values(clEnumValN(
|
||||
Action::DUMP_CONCRETEWITHLOOPS, "dump-concrete-with-loops",
|
||||
"Lower to Concrete, replace linalg ops with loops and dump result")),
|
||||
llvm::cl::values(
|
||||
clEnumValN(Action::DUMP_BCONCRETE, "dump-bconcrete",
|
||||
"Lower to Bufferized Concrete and dump result")),
|
||||
@@ -369,9 +369,9 @@ cmdlineCompilationOptions() {
|
||||
"The large-integers options should all be set",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
if (cmdline::largeIntegerPackingKeyswitch.size() != 5) {
|
||||
if (cmdline::largeIntegerPackingKeyswitch.size() != 4) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"The large-integers-packing-keyswitch must be a list of 5 integer",
|
||||
"The large-integers-packing-keyswitch must be a list of 4 integer",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
if (cmdline::largeIntegerCircuitBootstrap.size() != 2) {
|
||||
@@ -488,15 +488,15 @@ mlir::LogicalResult processInputBuffer(
|
||||
case Action::DUMP_FHE:
|
||||
target = mlir::concretelang::CompilerEngine::Target::FHE;
|
||||
break;
|
||||
case Action::DUMP_FHE_NO_LINALG:
|
||||
target = mlir::concretelang::CompilerEngine::Target::FHE_NO_LINALG;
|
||||
break;
|
||||
case Action::DUMP_TFHE:
|
||||
target = mlir::concretelang::CompilerEngine::Target::TFHE;
|
||||
break;
|
||||
case Action::DUMP_CONCRETE:
|
||||
target = mlir::concretelang::CompilerEngine::Target::CONCRETE;
|
||||
break;
|
||||
case Action::DUMP_CONCRETEWITHLOOPS:
|
||||
target = mlir::concretelang::CompilerEngine::Target::CONCRETEWITHLOOPS;
|
||||
break;
|
||||
case Action::DUMP_BCONCRETE:
|
||||
target = mlir::concretelang::CompilerEngine::Target::BCONCRETE;
|
||||
break;
|
||||
|
||||
@@ -6,3 +6,10 @@ func.func @main(%arg0: !FHE.eint<15>, %cst: tensor<32768xi64>) -> tensor<1x!FHE.
|
||||
%6 = tensor.from_elements %1 : tensor<1x!FHE.eint<15>> // ERROR HERE line 4
|
||||
return %6 : tensor<1x!FHE.eint<15>>
|
||||
}
|
||||
|
||||
// Ensures that tensors of multiple elements can be constructed as well.
|
||||
func.func @main2(%arg0: !FHE.eint<15>, %cst: tensor<32768xi64>) -> tensor<2x!FHE.eint<15>> {
|
||||
%1 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<15>, tensor<32768xi64>) -> !FHE.eint<15>
|
||||
%6 = tensor.from_elements %1, %arg0 : tensor<2x!FHE.eint<15>> // ERROR HERE line 4
|
||||
return %6 : tensor<2x!FHE.eint<15>>
|
||||
}
|
||||
|
||||
@@ -8,12 +8,3 @@ func.func @add_lwe_ciphertexts(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !
|
||||
%0 = "Concrete.add_lwe_ciphertexts"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<2048,7>, !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
return %0 : !Concrete.lwe_ciphertext<2048,7>
|
||||
}
|
||||
|
||||
//CHECK: func @add_crt_lwe_ciphertexts(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.add_crt_lwe_tensor"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> tensor<5x2049xi64>
|
||||
//CHECK: return %[[V0]] : tensor<5x2049xi64>
|
||||
//CHECK: }
|
||||
func.func @add_crt_lwe_ciphertexts(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>, %arg1: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7> {
|
||||
%0 = "Concrete.add_lwe_ciphertexts"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>, !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>
|
||||
return %0 : !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>
|
||||
}
|
||||
|
||||
@@ -2,39 +2,21 @@
|
||||
|
||||
|
||||
//CHECK: func.func @add_glwe_const_int(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> {
|
||||
//CHECK: %c1_i8 = arith.constant 1 : i8
|
||||
//CHECK: %[[V0:.*]] = arith.extui %c1_i8 : i8 to i64
|
||||
//CHECK: %c56_i64 = arith.constant 56 : i64
|
||||
//CHECK: %[[V1:.*]] = arith.shli %[[V0]], %c56_i64 : i64
|
||||
//CHECK: %[[V2:.*]] = "BConcrete.add_plaintext_lwe_tensor"(%[[A0]], %[[V1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
|
||||
//CHECK: %c1_i64 = arith.constant 1 : i64
|
||||
//CHECK: %[[V2:.*]] = "BConcrete.add_plaintext_lwe_tensor"(%[[A0]], %c1_i64) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
|
||||
//CHECK: return %[[V2]] : tensor<1025xi64>
|
||||
//CHECK: }
|
||||
func.func @add_glwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
|
||||
%0 = arith.constant 1 : i8
|
||||
%2 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<1024,7>, i8) -> !Concrete.lwe_ciphertext<1024,7>
|
||||
%0 = arith.constant 1 : i64
|
||||
%2 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<1024,7>, i64) -> !Concrete.lwe_ciphertext<1024,7>
|
||||
return %2 : !Concrete.lwe_ciphertext<1024,7>
|
||||
}
|
||||
|
||||
//CHECK: func.func @add_glwe_int(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: i5) -> tensor<1025xi64> {
|
||||
//CHECK: %[[V0:.*]] = arith.extui %[[A1]] : i5 to i64
|
||||
//CHECK: %c59_i64 = arith.constant 59 : i64
|
||||
//CHECK: %[[V1:.*]] = arith.shli %[[V0]], %c59_i64 : i64
|
||||
//CHECK: %[[V2:.*]] = "BConcrete.add_plaintext_lwe_tensor"(%[[A0]], %[[V1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
|
||||
//CHECK: func.func @add_glwe_int(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: i64) -> tensor<1025xi64> {
|
||||
//CHECK: %[[V2:.*]] = "BConcrete.add_plaintext_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
|
||||
//CHECK: return %[[V2]] : tensor<1025xi64>
|
||||
//CHECK: }
|
||||
func.func @add_glwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4> {
|
||||
%1 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<1024,4>, i5) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
func.func @add_glwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i64) -> !Concrete.lwe_ciphertext<1024,4> {
|
||||
%1 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<1024,4>, i64) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
return %1 : !Concrete.lwe_ciphertext<1024,4>
|
||||
}
|
||||
|
||||
|
||||
//CHECK: func.func @add_plaintext_lwe_ciphertext(%[[A0:.*]]: tensor<5x1025xi64>) -> tensor<5x1025xi64> {
|
||||
//CHECK: %c1_i8 = arith.constant 1 : i8
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_crt_lwe_tensor"(%[[A0]], %c1_i8) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>, i8) -> tensor<5x1025xi64>
|
||||
//CHECK: return %[[V0]] : tensor<5x1025xi64>
|
||||
//CHECK: }
|
||||
func.func @add_plaintext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7> {
|
||||
%0 = arith.constant 1 : i8
|
||||
%2 = "Concrete.add_plaintext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>, i8) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>
|
||||
return %2 : !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>
|
||||
}
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @apply_lookup_table(%arg0: tensor<4xi64>) -> tensor<1024xi64> {
|
||||
// CHECK-NEXT: %0 = "BConcrete.encode_expand_lut_for_bootstrap_tensor"(%arg0) {outputBits = 3 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<1024xi64>
|
||||
// CHECK-NEXT: return %0 : tensor<1024xi64>
|
||||
// CHECK-NEXT: }
|
||||
func.func @apply_lookup_table(%arg0: tensor<4xi64>) -> tensor<1024xi64> {
|
||||
%0 = "Concrete.encode_expand_lut_for_bootstrap"(%arg0) {outputBits = 3 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<1024xi64>
|
||||
return %0 : tensor<1024xi64>
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @main(%arg0: tensor<4xi64>) -> tensor<40960xi64> {
|
||||
// CHECK-NEXT: %0 = "BConcrete.encode_expand_lut_for_woppbs_tensor"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64>
|
||||
// CHECK-NEXT: return %0 : tensor<40960xi64>
|
||||
// CHECK-NEXT: }
|
||||
func.func @main(%arg0: tensor<4xi64>) -> tensor<40960xi64> {
|
||||
%0 = "Concrete.encode_expand_lut_for_woppbs"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64>
|
||||
return %0 : tensor<40960xi64>
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @main(%arg0: i64) -> tensor<5xi64> {
|
||||
// CHECK-NEXT: %0 = "BConcrete.encode_plaintext_with_crt_tensor"(%arg0) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64>
|
||||
// CHECK-NEXT: return %0 : tensor<5xi64>
|
||||
// CHECK-NEXT: }
|
||||
func.func @main(%arg0: i64) -> tensor<5xi64> {
|
||||
%0 = "Concrete.encode_plaintext_with_crt"(%arg0) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64>
|
||||
return %0 : tensor<5xi64>
|
||||
}
|
||||
@@ -6,10 +6,3 @@
|
||||
func.func @identity(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
|
||||
return %arg0 : !Concrete.lwe_ciphertext<1024,7>
|
||||
}
|
||||
|
||||
// CHECK: func.func @identity_crt(%arg0: tensor<5x1025xi64>) -> tensor<5x1025xi64> {
|
||||
// CHECK-NEXT: return %arg0 : tensor<5x1025xi64>
|
||||
// CHECK-NEXT: }
|
||||
func.func @identity_crt(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7> {
|
||||
return %arg0 : !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>
|
||||
}
|
||||
|
||||
@@ -1,33 +1,21 @@
|
||||
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
|
||||
|
||||
//CHECK: func.func @mul_lwe_const_int(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> {
|
||||
//CHECK: %c1_i8 = arith.constant 1 : i8
|
||||
//CHECK: %[[V0:.*]] = arith.extui %c1_i8 : i8 to i64
|
||||
//CHECK: %[[V1:.*]] = "BConcrete.mul_cleartext_lwe_tensor"(%[[A0]], %[[V0]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
|
||||
//CHECK: %c1_i64 = arith.constant 1 : i64
|
||||
//CHECK: %[[V1:.*]] = "BConcrete.mul_cleartext_lwe_tensor"(%[[A0]], %c1_i64) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
|
||||
//CHECK: return %[[V1]] : tensor<1025xi64>
|
||||
//CHECK: }
|
||||
func.func @mul_lwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
|
||||
%0 = arith.constant 1 : i8
|
||||
%2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<1024,7>, i8) -> !Concrete.lwe_ciphertext<1024,7>
|
||||
%0 = arith.constant 1 : i64
|
||||
%2 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %0) : (!Concrete.lwe_ciphertext<1024,7>, i64) -> !Concrete.lwe_ciphertext<1024,7>
|
||||
return %2 : !Concrete.lwe_ciphertext<1024,7>
|
||||
}
|
||||
|
||||
//CHECK: func.func @mul_lwe_int(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: i5) -> tensor<1025xi64> {
|
||||
//CHECK: %[[V0:.*]] = arith.extui %[[A1]] : i5 to i64
|
||||
//CHECK: %[[V1:.*]] = "BConcrete.mul_cleartext_lwe_tensor"(%[[A0]], %[[V0]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
|
||||
//CHECK: func.func @mul_lwe_int(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: i64) -> tensor<1025xi64> {
|
||||
//CHECK: %[[V1:.*]] = "BConcrete.mul_cleartext_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
|
||||
//CHECK: return %[[V1]] : tensor<1025xi64>
|
||||
//CHECK: }
|
||||
func.func @mul_lwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4> {
|
||||
%1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<1024,4>, i5) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
func.func @mul_lwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i64) -> !Concrete.lwe_ciphertext<1024,4> {
|
||||
%1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<1024,4>, i64) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
return %1 : !Concrete.lwe_ciphertext<1024,4>
|
||||
}
|
||||
|
||||
|
||||
//CHECK: func.func @mul_cleartext_lwe_ciphertext_crt(%[[A0:.*]]: tensor<5x1025xi64>, %[[A1:.*]]: i5) -> tensor<5x1025xi64> {
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_crt_lwe_tensor"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>, i5) -> tensor<5x1025xi64>
|
||||
//CHECK: return %[[V0]] : tensor<5x1025xi64>
|
||||
//CHECK: }
|
||||
func.func @mul_cleartext_lwe_ciphertext_crt(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4> {
|
||||
%1 = "Concrete.mul_cleartext_lwe_ciphertext"(%arg0, %arg1) : (!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>, i5) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>
|
||||
return %1 : !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>
|
||||
}
|
||||
|
||||
@@ -8,12 +8,3 @@ func.func @neg_lwe(%arg0: !Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_cip
|
||||
%0 = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
return %0 : !Concrete.lwe_ciphertext<1024,4>
|
||||
}
|
||||
|
||||
//CHECK: func.func @negate_lwe_ciphertext(%[[A0:.*]]: tensor<5x1025xi64>) -> tensor<5x1025xi64> {
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.negate_crt_lwe_tensor"(%[[A0]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>) -> tensor<5x1025xi64>
|
||||
//CHECK: return %[[V0]] : tensor<5x1025xi64>
|
||||
//CHECK: }
|
||||
func.func @negate_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4> {
|
||||
%0 = "Concrete.negate_lwe_ciphertext"(%arg0) : (!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>
|
||||
return %0 : !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>
|
||||
}
|
||||
|
||||
@@ -31,16 +31,6 @@ func.func @tensor_collatenspse_shape(%arg0: tensor<2x3x2x3x4x!Concrete.lwe_ciphe
|
||||
return %0 : tensor<6x2x12x!Concrete.lwe_ciphertext<1024,4>>
|
||||
}
|
||||
|
||||
// -----
|
||||
//CHECK: func.func @tensor_collapse_shape_crt(%[[A0:.*]]: tensor<2x3x4x5x6x5x1025xi64>) -> tensor<720x5x1025xi64> {
|
||||
//CHECK: %[[V0:.*]] = tensor.collapse_shape %[[A0]] [[_:\[\[0, 1, 2, 3, 4\], \[5\], \[6\]\]]] : tensor<2x3x4x5x6x5x1025xi64> into tensor<720x5x1025xi64>
|
||||
//CHECK: return %[[V0]] : tensor<720x5x1025xi64>
|
||||
//CHECK: }
|
||||
func.func @tensor_collapse_shape_crt(%arg0: tensor<2x3x4x5x6x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>>) -> tensor<720x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>> {
|
||||
%0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3, 4]] {MANP = 1 : ui1}: tensor<2x3x4x5x6x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>> into tensor<720x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>>
|
||||
return %0 : tensor<720x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>>
|
||||
}
|
||||
|
||||
// -----
|
||||
//CHECK: func.func @tensor_expand_shape_crt(%[[A0:.*]]: tensor<30x1025xi64>) -> tensor<5x6x1025xi64> {
|
||||
//CHECK: %[[V0:.*]] = tensor.expand_shape %[[A0]] [[_:\[\[0, 1\], \[2\]\]]] : tensor<30x1025xi64> into tensor<5x6x1025xi64>
|
||||
|
||||
@@ -6,10 +6,3 @@
|
||||
func.func @tensor_identity(%arg0: tensor<2x3x4x!Concrete.lwe_ciphertext<1024,7>>) -> tensor<2x3x4x!Concrete.lwe_ciphertext<1024,7>> {
|
||||
return %arg0 : tensor<2x3x4x!Concrete.lwe_ciphertext<1024,7>>
|
||||
}
|
||||
|
||||
// CHECK: func.func @tensor_identity_crt(%arg0: tensor<2x3x4x5x1025xi64>) -> tensor<2x3x4x5x1025xi64> {
|
||||
// CHECK-NEXT: return %arg0 : tensor<2x3x4x5x1025xi64>
|
||||
// CHECK-NEXT: }
|
||||
func.func @tensor_identity_crt(%arg0: tensor<2x3x4x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>>) -> tensor<2x3x4x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>> {
|
||||
return %arg0 : tensor<2x3x4x!Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>>
|
||||
}
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
// RUN: concretecompiler %s --passes fhe-to-tfhe --action=dump-tfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @add_eint_int(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
func.func @add_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8
|
||||
// CHECK-NEXT: %[[V2:.*]] = "TFHE.add_glwe_int"(%arg0, %[[V1]]) : (!TFHE.glwe<{_,_,_}{7}>, i8) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
// CHECK-NEXT: return %[[V2]] : !TFHE.glwe<{_,_,_}{7}>
|
||||
|
||||
%0 = arith.constant 1 : i8
|
||||
%1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
// RUN: concretecompiler %s --passes fhe-to-tfhe --action=dump-tfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @apply_lookup_table(%[[A0:.*]]: !TFHE.glwe<{_,_,_}{2}>, %[[LUT:.*]]: tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{3}> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "TFHE.keyswitch_glwe"(%[[A0]]) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %[[LUT]]) {baseLog = -1 : i32, glweDimension = -1 : i32, level = -1 : i32, polySize = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{3}>
|
||||
// CHECK-NEXT: return %[[V2]] : !TFHE.glwe<{_,_,_}{3}>
|
||||
func.func @apply_lookup_table(%arg0: !FHE.eint<2>, %arg1: tensor<4xi64>) -> !FHE.eint<3> {
|
||||
%1 = "FHE.apply_lookup_table"(%arg0, %arg1): (!FHE.eint<2>, tensor<4xi64>) -> (!FHE.eint<3>)
|
||||
return %1: !FHE.eint<3>
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s
|
||||
|
||||
//CHECK: #map0 = affine_map<(d0, d1, d2, d3) -> (d1)>
|
||||
//CHECK-NEXT: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
//CHECK-NEXT: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
|
||||
//CHECK-NEXT: #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
|
||||
//CHECK-NEXT: #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
|
||||
//CHECK-NEXT: module {
|
||||
//CHECK-NEXT: func.func @conv2d(%arg0: tensor<100x3x28x28x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<4x3x14x14xi3>, %arg2: tensor<4xi3>) -> tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>> {
|
||||
//CHECK-NEXT: %0 = "TFHE.zero_tensor"() : () -> tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %1 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<4xi3>) outs(%0 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: ^bb0(%arg3: i3, %arg4: !TFHE.glwe<{_,_,_}{2}>):
|
||||
//CHECK-NEXT: %3 = "TFHE.add_glwe_int"(%arg4, %arg3) : (!TFHE.glwe<{_,_,_}{2}>, i3) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
//CHECK-NEXT: linalg.yield %3 : !TFHE.glwe<{_,_,_}{2}>
|
||||
//CHECK-NEXT: } -> tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<100x3x28x28x!TFHE.glwe<{_,_,_}{2}>>, tensor<4x3x14x14xi3>) outs(%1 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: ^bb0(%arg3: !TFHE.glwe<{_,_,_}{2}>, %arg4: i3, %arg5: !TFHE.glwe<{_,_,_}{2}>):
|
||||
//CHECK-NEXT: %3 = "TFHE.mul_glwe_int"(%arg3, %arg4) : (!TFHE.glwe<{_,_,_}{2}>, i3) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
//CHECK-NEXT: %4 = "TFHE.add_glwe"(%arg5, %3) : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
//CHECK-NEXT: linalg.yield %4 : !TFHE.glwe<{_,_,_}{2}>
|
||||
//CHECK-NEXT: } -> tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: return %2 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: }
|
||||
|
||||
func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
|
||||
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0, 0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
|
||||
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
// RUN: concretecompiler %s --passes fhe-to-tfhe --action=dump-tfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: #map0 = affine_map<(d0) -> (d0)>
|
||||
// CHECK-NEXT: #map1 = affine_map<(d0) -> (0)>
|
||||
// CHECK-NEXT: module {
|
||||
// CHECK-NEXT: func.func @linalg_generic(%arg0: tensor<2x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<2xi3>, %arg2: tensor<1x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
// CHECK-NEXT: %0 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!TFHE.glwe<{_,_,_}{2}>>, tensor<2xi3>) outs(%arg2 : tensor<1x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
// CHECK-NEXT: ^bb0(%arg3: !TFHE.glwe<{_,_,_}{2}>, %arg4: i3, %arg5: !TFHE.glwe<{_,_,_}{2}>):
|
||||
// CHECK-NEXT: %1 = "TFHE.mul_glwe_int"(%arg3, %arg4) : (!TFHE.glwe<{_,_,_}{2}>, i3) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
// CHECK-NEXT: %2 = "TFHE.add_glwe"(%1, %arg5) : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
// CHECK-NEXT: linalg.yield %2 : !TFHE.glwe<{_,_,_}{2}>
|
||||
// CHECK-NEXT: } -> tensor<1x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
#map0 = affine_map<(d0) -> (d0)>
|
||||
#map1 = affine_map<(d0) -> (0)>
|
||||
module {
|
||||
func.func @linalg_generic(%arg0: tensor<2x!FHE.eint<2>>, %arg1: tensor<2xi3>, %acc: tensor<1x!FHE.eint<2>>) {
|
||||
%2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!FHE.eint<2>>, tensor<2xi3>) outs(%acc : tensor<1x!FHE.eint<2>>) {
|
||||
^bb0(%arg2: !FHE.eint<2>, %arg3: i3, %arg4: !FHE.eint<2>): // no predecessors
|
||||
%4 = "FHE.mul_eint_int"(%arg2, %arg3) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
%5 = "FHE.add_eint"(%4, %arg4) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
linalg.yield %5 : !FHE.eint<2>
|
||||
} -> tensor<1x!FHE.eint<2>>
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
// RUN: concretecompiler %s --passes fhe-to-tfhe --action=dump-tfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @mul_eint_int(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
func.func @mul_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = arith.constant 2 : i8
|
||||
// CHECK-NEXT: %[[V2:.*]] = "TFHE.mul_glwe_int"(%arg0, %[[V1]]) : (!TFHE.glwe<{_,_,_}{7}>, i8) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
// CHECK-NEXT: return %[[V2]] : !TFHE.glwe<{_,_,_}{7}>
|
||||
|
||||
%0 = arith.constant 2 : i8
|
||||
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
// RUN: concretecompiler %s --passes fhe-to-tfhe --action=dump-tfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @neg_eint(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
func.func @neg_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "TFHE.neg_glwe"(%arg0) : (!TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
// CHECK-NEXT: return %[[V1]] : !TFHE.glwe<{_,_,_}{7}>
|
||||
|
||||
%1 = "FHE.neg_eint"(%arg0): (!FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
// RUN: concretecompiler %s --passes fhe-to-tfhe --action=dump-tfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @sub_int_eint(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
func.func @sub_int_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8
|
||||
// CHECK-NEXT: %[[V2:.*]] = "TFHE.sub_int_glwe"(%[[V1]], %arg0) : (i8, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
// CHECK-NEXT: return %[[V2]] : !TFHE.glwe<{_,_,_}{7}>
|
||||
|
||||
%0 = arith.constant 1 : i8
|
||||
%1 = "FHE.sub_int_eint"(%0, %arg0): (i8, !FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
// RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @add_eint(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>, %arg1: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
func.func @add_eint(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %c0 = arith.constant 0 : index
|
||||
// CHECK-NEXT: %c1 = arith.constant 1 : index
|
||||
// CHECK-NEXT: %c5 = arith.constant 5 : index
|
||||
// CHECK-NEXT: %0:2 = scf.for %arg2 = %c0 to %c5 step %c1 iter_args(%arg3 = %arg0, %arg4 = %arg1) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>, tensor<5x!TFHE.glwe<{_,_,_}{7}>>) {
|
||||
// CHECK-NEXT: %1 = tensor.extract %arg3[%arg2] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: %2 = tensor.extract %arg4[%arg2] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: %3 = "TFHE.add_glwe"(%1, %2) : (!TFHE.glwe<{_,_,_}{7}>, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
// CHECK-NEXT: %4 = tensor.insert %3 into %arg3[%arg2] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: scf.yield %4, %arg4 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>, tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %0#0 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
|
||||
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
// RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @add_eint_int(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
func.func @add_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %c1_i8 = arith.constant 1 : i8
|
||||
// CHECK-NEXT: %0 = arith.extui %c1_i8 : i8 to i64
|
||||
// CHECK-NEXT: %1 = "TFHE.encode_plaintext_with_crt"(%0) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64>
|
||||
// CHECK-NEXT: %c0 = arith.constant 0 : index
|
||||
// CHECK-NEXT: %c1 = arith.constant 1 : index
|
||||
// CHECK-NEXT: %c5 = arith.constant 5 : index
|
||||
// CHECK-NEXT: %2:2 = scf.for %arg1 = %c0 to %c5 step %c1 iter_args(%arg2 = %arg0, %arg3 = %1) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>, tensor<5xi64>) {
|
||||
// CHECK-NEXT: %3 = tensor.extract %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: %4 = tensor.extract %arg3[%arg1] : tensor<5xi64>
|
||||
// CHECK-NEXT: %5 = "TFHE.add_glwe_int"(%3, %4) : (!TFHE.glwe<{_,_,_}{7}>, i64) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
// CHECK-NEXT: %6 = tensor.insert %5 into %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: scf.yield %6, %arg3 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>, tensor<5xi64>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %2#0 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
|
||||
%0 = arith.constant 1 : i8
|
||||
%1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
// RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @apply_lookup_table(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<4xi64>) -> tensor<5x!TFHE.glwe<{_,_,_}{3}>>
|
||||
// CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_woppbs"(%arg1) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64>
|
||||
// CHECK-NEXT: %1 = "TFHE.wop_pbs_glwe"(%arg0, %0) {bootstrapBaseLog = -1 : i32, bootstrapLevel = -1 : i32, circuitBootstrapBaseLog = -1 : i32, circuitBootstrapLevel = -1 : i32, crtDecomposition = [], keyswitchBaseLog = -1 : i32, keyswitchLevel = -1 : i32, packingKeySwitchBaseLog = -1 : i32, packingKeySwitchInputLweDimension = -1 : i32, packingKeySwitchLevel = -1 : i32, packingKeySwitchoutputPolynomialSize = -1 : i32} : (tensor<5x!TFHE.glwe<{_,_,_}{2}>>, tensor<40960xi64>) -> tensor<5x!TFHE.glwe<{_,_,_}{3}>>
|
||||
// CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<{_,_,_}{3}>>
|
||||
func.func @apply_lookup_table(%arg0: !FHE.eint<2>, %arg1: tensor<4xi64>) -> !FHE.eint<3> {
|
||||
%1 = "FHE.apply_lookup_table"(%arg0, %arg1): (!FHE.eint<2>, tensor<4xi64>) -> (!FHE.eint<3>)
|
||||
return %1: !FHE.eint<3>
|
||||
}
|
||||
@@ -1,11 +1,10 @@
|
||||
// RUN: concretecompiler %s --passes fhe-to-tfhe --action=dump-tfhe 2>&1| FileCheck %s
|
||||
// RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
|
||||
|
||||
//CHECK: func.func @apply_lookup_table_cst(%[[A0:.*]]: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> {
|
||||
//CHECK-NEXT: %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64>
|
||||
//CHECK-NEXT: %[[V1:.*]] = "TFHE.keyswitch_glwe"(%[[A0]]) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
//CHECK-NEXT: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %cst) {baseLog = -1 : i32, glweDimension = -1 : i32, level = -1 : i32, polySize = -1 : i32} : (!TFHE.glwe<{_,_,_}{7}>, tensor<128xi64>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
//CHECK-NEXT: return %[[V2]] : !TFHE.glwe<{_,_,_}{7}>
|
||||
//CHECK-NEXT: }
|
||||
// CHECK: func.func @apply_lookup_table_cst(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> {
|
||||
// CHECK-NEXT: %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64>
|
||||
// CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_woppbs"(%cst) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<128xi64>) -> tensor<40960xi64>
|
||||
// CHECK-NEXT: %1 = "TFHE.wop_pbs_glwe"(%arg0, %0) {bootstrapBaseLog = -1 : i32, bootstrapLevel = -1 : i32, circuitBootstrapBaseLog = -1 : i32, circuitBootstrapLevel = -1 : i32, crtDecomposition = [], keyswitchBaseLog = -1 : i32, keyswitchLevel = -1 : i32, packingKeySwitchBaseLog = -1 : i32, packingKeySwitchInputLweDimension = -1 : i32, packingKeySwitchLevel = -1 : i32, packingKeySwitchoutputPolynomialSize = -1 : i32} : (tensor<5x!TFHE.glwe<{_,_,_}{7}>>, tensor<40960xi64>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
func.func @apply_lookup_table_cst(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64>
|
||||
%1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.eint<7>, tensor<128xi64>) -> (!FHE.eint<7>)
|
||||
@@ -0,0 +1,95 @@
|
||||
// RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @conv2d(%arg0: tensor<100x3x28x28x5x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<4x3x14x14xi3>, %arg2: tensor<4xi3>) -> tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: %c4 = arith.constant 4 : index
|
||||
// CHECK-NEXT: %c100 = arith.constant 100 : index
|
||||
// CHECK-NEXT: %c15 = arith.constant 15 : index
|
||||
// CHECK-NEXT: %c0 = arith.constant 0 : index
|
||||
// CHECK-NEXT: %c1 = arith.constant 1 : index
|
||||
// CHECK-NEXT: %c3 = arith.constant 3 : index
|
||||
// CHECK-NEXT: %c14 = arith.constant 14 : index
|
||||
// CHECK-NEXT: %0 = "TFHE.zero_tensor"() : () -> tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: %1 = scf.for %arg3 = %c0 to %c100 step %c1 iter_args(%arg4 = %0) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
// CHECK-NEXT: %3 = scf.for %arg5 = %c0 to %c4 step %c1 iter_args(%arg6 = %arg4) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
// CHECK-NEXT: %4 = scf.for %arg7 = %c0 to %c15 step %c1 iter_args(%arg8 = %arg6) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
// CHECK-NEXT: %5 = scf.for %arg9 = %c0 to %c15 step %c1 iter_args(%arg10 = %arg8) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
// CHECK-NEXT: %6 = tensor.extract %arg2[%arg5] : tensor<4xi3>
|
||||
// CHECK-NEXT: %c0_0 = arith.constant 0 : index
|
||||
// CHECK-NEXT: %7 = tensor.extract_slice %0[%arg3, %arg5, %arg7, %arg9, %c0_0] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> to tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: %8 = arith.extui %6 : i3 to i64
|
||||
// CHECK-NEXT: %9 = "TFHE.encode_plaintext_with_crt"(%8) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64>
|
||||
// CHECK-NEXT: %c0_1 = arith.constant 0 : index
|
||||
// CHECK-NEXT: %c1_2 = arith.constant 1 : index
|
||||
// CHECK-NEXT: %c5 = arith.constant 5 : index
|
||||
// CHECK-NEXT: %10:2 = scf.for %arg11 = %c0_1 to %c5 step %c1_2 iter_args(%arg12 = %7, %arg13 = %9) -> (tensor<5x!TFHE.glwe<{_,_,_}{2}>>, tensor<5xi64>) {
|
||||
// CHECK-NEXT: %12 = tensor.extract %arg12[%arg11] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: %13 = tensor.extract %arg13[%arg11] : tensor<5xi64>
|
||||
// CHECK-NEXT: %14 = "TFHE.add_glwe_int"(%12, %13) : (!TFHE.glwe<{_,_,_}{2}>, i64) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
// CHECK-NEXT: %15 = tensor.insert %14 into %arg12[%arg11] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: scf.yield %15, %arg13 : tensor<5x!TFHE.glwe<{_,_,_}{2}>>, tensor<5xi64>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: %c0_3 = arith.constant 0 : index
|
||||
// CHECK-NEXT: %11 = tensor.insert_slice %10#0 into %arg10[%arg3, %arg5, %arg7, %arg9, %c0_3] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<5x!TFHE.glwe<{_,_,_}{2}>> into tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: scf.yield %11 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %5 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %4 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %3 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: %2 = scf.for %arg3 = %c0 to %c100 step %c1 iter_args(%arg4 = %1) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
// CHECK-NEXT: %3 = scf.for %arg5 = %c0 to %c4 step %c1 iter_args(%arg6 = %arg4) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
// CHECK-NEXT: %4 = scf.for %arg7 = %c0 to %c15 step %c1 iter_args(%arg8 = %arg6) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
// CHECK-NEXT: %5 = scf.for %arg9 = %c0 to %c15 step %c1 iter_args(%arg10 = %arg8) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
// CHECK-NEXT: %6 = scf.for %arg11 = %c0 to %c3 step %c1 iter_args(%arg12 = %arg10) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
// CHECK-NEXT: %7 = scf.for %arg13 = %c0 to %c14 step %c1 iter_args(%arg14 = %arg12) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
// CHECK-NEXT: %8 = scf.for %arg15 = %c0 to %c14 step %c1 iter_args(%arg16 = %arg14) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
// CHECK-NEXT: %9 = affine.apply #map(%arg7, %arg13)
|
||||
// CHECK-NEXT: %10 = affine.apply #map(%arg9, %arg15)
|
||||
// CHECK-NEXT: %c0_0 = arith.constant 0 : index
|
||||
// CHECK-NEXT: %11 = tensor.extract_slice %arg0[%arg3, %arg11, %9, %10, %c0_0] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<100x3x28x28x5x!TFHE.glwe<{_,_,_}{2}>> to tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: %12 = tensor.extract %arg1[%arg5, %arg11, %arg13, %arg15] : tensor<4x3x14x14xi3>
|
||||
// CHECK-NEXT: %c0_1 = arith.constant 0 : index
|
||||
// CHECK-NEXT: %13 = tensor.extract_slice %1[%arg3, %arg5, %arg7, %arg9, %c0_1] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> to tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: %14 = arith.extsi %12 : i3 to i64
|
||||
// CHECK-NEXT: %c0_2 = arith.constant 0 : index
|
||||
// CHECK-NEXT: %c1_3 = arith.constant 1 : index
|
||||
// CHECK-NEXT: %c5 = arith.constant 5 : index
|
||||
// CHECK-NEXT: %15 = scf.for %arg17 = %c0_2 to %c5 step %c1_3 iter_args(%arg18 = %11) -> (tensor<5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
// CHECK-NEXT: %18 = tensor.extract %arg18[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: %19 = "TFHE.mul_glwe_int"(%18, %14) : (!TFHE.glwe<{_,_,_}{2}>, i64) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
// CHECK-NEXT: %20 = tensor.insert %19 into %arg18[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: scf.yield %20 : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: %c0_4 = arith.constant 0 : index
|
||||
// CHECK-NEXT: %c1_5 = arith.constant 1 : index
|
||||
// CHECK-NEXT: %c5_6 = arith.constant 5 : index
|
||||
// CHECK-NEXT: %16:2 = scf.for %arg17 = %c0_4 to %c5_6 step %c1_5 iter_args(%arg18 = %13, %arg19 = %15) -> (tensor<5x!TFHE.glwe<{_,_,_}{2}>>, tensor<5x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
// CHECK-NEXT: %18 = tensor.extract %arg18[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: %19 = tensor.extract %arg19[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: %20 = "TFHE.add_glwe"(%18, %19) : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
// CHECK-NEXT: %21 = tensor.insert %20 into %arg18[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: scf.yield %21, %arg19 : tensor<5x!TFHE.glwe<{_,_,_}{2}>>, tensor<5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: %c0_7 = arith.constant 0 : index
|
||||
// CHECK-NEXT: %17 = tensor.insert_slice %16#0 into %arg16[%arg3, %arg5, %arg7, %arg9, %c0_7] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<5x!TFHE.glwe<{_,_,_}{2}>> into tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: scf.yield %17 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %8 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %7 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %6 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %5 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %4 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %3 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %2 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>
|
||||
func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
|
||||
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0, 0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
|
||||
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
// RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @mul_eint_int(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
func.func @mul_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %c2_i8 = arith.constant 2 : i8
|
||||
// CHECK-NEXT: %0 = arith.extsi %c2_i8 : i8 to i64
|
||||
// CHECK-NEXT: %c0 = arith.constant 0 : index
|
||||
// CHECK-NEXT: %c1 = arith.constant 1 : index
|
||||
// CHECK-NEXT: %c5 = arith.constant 5 : index
|
||||
// CHECK-NEXT: %1 = scf.for %arg1 = %c0 to %c5 step %c1 iter_args(%arg2 = %arg0) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>) {
|
||||
// CHECK-NEXT: %2 = tensor.extract %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: %3 = "TFHE.mul_glwe_int"(%2, %0) : (!TFHE.glwe<{_,_,_}{7}>, i64) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
// CHECK-NEXT: %4 = tensor.insert %3 into %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: scf.yield %4 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
|
||||
%0 = arith.constant 2 : i8
|
||||
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
// RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @neg_eint(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
func.func @neg_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %c0 = arith.constant 0 : index
|
||||
// CHECK-NEXT: %c1 = arith.constant 1 : index
|
||||
// CHECK-NEXT: %c5 = arith.constant 5 : index
|
||||
// CHECK-NEXT: %0 = scf.for %arg1 = %c0 to %c5 step %c1 iter_args(%arg2 = %arg0) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>) {
|
||||
// CHECK-NEXT: %1 = tensor.extract %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: %2 = "TFHE.neg_glwe"(%1) : (!TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
// CHECK-NEXT: %3 = tensor.insert %2 into %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: scf.yield %3 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %0 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
|
||||
%1 = "FHE.neg_eint"(%arg0): (!FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
// RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
|
||||
// RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @sub_int_eint(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
func.func @sub_int_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %c1_i8 = arith.constant 1 : i8
|
||||
// CHECK-NEXT: %0 = arith.extui %c1_i8 : i8 to i64
|
||||
// CHECK-NEXT: %1 = "TFHE.encode_plaintext_with_crt"(%0) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64>
|
||||
// CHECK-NEXT: %c0 = arith.constant 0 : index
|
||||
// CHECK-NEXT: %c1 = arith.constant 1 : index
|
||||
// CHECK-NEXT: %c5 = arith.constant 5 : index
|
||||
// CHECK-NEXT: %2:2 = scf.for %arg1 = %c0 to %c5 step %c1 iter_args(%arg2 = %arg0, %arg3 = %1) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>, tensor<5xi64>) {
|
||||
// CHECK-NEXT: %3 = tensor.extract %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: %4 = tensor.extract %arg3[%arg1] : tensor<5xi64>
|
||||
// CHECK-NEXT: %5 = "TFHE.sub_int_glwe"(%4, %3) : (i64, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
// CHECK-NEXT: %6 = tensor.insert %5 into %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
// CHECK-NEXT: scf.yield %6, %arg3 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>, tensor<5xi64>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %2#0 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>
|
||||
|
||||
%0 = arith.constant 1 : i8
|
||||
%1 = "FHE.sub_int_eint"(%0, %arg0): (i8, !FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
@@ -1,8 +1,8 @@
|
||||
// RUN: concretecompiler %s --passes fhe-to-tfhe --action=dump-tfhe 2>&1| FileCheck %s
|
||||
// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @add_eint(%arg0: !TFHE.glwe<{_,_,_}{7}>, %arg1: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
func.func @add_eint(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "TFHE.add_glwe"(%arg0, %arg1) : (!TFHE.glwe<{_,_,_}{7}>, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
// CHECK-NEXT: %[[V1:.*]] = "TFHE.add_glwe"(%arg0, %arg1) {MANP = 2 : ui3} : (!TFHE.glwe<{_,_,_}{7}>, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
// CHECK-NEXT: return %[[V1]] : !TFHE.glwe<{_,_,_}{7}>
|
||||
|
||||
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
@@ -0,0 +1,16 @@
|
||||
// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @add_eint_int(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
func.func @add_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %c1_i8 = arith.constant 1 : i8
|
||||
// CHECK-NEXT: %0 = arith.extui %c1_i8 : i8 to i64
|
||||
// CHECK-NEXT: %c56_i64 = arith.constant 56 : i64
|
||||
// CHECK-NEXT: %1 = arith.shli %0, %c56_i64 : i64
|
||||
// CHECK-NEXT: %2 = "TFHE.add_glwe_int"(%arg0, %1) : (!TFHE.glwe<{_,_,_}{7}>, i64) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
// CHECK-NEXT: return %2 : !TFHE.glwe<{_,_,_}{7}>
|
||||
|
||||
|
||||
%0 = arith.constant 1 : i8
|
||||
%1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @apply_lookup_table(%arg0: !TFHE.glwe<{_,_,_}{2}>, %arg1: tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{3}> {
|
||||
// CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_bootstrap"(%arg1) {outputBits = 3 : i32, polySize = 256 : i32} : (tensor<4xi64>) -> tensor<256xi64>
|
||||
// CHECK-NEXT: %1 = "TFHE.keyswitch_glwe"(%arg0) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
// CHECK-NEXT: %2 = "TFHE.bootstrap_glwe"(%1, %0) {baseLog = -1 : i32, glweDimension = -1 : i32, level = -1 : i32, polySize = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>, tensor<256xi64>) -> !TFHE.glwe<{_,_,_}{3}>
|
||||
// CHECK-NEXT: return %2 : !TFHE.glwe<{_,_,_}{3}>
|
||||
func.func @apply_lookup_table(%arg0: !FHE.eint<2>, %arg1: tensor<4xi64>) -> !FHE.eint<3> {
|
||||
%1 = "FHE.apply_lookup_table"(%arg0, %arg1): (!FHE.eint<2>, tensor<4xi64>) -> (!FHE.eint<3>)
|
||||
return %1: !FHE.eint<3>
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s
|
||||
|
||||
//CHECK: func.func @apply_lookup_table_cst(%[[A0:.*]]: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> {
|
||||
|
||||
//CHECK-NEXT: %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64>
|
||||
//CHECK-NEXT: %0 = "TFHE.encode_expand_lut_for_bootstrap"(%cst) {outputBits = 7 : i32, polySize = 8192 : i32} : (tensor<128xi64>) -> tensor<8192xi64>
|
||||
//CHECK-NEXT: %1 = "TFHE.keyswitch_glwe"(%arg0) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
//CHECK-NEXT: %2 = "TFHE.bootstrap_glwe"(%1, %0) {baseLog = -1 : i32, glweDimension = -1 : i32, level = -1 : i32, polySize = -1 : i32} : (!TFHE.glwe<{_,_,_}{7}>, tensor<8192xi64>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
//CHECK-NEXT: return %2 : !TFHE.glwe<{_,_,_}{7}>
|
||||
func.func @apply_lookup_table_cst(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64>
|
||||
%1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.eint<7>, tensor<128xi64>) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s
|
||||
|
||||
//CHECK: func.func @conv2d(%arg0: tensor<100x3x28x28x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<4x3x14x14xi3>, %arg2: tensor<4xi3>) -> tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: %c4 = arith.constant 4 : index
|
||||
// CHECK-NEXT: %c100 = arith.constant 100 : index
|
||||
// CHECK-NEXT: %c15 = arith.constant 15 : index
|
||||
// CHECK-NEXT: %c0 = arith.constant 0 : index
|
||||
// CHECK-NEXT: %c1 = arith.constant 1 : index
|
||||
// CHECK-NEXT: %c3 = arith.constant 3 : index
|
||||
// CHECK-NEXT: %c14 = arith.constant 14 : index
|
||||
// CHECK-NEXT: %0 = "TFHE.zero_tensor"() : () -> tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: %1 = scf.for %arg3 = %c0 to %c100 step %c1 iter_args(%arg4 = %0) -> (tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
// CHECK-NEXT: %3 = scf.for %arg5 = %c0 to %c4 step %c1 iter_args(%arg6 = %arg4) -> (tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
// CHECK-NEXT: %4 = scf.for %arg7 = %c0 to %c15 step %c1 iter_args(%arg8 = %arg6) -> (tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
// CHECK-NEXT: %5 = scf.for %arg9 = %c0 to %c15 step %c1 iter_args(%arg10 = %arg8) -> (tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
// CHECK-NEXT: %6 = tensor.extract %arg2[%arg5] : tensor<4xi3>
|
||||
// CHECK-NEXT: %7 = tensor.extract %0[%arg3, %arg5, %arg7, %arg9] : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: %8 = arith.extui %6 : i3 to i64
|
||||
// CHECK-NEXT: %c61_i64 = arith.constant 61 : i64
|
||||
// CHECK-NEXT: %9 = arith.shli %8, %c61_i64 : i64
|
||||
// CHECK-NEXT: %10 = "TFHE.add_glwe_int"(%7, %9) : (!TFHE.glwe<{_,_,_}{2}>, i64) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
// CHECK-NEXT: %11 = tensor.insert %10 into %arg10[%arg3, %arg5, %arg7, %arg9] : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: scf.yield %11 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %5 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %4 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %3 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: %2 = scf.for %arg3 = %c0 to %c100 step %c1 iter_args(%arg4 = %1) -> (tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
// CHECK-NEXT: %3 = scf.for %arg5 = %c0 to %c4 step %c1 iter_args(%arg6 = %arg4) -> (tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
// CHECK-NEXT: %4 = scf.for %arg7 = %c0 to %c15 step %c1 iter_args(%arg8 = %arg6) -> (tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
// CHECK-NEXT: %5 = scf.for %arg9 = %c0 to %c15 step %c1 iter_args(%arg10 = %arg8) -> (tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
// CHECK-NEXT: %6 = scf.for %arg11 = %c0 to %c3 step %c1 iter_args(%arg12 = %arg10) -> (tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
// CHECK-NEXT: %7 = scf.for %arg13 = %c0 to %c14 step %c1 iter_args(%arg14 = %arg12) -> (tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
// CHECK-NEXT: %8 = scf.for %arg15 = %c0 to %c14 step %c1 iter_args(%arg16 = %arg14) -> (tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
// CHECK-NEXT: %9 = affine.apply #map(%arg7, %arg13)
|
||||
// CHECK-NEXT: %10 = affine.apply #map(%arg9, %arg15)
|
||||
// CHECK-NEXT: %11 = tensor.extract %arg0[%arg3, %arg11, %9, %10] : tensor<100x3x28x28x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: %12 = tensor.extract %arg1[%arg5, %arg11, %arg13, %arg15] : tensor<4x3x14x14xi3>
|
||||
// CHECK-NEXT: %13 = tensor.extract %1[%arg3, %arg5, %arg7, %arg9] : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: %14 = arith.extsi %12 : i3 to i64
|
||||
// CHECK-NEXT: %15 = "TFHE.mul_glwe_int"(%11, %14) : (!TFHE.glwe<{_,_,_}{2}>, i64) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
// CHECK-NEXT: %16 = "TFHE.add_glwe"(%13, %15) : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
// CHECK-NEXT: %17 = tensor.insert %16 into %arg16[%arg3, %arg5, %arg7, %arg9] : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: scf.yield %17 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %8 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %7 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %6 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %5 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %4 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: scf.yield %3 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %2 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
|
||||
func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
|
||||
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0, 0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
|
||||
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @mul_eint_int(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
func.func @mul_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %c2_i8 = arith.constant 2 : i8
|
||||
// CHECK-NEXT: %0 = arith.extsi %c2_i8 : i8 to i64
|
||||
// CHECK-NEXT: %1 = "TFHE.mul_glwe_int"(%arg0, %0) : (!TFHE.glwe<{_,_,_}{7}>, i64) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
// CHECK-NEXT: return %1 : !TFHE.glwe<{_,_,_}{7}>
|
||||
|
||||
%0 = arith.constant 2 : i8
|
||||
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @neg_eint(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
func.func @neg_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %0 = "TFHE.neg_glwe"(%arg0) {MANP = 1 : ui1} : (!TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
// CHECK-NEXT: return %0 : !TFHE.glwe<{_,_,_}{7}>
|
||||
|
||||
%1 = "FHE.neg_eint"(%arg0): (!FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @sub_int_eint(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
func.func @sub_int_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %c1_i8 = arith.constant 1 : i8
|
||||
// CHECK-NEXT: %0 = arith.extui %c1_i8 : i8 to i64
|
||||
// CHECK-NEXT: %c56_i64 = arith.constant 56 : i64
|
||||
// CHECK-NEXT: %1 = arith.shli %0, %c56_i64 : i64
|
||||
// CHECK-NEXT: %2 = "TFHE.sub_int_glwe"(%1, %arg0) : (i64, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
// CHECK-NEXT: return %2 : !TFHE.glwe<{_,_,_}{7}>
|
||||
|
||||
%0 = arith.constant 1 : i8
|
||||
%1 = "FHE.sub_int_eint"(%0, %arg0): (i8, !FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
@@ -1,22 +1,22 @@
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
|
||||
//CHECK: func.func @add_glwe_const_int(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
|
||||
//CHECK: %c1_i8 = arith.constant 1 : i8
|
||||
//CHECK: %[[V0:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[A0]], %c1_i8) : (!Concrete.lwe_ciphertext<1024,7>, i8) -> !Concrete.lwe_ciphertext<1024,7>
|
||||
//CHECK: %c1_i64 = arith.constant 1 : i64
|
||||
//CHECK: %[[V0:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[A0]], %c1_i64) : (!Concrete.lwe_ciphertext<1024,7>, i64) -> !Concrete.lwe_ciphertext<1024,7>
|
||||
//CHECK: return %[[V0]] : !Concrete.lwe_ciphertext<1024,7>
|
||||
//CHECK: }
|
||||
func.func @add_glwe_const_int(%arg0: !TFHE.glwe<{1024,1,64}{7}>) -> !TFHE.glwe<{1024,1,64}{7}> {
|
||||
%0 = arith.constant 1 : i8
|
||||
%1 = "TFHE.add_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,1,64}{7}>, i8) -> (!TFHE.glwe<{1024,1,64}{7}>)
|
||||
%0 = arith.constant 1 : i64
|
||||
%1 = "TFHE.add_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,1,64}{7}>, i64) -> (!TFHE.glwe<{1024,1,64}{7}>)
|
||||
return %1: !TFHE.glwe<{1024,1,64}{7}>
|
||||
}
|
||||
|
||||
|
||||
//CHECK: func.func @add_glwe_int(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,4>, %[[A1:.*]]: i5) -> !Concrete.lwe_ciphertext<1024,4> {
|
||||
//CHECK: %[[V0:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[A0]], %[[A1]]) : (!Concrete.lwe_ciphertext<1024,4>, i5) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
//CHECK: func.func @add_glwe_int(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,4>, %[[A1:.*]]: i64) -> !Concrete.lwe_ciphertext<1024,4> {
|
||||
//CHECK: %[[V0:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[A0]], %[[A1]]) : (!Concrete.lwe_ciphertext<1024,4>, i64) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
//CHECK: return %[[V0]] : !Concrete.lwe_ciphertext<1024,4>
|
||||
//CHECK: }
|
||||
func.func @add_glwe_int(%arg0: !TFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !TFHE.glwe<{1024,1,64}{4}> {
|
||||
%1 = "TFHE.add_glwe_int"(%arg0, %arg1): (!TFHE.glwe<{1024,1,64}{4}>, i5) -> (!TFHE.glwe<{1024,1,64}{4}>)
|
||||
func.func @add_glwe_int(%arg0: !TFHE.glwe<{1024,1,64}{4}>, %arg1: i64) -> !TFHE.glwe<{1024,1,64}{4}> {
|
||||
%1 = "TFHE.add_glwe_int"(%arg0, %arg1): (!TFHE.glwe<{1024,1,64}{4}>, i64) -> (!TFHE.glwe<{1024,1,64}{4}>)
|
||||
return %1: !TFHE.glwe<{1024,1,64}{4}>
|
||||
}
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @apply_lookup_table(%arg0: tensor<4xi64>) -> tensor<1024xi64> {
|
||||
// CHECK-NEXT: %0 = "Concrete.encode_expand_lut_for_bootstrap"(%arg0) {outputBits = 3 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<1024xi64>
|
||||
// CHECK-NEXT: return %0 : tensor<1024xi64>
|
||||
// CHECK-NEXT: }
|
||||
func.func @apply_lookup_table(%arg1: tensor<4xi64>) -> tensor<1024xi64> {
|
||||
%0 = "TFHE.encode_expand_lut_for_bootstrap"(%arg1) {outputBits = 3 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<1024xi64>
|
||||
return %0: tensor<1024xi64>
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @main(%arg0: tensor<4xi64>) -> tensor<40960xi64> {
|
||||
// CHECK-NEXT: %0 = "Concrete.encode_expand_lut_for_woppbs"(%arg0) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64>
|
||||
// CHECK-NEXT: return %0 : tensor<40960xi64>
|
||||
// CHECK-NEXT: }
|
||||
func.func @main(%arg1: tensor<4xi64>) -> tensor<40960xi64> {
|
||||
%0 = "TFHE.encode_expand_lut_for_woppbs"(%arg1) {crtBits = [1, 2, 3, 3, 4], crtDecomposition = [2, 3, 5, 7, 11], modulusProduct = 2310 : i32, polySize = 1024 : i32} : (tensor<4xi64>) -> tensor<40960xi64>
|
||||
return %0: tensor<40960xi64>
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @main(%arg0: i64) -> tensor<5xi64> {
|
||||
// CHECK-NEXT: %0 = "Concrete.encode_plaintext_with_crt"(%arg0) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64>
|
||||
// CHECK-NEXT: return %0 : tensor<5xi64>
|
||||
// CHECK-NEXT: }
|
||||
func.func @main(%arg1: i64) -> tensor<5xi64> {
|
||||
%0 = "TFHE.encode_plaintext_with_crt"(%arg1) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64>
|
||||
return %0: tensor<5xi64>
|
||||
}
|
||||
@@ -1,22 +1,22 @@
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
|
||||
//CHECK: func.func @mul_glwe_const_int(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
|
||||
//CHECK: %c1_i8 = arith.constant 1 : i8
|
||||
//CHECK: %[[V0:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%[[A0]], %c1_i8) : (!Concrete.lwe_ciphertext<1024,7>, i8) -> !Concrete.lwe_ciphertext<1024,7>
|
||||
//CHECK: %c1_i64 = arith.constant 1 : i64
|
||||
//CHECK: %[[V0:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%[[A0]], %c1_i64) : (!Concrete.lwe_ciphertext<1024,7>, i64) -> !Concrete.lwe_ciphertext<1024,7>
|
||||
//CHECK: return %[[V0]] : !Concrete.lwe_ciphertext<1024,7>
|
||||
//CHECK: }
|
||||
func.func @mul_glwe_const_int(%arg0: !TFHE.glwe<{1024,1,64}{7}>) -> !TFHE.glwe<{1024,1,64}{7}> {
|
||||
%0 = arith.constant 1 : i8
|
||||
%1 = "TFHE.mul_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,1,64}{7}>, i8) -> (!TFHE.glwe<{1024,1,64}{7}>)
|
||||
%0 = arith.constant 1 : i64
|
||||
%1 = "TFHE.mul_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,1,64}{7}>, i64) -> (!TFHE.glwe<{1024,1,64}{7}>)
|
||||
return %1: !TFHE.glwe<{1024,1,64}{7}>
|
||||
}
|
||||
|
||||
|
||||
//CHECK: func.func @mul_glwe_int(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,4>, %[[A1:.*]]: i5) -> !Concrete.lwe_ciphertext<1024,4> {
|
||||
//CHECK: %[[V0:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%[[A0]], %[[A1]]) : (!Concrete.lwe_ciphertext<1024,4>, i5) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
//CHECK: func.func @mul_glwe_int(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,4>, %[[A1:.*]]: i64) -> !Concrete.lwe_ciphertext<1024,4> {
|
||||
//CHECK: %[[V0:.*]] = "Concrete.mul_cleartext_lwe_ciphertext"(%[[A0]], %[[A1]]) : (!Concrete.lwe_ciphertext<1024,4>, i64) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
//CHECK: return %[[V0]] : !Concrete.lwe_ciphertext<1024,4>
|
||||
//CHECK: }
|
||||
func.func @mul_glwe_int(%arg0: !TFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !TFHE.glwe<{1024,1,64}{4}> {
|
||||
%1 = "TFHE.mul_glwe_int"(%arg0, %arg1): (!TFHE.glwe<{1024,1,64}{4}>, i5) -> (!TFHE.glwe<{1024,1,64}{4}>)
|
||||
func.func @mul_glwe_int(%arg0: !TFHE.glwe<{1024,1,64}{4}>, %arg1: i64) -> !TFHE.glwe<{1024,1,64}{4}> {
|
||||
%1 = "TFHE.mul_glwe_int"(%arg0, %arg1): (!TFHE.glwe<{1024,1,64}{4}>, i64) -> (!TFHE.glwe<{1024,1,64}{4}>)
|
||||
return %1: !TFHE.glwe<{1024,1,64}{4}>
|
||||
}
|
||||
|
||||
@@ -1,23 +1,23 @@
|
||||
// RUN: concretecompiler --passes tfhe-to-concrete --action=dump-concrete %s 2>&1| FileCheck %s
|
||||
|
||||
//CHECK: func.func @sub_const_int_glwe(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
|
||||
//CHECK: %c1_i8 = arith.constant 1 : i8
|
||||
//CHECK: %c1_i64 = arith.constant 1 : i64
|
||||
//CHECK: %[[V0:.*]] = "Concrete.negate_lwe_ciphertext"(%[[A0]]) : (!Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7>
|
||||
//CHECK: %[[V1:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[V0]], %c1_i8) : (!Concrete.lwe_ciphertext<1024,7>, i8) -> !Concrete.lwe_ciphertext<1024,7>
|
||||
//CHECK: %[[V1:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[V0]], %c1_i64) : (!Concrete.lwe_ciphertext<1024,7>, i64) -> !Concrete.lwe_ciphertext<1024,7>
|
||||
//CHECK: return %[[V1]] : !Concrete.lwe_ciphertext<1024,7>
|
||||
//CHECK: }
|
||||
func.func @sub_const_int_glwe(%arg0: !TFHE.glwe<{1024,1,64}{7}>) -> !TFHE.glwe<{1024,1,64}{7}> {
|
||||
%0 = arith.constant 1 : i8
|
||||
%1 = "TFHE.sub_int_glwe"(%0, %arg0): (i8, !TFHE.glwe<{1024,1,64}{7}>) -> (!TFHE.glwe<{1024,1,64}{7}>)
|
||||
%0 = arith.constant 1 : i64
|
||||
%1 = "TFHE.sub_int_glwe"(%0, %arg0): (i64, !TFHE.glwe<{1024,1,64}{7}>) -> (!TFHE.glwe<{1024,1,64}{7}>)
|
||||
return %1: !TFHE.glwe<{1024,1,64}{7}>
|
||||
}
|
||||
|
||||
//CHECK: func.func @sub_int_glwe(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,4>, %[[A1:.*]]: i5) -> !Concrete.lwe_ciphertext<1024,4> {
|
||||
//CHECK: func.func @sub_int_glwe(%[[A0:.*]]: !Concrete.lwe_ciphertext<1024,4>, %[[A1:.*]]: i64) -> !Concrete.lwe_ciphertext<1024,4> {
|
||||
//CHECK: %[[V0:.*]] = "Concrete.negate_lwe_ciphertext"(%[[A0]]) : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
//CHECK: %[[V1:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[V0]], %[[A1]]) : (!Concrete.lwe_ciphertext<1024,4>, i5) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
//CHECK: %[[V1:.*]] = "Concrete.add_plaintext_lwe_ciphertext"(%[[V0]], %[[A1]]) : (!Concrete.lwe_ciphertext<1024,4>, i64) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
//CHECK: return %[[V1]] : !Concrete.lwe_ciphertext<1024,4>
|
||||
//CHECK: }
|
||||
func.func @sub_int_glwe(%arg0: !TFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !TFHE.glwe<{1024,1,64}{4}> {
|
||||
%1 = "TFHE.sub_int_glwe"(%arg1, %arg0): (i5, !TFHE.glwe<{1024,1,64}{4}>) -> (!TFHE.glwe<{1024,1,64}{4}>)
|
||||
func.func @sub_int_glwe(%arg0: !TFHE.glwe<{1024,1,64}{4}>, %arg1: i64) -> !TFHE.glwe<{1024,1,64}{4}> {
|
||||
%1 = "TFHE.sub_int_glwe"(%arg1, %arg0): (i64, !TFHE.glwe<{1024,1,64}{4}>) -> (!TFHE.glwe<{1024,1,64}{4}>)
|
||||
return %1: !TFHE.glwe<{1024,1,64}{4}>
|
||||
}
|
||||
|
||||
@@ -9,15 +9,6 @@ func.func @add_lwe_ciphertexts(%arg0: tensor<2049xi64>, %arg1: tensor<2049xi64>)
|
||||
return %0 : tensor<2049xi64>
|
||||
}
|
||||
|
||||
//CHECK: func.func @add_crt_lwe_ciphertexts(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.add_crt_lwe_tensor"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> tensor<5x2049xi64>
|
||||
//CHECK: return %[[V0]] : tensor<5x2049xi64>
|
||||
//CHECK: }
|
||||
func.func @add_crt_lwe_ciphertexts(%arg0: tensor<5x2049xi64>, %arg1: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
|
||||
%0 = "BConcrete.add_crt_lwe_tensor"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> ( tensor<5x2049xi64>)
|
||||
return %0 : tensor<5x2049xi64>
|
||||
}
|
||||
|
||||
//CHECK: func.func @add_plaintext_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: i64) -> tensor<2049xi64> {
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, i64) -> tensor<2049xi64>
|
||||
//CHECK: return %[[V0]] : tensor<2049xi64>
|
||||
@@ -27,15 +18,6 @@ func.func @add_plaintext_lwe_ciphertext(%arg0: tensor<2049xi64>, %arg1: i64) ->
|
||||
return %0 : tensor<2049xi64>
|
||||
}
|
||||
|
||||
//CHECK: func.func @add_plaintext_crt_lwe_ciphertext(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: i64) -> tensor<5x2049xi64> {
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_crt_lwe_tensor"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> tensor<5x2049xi64>
|
||||
//CHECK: return %[[V0]] : tensor<5x2049xi64>
|
||||
//CHECK: }
|
||||
func.func @add_plaintext_crt_lwe_ciphertext(%arg0: tensor<5x2049xi64>, %arg1: i64) -> tensor<5x2049xi64> {
|
||||
%0 = "BConcrete.add_plaintext_crt_lwe_tensor"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> ( tensor<5x2049xi64>)
|
||||
return %0 : tensor<5x2049xi64>
|
||||
}
|
||||
|
||||
//CHECK: func @mul_cleartext_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: i64) -> tensor<2049xi64> {
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, i64) -> tensor<2049xi64>
|
||||
//CHECK: return %[[V0]] : tensor<2049xi64>
|
||||
@@ -45,15 +27,6 @@ func.func @mul_cleartext_lwe_ciphertext(%arg0: tensor<2049xi64>, %arg1: i64) ->
|
||||
return %0 : tensor<2049xi64>
|
||||
}
|
||||
|
||||
//CHECK: func.func @mul_cleartext_crt_lwe_ciphertext(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: i64) -> tensor<5x2049xi64> {
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_crt_lwe_tensor"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> tensor<5x2049xi64>
|
||||
//CHECK: return %[[V0]] : tensor<5x2049xi64>
|
||||
//CHECK: }
|
||||
func.func @mul_cleartext_crt_lwe_ciphertext(%arg0: tensor<5x2049xi64>, %arg1: i64) -> tensor<5x2049xi64> {
|
||||
%0 = "BConcrete.mul_cleartext_crt_lwe_tensor"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> (tensor<5x2049xi64>)
|
||||
return %0 : tensor<5x2049xi64>
|
||||
}
|
||||
|
||||
//CHECK: func.func @negate_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>) -> tensor<2049xi64> {
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.negate_lwe_tensor"(%[[A0]]) : (tensor<2049xi64>) -> tensor<2049xi64>
|
||||
//CHECK: return %[[V0]] : tensor<2049xi64>
|
||||
@@ -63,15 +36,6 @@ func.func @negate_lwe_ciphertext(%arg0: tensor<2049xi64>) -> tensor<2049xi64> {
|
||||
return %0 : tensor<2049xi64>
|
||||
}
|
||||
|
||||
//CHECK: func.func @negate_crt_lwe_ciphertext(%[[A0:.*]]: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.negate_crt_lwe_tensor"(%[[A0]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>) -> tensor<5x2049xi64>
|
||||
//CHECK: return %[[V0]] : tensor<5x2049xi64>
|
||||
//CHECK: }
|
||||
func.func @negate_crt_lwe_ciphertext(%arg0: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
|
||||
%0 = "BConcrete.negate_crt_lwe_tensor"(%arg0) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>) -> (tensor<5x2049xi64>)
|
||||
return %0 : tensor<5x2049xi64>
|
||||
}
|
||||
|
||||
//CHECK: func.func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<16xi64>) -> tensor<2049xi64> {
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.bootstrap_lwe_tensor"(%arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<2049xi64>, tensor<16xi64>) -> tensor<2049xi64>
|
||||
//CHECK: return %[[V0]] : tensor<2049xi64>
|
||||
|
||||
@@ -13,12 +13,6 @@ func.func @type_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Conc
|
||||
return %arg0: !Concrete.lwe_ciphertext<2048,7>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @type_lwe_ciphertext_with_crt(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>
|
||||
func.func @type_lwe_ciphertext_with_crt(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7> {
|
||||
// CHECK-NEXT: return %arg0 : !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>
|
||||
return %arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @type_cleartext(%arg0: !Concrete.cleartext<5>) -> !Concrete.cleartext<5>
|
||||
func.func @type_cleartext(%arg0: !Concrete.cleartext<5>) -> !Concrete.cleartext<5> {
|
||||
// CHECK-NEXT: return %arg0 : !Concrete.cleartext<5>
|
||||
|
||||
@@ -1,21 +1,24 @@
|
||||
// RUN: concretecompiler %s --action=dump-tfhe 2>&1 | FileCheck %s
|
||||
// RUN: concretecompiler %s --action=dump-fhe-no-linalg 2>&1 | FileCheck %s
|
||||
|
||||
//CHECK: #map0 = affine_map<(d0) -> (d0)>
|
||||
//CHECK-NEXT: #map1 = affine_map<(d0) -> (0)>
|
||||
//CHECK-NEXT: module {
|
||||
//CHECK-NEXT: func.func @dot_eint_int(%arg0: tensor<2x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<2xi3>) -> !TFHE.glwe<{_,_,_}{2}> {
|
||||
//CHECK-NEXT: %c0 = arith.constant 0 : index
|
||||
//CHECK-NEXT: %0 = "TFHE.zero_tensor"() : () -> tensor<1x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %1 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!TFHE.glwe<{_,_,_}{2}>>, tensor<2xi3>) outs(%0 : tensor<1x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: ^bb0(%arg2: !TFHE.glwe<{_,_,_}{2}>, %arg3: i3, %arg4: !TFHE.glwe<{_,_,_}{2}>):
|
||||
//CHECK-NEXT: %3 = "TFHE.mul_glwe_int"(%arg2, %arg3) : (!TFHE.glwe<{_,_,_}{2}>, i3) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
//CHECK-NEXT: %4 = "TFHE.add_glwe"(%3, %arg4) : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
//CHECK-NEXT: linalg.yield %4 : !TFHE.glwe<{_,_,_}{2}>
|
||||
//CHECK-NEXT: } -> tensor<1x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %2 = tensor.extract %1[%c0] : tensor<1x!TFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: return %2 : !TFHE.glwe<{_,_,_}{2}>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: }
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: func.func @dot_eint_int(%arg0: tensor<2x!FHE.eint<2>>, %arg1: tensor<2xi3>) -> !FHE.eint<2> {
|
||||
// CHECK-NEXT: %c2 = arith.constant 2 : index
|
||||
// CHECK-NEXT: %c0 = arith.constant 0 : index
|
||||
// CHECK-NEXT: %c1 = arith.constant 1 : index
|
||||
// CHECK-NEXT: %0 = "FHE.zero_tensor"() : () -> tensor<1x!FHE.eint<2>>
|
||||
// CHECK-NEXT: %1 = scf.for %arg2 = %c0 to %c2 step %c1 iter_args(%arg3 = %0) -> (tensor<1x!FHE.eint<2>>) {
|
||||
// CHECK-NEXT: %3 = tensor.extract %arg0[%arg2] : tensor<2x!FHE.eint<2>>
|
||||
// CHECK-NEXT: %4 = tensor.extract %arg1[%arg2] : tensor<2xi3>
|
||||
// CHECK-NEXT: %5 = tensor.extract %0[%c0] : tensor<1x!FHE.eint<2>>
|
||||
// CHECK-NEXT: %6 = "FHE.mul_eint_int"(%3, %4) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %7 = "FHE.add_eint"(%6, %5) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %8 = tensor.insert %7 into %arg3[%c0] : tensor<1x!FHE.eint<2>>
|
||||
// CHECK-NEXT: scf.yield %8 : tensor<1x!FHE.eint<2>>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: %2 = tensor.extract %1[%c0] : tensor<1x!FHE.eint<2>>
|
||||
// CHECK-NEXT: return %2 : !FHE.eint<2>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
func.func @dot_eint_int(%arg0: tensor<2x!FHE.eint<2>>,
|
||||
%arg1: tensor<2xi3>) -> !FHE.eint<2>
|
||||
{
|
||||
|
||||
@@ -51,21 +51,3 @@ func.func @add_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>, %arg1: !TFHE.glwe<{1024,
|
||||
%1 = "TFHE.add_glwe"(%arg0, %arg1): (!TFHE.glwe<{1024,12,64}{7}>, !TFHE.glwe<{1024,11,64}{7}>) -> (!TFHE.glwe<{1024,12,64}{7}>)
|
||||
return %1: !TFHE.glwe<{1024,12,64}{7}>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// GLWE polynomialSize parameter result
|
||||
func.func @add_glwe(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>, %arg1: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}> {
|
||||
// expected-error @+1 {{'TFHE.add_glwe' op should have the same GLWE 'crt' parameter}}
|
||||
%1 = "TFHE.add_glwe"(%arg0, %arg1): (!TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>, !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> (!TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>)
|
||||
return %1: !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// GLWE polynomialSize parameter inputs
|
||||
func.func @add_glwe(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>, %arg1: !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>) -> !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}> {
|
||||
// expected-error @+1 {{'TFHE.add_glwe' op should have the same GLWE 'crt' parameter}}
|
||||
%1 = "TFHE.add_glwe"(%arg0, %arg1): (!TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>, !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>) -> (!TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>)
|
||||
return %1: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>
|
||||
}
|
||||
|
||||
@@ -27,23 +27,3 @@ func.func @add_glwe_int(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,
|
||||
%1 = "TFHE.add_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,12,64}{7}>, i8) -> (!TFHE.glwe<{1024,11,64}{7}>)
|
||||
return %1: !TFHE.glwe<{1024,11,64}{7}>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// GLWE crt parameter
|
||||
func.func @add_glwe_int(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}> {
|
||||
%0 = arith.constant 1 : i8
|
||||
// expected-error @+1 {{'TFHE.add_glwe_int' op should have the same GLWE 'crt' parameter}}
|
||||
%1 = "TFHE.add_glwe_int"(%arg0, %0): (!TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>, i8) -> (!TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>)
|
||||
return %1: !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// integer width doesn't match GLWE parameter
|
||||
func.func @add_glwe_int(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,12,64}{7}> {
|
||||
%0 = arith.constant 1 : i9
|
||||
// expected-error @+1 {{'TFHE.add_glwe_int' op should have the width of `b` equals or less than 'p'+1}}
|
||||
%1 = "TFHE.add_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,12,64}{7}>, i9) -> (!TFHE.glwe<{1024,12,64}{7}>)
|
||||
return %1: !TFHE.glwe<{1024,12,64}{7}>
|
||||
}
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
|
||||
// CHECK-LABEL: func.func @add_glwe_int(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,12,64}{7}>
|
||||
func.func @add_glwe_int(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,12,64}{7}> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8
|
||||
// CHECK-NEXT: %[[V2:.*]] = "TFHE.add_glwe_int"(%arg0, %[[V1]]) : (!TFHE.glwe<{1024,12,64}{7}>, i8) -> !TFHE.glwe<{1024,12,64}{7}>
|
||||
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i64
|
||||
// CHECK-NEXT: %[[V2:.*]] = "TFHE.add_glwe_int"(%arg0, %[[V1]]) : (!TFHE.glwe<{1024,12,64}{7}>, i64) -> !TFHE.glwe<{1024,12,64}{7}>
|
||||
// CHECK-NEXT: return %[[V2]] : !TFHE.glwe<{1024,12,64}{7}>
|
||||
|
||||
%0 = arith.constant 1 : i8
|
||||
%1 = "TFHE.add_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,12,64}{7}>, i8) -> (!TFHE.glwe<{1024,12,64}{7}>)
|
||||
%0 = arith.constant 1 : i64
|
||||
%1 = "TFHE.add_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,12,64}{7}>, i64) -> (!TFHE.glwe<{1024,12,64}{7}>)
|
||||
return %1: !TFHE.glwe<{1024,12,64}{7}>
|
||||
}
|
||||
|
||||
@@ -27,23 +27,3 @@ func.func @mul_glwe_int(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,
|
||||
%1 = "TFHE.mul_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,12,64}{7}>, i8) -> (!TFHE.glwe<{1024,11,64}{7}>)
|
||||
return %1: !TFHE.glwe<{1024,11,64}{7}>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// GLWE crt parameter
|
||||
func.func @mul_glwe_int(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}> {
|
||||
%0 = arith.constant 1 : i8
|
||||
// expected-error @+1 {{'TFHE.mul_glwe_int' op should have the same GLWE 'crt' parameter}}
|
||||
%1 = "TFHE.mul_glwe_int"(%arg0, %0): (!TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>, i8) -> (!TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>)
|
||||
return %1: !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// integer width doesn't match GLWE parameter
|
||||
func.func @mul_glwe_int(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,12,64}{7}> {
|
||||
%0 = arith.constant 1 : i9
|
||||
// expected-error @+1 {{'TFHE.mul_glwe_int' op should have the width of `b` equals or less than 'p'+1}}
|
||||
%1 = "TFHE.mul_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,12,64}{7}>, i9) -> (!TFHE.glwe<{1024,12,64}{7}>)
|
||||
return %1: !TFHE.glwe<{1024,12,64}{7}>
|
||||
}
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
|
||||
// CHECK-LABEL: func.func @mul_glwe_int(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,12,64}{7}>
|
||||
func.func @mul_glwe_int(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,12,64}{7}> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8
|
||||
// CHECK-NEXT: %[[V2:.*]] = "TFHE.mul_glwe_int"(%arg0, %[[V1]]) : (!TFHE.glwe<{1024,12,64}{7}>, i8) -> !TFHE.glwe<{1024,12,64}{7}>
|
||||
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i64
|
||||
// CHECK-NEXT: %[[V2:.*]] = "TFHE.mul_glwe_int"(%arg0, %[[V1]]) : (!TFHE.glwe<{1024,12,64}{7}>, i64) -> !TFHE.glwe<{1024,12,64}{7}>
|
||||
// CHECK-NEXT: return %[[V2]] : !TFHE.glwe<{1024,12,64}{7}>
|
||||
|
||||
%0 = arith.constant 1 : i8
|
||||
%1 = "TFHE.mul_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,12,64}{7}>, i8) -> (!TFHE.glwe<{1024,12,64}{7}>)
|
||||
%0 = arith.constant 1 : i64
|
||||
%1 = "TFHE.mul_glwe_int"(%arg0, %0): (!TFHE.glwe<{1024,12,64}{7}>, i64) -> (!TFHE.glwe<{1024,12,64}{7}>)
|
||||
return %1: !TFHE.glwe<{1024,12,64}{7}>
|
||||
}
|
||||
|
||||
@@ -27,15 +27,6 @@ func.func @neg_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,11,6
|
||||
|
||||
// -----
|
||||
|
||||
// GLWE crt parameter
|
||||
func.func @neg_glwe(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}> {
|
||||
// expected-error @+1 {{'TFHE.neg_glwe' op should have the same GLWE 'crt' parameter}}
|
||||
%1 = "TFHE.neg_glwe"(%arg0): (!TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> (!TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>)
|
||||
return %1: !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// integer width doesn't match GLWE parameter
|
||||
func.func @neg_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,11,64}{7}> {
|
||||
// expected-error @+1 {{'TFHE.neg_glwe' op should have the same GLWE 'polynomialSize' parameter}}
|
||||
|
||||
@@ -28,16 +28,6 @@ func.func @sub_int_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,
|
||||
return %1: !TFHE.glwe<{1024,11,64}{7}>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// GLWE polynomialSize parameter
|
||||
func.func @sub_int_glwe(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}> {
|
||||
%0 = arith.constant 1 : i8
|
||||
// expected-error @+1 {{'TFHE.sub_int_glwe' op should have the same GLWE 'crt' parameter}}
|
||||
%1 = "TFHE.sub_int_glwe"(%0, %arg0): (i8, !TFHE.glwe<crt=[2,3,5,7,11]{1024,12,64}{7}>) -> (!TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>)
|
||||
return %1: !TFHE.glwe<crt=[7,3,5,7,11]{1024,12,64}{7}>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
|
||||
// CHECK-LABEL: func.func @sub_int_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,12,64}{7}>
|
||||
func.func @sub_int_glwe(%arg0: !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,12,64}{7}> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8
|
||||
// CHECK-NEXT: %[[V2:.*]] = "TFHE.sub_int_glwe"(%[[V1]], %arg0) : (i8, !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,12,64}{7}>
|
||||
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i64
|
||||
// CHECK-NEXT: %[[V2:.*]] = "TFHE.sub_int_glwe"(%[[V1]], %arg0) : (i64, !TFHE.glwe<{1024,12,64}{7}>) -> !TFHE.glwe<{1024,12,64}{7}>
|
||||
// CHECK-NEXT: return %[[V2]] : !TFHE.glwe<{1024,12,64}{7}>
|
||||
|
||||
%0 = arith.constant 1 : i8
|
||||
%1 = "TFHE.sub_int_glwe"(%0, %arg0): (i8, !TFHE.glwe<{1024,12,64}{7}>) -> (!TFHE.glwe<{1024,12,64}{7}>)
|
||||
%0 = arith.constant 1 : i64
|
||||
%1 = "TFHE.sub_int_glwe"(%0, %arg0): (i64, !TFHE.glwe<{1024,12,64}{7}>) -> (!TFHE.glwe<{1024,12,64}{7}>)
|
||||
return %1: !TFHE.glwe<{1024,12,64}{7}>
|
||||
}
|
||||
|
||||
@@ -11,15 +11,3 @@ func.func @glwe_1(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> {
|
||||
// CHECK-LABEL: return %arg0 : !TFHE.glwe<{_,_,_}{7}>
|
||||
return %arg0: !TFHE.glwe<{_,_,_}{7}>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @glwe_crt(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{_,_,_}{7}>) -> !TFHE.glwe<crt=[2,3,5,7,11]{_,_,_}{7}>
|
||||
func.func @glwe_crt(%arg0: !TFHE.glwe<crt=[2,3,5,7,11]{_,_,_}{7}>) -> !TFHE.glwe<crt=[2,3,5,7,11]{_,_,_}{7}> {
|
||||
// CHECK-LABEL: return %arg0 : !TFHE.glwe<crt=[2,3,5,7,11]{_,_,_}{7}>
|
||||
return %arg0: !TFHE.glwe<crt=[2,3,5,7,11]{_,_,_}{7}>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @glwe_crt_undef(%arg0: !TFHE.glwe<crt=[_,_,_,_,_]{_,_,_}{7}>) -> !TFHE.glwe<crt=[_,_,_,_,_]{_,_,_}{7}>
|
||||
func.func @glwe_crt_undef(%arg0: !TFHE.glwe<crt=[_,_,_,_,_]{_,_,_}{7}>) -> !TFHE.glwe<crt=[_,_,_,_,_]{_,_,_}{7}> {
|
||||
// CHECK-LABEL: return %arg0 : !TFHE.glwe<crt=[_,_,_,_,_]{_,_,_}{7}>
|
||||
return %arg0: !TFHE.glwe<crt=[_,_,_,_,_]{_,_,_}{7}>
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: concretecompiler --split-input-file --action=dump-concrete-with-loops --batch-concrete-ops %s 2>&1| FileCheck %s
|
||||
// RUN: concretecompiler --split-input-file --action=dump-concrete --batch-concrete-ops %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @batch_continuous_slice_keyswitch(%arg0: tensor<2x3x4x!Concrete.lwe_ciphertext<572,2>>) -> tensor<2x3x4x!Concrete.lwe_ciphertext<572,2>> {
|
||||
func.func @batch_continuous_slice_keyswitch(%arg0: tensor<2x3x4x!Concrete.lwe_ciphertext<572,2>>) -> tensor<2x3x4x!Concrete.lwe_ciphertext<572,2>> {
|
||||
|
||||
Reference in New Issue
Block a user