mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
chore: rename dialects
HLFHE to FHE MidLFHE to TFHE LowLFHE to Concrete
This commit is contained in:
@@ -68,8 +68,8 @@ build-end-to-end-jit-clear-tensor: build-initialized
|
||||
build-end-to-end-jit-encrypted-tensor: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target end_to_end_jit_encrypted_tensor
|
||||
|
||||
build-end-to-end-jit-hlfhelinalg: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target end_to_end_jit_hlfhelinalg
|
||||
build-end-to-end-jit-fhelinalg: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target end_to_end_jit_fhelinalg
|
||||
|
||||
build-end-to-end-jit-lambda: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target end_to_end_jit_lambda
|
||||
@@ -80,7 +80,7 @@ build-end-to-end-jit-dfr: build-initialized
|
||||
build-end-to-end-jit-auto-parallelization: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target end_to_end_jit_auto_parallelization
|
||||
|
||||
build-end-to-end-jit: build-end-to-end-jit-test build-end-to-end-jit-clear-tensor build-end-to-end-jit-encrypted-tensor build-end-to-end-jit-hlfhelinalg
|
||||
build-end-to-end-jit: build-end-to-end-jit-test build-end-to-end-jit-clear-tensor build-end-to-end-jit-encrypted-tensor build-end-to-end-jit-fhelinalg
|
||||
|
||||
|
||||
test-end-to-end-jit-test: build-end-to-end-jit-test
|
||||
@@ -92,8 +92,8 @@ test-end-to-end-jit-clear-tensor: build-end-to-end-jit-clear-tensor
|
||||
test-end-to-end-jit-encrypted-tensor: build-end-to-end-jit-encrypted-tensor
|
||||
$(BUILD_DIR)/bin/end_to_end_jit_encrypted_tensor
|
||||
|
||||
test-end-to-end-jit-hlfhelinalg: build-end-to-end-jit-hlfhelinalg
|
||||
$(BUILD_DIR)/bin/end_to_end_jit_hlfhelinalg
|
||||
test-end-to-end-jit-fhelinalg: build-end-to-end-jit-fhelinalg
|
||||
$(BUILD_DIR)/bin/end_to_end_jit_fhelinalg
|
||||
|
||||
test-end-to-end-jit-lambda: build-initialized build-end-to-end-jit-lambda
|
||||
$(BUILD_DIR)/bin/end_to_end_jit_lambda
|
||||
@@ -104,7 +104,7 @@ test-end-to-end-jit-dfr: build-end-to-end-jit-dfr
|
||||
test-end-to-end-jit-auto-parallelization: build-end-to-end-jit-auto-parallelization
|
||||
$(BUILD_DIR)/bin/end_to_end_jit_auto_parallelization
|
||||
|
||||
test-end-to-end-jit: test-end-to-end-jit-test test-end-to-end-jit-clear-tensor test-end-to-end-jit-encrypted-tensor test-end-to-end-jit-hlfhelinalg
|
||||
test-end-to-end-jit: test-end-to-end-jit-test test-end-to-end-jit-clear-tensor test-end-to-end-jit-encrypted-tensor test-end-to-end-jit-fhelinalg
|
||||
|
||||
show-stress-tests-summary:
|
||||
@echo '------ Stress tests summary ------'
|
||||
@@ -163,7 +163,7 @@ update_python_version:
|
||||
test-end-to-end-jit-test \
|
||||
test-end-to-end-jit-clear-tensor \
|
||||
test-end-to-end-jit-encrypted-tensor \
|
||||
test-end-to-end-jit-hlfhelinalg \
|
||||
test-end-to-end-jit-fhelinalg \
|
||||
test-python \
|
||||
test \
|
||||
add-deps \
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
#ifndef CONCRETELANG_C_DIALECT_HLFHE_H
|
||||
#define CONCRETELANG_C_DIALECT_HLFHE_H
|
||||
#ifndef CONCRETELANG_C_DIALECT_FHE_H
|
||||
#define CONCRETELANG_C_DIALECT_FHE_H
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Registration.h"
|
||||
@@ -13,18 +13,18 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(HLFHE, hlfhe);
|
||||
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(FHE, fhe);
|
||||
|
||||
/// Creates an encrypted integer type of `width` bits
|
||||
MLIR_CAPI_EXPORTED MlirType hlfheEncryptedIntegerTypeGetChecked(
|
||||
MLIR_CAPI_EXPORTED MlirType fheEncryptedIntegerTypeGetChecked(
|
||||
MlirContext context, unsigned width,
|
||||
mlir::function_ref<mlir::InFlightDiagnostic()> emitError);
|
||||
|
||||
/// If the type is an EncryptedInteger
|
||||
MLIR_CAPI_EXPORTED bool hlfheTypeIsAnEncryptedIntegerType(MlirType);
|
||||
MLIR_CAPI_EXPORTED bool fheTypeIsAnEncryptedIntegerType(MlirType);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // CONCRETELANG_C_DIALECT_HLFHE_H
|
||||
#endif // CONCRETELANG_C_DIALECT_FHE_H
|
||||
@@ -1,8 +1,8 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
#ifndef CONCRETELANG_C_DIALECT_HLFHELINALG_H
|
||||
#define CONCRETELANG_C_DIALECT_HLFHELINALG_H
|
||||
#ifndef CONCRETELANG_C_DIALECT_FHELINALG_H
|
||||
#define CONCRETELANG_C_DIALECT_FHELINALG_H
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Registration.h"
|
||||
@@ -13,10 +13,10 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(HLFHELinalg, hlfhelinalg);
|
||||
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(FHELinalg, fhelinalg);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // CONCRETELANG_C_DIALECT_HLFHELINALG_H
|
||||
#endif // CONCRETELANG_C_DIALECT_FHELINALG_H
|
||||
@@ -3,5 +3,5 @@ mlir_tablegen(Passes.h.inc -gen-pass-decls -name Conversion)
|
||||
add_public_tablegen_target(ConcretelangConversionPassIncGen)
|
||||
|
||||
|
||||
add_subdirectory(HLFHEToMidLFHE)
|
||||
add_subdirectory(MidLFHEToLowLFHE)
|
||||
add_subdirectory(FHEToTFHE)
|
||||
add_subdirectory(TFHEToConcrete)
|
||||
@@ -2,8 +2,8 @@
|
||||
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_LOWLFHETOCONCRETECAPI_PASS_H_
|
||||
#define CONCRETELANG_CONVERSION_LOWLFHETOCONCRETECAPI_PASS_H_
|
||||
#ifndef CONCRETELANG_CONVERSION_CONCRETETOCONCRETECAPI_PASS_H_
|
||||
#define CONCRETELANG_CONVERSION_CONCRETETOCONCRETECAPI_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
@@ -11,10 +11,10 @@
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
/// Create a pass to convert `LowLFHE` operators to function call to the
|
||||
/// Create a pass to convert `Concrete` operators to function call to the
|
||||
/// `ConcreteCAPI`
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertLowLFHEToConcreteCAPIPass();
|
||||
createConvertConcreteToConcreteCAPIPass();
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -2,15 +2,15 @@
|
||||
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_LOWLFHEUNPARAMETRIZE_PASS_H_
|
||||
#define CONCRETELANG_CONVERSION_LOWLFHEUNPARAMETRIZE_PASS_H_
|
||||
#ifndef CONCRETELANG_CONVERSION_CONCRETEUNPARAMETRIZE_PASS_H_
|
||||
#define CONCRETELANG_CONVERSION_CONCRETEUNPARAMETRIZE_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertLowLFHEUnparametrizePass();
|
||||
createConvertConcreteUnparametrizePass();
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -2,16 +2,16 @@
|
||||
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_HLFHETENSOROPSTOLINALG_PASS_H_
|
||||
#define CONCRETELANG_CONVERSION_HLFHETENSOROPSTOLINALG_PASS_H_
|
||||
#ifndef CONCRETELANG_CONVERSION_FHETENSOROPSTOLINALG_PASS_H_
|
||||
#define CONCRETELANG_CONVERSION_FHETENSOROPSTOLINALG_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
/// Create a pass to convert `HLFHE` tensor operators to linal.generic
|
||||
/// Create a pass to convert `FHE` tensor operators to linal.generic
|
||||
/// operators.
|
||||
std::unique_ptr<mlir::FunctionPass> createConvertHLFHETensorOpsToLinalg();
|
||||
std::unique_ptr<mlir::FunctionPass> createConvertFHETensorOpsToLinalg();
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Patterns.td)
|
||||
mlir_tablegen(Patterns.h.inc -gen-rewriters -name FHE)
|
||||
add_public_tablegen_target(FHEToTFHEPatternsIncGen)
|
||||
|
||||
add_concretelang_doc(Patterns FHEToTFHEPatterns concretelang/ -gen-pass-doc)
|
||||
@@ -2,15 +2,15 @@
|
||||
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_MIDLFHETOLOWLFHE_PASS_H_
|
||||
#define CONCRETELANG_CONVERSION_MIDLFHETOLOWLFHE_PASS_H_
|
||||
#ifndef CONCRETELANG_CONVERSION_FHETOTFHE_PASS_H_
|
||||
#define CONCRETELANG_CONVERSION_FHETOTFHE_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
/// Create a pass to convert `MidLFHE` dialect to `LowLFHE` dialect.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertMidLFHEToLowLFHEPass();
|
||||
/// Create a pass to convert `FHE` dialect to `TFHE` dialect.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertFHEToTFHEPass();
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -1,50 +1,50 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_HLFHETOMIDLFHE_PATTERNS_H_
|
||||
#define CONCRETELANG_CONVERSION_HLFHETOMIDLFHE_PATTERNS_H_
|
||||
#ifndef CONCRETELANG_CONVERSION_FHETOTFHE_PATTERNS_H_
|
||||
#define CONCRETELANG_CONVERSION_FHETOTFHE_PATTERNS_H_
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "concretelang/Dialect/HLFHE/IR/HLFHEOps.h"
|
||||
#include "concretelang/Dialect/MidLFHE/IR/MidLFHEOps.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
using HLFHE::EncryptedIntegerType;
|
||||
using MidLFHE::GLWECipherTextType;
|
||||
using FHE::EncryptedIntegerType;
|
||||
using TFHE::GLWECipherTextType;
|
||||
|
||||
/// Converts HLFHE::EncryptedInteger into MidLFHE::GlweCiphetext
|
||||
/// Converts FHE::EncryptedInteger into TFHE::GlweCiphetext
|
||||
GLWECipherTextType
|
||||
convertTypeEncryptedIntegerToGLWE(mlir::MLIRContext *context,
|
||||
EncryptedIntegerType &eint) {
|
||||
return GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth());
|
||||
}
|
||||
|
||||
mlir::Value createZeroGLWEOpFromHLFHE(mlir::PatternRewriter &rewriter,
|
||||
mlir::Value createZeroGLWEOpFromFHE(mlir::PatternRewriter &rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::OpResult result) {
|
||||
mlir::SmallVector<mlir::Value> args{};
|
||||
mlir::SmallVector<mlir::NamedAttribute, 0> attrs;
|
||||
auto eint =
|
||||
result.getType().cast<mlir::concretelang::HLFHE::EncryptedIntegerType>();
|
||||
result.getType().cast<mlir::concretelang::FHE::EncryptedIntegerType>();
|
||||
mlir::SmallVector<mlir::Type, 1> resTypes{
|
||||
convertTypeEncryptedIntegerToGLWE(rewriter.getContext(), eint)};
|
||||
MidLFHE::ZeroGLWEOp op =
|
||||
rewriter.create<MidLFHE::ZeroGLWEOp>(loc, resTypes, args, attrs);
|
||||
TFHE::ZeroGLWEOp op =
|
||||
rewriter.create<TFHE::ZeroGLWEOp>(loc, resTypes, args, attrs);
|
||||
return op.getODSResults(0).front();
|
||||
}
|
||||
|
||||
template <class Operator>
|
||||
mlir::Value createGLWEOpFromHLFHE(mlir::PatternRewriter &rewriter,
|
||||
mlir::Value createGLWEOpFromFHE(mlir::PatternRewriter &rewriter,
|
||||
mlir::Location loc, mlir::Value arg0,
|
||||
mlir::Value arg1, mlir::OpResult result) {
|
||||
mlir::SmallVector<mlir::Value, 2> args{arg0, arg1};
|
||||
mlir::SmallVector<mlir::NamedAttribute, 0> attrs;
|
||||
auto eint =
|
||||
result.getType().cast<mlir::concretelang::HLFHE::EncryptedIntegerType>();
|
||||
result.getType().cast<mlir::concretelang::FHE::EncryptedIntegerType>();
|
||||
mlir::SmallVector<mlir::Type, 1> resTypes{
|
||||
convertTypeEncryptedIntegerToGLWE(rewriter.getContext(), eint)};
|
||||
Operator op = rewriter.create<Operator>(loc, resTypes, args, attrs);
|
||||
@@ -52,13 +52,13 @@ mlir::Value createGLWEOpFromHLFHE(mlir::PatternRewriter &rewriter,
|
||||
}
|
||||
|
||||
template <class Operator>
|
||||
mlir::Value createGLWEOpFromHLFHE(mlir::PatternRewriter &rewriter,
|
||||
mlir::Value createGLWEOpFromFHE(mlir::PatternRewriter &rewriter,
|
||||
mlir::Location loc, mlir::Value arg0,
|
||||
mlir::OpResult result) {
|
||||
mlir::SmallVector<mlir::Value, 1> args{arg0};
|
||||
mlir::SmallVector<mlir::NamedAttribute, 0> attrs;
|
||||
auto eint =
|
||||
result.getType().cast<mlir::concretelang::HLFHE::EncryptedIntegerType>();
|
||||
result.getType().cast<mlir::concretelang::FHE::EncryptedIntegerType>();
|
||||
mlir::SmallVector<mlir::Type, 1> resTypes{
|
||||
convertTypeEncryptedIntegerToGLWE(rewriter.getContext(), eint)};
|
||||
Operator op = rewriter.create<Operator>(loc, resTypes, args, attrs);
|
||||
@@ -66,7 +66,7 @@ mlir::Value createGLWEOpFromHLFHE(mlir::PatternRewriter &rewriter,
|
||||
}
|
||||
|
||||
mlir::Value
|
||||
createApplyLookupTableGLWEOpFromHLFHE(mlir::PatternRewriter &rewriter,
|
||||
createApplyLookupTableGLWEOpFromFHE(mlir::PatternRewriter &rewriter,
|
||||
mlir::Location loc, mlir::Value arg0,
|
||||
mlir::Value arg1, mlir::OpResult result) {
|
||||
mlir::SmallVector<mlir::Value, 2> args{arg0, arg1};
|
||||
@@ -86,10 +86,10 @@ createApplyLookupTableGLWEOpFromHLFHE(mlir::PatternRewriter &rewriter,
|
||||
unset),
|
||||
};
|
||||
auto eint =
|
||||
result.getType().cast<mlir::concretelang::HLFHE::EncryptedIntegerType>();
|
||||
result.getType().cast<mlir::concretelang::FHE::EncryptedIntegerType>();
|
||||
mlir::SmallVector<mlir::Type, 1> resTypes{
|
||||
convertTypeEncryptedIntegerToGLWE(rewriter.getContext(), eint)};
|
||||
auto op = rewriter.create<concretelang::MidLFHE::ApplyLookupTable>(loc, resTypes,
|
||||
auto op = rewriter.create<concretelang::TFHE::ApplyLookupTable>(loc, resTypes,
|
||||
args, attrs);
|
||||
return op.getODSResults(0).front();
|
||||
}
|
||||
@@ -98,10 +98,10 @@ createApplyLookupTableGLWEOpFromHLFHE(mlir::PatternRewriter &rewriter,
|
||||
} // namespace mlir
|
||||
|
||||
namespace {
|
||||
#include "concretelang/Conversion/HLFHEToMidLFHE/Patterns.h.inc"
|
||||
#include "concretelang/Conversion/FHEToTFHE/Patterns.h.inc"
|
||||
}
|
||||
|
||||
void populateWithGeneratedHLFHEToMidLFHE(mlir::RewritePatternSet &patterns) {
|
||||
void populateWithGeneratedFHEToTFHE(mlir::RewritePatternSet &patterns) {
|
||||
populateWithGenerated(patterns);
|
||||
}
|
||||
|
||||
@@ -1,47 +1,47 @@
|
||||
#ifndef CONCRETELANG_CONVERSION_HLFHETOMIDLFHE_PATTERNS
|
||||
#define CONCRETELANG_CONVERSION_HLFHETOMIDLFHE_PATTERNS
|
||||
#ifndef CONCRETELANG_CONVERSION_FHETOTFHE_PATTERNS
|
||||
#define CONCRETELANG_CONVERSION_FHETOTFHE_PATTERNS
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
include "concretelang/Dialect/HLFHE/IR/HLFHEOps.td"
|
||||
include "concretelang/Dialect/MidLFHE/IR/MidLFHEOps.td"
|
||||
include "concretelang/Dialect/FHE/IR/FHEOps.td"
|
||||
include "concretelang/Dialect/TFHE/IR/TFHEOps.td"
|
||||
|
||||
def createZeroGLWEOp : NativeCodeCall<"mlir::concretelang::createZeroGLWEOpFromHLFHE($_builder, $_loc, $0)">;
|
||||
def createZeroGLWEOp : NativeCodeCall<"mlir::concretelang::createZeroGLWEOpFromFHE($_builder, $_loc, $0)">;
|
||||
|
||||
def ZeroEintPattern : Pat<
|
||||
(ZeroEintOp:$result),
|
||||
(createZeroGLWEOp $result)>;
|
||||
|
||||
def createAddGLWEIntOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromHLFHE<mlir::concretelang::MidLFHE::AddGLWEIntOp>($_builder, $_loc, $0, $1, $2)">;
|
||||
def createAddGLWEIntOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromFHE<mlir::concretelang::TFHE::AddGLWEIntOp>($_builder, $_loc, $0, $1, $2)">;
|
||||
|
||||
def AddEintIntPattern : Pat<
|
||||
(AddEintIntOp:$result $arg0, $arg1),
|
||||
(createAddGLWEIntOp $arg0, $arg1, $result)>;
|
||||
|
||||
def createAddGLWEOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromHLFHE<mlir::concretelang::MidLFHE::AddGLWEOp>($_builder, $_loc, $0, $1, $2)">;
|
||||
def createAddGLWEOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromFHE<mlir::concretelang::TFHE::AddGLWEOp>($_builder, $_loc, $0, $1, $2)">;
|
||||
|
||||
def AddEintPattern : Pat<
|
||||
(AddEintOp:$result $arg0, $arg1),
|
||||
(createAddGLWEOp $arg0, $arg1, $result)>;
|
||||
|
||||
def createSubIntGLWEOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromHLFHE<mlir::concretelang::MidLFHE::SubIntGLWEOp>($_builder, $_loc, $0, $1, $2)">;
|
||||
def createSubIntGLWEOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromFHE<mlir::concretelang::TFHE::SubIntGLWEOp>($_builder, $_loc, $0, $1, $2)">;
|
||||
|
||||
def SubIntEintPattern : Pat<
|
||||
(SubIntEintOp:$result $arg0, $arg1),
|
||||
(createSubIntGLWEOp $arg0, $arg1, $result)>;
|
||||
|
||||
def createNegGLWEOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromHLFHE<mlir::concretelang::MidLFHE::NegGLWEOp>($_builder, $_loc, $0, $1)">;
|
||||
def createNegGLWEOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromFHE<mlir::concretelang::TFHE::NegGLWEOp>($_builder, $_loc, $0, $1)">;
|
||||
|
||||
def NegEintPattern : Pat<
|
||||
(NegEintOp:$result $arg0),
|
||||
(createNegGLWEOp $arg0, $result)>;
|
||||
|
||||
def createMulGLWEIntOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromHLFHE<mlir::concretelang::MidLFHE::MulGLWEIntOp>($_builder, $_loc, $0, $1, $2)">;
|
||||
def createMulGLWEIntOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromFHE<mlir::concretelang::TFHE::MulGLWEIntOp>($_builder, $_loc, $0, $1, $2)">;
|
||||
|
||||
def MulEintIntPattern : Pat<
|
||||
(MulEintIntOp:$result $arg0, $arg1),
|
||||
(createMulGLWEIntOp $arg0, $arg1, $result)>;
|
||||
|
||||
def createApplyLookupTableGLWEOp : NativeCodeCall<"mlir::concretelang::createApplyLookupTableGLWEOpFromHLFHE($_builder, $_loc, $0, $1, $2)">;
|
||||
def createApplyLookupTableGLWEOp : NativeCodeCall<"mlir::concretelang::createApplyLookupTableGLWEOpFromFHE($_builder, $_loc, $0, $1, $2)">;
|
||||
|
||||
def ApplyLookupTableEintPattern : Pat<
|
||||
(ApplyLookupTableEintOp:$result $arg0, $arg1),
|
||||
@@ -1,5 +0,0 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Patterns.td)
|
||||
mlir_tablegen(Patterns.h.inc -gen-rewriters -name HLFHE)
|
||||
add_public_tablegen_target(HLFHEToMidLFHEPatternsIncGen)
|
||||
|
||||
add_concretelang_doc(Patterns HLFHEToMidLFHEPatterns concretelang/ -gen-pass-doc)
|
||||
@@ -1,5 +0,0 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Patterns.td)
|
||||
mlir_tablegen(Patterns.h.inc -gen-rewriters -name MidLFHE)
|
||||
add_public_tablegen_target(MidLFHEToLowLFHEPatternsIncGen)
|
||||
|
||||
add_concretelang_doc(Patterns MidLFHEToLowLFHEPatterns concretelang/ -gen-pass-doc)
|
||||
@@ -9,16 +9,16 @@
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
|
||||
#include "concretelang/Conversion/HLFHETensorOpsToLinalg/Pass.h"
|
||||
#include "concretelang/Conversion/HLFHEToMidLFHE/Pass.h"
|
||||
#include "concretelang/Conversion/LowLFHEToConcreteCAPI/Pass.h"
|
||||
#include "concretelang/Conversion/LowLFHEUnparametrize/Pass.h"
|
||||
#include "concretelang/Conversion/FHETensorOpsToLinalg/Pass.h"
|
||||
#include "concretelang/Conversion/FHEToTFHE/Pass.h"
|
||||
#include "concretelang/Conversion/ConcreteToConcreteCAPI/Pass.h"
|
||||
#include "concretelang/Conversion/ConcreteUnparametrize/Pass.h"
|
||||
#include "concretelang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h"
|
||||
#include "concretelang/Conversion/MidLFHEGlobalParametrization/Pass.h"
|
||||
#include "concretelang/Conversion/MidLFHEToLowLFHE/Pass.h"
|
||||
#include "concretelang/Dialect/HLFHE/IR/HLFHEDialect.h"
|
||||
#include "concretelang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
|
||||
#include "concretelang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
|
||||
#include "concretelang/Conversion/TFHEGlobalParametrization/Pass.h"
|
||||
#include "concretelang/Conversion/TFHEToConcrete/Pass.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "concretelang/Conversion/Passes.h.inc"
|
||||
|
||||
@@ -3,45 +3,45 @@
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def HLFHETensorOpsToLinalg : FunctionPass<"hlfhe-tensor-ops-to-linalg"> {
|
||||
let summary = "Lowers tensor operations of HLFHE dialect to linalg.generic";
|
||||
let constructor = "mlir::concretelang::createConvertHLFHETensorOpsToLinalg()";
|
||||
def FHETensorOpsToLinalg : FunctionPass<"fhe-tensor-ops-to-linalg"> {
|
||||
let summary = "Lowers tensor operations of FHE dialect to linalg.generic";
|
||||
let constructor = "mlir::concretelang::createConvertFHETensorOpsToLinalg()";
|
||||
let dependentDialects = ["mlir::linalg::LinalgDialect"];
|
||||
}
|
||||
|
||||
def HLFHEToMidLFHE : Pass<"hlfhe-to-midlfhe", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from the HLFHE dialect to MidLFHE";
|
||||
let description = [{ Lowers operations from the HLFHE dialect to Std + Math }];
|
||||
let constructor = "mlir::concretelang::createConvertHLFHEToMidLFHEPass()";
|
||||
def FHEToTFHE : Pass<"fhe-to-tfhe", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from the FHE dialect to TFHE";
|
||||
let description = [{ Lowers operations from the FHE dialect to Std + Math }];
|
||||
let constructor = "mlir::concretelang::createConvertFHEToTFHEPass()";
|
||||
let options = [];
|
||||
let dependentDialects = ["mlir::linalg::LinalgDialect"];
|
||||
}
|
||||
|
||||
def MidLFHEGlobalParametrization : Pass<"midlfhe-global-parametrization", "mlir::ModuleOp"> {
|
||||
let summary = "Inject global fhe parameters to the MidLFHE dialect";
|
||||
let constructor = "mlir::concretelang::createConvertMidLFHEToLowLFHEPass()";
|
||||
def TFHEGlobalParametrization : Pass<"tfhe-global-parametrization", "mlir::ModuleOp"> {
|
||||
let summary = "Inject global fhe parameters to the TFHE dialect";
|
||||
let constructor = "mlir::concretelang::createConvertTFHEToConcretePass()";
|
||||
let options = [];
|
||||
let dependentDialects = ["mlir::concretelang::MidLFHE::MidLFHEDialect"];
|
||||
let dependentDialects = ["mlir::concretelang::TFHE::TFHEDialect"];
|
||||
}
|
||||
|
||||
def MidLFHEToLowLFHE : Pass<"midlfhe-to-lowlfhe", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from the MidLFHE dialect to LowLFHE";
|
||||
let description = [{ Lowers operations from the MidLFHE dialect to LowLFHE }];
|
||||
let constructor = "mlir::concretelang::createConvertMidLFHEToLowLFHEPass()";
|
||||
def TFHEToConcrete : Pass<"tfhe-to-concrete", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from the TFHE dialect to Concrete";
|
||||
let description = [{ Lowers operations from the TFHE dialect to Concrete }];
|
||||
let constructor = "mlir::concretelang::createConvertTFHEToConcretePass()";
|
||||
let options = [];
|
||||
let dependentDialects = ["mlir::linalg::LinalgDialect"];
|
||||
}
|
||||
|
||||
def LowLFHEToConcreteCAPI : Pass<"lowlfhe-to-concrete-c-api", "mlir::ModuleOp"> {
|
||||
let summary = "Lower operations from the LowLFHE dialect to std with function call to the Concrete C API";
|
||||
let constructor = "mlir::concretelang::createConvertLowLFHEToConcreteCAPIPass()";
|
||||
let dependentDialects = ["mlir::concretelang::LowLFHE::LowLFHEDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"];
|
||||
def ConcreteToConcreteCAPI : Pass<"concrete-to-concrete-c-api", "mlir::ModuleOp"> {
|
||||
let summary = "Lower operations from the Concrete dialect to std with function call to the Concrete C API";
|
||||
let constructor = "mlir::concretelang::createConvertConcreteToConcreteCAPIPass()";
|
||||
let dependentDialects = ["mlir::concretelang::Concrete::ConcreteDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"];
|
||||
}
|
||||
|
||||
def LowLFHEUnparametrize : Pass<"lowlfhe-unparametrize", "mlir::ModuleOp"> {
|
||||
let summary = "Unparametrize LowLFHE types and remove unrealized_conversion_cast";
|
||||
let constructor = "mlir::concretelang::createConvertLowLFHEToConcreteCAPIPass()";
|
||||
let dependentDialects = ["mlir::concretelang::LowLFHE::LowLFHEDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"];
|
||||
def ConcreteUnparametrize : Pass<"concrete-unparametrize", "mlir::ModuleOp"> {
|
||||
let summary = "Unparametrize Concrete types and remove unrealized_conversion_cast";
|
||||
let constructor = "mlir::concretelang::createConvertConcreteToConcreteCAPIPass()";
|
||||
let dependentDialects = ["mlir::concretelang::Concrete::ConcreteDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"];
|
||||
}
|
||||
|
||||
def MLIRLowerableDialectsToLLVM : Pass<"mlir-lowerable-dialects-to-llvm", "mlir::ModuleOp"> {
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_MIDLFHEGLOBALPARAMETRIZATION_PASS_H_
|
||||
#define CONCRETELANG_CONVERSION_MIDLFHEGLOBALPARAMETRIZATION_PASS_H_
|
||||
#ifndef CONCRETELANG_CONVERSION_TFHEGLOBALPARAMETRIZATION_PASS_H_
|
||||
#define CONCRETELANG_CONVERSION_TFHEGLOBALPARAMETRIZATION_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
@@ -11,9 +11,9 @@
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
/// Create a pass to inject fhe parameters to the MidLFHE types and operators.
|
||||
/// Create a pass to inject fhe parameters to the TFHE types and operators.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertMidLFHEGlobalParametrizationPass(
|
||||
createConvertTFHEGlobalParametrizationPass(
|
||||
mlir::concretelang::V0FHEContext &fheContext);
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -0,0 +1,5 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Patterns.td)
|
||||
mlir_tablegen(Patterns.h.inc -gen-rewriters -name TFHE)
|
||||
add_public_tablegen_target(TFHEToConcretePatternsIncGen)
|
||||
|
||||
add_concretelang_doc(Patterns TFHEToConcretePatterns concretelang/ -gen-pass-doc)
|
||||
@@ -2,15 +2,15 @@
|
||||
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_HLFHETOMIDLFHE_PASS_H_
|
||||
#define CONCRETELANG_CONVERSION_HLFHETOMIDLFHE_PASS_H_
|
||||
#ifndef CONCRETELANG_CONVERSION_TFHETOCONCRETE_PASS_H_
|
||||
#define CONCRETELANG_CONVERSION_TFHETOCONCRETE_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
/// Create a pass to convert `HLFHE` dialect to `MidLFHE` dialect.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertHLFHEToMidLFHEPass();
|
||||
/// Create a pass to convert `TFHE` dialect to `Concrete` dialect.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTFHEToConcretePass();
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -1,22 +1,22 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CONVERSION_MIDLFHETOLOWLFHE_PATTERNS_H_
|
||||
#define CONCRETELANG_CONVERSION_MIDLFHETOLOWLFHE_PATTERNS_H_
|
||||
#ifndef CONCRETELANG_CONVERSION_TFHETOCONCRETE_PATTERNS_H_
|
||||
#define CONCRETELANG_CONVERSION_TFHETOCONCRETE_PATTERNS_H_
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "concretelang/Dialect/LowLFHE/IR/LowLFHEOps.h"
|
||||
#include "concretelang/Dialect/MidLFHE/IR/MidLFHEOps.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
using LowLFHE::CleartextType;
|
||||
using LowLFHE::LweCiphertextType;
|
||||
using LowLFHE::PlaintextType;
|
||||
using MidLFHE::GLWECipherTextType;
|
||||
using Concrete::CleartextType;
|
||||
using Concrete::LweCiphertextType;
|
||||
using Concrete::PlaintextType;
|
||||
using TFHE::GLWECipherTextType;
|
||||
|
||||
LweCiphertextType convertTypeToLWE(mlir::MLIRContext *context,
|
||||
mlir::Type type) {
|
||||
@@ -81,7 +81,7 @@ CleartextType convertCleartextTypeFromType(mlir::MLIRContext *context,
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
mlir::Value createZeroLWEOpFromMidLFHE(mlir::PatternRewriter &rewriter,
|
||||
mlir::Value createZeroLWEOpFromTFHE(mlir::PatternRewriter &rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::OpResult result) {
|
||||
mlir::SmallVector<mlir::Value> args{};
|
||||
@@ -89,13 +89,13 @@ mlir::Value createZeroLWEOpFromMidLFHE(mlir::PatternRewriter &rewriter,
|
||||
auto glwe = result.getType().cast<GLWECipherTextType>();
|
||||
mlir::SmallVector<mlir::Type, 1> resTypes{
|
||||
convertTypeToLWE(rewriter.getContext(), glwe)};
|
||||
LowLFHE::ZeroLWEOp op =
|
||||
rewriter.create<LowLFHE::ZeroLWEOp>(loc, resTypes, args, attrs);
|
||||
Concrete::ZeroLWEOp op =
|
||||
rewriter.create<Concrete::ZeroLWEOp>(loc, resTypes, args, attrs);
|
||||
return op.getODSResults(0).front();
|
||||
}
|
||||
|
||||
template <class Operator>
|
||||
mlir::Value createLowLFHEOpFromMidLFHE(mlir::PatternRewriter &rewriter,
|
||||
mlir::Value createConcreteOpFromTFHE(mlir::PatternRewriter &rewriter,
|
||||
mlir::Location loc, mlir::Value arg0,
|
||||
mlir::Value arg1,
|
||||
mlir::OpResult result) {
|
||||
@@ -116,14 +116,14 @@ mlir::Value createAddPlainLweCiphertextWithGlwe(
|
||||
// encode int into plaintext
|
||||
mlir::Value encoded =
|
||||
rewriter
|
||||
.create<mlir::concretelang::LowLFHE::EncodeIntOp>(loc, encoded_type, arg1)
|
||||
.create<mlir::concretelang::Concrete::EncodeIntOp>(loc, encoded_type, arg1)
|
||||
.plaintext();
|
||||
// convert result type
|
||||
LweCiphertextType lwe_type =
|
||||
convertTypeToLWE(rewriter.getContext(), result.getType());
|
||||
// replace op using the encoded plaintext instead of int
|
||||
auto op =
|
||||
rewriter.create<mlir::concretelang::LowLFHE::AddPlaintextLweCiphertextOp>(
|
||||
rewriter.create<mlir::concretelang::Concrete::AddPlaintextLweCiphertextOp>(
|
||||
loc, lwe_type, arg0, encoded);
|
||||
return op.getODSResults(0).front();
|
||||
}
|
||||
@@ -142,7 +142,7 @@ mlir::Value createSubIntLweCiphertext(mlir::PatternRewriter &rewriter,
|
||||
auto arg1_type = arg1.getType();
|
||||
auto negated_arg1 =
|
||||
rewriter
|
||||
.create<mlir::concretelang::LowLFHE::NegateLweCiphertextOp>(
|
||||
.create<mlir::concretelang::Concrete::NegateLweCiphertextOp>(
|
||||
loc, convertTypeToLWE(rewriter.getContext(), arg1_type), arg1)
|
||||
.result();
|
||||
return createAddPlainLweCiphertextWithGlwe(rewriter, loc, negated_arg1, arg0,
|
||||
@@ -154,7 +154,7 @@ mlir::Value createNegLweCiphertext(mlir::PatternRewriter &rewriter,
|
||||
mlir::OpResult result) {
|
||||
auto arg0_type = arg0.getType();
|
||||
auto negated =
|
||||
rewriter.create<mlir::concretelang::LowLFHE::NegateLweCiphertextOp>(
|
||||
rewriter.create<mlir::concretelang::Concrete::NegateLweCiphertextOp>(
|
||||
loc, convertTypeToLWE(rewriter.getContext(), arg0_type), arg0);
|
||||
return negated.getODSResults(0).front();
|
||||
}
|
||||
@@ -168,7 +168,7 @@ mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter &rewriter,
|
||||
convertCleartextTypeFromType(rewriter.getContext(), inType);
|
||||
// encode int into plaintext
|
||||
mlir::Value encoded = rewriter
|
||||
.create<mlir::concretelang::LowLFHE::IntToCleartextOp>(
|
||||
.create<mlir::concretelang::Concrete::IntToCleartextOp>(
|
||||
loc, encoded_type, arg1)
|
||||
.cleartext();
|
||||
// convert result type
|
||||
@@ -176,12 +176,12 @@ mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter &rewriter,
|
||||
LweCiphertextType lwe_type = convertTypeToLWE(rewriter.getContext(), resType);
|
||||
// replace op using the encoded plaintext instead of int
|
||||
auto op =
|
||||
rewriter.create<mlir::concretelang::LowLFHE::MulCleartextLweCiphertextOp>(
|
||||
rewriter.create<mlir::concretelang::Concrete::MulCleartextLweCiphertextOp>(
|
||||
loc, lwe_type, arg0, encoded);
|
||||
return op.getODSResults(0).front();
|
||||
}
|
||||
|
||||
// This is the rewritting of the HLFHE::ApplyLookupTable operation, it will be
|
||||
// This is the rewritting of the FHE::ApplyLookupTable operation, it will be
|
||||
// rewritten as 3 new operations:
|
||||
// - Create the required GLWE ciphertext out of the plain lookup table
|
||||
// - Keyswitch the input ciphertext to match the input key of the bootstrapping
|
||||
@@ -189,7 +189,7 @@ mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter &rewriter,
|
||||
// Example:
|
||||
// from:
|
||||
// ```
|
||||
// "%result = MidLFHE.apply_lookup_table"(% arg0, % tlu){
|
||||
// "%result = TFHE.apply_lookup_table"(% arg0, % tlu){
|
||||
// glweDimension = 1 : i32,
|
||||
// polynomialSize = 2048 : i32,
|
||||
// levelKS = 3 : i32,
|
||||
@@ -197,29 +197,29 @@ mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter &rewriter,
|
||||
// levelBS = 5 : i32,
|
||||
// baseLogBS = 4 : i32,
|
||||
// outputSizeKS = 600 : i32
|
||||
// } : (!MidLFHE.glwe<{2048, 1, 64} {4}>, tensor<16xi4>)
|
||||
// ->(!MidLFHE.glwe<{2048, 1, 64} {4}>)
|
||||
// } : (!TFHE.glwe<{2048, 1, 64} {4}>, tensor<16xi4>)
|
||||
// ->(!TFHE.glwe<{2048, 1, 64} {4}>)
|
||||
// ```
|
||||
// to:
|
||||
// ```
|
||||
// % accumulator =
|
||||
// "LowLFHE.glwe_from_table"(
|
||||
// "Concrete.glwe_from_table"(
|
||||
// % [[TABLE]]){glweDimension = 1 : i32, p = 4 : i32, polynomialSize =
|
||||
// 2048 : i32}
|
||||
// : (tensor<16xi4>)
|
||||
// ->!LowLFHE.glwe_ciphertext
|
||||
// % keyswitched = "LowLFHE.keyswitch_lwe"(% arg0){
|
||||
// ->!Concrete.glwe_ciphertext
|
||||
// % keyswitched = "Concrete.keyswitch_lwe"(% arg0){
|
||||
// baseLog = 2 : i32,
|
||||
// level = 3 : i32
|
||||
// } : (!LowLFHE.lwe_ciphertext<2048, 4>)
|
||||
// ->!LowLFHE.lwe_ciphertext<600, 4>
|
||||
// % result = "LowLFHE.bootstrap_lwe"(% keyswitched, % accumulator){
|
||||
// } : (!Concrete.lwe_ciphertext<2048, 4>)
|
||||
// ->!Concrete.lwe_ciphertext<600, 4>
|
||||
// % result = "Concrete.bootstrap_lwe"(% keyswitched, % accumulator){
|
||||
// baseLog = 4 : i32,
|
||||
// glweDimension = 1 : i32,
|
||||
// level = 5 : i32,
|
||||
// polynomialSize = 2048 : i32
|
||||
// } : (!LowLFHE.lwe_ciphertext<600, 4>, !LowLFHE.glwe_ciphertext)
|
||||
// ->!LowLFHE.lwe_ciphertext<2048, 4>
|
||||
// } : (!Concrete.lwe_ciphertext<600, 4>, !Concrete.glwe_ciphertext)
|
||||
// ->!Concrete.lwe_ciphertext<2048, 4>
|
||||
// ```
|
||||
mlir::Value createPBS(mlir::PatternRewriter &rewriter, mlir::Location loc,
|
||||
mlir::Value ct, mlir::Value table,
|
||||
@@ -236,8 +236,8 @@ mlir::Value createPBS(mlir::PatternRewriter &rewriter, mlir::Location loc,
|
||||
mlir::IntegerAttr precision = rewriter.getI32IntegerAttr(lwe_type.getP());
|
||||
mlir::Value accumulator =
|
||||
rewriter
|
||||
.create<mlir::concretelang::LowLFHE::GlweFromTable>(
|
||||
loc, LowLFHE::GlweCiphertextType::get(rewriter.getContext()),
|
||||
.create<mlir::concretelang::Concrete::GlweFromTable>(
|
||||
loc, Concrete::GlweCiphertextType::get(rewriter.getContext()),
|
||||
table, polynomialSize, glweDimension, precision)
|
||||
.result();
|
||||
|
||||
@@ -255,7 +255,7 @@ mlir::Value createPBS(mlir::PatternRewriter &rewriter, mlir::Location loc,
|
||||
convertTypeToLWE(rewriter.getContext(), result.getType());
|
||||
mlir::Value keyswitched =
|
||||
rewriter
|
||||
.create<mlir::concretelang::LowLFHE::KeySwitchLweOp>(loc, ksOutType,
|
||||
.create<mlir::concretelang::Concrete::KeySwitchLweOp>(loc, ksOutType,
|
||||
ksArgs, ksAttrs)
|
||||
.result();
|
||||
|
||||
@@ -275,7 +275,7 @@ mlir::Value createPBS(mlir::PatternRewriter &rewriter, mlir::Location loc,
|
||||
};
|
||||
mlir::Value bootstrapped =
|
||||
rewriter
|
||||
.create<mlir::concretelang::LowLFHE::BootstrapLweOp>(loc, lwe_type,
|
||||
.create<mlir::concretelang::Concrete::BootstrapLweOp>(loc, lwe_type,
|
||||
bsArgs, bsAttrs)
|
||||
.result();
|
||||
|
||||
@@ -286,10 +286,10 @@ mlir::Value createPBS(mlir::PatternRewriter &rewriter, mlir::Location loc,
|
||||
} // namespace mlir
|
||||
|
||||
namespace {
|
||||
#include "concretelang/Conversion/MidLFHEToLowLFHE/Patterns.h.inc"
|
||||
#include "concretelang/Conversion/TFHEToConcrete/Patterns.h.inc"
|
||||
}
|
||||
|
||||
void populateWithGeneratedMidLFHEToLowLFHE(mlir::RewritePatternSet &patterns) {
|
||||
void populateWithGeneratedTFHEToConcrete(mlir::RewritePatternSet &patterns) {
|
||||
populateWithGenerated(patterns);
|
||||
}
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
#ifndef CONCRETELANG_CONVERSION_MIDLFHETOLOWLFHE_PATTERNS
|
||||
#define CONCRETELANG_CONVERSION_MIDLFHETOLOWLFHE_PATTERNS
|
||||
#ifndef CONCRETELANG_CONVERSION_TFHETOCONCRETE_PATTERNS
|
||||
#define CONCRETELANG_CONVERSION_TFHETOCONCRETE_PATTERNS
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "concretelang/Dialect/LowLFHE/IR/LowLFHEOps.td"
|
||||
include "concretelang/Dialect/MidLFHE/IR/MidLFHEOps.td"
|
||||
include "concretelang/Dialect/Concrete/IR/ConcreteOps.td"
|
||||
include "concretelang/Dialect/TFHE/IR/TFHEOps.td"
|
||||
|
||||
def createZeroLWEOp : NativeCodeCall<"mlir::concretelang::createZeroLWEOpFromMidLFHE($_builder, $_loc, $0)">;
|
||||
def createZeroLWEOp : NativeCodeCall<"mlir::concretelang::createZeroLWEOpFromTFHE($_builder, $_loc, $0)">;
|
||||
|
||||
def ZeroGLWEPattern : Pat<
|
||||
(ZeroGLWEOp:$result),
|
||||
(createZeroLWEOp $result)>;
|
||||
|
||||
def createAddLWEOp : NativeCodeCall<"mlir::concretelang::createLowLFHEOpFromMidLFHE<mlir::concretelang::LowLFHE::AddLweCiphertextsOp>($_builder, $_loc, $0, $1, $2)">;
|
||||
def createAddLWEOp : NativeCodeCall<"mlir::concretelang::createConcreteOpFromTFHE<mlir::concretelang::Concrete::AddLweCiphertextsOp>($_builder, $_loc, $0, $1, $2)">;
|
||||
|
||||
def AddGLWEPattern : Pat<
|
||||
(AddGLWEOp:$result $arg0, $arg1),
|
||||
@@ -1,5 +1,5 @@
|
||||
add_subdirectory(HLFHE)
|
||||
add_subdirectory(HLFHELinalg)
|
||||
add_subdirectory(MidLFHE)
|
||||
add_subdirectory(LowLFHE)
|
||||
add_subdirectory(FHE)
|
||||
add_subdirectory(FHELinalg)
|
||||
add_subdirectory(TFHE)
|
||||
add_subdirectory(Concrete)
|
||||
add_subdirectory(RT)
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
set(LLVM_TARGET_DEFINITIONS ConcreteOps.td)
|
||||
mlir_tablegen(ConcreteOps.h.inc -gen-op-decls)
|
||||
mlir_tablegen(ConcreteOps.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(ConcreteOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=Concrete)
|
||||
mlir_tablegen(ConcreteOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=Concrete)
|
||||
mlir_tablegen(ConcreteOpsDialect.h.inc -gen-dialect-decls -dialect=Concrete)
|
||||
mlir_tablegen(ConcreteOpsDialect.cpp.inc -gen-dialect-defs -dialect=Concrete)
|
||||
add_public_tablegen_target(MLIRConcreteOpsIncGen)
|
||||
add_dependencies(mlir-headers MLIRConcreteOpsIncGen)
|
||||
|
||||
add_concretelang_doc(ConcreteDialect ConcreteDialect concretelang/ -gen-dialect-doc)
|
||||
add_concretelang_doc(ConcreteOps ConcreteOps concretelang/ -gen-op-doc)
|
||||
add_concretelang_doc(ConcreteTypes ConcreteTypes concretelang/ -gen-typedef-doc)
|
||||
@@ -1,8 +1,8 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_MIDLFHE_IR_MIDLFHEDIALECT_H
|
||||
#define CONCRETELANG_DIALECT_MIDLFHE_IR_MIDLFHEDIALECT_H
|
||||
#ifndef CONCRETELANG_DIALECT_Concrete_IR_ConcreteDIALECT_H
|
||||
#define CONCRETELANG_DIALECT_Concrete_IR_ConcreteDIALECT_H
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
@@ -11,6 +11,6 @@
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
#include "concretelang/Dialect/MidLFHE/IR/MidLFHEOpsDialect.h.inc"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteOpsDialect.h.inc"
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,15 @@
|
||||
#ifndef CONCRETELANG_DIALECT_Concrete_IR_Concrete_DIALECT
|
||||
#define CONCRETELANG_DIALECT_Concrete_IR_Concrete_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def Concrete_Dialect : Dialect {
|
||||
let name = "Concrete";
|
||||
let summary = "Low Level Fully Homorphic Encryption dialect";
|
||||
let description = [{
|
||||
A dialect for representation of low level operation on fully homomorphic ciphertext.
|
||||
}];
|
||||
let cppNamespace = "::mlir::concretelang::Concrete";
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -1,8 +1,8 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_MIDLFHE_IR_MIDLFHEOPS_H
|
||||
#define CONCRETELANG_DIALECT_MIDLFHE_IR_MIDLFHEOPS_H
|
||||
#ifndef CONCRETELANG_DIALECT_Concrete_Concrete_OPS_H
|
||||
#define CONCRETELANG_DIALECT_Concrete_Concrete_OPS_H
|
||||
|
||||
#include <mlir/IR/Builders.h>
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
@@ -10,9 +10,9 @@
|
||||
#include <mlir/Interfaces/ControlFlowInterfaces.h>
|
||||
#include <mlir/Interfaces/SideEffectInterfaces.h>
|
||||
|
||||
#include "concretelang/Dialect/MidLFHE/IR/MidLFHETypes.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "concretelang/Dialect/MidLFHE/IR/MidLFHEOps.h.inc"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h.inc"
|
||||
|
||||
#endif
|
||||
@@ -1,58 +1,58 @@
|
||||
#ifndef CONCRETELANG_DIALECT_LowLFHE_IR_LowLFHE_OPS
|
||||
#define CONCRETELANG_DIALECT_LowLFHE_IR_LowLFHE_OPS
|
||||
#ifndef CONCRETELANG_DIALECT_Concrete_IR_Concrete_OPS
|
||||
#define CONCRETELANG_DIALECT_Concrete_IR_Concrete_OPS
|
||||
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
|
||||
include "concretelang/Dialect/LowLFHE/IR/LowLFHEDialect.td"
|
||||
include "concretelang/Dialect/LowLFHE/IR/LowLFHETypes.td"
|
||||
include "concretelang/Dialect/Concrete/IR/ConcreteDialect.td"
|
||||
include "concretelang/Dialect/Concrete/IR/ConcreteTypes.td"
|
||||
|
||||
class LowLFHE_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<LowLFHE_Dialect, mnemonic, traits>;
|
||||
class Concrete_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<Concrete_Dialect, mnemonic, traits>;
|
||||
|
||||
def ZeroLWEOp : LowLFHE_Op<"zero"> {
|
||||
def ZeroLWEOp : Concrete_Op<"zero"> {
|
||||
let summary = "Returns a trivial encyption of 0";
|
||||
|
||||
let arguments = (ins);
|
||||
let results = (outs LweCiphertextType:$out);
|
||||
}
|
||||
|
||||
def AddLweCiphertextsOp : LowLFHE_Op<"add_lwe_ciphertexts"> {
|
||||
def AddLweCiphertextsOp : Concrete_Op<"add_lwe_ciphertexts"> {
|
||||
let summary = "Returns the sum of 2 lwe ciphertexts";
|
||||
|
||||
let arguments = (ins LweCiphertextType:$lhs, LweCiphertextType:$rhs);
|
||||
let results = (outs LweCiphertextType:$result);
|
||||
}
|
||||
|
||||
def AddPlaintextLweCiphertextOp : LowLFHE_Op<"add_plaintext_lwe_ciphertext"> {
|
||||
def AddPlaintextLweCiphertextOp : Concrete_Op<"add_plaintext_lwe_ciphertext"> {
|
||||
let summary = "Returns the sum of a clear integer and a lwe ciphertext";
|
||||
|
||||
let arguments = (ins LweCiphertextType:$lhs, PlaintextType:$rhs);
|
||||
let results = (outs LweCiphertextType:$result);
|
||||
}
|
||||
|
||||
def MulCleartextLweCiphertextOp : LowLFHE_Op<"mul_cleartext_lwe_ciphertext"> {
|
||||
def MulCleartextLweCiphertextOp : Concrete_Op<"mul_cleartext_lwe_ciphertext"> {
|
||||
let summary = "Returns the product of a clear integer and a lwe ciphertext";
|
||||
|
||||
let arguments = (ins LweCiphertextType:$lhs, CleartextType:$rhs);
|
||||
let results = (outs LweCiphertextType:$result);
|
||||
}
|
||||
|
||||
def NegateLweCiphertextOp : LowLFHE_Op<"negate_lwe_ciphertext"> {
|
||||
def NegateLweCiphertextOp : Concrete_Op<"negate_lwe_ciphertext"> {
|
||||
let summary = "Negates a lwe ciphertext";
|
||||
|
||||
let arguments = (ins LweCiphertextType:$ciphertext);
|
||||
let results = (outs LweCiphertextType:$result);
|
||||
}
|
||||
|
||||
def GlweFromTable : LowLFHE_Op<"glwe_from_table"> {
|
||||
def GlweFromTable : Concrete_Op<"glwe_from_table"> {
|
||||
let summary = "Creates a GLWE ciphertext which is the trivial encrytion of a the input table interpreted as a polynomial (to use later in a bootstrap)";
|
||||
|
||||
let arguments = (ins TensorOf<[AnyInteger]>:$table, I32Attr:$polynomialSize, I32Attr:$glweDimension, I32Attr:$p);
|
||||
let results = (outs GlweCiphertextType:$result);
|
||||
}
|
||||
|
||||
def BootstrapLweOp : LowLFHE_Op<"bootstrap_lwe"> {
|
||||
def BootstrapLweOp : Concrete_Op<"bootstrap_lwe"> {
|
||||
let summary = "Bootstraps a LWE ciphertext with a GLWE trivial encryption of the lookup table";
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ def BootstrapLweOp : LowLFHE_Op<"bootstrap_lwe"> {
|
||||
let results = (outs LweCiphertextType:$result);
|
||||
}
|
||||
|
||||
def KeySwitchLweOp : LowLFHE_Op<"keyswitch_lwe"> {
|
||||
def KeySwitchLweOp : Concrete_Op<"keyswitch_lwe"> {
|
||||
let summary = "Keyswitches a LWE ciphertext";
|
||||
|
||||
let arguments = (ins
|
||||
@@ -80,14 +80,14 @@ def KeySwitchLweOp : LowLFHE_Op<"keyswitch_lwe"> {
|
||||
let results = (outs LweCiphertextType:$result);
|
||||
}
|
||||
|
||||
def EncodeIntOp : LowLFHE_Op<"encode_int"> {
|
||||
def EncodeIntOp : Concrete_Op<"encode_int"> {
|
||||
let summary = "Encodes an integer (for it to later be added to a LWE ciphertext)";
|
||||
|
||||
let arguments = (ins AnyInteger:$i);
|
||||
let results = (outs PlaintextType:$plaintext);
|
||||
}
|
||||
|
||||
def IntToCleartextOp : LowLFHE_Op<"int_to_cleartext"> {
|
||||
def IntToCleartextOp : Concrete_Op<"int_to_cleartext"> {
|
||||
let summary = "Keyswitches a LWE ciphertext";
|
||||
|
||||
let arguments = (ins AnyInteger:$i);
|
||||
@@ -1,8 +1,8 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_LowLFHE_IR_LowLFHETYPES_H
|
||||
#define CONCRETELANG_DIALECT_LowLFHE_IR_LowLFHETYPES_H
|
||||
#ifndef CONCRETELANG_DIALECT_Concrete_IR_ConcreteTYPES_H
|
||||
#define CONCRETELANG_DIALECT_Concrete_IR_ConcreteTYPES_H
|
||||
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
@@ -10,6 +10,6 @@
|
||||
#include <mlir/IR/DialectImplementation.h>
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "concretelang/Dialect/LowLFHE/IR/LowLFHEOpsTypes.h.inc"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteOpsTypes.h.inc"
|
||||
|
||||
#endif
|
||||
@@ -1,13 +1,13 @@
|
||||
#ifndef CONCRETELANG_DIALECT_LowLFHE_IR_LowLFHE_TYPES
|
||||
#define CONCRETELANG_DIALECT_LowLFHE_IR_LowLFHE_TYPES
|
||||
#ifndef CONCRETELANG_DIALECT_Concrete_IR_Concrete_TYPES
|
||||
#define CONCRETELANG_DIALECT_Concrete_IR_Concrete_TYPES
|
||||
|
||||
include "mlir/IR/BuiltinTypes.td"
|
||||
|
||||
include "concretelang/Dialect/LowLFHE/IR/LowLFHEDialect.td"
|
||||
include "concretelang/Dialect/Concrete/IR/ConcreteDialect.td"
|
||||
|
||||
class LowLFHE_Type<string name, list<Trait> traits = []> : TypeDef<LowLFHE_Dialect, name, traits> { }
|
||||
class Concrete_Type<string name, list<Trait> traits = []> : TypeDef<Concrete_Dialect, name, traits> { }
|
||||
|
||||
def GlweCiphertextType : LowLFHE_Type<"GlweCiphertext"> {
|
||||
def GlweCiphertextType : Concrete_Type<"GlweCiphertext"> {
|
||||
let mnemonic = "glwe_ciphertext";
|
||||
|
||||
let summary = "A GLWE ciphertext (encryption of a polynomial of fixed-precision integers)";
|
||||
@@ -25,7 +25,7 @@ def GlweCiphertextType : LowLFHE_Type<"GlweCiphertext"> {
|
||||
}];
|
||||
}
|
||||
|
||||
def LweCiphertextType : LowLFHE_Type<"LweCiphertext", [MemRefElementTypeInterface]> {
|
||||
def LweCiphertextType : Concrete_Type<"LweCiphertext", [MemRefElementTypeInterface]> {
|
||||
let mnemonic = "lwe_ciphertext";
|
||||
|
||||
let summary = "A LWE ciphertext (encryption of a fixed-precision integer)";
|
||||
@@ -70,7 +70,7 @@ def LweCiphertextType : LowLFHE_Type<"LweCiphertext", [MemRefElementTypeInterfac
|
||||
}];
|
||||
}
|
||||
|
||||
def CleartextType : LowLFHE_Type<"Cleartext"> {
|
||||
def CleartextType : Concrete_Type<"Cleartext"> {
|
||||
let mnemonic = "cleartext";
|
||||
|
||||
let summary = "A cleartext (a fixed-precision integer) ready to be multiplied to a LWE ciphertext";
|
||||
@@ -104,7 +104,7 @@ def CleartextType : LowLFHE_Type<"Cleartext"> {
|
||||
}];
|
||||
}
|
||||
|
||||
def PlaintextType : LowLFHE_Type<"Plaintext"> {
|
||||
def PlaintextType : Concrete_Type<"Plaintext"> {
|
||||
let mnemonic = "plaintext";
|
||||
|
||||
let summary = "A Plaintext (a fixed-precision integer) ready to be added to a LWE ciphertext";
|
||||
@@ -138,7 +138,7 @@ def PlaintextType : LowLFHE_Type<"Plaintext"> {
|
||||
}];
|
||||
}
|
||||
|
||||
def PlaintextListType : LowLFHE_Type<"PlaintextList"> {
|
||||
def PlaintextListType : Concrete_Type<"PlaintextList"> {
|
||||
let mnemonic = "plaintext_list";
|
||||
|
||||
let summary = "List of plaintexts";
|
||||
@@ -156,7 +156,7 @@ def PlaintextListType : LowLFHE_Type<"PlaintextList"> {
|
||||
}];
|
||||
}
|
||||
|
||||
def ForeignPlaintextListType : LowLFHE_Type<"ForeignPlaintextList"> {
|
||||
def ForeignPlaintextListType : Concrete_Type<"ForeignPlaintextList"> {
|
||||
let mnemonic = "foreign_plaintext_list";
|
||||
|
||||
let summary = "A foreign (reference to a independently allocated memory space) plaintext list";
|
||||
@@ -174,7 +174,7 @@ def ForeignPlaintextListType : LowLFHE_Type<"ForeignPlaintextList"> {
|
||||
}];
|
||||
}
|
||||
|
||||
def LweKeySwitchKeyType : LowLFHE_Type<"LweKeySwitchKey"> {
|
||||
def LweKeySwitchKeyType : Concrete_Type<"LweKeySwitchKey"> {
|
||||
let mnemonic = "lwe_key_switch_key";
|
||||
|
||||
let summary = "A LWE keyswitching key";
|
||||
@@ -192,7 +192,7 @@ def LweKeySwitchKeyType : LowLFHE_Type<"LweKeySwitchKey"> {
|
||||
}];
|
||||
}
|
||||
|
||||
def LweBootstrapKeyType : LowLFHE_Type<"LweBootstrapKey"> {
|
||||
def LweBootstrapKeyType : Concrete_Type<"LweBootstrapKey"> {
|
||||
let mnemonic = "lwe_bootstrap_key";
|
||||
|
||||
let summary = "A LWE bootstrapping key";
|
||||
@@ -210,7 +210,7 @@ def LweBootstrapKeyType : LowLFHE_Type<"LweBootstrapKey"> {
|
||||
}];
|
||||
}
|
||||
|
||||
def Context : LowLFHE_Type<"Context"> {
|
||||
def Context : Concrete_Type<"Context"> {
|
||||
let mnemonic = "context";
|
||||
|
||||
let summary = "A runtime context";
|
||||
@@ -1,8 +1,8 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_HLFHE_ANALYSIS_MANP_H
|
||||
#define CONCRETELANG_DIALECT_HLFHE_ANALYSIS_MANP_H
|
||||
#ifndef CONCRETELANG_DIALECT_FHE_ANALYSIS_MANP_H
|
||||
#define CONCRETELANG_DIALECT_FHE_ANALYSIS_MANP_H
|
||||
|
||||
#include <functional>
|
||||
#include <mlir/Pass/Pass.h>
|
||||
@@ -1,10 +1,10 @@
|
||||
#ifndef CONCRETELANG_DIALECT_HLFHE_ANALYSIS_MANP
|
||||
#define CONCRETELANG_DIALECT_HLFHE_ANALYSIS_MANP
|
||||
#ifndef CONCRETELANG_DIALECT_FHE_ANALYSIS_MANP
|
||||
#define CONCRETELANG_DIALECT_FHE_ANALYSIS_MANP
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def MANP : FunctionPass<"MANP"> {
|
||||
let summary = "HLFHE Minimal Arithmetic Noise Padding Pass";
|
||||
let summary = "FHE Minimal Arithmetic Noise Padding Pass";
|
||||
let description = [{
|
||||
This pass calculates the Minimal Arithmetic Noise Padding
|
||||
(MANP) for each operation of a function and stores the result in an
|
||||
@@ -15,14 +15,14 @@ def MANP : FunctionPass<"MANP"> {
|
||||
|
||||
The pass supports the following operations:
|
||||
|
||||
- HLFHELinalg.dot_eint_int
|
||||
- HLFHE.zero
|
||||
- HLFHE.add_eint_int
|
||||
- HLFHE.add_eint
|
||||
- HLFHE.sub_int_eint
|
||||
- HLFHE.neg_eint
|
||||
- HLFHE.mul_eint_int
|
||||
- HLFHE.apply_lookup_table
|
||||
- FHELinalg.dot_eint_int
|
||||
- FHE.zero
|
||||
- FHE.add_eint_int
|
||||
- FHE.add_eint
|
||||
- FHE.sub_int_eint
|
||||
- FHE.neg_eint
|
||||
- FHE.mul_eint_int
|
||||
- FHE.apply_lookup_table
|
||||
|
||||
If any other operation is encountered, the pass conservatively
|
||||
fails. The pass further makes the optimistic assumption that all
|
||||
@@ -42,24 +42,24 @@ def MANP : FunctionPass<"MANP"> {
|
||||
|
||||
with the following replacement rules:
|
||||
|
||||
- Function argument a -> HLFHELinalg.dot_eint_int([a], [1])
|
||||
- HLFHE.apply_lookup_table -> HLFHELinalg.dot_eint_int([LUT result], [1])
|
||||
- HLFHE.zero() -> HLFHELinalg.dot_eint_int([encrypted 0], [1])
|
||||
- HLFHE.add_eint_int(e, c) -> HLFHELinalg.dot_eint_int([e, 1], [1, c])
|
||||
- HLFHE.add_eint(e0, e1) -> HLFHELinalg.dot_eint_int([e0, e1], [1, 1])
|
||||
- HLFHE.sub_int_eint(c, e) -> HLFHELinalg.dot_eint_int([e, c], [1, -1])
|
||||
- HLFHE.neg_eint(e) -> HLFHELinalg.dot_eint_int([e], [-1])
|
||||
- HLFHE.mul_eint_int(e, c) -> HLFHELinalg.dot_eint_int([e], [c])
|
||||
- Function argument a -> FHELinalg.dot_eint_int([a], [1])
|
||||
- FHE.apply_lookup_table -> FHELinalg.dot_eint_int([LUT result], [1])
|
||||
- FHE.zero() -> FHELinalg.dot_eint_int([encrypted 0], [1])
|
||||
- FHE.add_eint_int(e, c) -> FHELinalg.dot_eint_int([e, 1], [1, c])
|
||||
- FHE.add_eint(e0, e1) -> FHELinalg.dot_eint_int([e0, e1], [1, 1])
|
||||
- FHE.sub_int_eint(c, e) -> FHELinalg.dot_eint_int([e, c], [1, -1])
|
||||
- FHE.neg_eint(e) -> FHELinalg.dot_eint_int([e], [-1])
|
||||
- FHE.mul_eint_int(e, c) -> FHELinalg.dot_eint_int([e], [c])
|
||||
|
||||
Dependent dot operations, e.g.,
|
||||
|
||||
a = HLFHELinalg.dot_eint_int([a0, a1, ...], [c0, c1, ...])
|
||||
b = HLFHELinalg.dot_eint_int([b0, b1, ...], [d0, d1, ...])
|
||||
x = HLFHELinalg.dot_eint_int([a, b, ...], [f0, f1, ...])
|
||||
a = FHELinalg.dot_eint_int([a0, a1, ...], [c0, c1, ...])
|
||||
b = FHELinalg.dot_eint_int([b0, b1, ...], [d0, d1, ...])
|
||||
x = FHELinalg.dot_eint_int([a, b, ...], [f0, f1, ...])
|
||||
|
||||
are merged as follows:
|
||||
|
||||
x = HLFHELinalg.dot_eint_int([a0, a1, ..., b0, b1, ...],
|
||||
x = FHELinalg.dot_eint_int([a0, a1, ..., b0, b1, ...],
|
||||
[f0*c0, f0*c1, ..., f1*d0, f1*d1, ...])
|
||||
|
||||
However, the implementation does not explicitly create the
|
||||
@@ -80,15 +80,15 @@ def MANP : FunctionPass<"MANP"> {
|
||||
for the supported operations:
|
||||
|
||||
- Function argument -> 1
|
||||
- HLFHE.apply_lookup_table -> 1
|
||||
- HLFHE.zero() -> 1
|
||||
- HLFHELinalg.dot_eint_int([e0, e1, ...], [c0, c1, ...]) ->
|
||||
- FHE.apply_lookup_table -> 1
|
||||
- FHE.zero() -> 1
|
||||
- FHELinalg.dot_eint_int([e0, e1, ...], [c0, c1, ...]) ->
|
||||
c0*c0*sqN(e0) + c1*c1*sqN(e1) + ...
|
||||
- HLFHE.add_eint_int(e, c) -> 1*1*sqN(e) + c*c*1*1 = sqN(e) + c*c
|
||||
- HLFHE.add_eint(e0, e1) -> 1*1*sqN(e0) + 1*1*sqN(e2) = sqN(e1) + sqN(e2)
|
||||
- HLFHE.sub_int_eint(c, e) -> 1*1*sqN(e) + c*c*(-1)*(-1) = sqN(e) + c*c
|
||||
- HLFHE.neg_eint(e) -> (-1)*(-1)*sqN(e) = sqN(e)
|
||||
- HLFHE.mul_eint_int(e, c) -> c*c*sqN(e)
|
||||
- FHE.add_eint_int(e, c) -> 1*1*sqN(e) + c*c*1*1 = sqN(e) + c*c
|
||||
- FHE.add_eint(e0, e1) -> 1*1*sqN(e0) + 1*1*sqN(e2) = sqN(e1) + sqN(e2)
|
||||
- FHE.sub_int_eint(c, e) -> 1*1*sqN(e) + c*c*(-1)*(-1) = sqN(e) + c*c
|
||||
- FHE.neg_eint(e) -> (-1)*(-1)*sqN(e) = sqN(e)
|
||||
- FHE.mul_eint_int(e, c) -> c*c*sqN(e)
|
||||
|
||||
The final, non-squared 2-norm of an operation is the square root of the
|
||||
squared value rounded to the next highest integer.
|
||||
@@ -96,7 +96,7 @@ def MANP : FunctionPass<"MANP"> {
|
||||
}
|
||||
|
||||
def MaxMANP : FunctionPass<"MaxMANP"> {
|
||||
let summary = "Extract maximum HLFHE Minimal Arithmetic Noise Padding and maximum encrypted integer width";
|
||||
let summary = "Extract maximum FHE Minimal Arithmetic Noise Padding and maximum encrypted integer width";
|
||||
let description = [{
|
||||
This pass calculates the squared Minimal Arithmetic Noise Padding
|
||||
(MANP) for each operation using the MANP pass and extracts the
|
||||
13
compiler/include/concretelang/Dialect/FHE/IR/CMakeLists.txt
Normal file
13
compiler/include/concretelang/Dialect/FHE/IR/CMakeLists.txt
Normal file
@@ -0,0 +1,13 @@
|
||||
set(LLVM_TARGET_DEFINITIONS FHEOps.td)
|
||||
mlir_tablegen(FHEOps.h.inc -gen-op-decls)
|
||||
mlir_tablegen(FHEOps.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(FHEOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=FHE)
|
||||
mlir_tablegen(FHEOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=FHE)
|
||||
mlir_tablegen(FHEOpsDialect.h.inc -gen-dialect-decls -dialect=FHE)
|
||||
mlir_tablegen(FHEOpsDialect.cpp.inc -gen-dialect-defs -dialect=FHE)
|
||||
add_public_tablegen_target(MLIRFHEOpsIncGen)
|
||||
add_dependencies(mlir-headers MLIRFHEOpsIncGen)
|
||||
|
||||
add_concretelang_doc(FHEDialect FHEDialect concretelang/ -gen-dialect-doc)
|
||||
add_concretelang_doc(FHEOps FHEOps concretelang/ -gen-op-doc)
|
||||
add_concretelang_doc(FHETypes FHETypes concretelang/ -gen-typedef-doc)
|
||||
@@ -1,8 +1,8 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_HLFHE_IR_HLFHEDIALECT_H
|
||||
#define CONCRETELANG_DIALECT_HLFHE_IR_HLFHEDIALECT_H
|
||||
#ifndef CONCRETELANG_DIALECT_FHE_IR_FHEDIALECT_H
|
||||
#define CONCRETELANG_DIALECT_FHE_IR_FHEDIALECT_H
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
@@ -11,6 +11,6 @@
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
#include "concretelang/Dialect/HLFHE/IR/HLFHEOpsDialect.h.inc"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEOpsDialect.h.inc"
|
||||
|
||||
#endif
|
||||
@@ -1,4 +1,4 @@
|
||||
//===- HLFHEDialect.td - HLFHE dialect ----------------*- tablegen -*-===//
|
||||
//===- FHEDialect.td - FHE dialect ----------------*- tablegen -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
@@ -6,18 +6,18 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_HLFHE_IR_HLFHE_DIALECT
|
||||
#define CONCRETELANG_DIALECT_HLFHE_IR_HLFHE_DIALECT
|
||||
#ifndef CONCRETELANG_DIALECT_FHE_IR_FHE_DIALECT
|
||||
#define CONCRETELANG_DIALECT_FHE_IR_FHE_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def HLFHE_Dialect : Dialect {
|
||||
let name = "HLFHE";
|
||||
def FHE_Dialect : Dialect {
|
||||
let name = "FHE";
|
||||
let summary = "High Level Fully Homorphic Encryption dialect";
|
||||
let description = [{
|
||||
A dialect for representation of high level operation on fully homomorphic ciphertext.
|
||||
}];
|
||||
let cppNamespace = "::mlir::concretelang::HLFHE";
|
||||
let cppNamespace = "::mlir::concretelang::FHE";
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -1,19 +1,19 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_HLFHE_IR_HLFHEOPS_H
|
||||
#define CONCRETELANG_DIALECT_HLFHE_IR_HLFHEOPS_H
|
||||
#ifndef CONCRETELANG_DIALECT_FHE_IR_FHEOPS_H
|
||||
#define CONCRETELANG_DIALECT_FHE_IR_FHEOPS_H
|
||||
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/BuiltinTypes.h>
|
||||
#include <mlir/Interfaces/ControlFlowInterfaces.h>
|
||||
#include <mlir/Interfaces/SideEffectInterfaces.h>
|
||||
|
||||
#include "concretelang/Dialect/HLFHE/IR/HLFHETypes.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace HLFHE {
|
||||
namespace FHE {
|
||||
|
||||
bool verifyEncryptedIntegerInputAndResultConsistency(
|
||||
OpState &op, EncryptedIntegerType &input, EncryptedIntegerType &result);
|
||||
@@ -23,7 +23,7 @@ bool verifyEncryptedIntegerAndIntegerInputsConsistency(OpState &op,
|
||||
IntegerType &b);
|
||||
|
||||
/** Shared error message for all ApplyLookupTable variant Op (several Dialect)
|
||||
* E.g. HLFHE.apply_lookup_table(input, lut)
|
||||
* E.g. FHE.apply_lookup_table(input, lut)
|
||||
* Message when the lut tensor has an invalid size,
|
||||
* i.e. it cannot accomodate the input elements bitwidth
|
||||
*/
|
||||
@@ -38,11 +38,11 @@ void emitErrorBadLutSize(Op &op, std::string lutName, std::string inputName,
|
||||
<< " elements bitwidth (" << bitWidth << ")";
|
||||
}
|
||||
|
||||
} // namespace HLFHE
|
||||
} // namespace FHE
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "concretelang/Dialect/HLFHE/IR/HLFHEOps.h.inc"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEOps.h.inc"
|
||||
|
||||
#endif
|
||||
@@ -1,4 +1,4 @@
|
||||
//===- HLFHEOps.td - High level FHE dialect ops ----------------*- tablegen -*-===//
|
||||
//===- FHEOps.td - High level FHE dialect ops ----------------*- tablegen -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
@@ -6,27 +6,27 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_HLFHE_IR_HLFHE_OPS
|
||||
#define CONCRETELANG_DIALECT_HLFHE_IR_HLFHE_OPS
|
||||
#ifndef CONCRETELANG_DIALECT_FHE_IR_FHE_OPS
|
||||
#define CONCRETELANG_DIALECT_FHE_IR_FHE_OPS
|
||||
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
|
||||
include "concretelang/Dialect/HLFHE/IR/HLFHEDialect.td"
|
||||
include "concretelang/Dialect/HLFHE/IR/HLFHETypes.td"
|
||||
include "concretelang/Dialect/FHE/IR/FHEDialect.td"
|
||||
include "concretelang/Dialect/FHE/IR/FHETypes.td"
|
||||
|
||||
class HLFHE_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<HLFHE_Dialect, mnemonic, traits>;
|
||||
class FHE_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<FHE_Dialect, mnemonic, traits>;
|
||||
|
||||
// Generates an encrypted zero constant
|
||||
def ZeroEintOp : HLFHE_Op<"zero", [NoSideEffect]> {
|
||||
def ZeroEintOp : FHE_Op<"zero", [NoSideEffect]> {
|
||||
let summary = "Return an encryption of 0";
|
||||
|
||||
let arguments = (ins);
|
||||
let results = (outs EncryptedIntegerType:$out);
|
||||
}
|
||||
|
||||
def AddEintIntOp : HLFHE_Op<"add_eint_int"> {
|
||||
def AddEintIntOp : FHE_Op<"add_eint_int"> {
|
||||
|
||||
let summary = "Adds an encrypted integer and a clear integer";
|
||||
|
||||
@@ -40,11 +40,11 @@ def AddEintIntOp : HLFHE_Op<"add_eint_int"> {
|
||||
];
|
||||
|
||||
let verifier = [{
|
||||
return ::mlir::concretelang::HLFHE::verifyAddEintIntOp(*this);
|
||||
return ::mlir::concretelang::FHE::verifyAddEintIntOp(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def AddEintOp : HLFHE_Op<"add_eint"> {
|
||||
def AddEintOp : FHE_Op<"add_eint"> {
|
||||
|
||||
let summary = "Adds two encrypted integers";
|
||||
|
||||
@@ -58,11 +58,11 @@ def AddEintOp : HLFHE_Op<"add_eint"> {
|
||||
];
|
||||
|
||||
let verifier = [{
|
||||
return ::mlir::concretelang::HLFHE::verifyAddEintOp(*this);
|
||||
return ::mlir::concretelang::FHE::verifyAddEintOp(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def SubIntEintOp : HLFHE_Op<"sub_int_eint"> {
|
||||
def SubIntEintOp : FHE_Op<"sub_int_eint"> {
|
||||
|
||||
let summary = "Substract a clear integer and an encrypted integer";
|
||||
|
||||
@@ -76,11 +76,11 @@ def SubIntEintOp : HLFHE_Op<"sub_int_eint"> {
|
||||
];
|
||||
|
||||
let verifier = [{
|
||||
return ::mlir::concretelang::HLFHE::verifySubIntEintOp(*this);
|
||||
return ::mlir::concretelang::FHE::verifySubIntEintOp(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def NegEintOp : HLFHE_Op<"neg_eint"> {
|
||||
def NegEintOp : FHE_Op<"neg_eint"> {
|
||||
|
||||
let summary = "Negates an encrypted integer";
|
||||
|
||||
@@ -94,11 +94,11 @@ def NegEintOp : HLFHE_Op<"neg_eint"> {
|
||||
];
|
||||
|
||||
let verifier = [{
|
||||
return ::mlir::concretelang::HLFHE::verifyNegEintOp(*this);
|
||||
return ::mlir::concretelang::FHE::verifyNegEintOp(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def MulEintIntOp : HLFHE_Op<"mul_eint_int"> {
|
||||
def MulEintIntOp : FHE_Op<"mul_eint_int"> {
|
||||
|
||||
let summary = "Mulitplies an encrypted integer and a clear integer";
|
||||
|
||||
@@ -112,11 +112,11 @@ def MulEintIntOp : HLFHE_Op<"mul_eint_int"> {
|
||||
];
|
||||
|
||||
let verifier = [{
|
||||
return ::mlir::concretelang::HLFHE::verifyMulEintIntOp(*this);
|
||||
return ::mlir::concretelang::FHE::verifyMulEintIntOp(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def ApplyLookupTableEintOp : HLFHE_Op<"apply_lookup_table"> {
|
||||
def ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table"> {
|
||||
|
||||
let summary = "Applies a clear lookup table to an encrypted integer";
|
||||
|
||||
@@ -125,7 +125,7 @@ def ApplyLookupTableEintOp : HLFHE_Op<"apply_lookup_table"> {
|
||||
let results = (outs EncryptedIntegerType);
|
||||
|
||||
let verifier = [{
|
||||
return ::mlir::concretelang::HLFHE::verifyApplyLookupTable(*this);
|
||||
return ::mlir::concretelang::FHE::verifyApplyLookupTable(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_HLFHE_IR_HLFHETYPES_H
|
||||
#define CONCRETELANG_DIALECT_HLFHE_IR_HLFHETYPES_H
|
||||
#ifndef CONCRETELANG_DIALECT_FHE_IR_FHETYPES_H
|
||||
#define CONCRETELANG_DIALECT_FHE_IR_FHETYPES_H
|
||||
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
@@ -10,6 +10,6 @@
|
||||
#include <mlir/IR/DialectImplementation.h>
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "concretelang/Dialect/HLFHE/IR/HLFHEOpsTypes.h.inc"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEOpsTypes.h.inc"
|
||||
|
||||
#endif
|
||||
@@ -1,13 +1,13 @@
|
||||
#ifndef CONCRETELANG_DIALECT_HLFHE_IR_HLFHE_TYPES
|
||||
#define CONCRETELANG_DIALECT_HLFHE_IR_HLFHE_TYPES
|
||||
#ifndef CONCRETELANG_DIALECT_FHE_IR_FHE_TYPES
|
||||
#define CONCRETELANG_DIALECT_FHE_IR_FHE_TYPES
|
||||
|
||||
include "concretelang/Dialect/HLFHE/IR/HLFHEDialect.td"
|
||||
include "concretelang/Dialect/FHE/IR/FHEDialect.td"
|
||||
include "mlir/IR/BuiltinTypes.td"
|
||||
|
||||
class HLFHE_Type<string name, list<Trait> traits = []> :
|
||||
TypeDef<HLFHE_Dialect, name, traits> { }
|
||||
class FHE_Type<string name, list<Trait> traits = []> :
|
||||
TypeDef<FHE_Dialect, name, traits> { }
|
||||
|
||||
def EncryptedIntegerType : HLFHE_Type<"EncryptedInteger",
|
||||
def EncryptedIntegerType : FHE_Type<"EncryptedInteger",
|
||||
[MemRefElementTypeInterface]> {
|
||||
let mnemonic = "eint";
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
set(LLVM_TARGET_DEFINITIONS FHELinalgOps.td)
|
||||
mlir_tablegen(FHELinalgOps.h.inc -gen-op-decls)
|
||||
mlir_tablegen(FHELinalgOps.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(FHELinalgOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=FHELinalg)
|
||||
mlir_tablegen(FHELinalgOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=FHELinalg)
|
||||
mlir_tablegen(FHELinalgOpsDialect.h.inc -gen-dialect-decls -dialect=FHELinalg)
|
||||
mlir_tablegen(FHELinalgOpsDialect.cpp.inc -gen-dialect-defs -dialect=FHELinalg)
|
||||
add_public_tablegen_target(MLIRFHELinalgOpsIncGen)
|
||||
add_dependencies(mlir-headers MLIRFHELinalgOpsIncGen)
|
||||
|
||||
add_concretelang_doc(FHELinalgDialect FHELinalgDialect concretelang/ -gen-dialect-doc)
|
||||
add_concretelang_doc(FHELinalgOps FHELinalgOps concretelang/ -gen-op-doc)
|
||||
add_concretelang_doc(FHELinalgTypes FHELinalgTypes concretelang/ -gen-typedef-doc)
|
||||
@@ -1,13 +1,13 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_HLFHELinalg_IR_HLFHELinalgDIALECT_H
|
||||
#define CONCRETELANG_DIALECT_HLFHELinalg_IR_HLFHELinalgDIALECT_H
|
||||
#ifndef CONCRETELANG_DIALECT_FHELinalg_IR_FHELinalgDIALECT_H
|
||||
#define CONCRETELANG_DIALECT_FHELinalg_IR_FHELinalgDIALECT_H
|
||||
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
#include "concretelang/Dialect/HLFHELinalg/IR/HLFHELinalgOpsDialect.h.inc"
|
||||
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOpsDialect.h.inc"
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,15 @@
|
||||
#ifndef CONCRETELANG_DIALECT_FHELinalg_IR_FHELinalg_DIALECT
|
||||
#define CONCRETELANG_DIALECT_FHELinalg_IR_FHELinalg_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def FHELinalg_Dialect : Dialect {
|
||||
let name = "FHELinalg";
|
||||
let summary = "High Level Fully Homorphic Encryption Linalg dialect";
|
||||
let description = [{
|
||||
A dialect for representation of high level linalg operations on fully homomorphic ciphertexts.
|
||||
}];
|
||||
let cppNamespace = "::mlir::concretelang::FHELinalg";
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -1,13 +1,13 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_HLFHELinalg_IR_HLFHELinalgOPS_H
|
||||
#define CONCRETELANG_DIALECT_HLFHELinalg_IR_HLFHELinalgOPS_H
|
||||
#ifndef CONCRETELANG_DIALECT_FHELinalg_IR_FHELinalgOPS_H
|
||||
#define CONCRETELANG_DIALECT_FHELinalg_IR_FHELinalgOPS_H
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "concretelang/Dialect/HLFHE/IR/HLFHETypes.h"
|
||||
#include "concretelang/Dialect/HLFHELinalg/IR/HLFHELinalgTypes.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
|
||||
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgTypes.h"
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/BuiltinTypes.h>
|
||||
|
||||
@@ -42,8 +42,8 @@ public:
|
||||
|
||||
/// TensorBinaryEintInt verifies that the operation matches the following
|
||||
/// signature
|
||||
/// `(tensor<...x!HLFHE.eint<$p>>, tensor<...xi$p'>) ->
|
||||
/// tensor<...x!HLFHE.eint<$p>>` where `$p <= $p+1`.
|
||||
/// `(tensor<...x!FHE.eint<$p>>, tensor<...xi$p'>) ->
|
||||
/// tensor<...x!FHE.eint<$p>>` where `$p <= $p+1`.
|
||||
template <typename ConcreteType>
|
||||
class TensorBinaryEintInt
|
||||
: public mlir::OpTrait::TraitBase<ConcreteType, TensorBinaryEintInt> {
|
||||
@@ -55,8 +55,8 @@ public:
|
||||
|
||||
/// TensorBinaryEintInt verifies that the operation matches the following
|
||||
/// signature
|
||||
/// `(tensor<...xi$p'>, tensor<...x!HLFHE.eint<$p>>) ->
|
||||
/// tensor<...x!HLFHE.eint<$p>>` where `$p <= $p+1`.
|
||||
/// `(tensor<...xi$p'>, tensor<...x!FHE.eint<$p>>) ->
|
||||
/// tensor<...x!FHE.eint<$p>>` where `$p <= $p+1`.
|
||||
template <typename ConcreteType>
|
||||
class TensorBinaryIntEint
|
||||
: public mlir::OpTrait::TraitBase<ConcreteType, TensorBinaryEintInt> {
|
||||
@@ -67,8 +67,8 @@ public:
|
||||
};
|
||||
|
||||
/// TensorBinary verify the operation match the following signature
|
||||
/// `(tensor<...x!HLFHE.eint<$p>>, tensor<...x!HLFHE.eint<$p>>) ->
|
||||
/// tensor<...x!HLFHE.eint<$p>>`
|
||||
/// `(tensor<...x!FHE.eint<$p>>, tensor<...x!FHE.eint<$p>>) ->
|
||||
/// tensor<...x!FHE.eint<$p>>`
|
||||
template <typename ConcreteType>
|
||||
class TensorBinaryEint
|
||||
: public mlir::OpTrait::TraitBase<ConcreteType, TensorBinaryEint> {
|
||||
@@ -79,7 +79,7 @@ public:
|
||||
};
|
||||
|
||||
/// TensorBinary verify the operation match the following signature
|
||||
/// `(tensor<...x!HLFHE.eint<$p>>) -> tensor<...x!HLFHE.eint<$p>>`
|
||||
/// `(tensor<...x!FHE.eint<$p>>) -> tensor<...x!FHE.eint<$p>>`
|
||||
template <typename ConcreteType>
|
||||
class TensorUnaryEint
|
||||
: public mlir::OpTrait::TraitBase<ConcreteType, TensorUnaryEint> {
|
||||
@@ -93,6 +93,6 @@ public:
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "concretelang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h.inc"
|
||||
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h.inc"
|
||||
|
||||
#endif
|
||||
@@ -1,14 +1,14 @@
|
||||
#ifndef CONCRETELANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_OPS
|
||||
#define CONCRETELANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_OPS
|
||||
#ifndef CONCRETELANG_DIALECT_FHELinalg_IR_FHELinalg_OPS
|
||||
#define CONCRETELANG_DIALECT_FHELinalg_IR_FHELinalg_OPS
|
||||
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
|
||||
include "concretelang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.td"
|
||||
include "concretelang/Dialect/HLFHELinalg/IR/HLFHELinalgTypes.td"
|
||||
include "concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.td"
|
||||
include "concretelang/Dialect/FHELinalg/IR/FHELinalgTypes.td"
|
||||
|
||||
class HLFHELinalg_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<HLFHELinalg_Dialect, mnemonic, traits>;
|
||||
class FHELinalg_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<FHELinalg_Dialect, mnemonic, traits>;
|
||||
|
||||
// TensorBroadcastingRules verify that the operands and result verify the broadcasting rules
|
||||
def TensorBroadcastingRules : NativeOpTrait<"TensorBroadcastingRules">;
|
||||
@@ -18,7 +18,7 @@ def TensorBinaryEint : NativeOpTrait<"TensorBinaryEint">;
|
||||
def TensorUnaryEint : NativeOpTrait<"TensorUnaryEint">;
|
||||
|
||||
|
||||
def AddEintIntOp : HLFHELinalg_Op<"add_eint_int", [TensorBroadcastingRules, TensorBinaryEintInt]> {
|
||||
def AddEintIntOp : FHELinalg_Op<"add_eint_int", [TensorBroadcastingRules, TensorBinaryEintInt]> {
|
||||
let summary = "Returns a tensor that contains the addition of a tensor of encrypted integers and a tensor of clear integers.";
|
||||
|
||||
let description = [{
|
||||
@@ -28,10 +28,10 @@ def AddEintIntOp : HLFHELinalg_Op<"add_eint_int", [TensorBroadcastingRules, Tens
|
||||
Examples:
|
||||
```mlir
|
||||
// Returns the term to term addition of `%a0` with `%a1`
|
||||
"HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<4>>, tensor<4xi5>) -> tensor<4x!HLFHE.eint<4>>
|
||||
"FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<4>>, tensor<4xi5>) -> tensor<4x!FHE.eint<4>>
|
||||
|
||||
// Returns the term to term addition of `%a0` with `%a1`, where dimensions equal to one are stretched.
|
||||
"HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x1x4x!HLFHE.eint<4>>, tensor<1x4x4xi5>) -> tensor<4x4x4x!HLFHE.eint<4>>
|
||||
"FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x1x4x!FHE.eint<4>>, tensor<1x4x4xi5>) -> tensor<4x4x4x!FHE.eint<4>>
|
||||
|
||||
// Returns the addition of a 3x3 matrix of encrypted integers and a 3x1 matrix (a column) of integers.
|
||||
//
|
||||
@@ -40,7 +40,7 @@ def AddEintIntOp : HLFHELinalg_Op<"add_eint_int", [TensorBroadcastingRules, Tens
|
||||
// [7,8,9] [3] [10,11,12]
|
||||
//
|
||||
// The dimension #1 of operand #2 is stretched as it is equals to 1.
|
||||
"HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x1xi5>) -> tensor<3x3x!HLFHE.eint<4>>
|
||||
"FHELinalg.add_eint_int"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3x1xi5>) -> tensor<3x3x!FHE.eint<4>>
|
||||
|
||||
// Returns the addition of a 3x3 matrix of encrypted integers and a 1x3 matrix (a line) of integers.
|
||||
//
|
||||
@@ -49,10 +49,10 @@ def AddEintIntOp : HLFHELinalg_Op<"add_eint_int", [TensorBroadcastingRules, Tens
|
||||
// [7,8,9] [8,10,12]
|
||||
//
|
||||
// The dimension #2 of operand #2 is stretched as it is equals to 1.
|
||||
"HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<1x3xi5>) -> tensor<3x3x!HLFHE.eint<4>>
|
||||
"FHELinalg.add_eint_int"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<1x3xi5>) -> tensor<3x3x!FHE.eint<4>>
|
||||
|
||||
// Same behavior than the previous one, but as the dimension #2 is missing of operand #2.
|
||||
"HLFHELinalg.add_eint_int(%a0, %a1)" : (tensor<3x4x!HLFHE.eint<4>>, tensor<3xi5>) -> tensor<4x4x4x!HLFHE.eint<4>>
|
||||
"FHELinalg.add_eint_int(%a0, %a1)" : (tensor<3x4x!FHE.eint<4>>, tensor<3xi5>) -> tensor<4x4x4x!FHE.eint<4>>
|
||||
|
||||
```
|
||||
}];
|
||||
@@ -71,7 +71,7 @@ def AddEintIntOp : HLFHELinalg_Op<"add_eint_int", [TensorBroadcastingRules, Tens
|
||||
];
|
||||
}
|
||||
|
||||
def AddEintOp : HLFHELinalg_Op<"add_eint", [TensorBroadcastingRules, TensorBinaryEint]> {
|
||||
def AddEintOp : FHELinalg_Op<"add_eint", [TensorBroadcastingRules, TensorBinaryEint]> {
|
||||
let summary = "Returns a tensor that contains the addition of two tensor of encrypted integers.";
|
||||
|
||||
let description = [{
|
||||
@@ -81,10 +81,10 @@ def AddEintOp : HLFHELinalg_Op<"add_eint", [TensorBroadcastingRules, TensorBinar
|
||||
Examples:
|
||||
```mlir
|
||||
// Returns the term to term addition of `%a0` with `%a1`
|
||||
"HLFHELinalg.add_eint"(%a0, %a1) : (tensor<4x!HLFHE.eint<4>>, tensor<4x!HLFHE.eint<4>>) -> tensor<4x!HLFHE.eint<4>>
|
||||
"FHELinalg.add_eint"(%a0, %a1) : (tensor<4x!FHE.eint<4>>, tensor<4x!FHE.eint<4>>) -> tensor<4x!FHE.eint<4>>
|
||||
|
||||
// Returns the term to term addition of `%a0` with `%a1`, where dimensions equal to one are stretched.
|
||||
"HLFHELinalg.add_eint"(%a0, %a1) : (tensor<4x1x4x!HLFHE.eint<4>>, tensor<1x4x4x!HLFHE.eint<4>>) -> tensor<4x4x4x!HLFHE.eint<4>>
|
||||
"FHELinalg.add_eint"(%a0, %a1) : (tensor<4x1x4x!FHE.eint<4>>, tensor<1x4x4x!FHE.eint<4>>) -> tensor<4x4x4x!FHE.eint<4>>
|
||||
|
||||
// Returns the addition of a 3x3 matrix of encrypted integers and a 3x1 matrix (a column) of encrypted integers.
|
||||
//
|
||||
@@ -93,7 +93,7 @@ def AddEintOp : HLFHELinalg_Op<"add_eint", [TensorBroadcastingRules, TensorBinar
|
||||
// [7,8,9] [3] [10,11,12]
|
||||
//
|
||||
// The dimension #1 of operand #2 is stretched as it is equals to 1.
|
||||
"HLFHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x1x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
|
||||
"FHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3x1x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>>
|
||||
|
||||
// Returns the addition of a 3x3 matrix of encrypted integers and a 1x3 matrix (a line) of encrypted integers.
|
||||
//
|
||||
@@ -102,10 +102,10 @@ def AddEintOp : HLFHELinalg_Op<"add_eint", [TensorBroadcastingRules, TensorBinar
|
||||
// [7,8,9] [8,10,12]
|
||||
//
|
||||
// The dimension #2 of operand #2 is stretched as it is equals to 1.
|
||||
"HLFHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<1x3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
|
||||
"FHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<1x3x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>>
|
||||
|
||||
// Same behavior than the previous one, but as the dimension #2 of operand #2 is missing.
|
||||
"HLFHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
|
||||
"FHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>>
|
||||
```
|
||||
}];
|
||||
|
||||
@@ -123,7 +123,7 @@ def AddEintOp : HLFHELinalg_Op<"add_eint", [TensorBroadcastingRules, TensorBinar
|
||||
];
|
||||
}
|
||||
|
||||
def SubIntEintOp : HLFHELinalg_Op<"sub_int_eint", [TensorBroadcastingRules, TensorBinaryIntEint]> {
|
||||
def SubIntEintOp : FHELinalg_Op<"sub_int_eint", [TensorBroadcastingRules, TensorBinaryIntEint]> {
|
||||
let summary = "Returns a tensor that contains the substraction of a tensor of clear integers and a tensor of encrypted integers.";
|
||||
|
||||
let description = [{
|
||||
@@ -133,10 +133,10 @@ def SubIntEintOp : HLFHELinalg_Op<"sub_int_eint", [TensorBroadcastingRules, Tens
|
||||
Examples:
|
||||
```mlir
|
||||
// Returns the term to term substraction of `%a0` with `%a1`
|
||||
"HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4xi5>, tensor<4x!HLFHE.eint<4>>) -> tensor<4x!HLFHE.eint<4>>
|
||||
"FHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4xi5>, tensor<4x!FHE.eint<4>>) -> tensor<4x!FHE.eint<4>>
|
||||
|
||||
// Returns the term to term substraction of `%a0` with `%a1`, where dimensions equal to one are stretched.
|
||||
"HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4x1x4xi5>, tensor<1x4x4x!HLFHE.eint<4>>) -> tensor<4x4x4x!HLFHE.eint<4>>
|
||||
"FHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4x1x4xi5>, tensor<1x4x4x!FHE.eint<4>>) -> tensor<4x4x4x!FHE.eint<4>>
|
||||
|
||||
// Returns the substraction of a 3x3 matrix of integers and a 3x1 matrix (a column) of encrypted integers.
|
||||
//
|
||||
@@ -145,7 +145,7 @@ def SubIntEintOp : HLFHELinalg_Op<"sub_int_eint", [TensorBroadcastingRules, Tens
|
||||
// [7,8,9] [3] [4,5,6]
|
||||
//
|
||||
// The dimension #1 of operand #2 is stretched as it is equals to 1.
|
||||
"HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<3x3xi5>, tensor<3x1x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
|
||||
"FHELinalg.sub_int_eint"(%a0, %a1) : (tensor<3x3xi5>, tensor<3x1x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>>
|
||||
|
||||
// Returns the substraction of a 3x3 matrix of integers and a 1x3 matrix (a line) of encrypted integers.
|
||||
//
|
||||
@@ -154,10 +154,10 @@ def SubIntEintOp : HLFHELinalg_Op<"sub_int_eint", [TensorBroadcastingRules, Tens
|
||||
// [7,8,9] [6,6,6]
|
||||
//
|
||||
// The dimension #2 of operand #2 is stretched as it is equals to 1.
|
||||
"HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<3x3xi5>, tensor<1x3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
|
||||
"FHELinalg.sub_int_eint"(%a0, %a1) : (tensor<3x3xi5>, tensor<1x3x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>>
|
||||
|
||||
// Same behavior than the previous one, but as the dimension #2 is missing of operand #2.
|
||||
"HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<3x3xi5>, tensor<3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
|
||||
"FHELinalg.sub_int_eint"(%a0, %a1) : (tensor<3x3xi5>, tensor<3x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>>
|
||||
|
||||
```
|
||||
}];
|
||||
@@ -176,7 +176,7 @@ def SubIntEintOp : HLFHELinalg_Op<"sub_int_eint", [TensorBroadcastingRules, Tens
|
||||
];
|
||||
}
|
||||
|
||||
def NegEintOp : HLFHELinalg_Op<"neg_eint", [TensorUnaryEint]> {
|
||||
def NegEintOp : FHELinalg_Op<"neg_eint", [TensorUnaryEint]> {
|
||||
let summary = "Returns a tensor that contains the negation of a tensor of encrypted integers.";
|
||||
|
||||
let description = [{
|
||||
@@ -185,7 +185,7 @@ def NegEintOp : HLFHELinalg_Op<"neg_eint", [TensorUnaryEint]> {
|
||||
Examples:
|
||||
```mlir
|
||||
// Returns the term to term negation of `%a0`
|
||||
"HLFHELinalg.neg_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
|
||||
"FHELinalg.neg_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>>
|
||||
//
|
||||
// ( [1,2,3] ) [31,30,29]
|
||||
// negate ( [4,5,6] ) = [28,27,26]
|
||||
@@ -208,7 +208,7 @@ def NegEintOp : HLFHELinalg_Op<"neg_eint", [TensorUnaryEint]> {
|
||||
];
|
||||
}
|
||||
|
||||
def MulEintIntOp : HLFHELinalg_Op<"mul_eint_int", [TensorBroadcastingRules, TensorBinaryEintInt]> {
|
||||
def MulEintIntOp : FHELinalg_Op<"mul_eint_int", [TensorBroadcastingRules, TensorBinaryEintInt]> {
|
||||
let summary = "Returns a tensor that contains the multiplication of a tensor of encrypted integers and a tensor of clear integers.";
|
||||
|
||||
let description = [{
|
||||
@@ -218,10 +218,10 @@ def MulEintIntOp : HLFHELinalg_Op<"mul_eint_int", [TensorBroadcastingRules, Tens
|
||||
Examples:
|
||||
```mlir
|
||||
// Returns the term to term multiplication of `%a0` with `%a1`
|
||||
"HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<4>>, tensor<4xi5>) -> tensor<4x!HLFHE.eint<4>>
|
||||
"FHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<4>>, tensor<4xi5>) -> tensor<4x!FHE.eint<4>>
|
||||
|
||||
// Returns the term to term multiplication of `%a0` with `%a1`, where dimensions equal to one are stretched.
|
||||
"HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x1x4x!HLFHE.eint<4>>, tensor<1x4x4xi5>) -> tensor<4x4x4x!HLFHE.eint<4>>
|
||||
"FHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x1x4x!FHE.eint<4>>, tensor<1x4x4xi5>) -> tensor<4x4x4x!FHE.eint<4>>
|
||||
|
||||
// Returns the multiplication of a 3x3 matrix of encrypted integers and a 3x1 matrix (a column) of integers.
|
||||
//
|
||||
@@ -230,7 +230,7 @@ def MulEintIntOp : HLFHELinalg_Op<"mul_eint_int", [TensorBroadcastingRules, Tens
|
||||
// [7,8,9] [3] [21,24,27]
|
||||
//
|
||||
// The dimension #1 of operand #2 is stretched as it is equals to 1.
|
||||
"HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x1xi5>) -> tensor<3x3x!HLFHE.eint<4>>
|
||||
"FHELinalg.mul_eint_int"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3x1xi5>) -> tensor<3x3x!FHE.eint<4>>
|
||||
|
||||
// Returns the multiplication of a 3x3 matrix of encrypted integers and a 1x3 matrix (a line) of integers.
|
||||
//
|
||||
@@ -239,10 +239,10 @@ def MulEintIntOp : HLFHELinalg_Op<"mul_eint_int", [TensorBroadcastingRules, Tens
|
||||
// [7,8,9] [8,10,12]
|
||||
//
|
||||
// The dimension #2 of operand #2 is stretched as it is equals to 1.
|
||||
"HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<1x3xi5>) -> tensor<3x3x!HLFHE.eint<4>>
|
||||
"FHELinalg.mul_eint_int"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<1x3xi5>) -> tensor<3x3x!FHE.eint<4>>
|
||||
|
||||
// Same behavior than the previous one, but as the dimension #2 is missing of operand #2.
|
||||
"HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3xi5>) -> tensor<3x3x!HLFHE.eint<4>>
|
||||
"FHELinalg.mul_eint_int"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3xi5>) -> tensor<3x3x!FHE.eint<4>>
|
||||
|
||||
```
|
||||
}];
|
||||
@@ -255,7 +255,7 @@ def MulEintIntOp : HLFHELinalg_Op<"mul_eint_int", [TensorBroadcastingRules, Tens
|
||||
let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
|
||||
}
|
||||
|
||||
def ApplyLookupTableEintOp : HLFHELinalg_Op<"apply_lookup_table", []> {
|
||||
def ApplyLookupTableEintOp : FHELinalg_Op<"apply_lookup_table", []> {
|
||||
let summary = "Returns a tensor that contains the result of the lookup on a table.";
|
||||
|
||||
let description = [{
|
||||
@@ -264,7 +264,7 @@ def ApplyLookupTableEintOp : HLFHELinalg_Op<"apply_lookup_table", []> {
|
||||
```mlir
|
||||
// The result of this operation, is a tensor that contains the result of the lookup on a table.
|
||||
// i.e. %res[i, ..., k] = %lut[%t[i, ..., k]]
|
||||
%res = HLFHELinalg.apply_lookup_table(%t, %lut): tensor<DNx...xD1x!HLFHE.eint<$p>>, tensor<D2^$pxi64> -> tensor<DNx...xD1x!HLFHE.eint<$p>>
|
||||
%res = FHELinalg.apply_lookup_table(%t, %lut): tensor<DNx...xD1x!FHE.eint<$p>>, tensor<D2^$pxi64> -> tensor<DNx...xD1x!FHE.eint<$p>>
|
||||
```
|
||||
|
||||
The `%lut` argument must be a tensor with one dimension, where its dimension is equals to `2^p` where `p` is the width of the encrypted integers.
|
||||
@@ -277,7 +277,7 @@ def ApplyLookupTableEintOp : HLFHELinalg_Op<"apply_lookup_table", []> {
|
||||
// [0,1,2] [1,3,5]
|
||||
// [3,0,1] lut [1,3,5,7] = [7,1,3]
|
||||
// [2,3,0] [5,7,1]
|
||||
"HLFHELinalg.apply_lookup_table"(%t, %lut) : (tensor<3x3x!HLFHE.eint<2>>, tensor<4xi64>) -> tensor<3x3x!HLFHE.eint<3>>
|
||||
"FHELinalg.apply_lookup_table"(%t, %lut) : (tensor<3x3x!FHE.eint<2>>, tensor<4xi64>) -> tensor<3x3x!FHE.eint<3>>
|
||||
```
|
||||
}];
|
||||
|
||||
@@ -289,11 +289,11 @@ def ApplyLookupTableEintOp : HLFHELinalg_Op<"apply_lookup_table", []> {
|
||||
let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
|
||||
|
||||
let verifier = [{
|
||||
return ::mlir::concretelang::HLFHELinalg::verifyApplyLookupTable(*this);
|
||||
return ::mlir::concretelang::FHELinalg::verifyApplyLookupTable(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def ApplyMultiLookupTableEintOp : HLFHELinalg_Op<"apply_multi_lookup_table", []> {
|
||||
def ApplyMultiLookupTableEintOp : FHELinalg_Op<"apply_multi_lookup_table", []> {
|
||||
let summary = "Returns a tensor that contains the result of the lookup on a table, using a different lookup table for each element.";
|
||||
|
||||
let description = [{
|
||||
@@ -303,7 +303,7 @@ def ApplyMultiLookupTableEintOp : HLFHELinalg_Op<"apply_multi_lookup_table", []>
|
||||
```mlir
|
||||
// The result of this operation, is a tensor that contains the result of the lookup on different tables.
|
||||
// i.e. %res[i, ..., k] = [ %luts[i][%t[i]], ..., %luts[k][%t[k]] ]
|
||||
%res = HLFHELinalg.apply_multi_lookup_table(%t, %lut): tensor<DNx...xD1x!HLFHE.eint<$p>>, tensor<DMx...xD1xD2^$pxi64> -> tensor<DNx...xD1x!HLFHE.eint<$p>>
|
||||
%res = FHELinalg.apply_multi_lookup_table(%t, %lut): tensor<DNx...xD1x!FHE.eint<$p>>, tensor<DMx...xD1xD2^$pxi64> -> tensor<DNx...xD1x!FHE.eint<$p>>
|
||||
```
|
||||
|
||||
The `%luts` argument should be a tensor with M dimension, where the first M-1 dimensions are broadcastable with the N dimensions of the encrypted tensor,
|
||||
@@ -318,7 +318,7 @@ def ApplyMultiLookupTableEintOp : HLFHELinalg_Op<"apply_multi_lookup_table", []>
|
||||
// [0,1] = [1,2]
|
||||
// [3,0] lut [[1,3,5,7], [0,2,4,6]] = [7,0]
|
||||
// [2,3] = [5,6]
|
||||
"HLFHELinalg.apply_multi_lookup_table"(%t, %luts) : (tensor<3x2x!HLFHE.eint<2>>, tensor<2x4xi64>) -> tensor<3x2x!HLFHE.eint<3>>
|
||||
"FHELinalg.apply_multi_lookup_table"(%t, %luts) : (tensor<3x2x!FHE.eint<2>>, tensor<2x4xi64>) -> tensor<3x2x!FHE.eint<3>>
|
||||
```
|
||||
|
||||
```mlir
|
||||
@@ -326,7 +326,7 @@ def ApplyMultiLookupTableEintOp : HLFHELinalg_Op<"apply_multi_lookup_table", []>
|
||||
// Returns the lookup of a vector of 3 encrypted indices of width 2 on a vector of 3 tables of size 4=2² of clear integers.
|
||||
//
|
||||
// [3,0,1] lut [[1,3,5,7], [0,2,4,6], [1,2,3,4]] = [7,0,2]
|
||||
"HLFHELinalg.apply_multi_lookup_table"(%t, %luts) : (tensor<3x!HLFHE.eint<2>>, tensor<3x4xi64>) -> tensor<3x!HLFHE.eint<3>>
|
||||
"FHELinalg.apply_multi_lookup_table"(%t, %luts) : (tensor<3x!FHE.eint<2>>, tensor<3x4xi64>) -> tensor<3x!FHE.eint<3>>
|
||||
```
|
||||
}];
|
||||
|
||||
@@ -338,11 +338,11 @@ def ApplyMultiLookupTableEintOp : HLFHELinalg_Op<"apply_multi_lookup_table", []>
|
||||
let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
|
||||
|
||||
let verifier = [{
|
||||
return ::mlir::concretelang::HLFHELinalg::verifyApplyMultiLookupTable(*this);
|
||||
return ::mlir::concretelang::FHELinalg::verifyApplyMultiLookupTable(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def ApplyMappedLookupTableEintOp : HLFHELinalg_Op<"apply_mapped_lookup_table", []> {
|
||||
def ApplyMappedLookupTableEintOp : FHELinalg_Op<"apply_mapped_lookup_table", []> {
|
||||
let summary = "Returns a tensor that contains the result of the lookup on a table, using a different lookup table for each element, specified by a map.";
|
||||
|
||||
let description = [{
|
||||
@@ -352,7 +352,7 @@ def ApplyMappedLookupTableEintOp : HLFHELinalg_Op<"apply_mapped_lookup_table", [
|
||||
```mlir
|
||||
// The result of this operation, is a tensor that contains the result of the lookup on different tables.
|
||||
// i.e. %res[i, ..., k] = %luts[ %map[i, ..., k] ][ %t[i, ..., k] ]
|
||||
%res = HLFHELinalg.apply_mapped_lookup_table(%t, %luts, %map): tensor<DNx...xD1x!HLFHE.eint<$p>>, tensor<DM x ^$p>, tensor<DNx...xD1xindex> -> tensor<DNx...xD1x!HLFHE.eint<$p>>
|
||||
%res = FHELinalg.apply_mapped_lookup_table(%t, %luts, %map): tensor<DNx...xD1x!FHE.eint<$p>>, tensor<DM x ^$p>, tensor<DNx...xD1xindex> -> tensor<DNx...xD1x!FHE.eint<$p>>
|
||||
```
|
||||
|
||||
Examples:
|
||||
@@ -363,7 +363,7 @@ def ApplyMappedLookupTableEintOp : HLFHELinalg_Op<"apply_mapped_lookup_table", [
|
||||
// [0,1] [0, 1] = [1,2]
|
||||
// [3,0] lut [[1,3,5,7], [0,2,4,6]] with [0, 1] = [7,0]
|
||||
// [2,3] [0, 1] = [5,6]
|
||||
"HLFHELinalg.apply_mapped_lookup_table"(%t, %luts, %map) : (tensor<3x2x!HLFHE.eint<2>>, tensor<2x4xi64>, tensor<3x2xindex>) -> tensor<3x2x!HLFHE.eint<3>>
|
||||
"FHELinalg.apply_mapped_lookup_table"(%t, %luts, %map) : (tensor<3x2x!FHE.eint<2>>, tensor<2x4xi64>, tensor<3x2xindex>) -> tensor<3x2x!FHE.eint<3>>
|
||||
```
|
||||
|
||||
Others examples:
|
||||
@@ -394,12 +394,12 @@ def ApplyMappedLookupTableEintOp : HLFHELinalg_Op<"apply_mapped_lookup_table", [
|
||||
let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
|
||||
|
||||
let verifier = [{
|
||||
return ::mlir::concretelang::HLFHELinalg::verifyApplyMappedLookupTable(*this);
|
||||
return ::mlir::concretelang::FHELinalg::verifyApplyMappedLookupTable(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
// Dot product
|
||||
def Dot : HLFHELinalg_Op<"dot_eint_int"> {
|
||||
def Dot : FHELinalg_Op<"dot_eint_int"> {
|
||||
let summary = "Returns the encrypted dot product between a vector of encrypted integers and a vector of clean integers.";
|
||||
|
||||
let description = [{
|
||||
@@ -408,7 +408,7 @@ def Dot : HLFHELinalg_Op<"dot_eint_int"> {
|
||||
Examples:
|
||||
```mlir
|
||||
// Returns the dot product of `%a0` with `%a1`
|
||||
"HLFHELinalg.dot_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<4>>, tensor<4xi5>) -> !HLFHE.eint<4>
|
||||
"FHELinalg.dot_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<4>>, tensor<4xi5>) -> !FHE.eint<4>
|
||||
|
||||
```
|
||||
}];
|
||||
@@ -420,11 +420,11 @@ def Dot : HLFHELinalg_Op<"dot_eint_int"> {
|
||||
let results = (outs EncryptedIntegerType:$out);
|
||||
|
||||
let verifier = [{
|
||||
return ::mlir::concretelang::HLFHELinalg::verifyDotEintInt(*this);
|
||||
return ::mlir::concretelang::FHELinalg::verifyDotEintInt(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def MatMulEintIntOp : HLFHELinalg_Op<"matmul_eint_int", [TensorBinaryEintInt]> {
|
||||
def MatMulEintIntOp : FHELinalg_Op<"matmul_eint_int", [TensorBinaryEintInt]> {
|
||||
let summary = "Returns a tensor that contains the result of the matrix multiplication of a matrix of encrypted integers and a matrix of clear integers.";
|
||||
|
||||
let description = [{
|
||||
@@ -432,7 +432,7 @@ def MatMulEintIntOp : HLFHELinalg_Op<"matmul_eint_int", [TensorBinaryEintInt]> {
|
||||
The width of the clear integers must be less than or equals to the witdh of encrypted integers.
|
||||
|
||||
```mlir
|
||||
"HLFHELinalg.matmul_eint_int(%a, %b) : (tensor<MxNx!HLFHE.eint<p>>, tensor<NxPxip'>) -> tensor<MxPx!HLFHE.eint<p>>"
|
||||
"FHELinalg.matmul_eint_int(%a, %b) : (tensor<MxNx!FHE.eint<p>>, tensor<NxPxip'>) -> tensor<MxPx!FHE.eint<p>>"
|
||||
```
|
||||
|
||||
Examples:
|
||||
@@ -445,7 +445,7 @@ def MatMulEintIntOp : HLFHELinalg_Op<"matmul_eint_int", [TensorBinaryEintInt]> {
|
||||
// [3,4] = [11,18,25]
|
||||
// [5,6] [17,28,39]
|
||||
//
|
||||
"HLFHELinalg.matmul_eint_int"(%a, %b) : (tensor<3x2x!HLFHE.eint<6>>, tensor<2x3xi7>) -> tensor<3x3x!HLFHE.eint<6>>
|
||||
"FHELinalg.matmul_eint_int"(%a, %b) : (tensor<3x2x!FHE.eint<6>>, tensor<2x3xi7>) -> tensor<3x3x!FHE.eint<6>>
|
||||
|
||||
```
|
||||
}];
|
||||
@@ -458,11 +458,11 @@ def MatMulEintIntOp : HLFHELinalg_Op<"matmul_eint_int", [TensorBinaryEintInt]> {
|
||||
let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
|
||||
|
||||
let verifier = [{
|
||||
return ::mlir::concretelang::HLFHELinalg::verifyMatmul<mlir::concretelang::HLFHELinalg::MatMulEintIntOp>(*this);
|
||||
return ::mlir::concretelang::FHELinalg::verifyMatmul<mlir::concretelang::FHELinalg::MatMulEintIntOp>(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def MatMulIntEintOp : HLFHELinalg_Op<"matmul_int_eint", [TensorBinaryIntEint]> {
|
||||
def MatMulIntEintOp : FHELinalg_Op<"matmul_int_eint", [TensorBinaryIntEint]> {
|
||||
let summary = "Returns a tensor that contains the result of the matrix multiplication of a matrix of clear integers and a matrix of encrypted integers.";
|
||||
|
||||
let description = [{
|
||||
@@ -470,7 +470,7 @@ def MatMulIntEintOp : HLFHELinalg_Op<"matmul_int_eint", [TensorBinaryIntEint]> {
|
||||
The width of the clear integers must be less than or equals to the witdh of encrypted integers.
|
||||
|
||||
```mlir
|
||||
"HLFHELinalg.matmul_int_eint(%a, %b) : (tensor<MxNxip'>, tensor<NxPxHLFHE.eint<p>>) -> tensor<MxPx!HLFHE.eint<p>>"
|
||||
"FHELinalg.matmul_int_eint(%a, %b) : (tensor<MxNxip'>, tensor<NxPxFHE.eint<p>>) -> tensor<MxPx!FHE.eint<p>>"
|
||||
```
|
||||
|
||||
Examples:
|
||||
@@ -483,7 +483,7 @@ def MatMulIntEintOp : HLFHELinalg_Op<"matmul_int_eint", [TensorBinaryIntEint]> {
|
||||
// [3,4] = [11,18,25]
|
||||
// [5,6] [17,28,39]
|
||||
//
|
||||
"HLFHELinalg.matmul_int_eint"(%a, %b) : (tensor<3x2xi7>, tensor<2x3x!HLFHE.eint<6>>) -> tensor<3x3x!HLFHE.eint<6>>
|
||||
"FHELinalg.matmul_int_eint"(%a, %b) : (tensor<3x2xi7>, tensor<2x3x!FHE.eint<6>>) -> tensor<3x3x!FHE.eint<6>>
|
||||
|
||||
```
|
||||
}];
|
||||
@@ -496,11 +496,11 @@ def MatMulIntEintOp : HLFHELinalg_Op<"matmul_int_eint", [TensorBinaryIntEint]> {
|
||||
let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
|
||||
|
||||
let verifier = [{
|
||||
return ::mlir::concretelang::HLFHELinalg::verifyMatmul<mlir::concretelang::HLFHELinalg::MatMulIntEintOp>(*this);
|
||||
return ::mlir::concretelang::FHELinalg::verifyMatmul<mlir::concretelang::FHELinalg::MatMulIntEintOp>(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def ZeroOp : HLFHELinalg_Op<"zero", []> {
|
||||
def ZeroOp : FHELinalg_Op<"zero", []> {
|
||||
let summary = "Creates a new tensor with all elements initialized to an encrypted zero.";
|
||||
|
||||
let description = [{
|
||||
@@ -508,7 +508,7 @@ def ZeroOp : HLFHELinalg_Op<"zero", []> {
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
%tensor = "HLFHELinalg.zero"() : () -> tensor<5x!HLFHE.eint<4>>
|
||||
%tensor = "FHELinalg.zero"() : () -> tensor<5x!FHE.eint<4>>
|
||||
```
|
||||
}];
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_HLFHELinalg_IR_HLFHELinalgTYPES_H
|
||||
#define CONCRETELANG_DIALECT_HLFHELinalg_IR_HLFHELinalgTYPES_H
|
||||
#ifndef CONCRETELANG_DIALECT_FHELinalg_IR_FHELinalgTYPES_H
|
||||
#define CONCRETELANG_DIALECT_FHELinalg_IR_FHELinalgTYPES_H
|
||||
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/BuiltinTypes.h>
|
||||
#include <mlir/IR/DialectImplementation.h>
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "concretelang/Dialect/HLFHELinalg/IR/HLFHELinalgOpsTypes.h.inc"
|
||||
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOpsTypes.h.inc"
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,11 @@
|
||||
#ifndef CONCRETELANG_DIALECT_FHELinalg_IR_FHELinalg_TYPES
|
||||
#define CONCRETELANG_DIALECT_FHELinalg_IR_FHELinalg_TYPES
|
||||
|
||||
include "concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.td"
|
||||
include "mlir/IR/BuiltinTypes.td"
|
||||
include "concretelang/Dialect/FHE/IR/FHETypes.td"
|
||||
|
||||
class FHELinalg_Type<string name, list<Trait> traits = []> :
|
||||
TypeDef<FHELinalg_Dialect, name, traits> { }
|
||||
|
||||
#endif
|
||||
@@ -1,3 +1,3 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Tiling.td)
|
||||
mlir_tablegen(Tiling.h.inc -gen-pass-decls -name Transforms)
|
||||
add_public_tablegen_target(ConcretelangHLFHELinalgTilingPassIncGen)
|
||||
add_public_tablegen_target(ConcretelangFHELinalgTilingPassIncGen)
|
||||
@@ -1,21 +1,21 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
#ifndef CONCRETELANG_HLFHELINALG_TILING_PASS_H
|
||||
#define CONCRETELANG_HLFHELINALG_TILING_PASS_H
|
||||
#ifndef CONCRETELANG_FHELINALG_TILING_PASS_H
|
||||
#define CONCRETELANG_FHELINALG_TILING_PASS_H
|
||||
|
||||
#include <mlir/Pass/Pass.h>
|
||||
#include <concretelang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.h>
|
||||
#include <concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.h>
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include <concretelang/Dialect/HLFHELinalg/Transforms/Tiling.h.inc>
|
||||
#include <concretelang/Dialect/FHELinalg/Transforms/Tiling.h.inc>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<mlir::OperationPass<>>
|
||||
createHLFHELinalgTilingMarkerPass(llvm::ArrayRef<int64_t> tileSizes);
|
||||
createFHELinalgTilingMarkerPass(llvm::ArrayRef<int64_t> tileSizes);
|
||||
|
||||
std::unique_ptr<mlir::OperationPass<>> createHLFHELinalgTilingPass();
|
||||
std::unique_ptr<mlir::OperationPass<>> createFHELinalgTilingPass();
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
#ifndef CONCRETELANG_FHELINALG_TILING_PASS
|
||||
#define CONCRETELANG_FHELINALG_TILING_PASS
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def FHELinalgTilingMarker : Pass<"fhe-linalg-tiling-marker"> {
|
||||
let summary =
|
||||
"Marks FHELinalg operations for tiling using a vector of tile sizes";
|
||||
let constructor = "mlir::concretelang::createFHELinalgTilingMarkerPass()";
|
||||
let options = [];
|
||||
let dependentDialects = [ "mlir::concretelang::FHELinalg::FHELinalgDialect" ];
|
||||
}
|
||||
|
||||
def FHELinalgTiling : Pass<"fhe-linalg-tiling"> {
|
||||
let summary = "Performs tiling of FHELinalg operations based on the "
|
||||
"tile-size attribute";
|
||||
let constructor = "mlir::concretelang::createFHELinalgTilingPass()";
|
||||
let options = [];
|
||||
let dependentDialects = [ "mlir::concretelang::FHELinalg::FHELinalgDialect" ];
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -1,13 +0,0 @@
|
||||
set(LLVM_TARGET_DEFINITIONS HLFHEOps.td)
|
||||
mlir_tablegen(HLFHEOps.h.inc -gen-op-decls)
|
||||
mlir_tablegen(HLFHEOps.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(HLFHEOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=HLFHE)
|
||||
mlir_tablegen(HLFHEOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=HLFHE)
|
||||
mlir_tablegen(HLFHEOpsDialect.h.inc -gen-dialect-decls -dialect=HLFHE)
|
||||
mlir_tablegen(HLFHEOpsDialect.cpp.inc -gen-dialect-defs -dialect=HLFHE)
|
||||
add_public_tablegen_target(MLIRHLFHEOpsIncGen)
|
||||
add_dependencies(mlir-headers MLIRHLFHEOpsIncGen)
|
||||
|
||||
add_concretelang_doc(HLFHEDialect HLFHEDialect concretelang/ -gen-dialect-doc)
|
||||
add_concretelang_doc(HLFHEOps HLFHEOps concretelang/ -gen-op-doc)
|
||||
add_concretelang_doc(HLFHETypes HLFHETypes concretelang/ -gen-typedef-doc)
|
||||
@@ -1,13 +0,0 @@
|
||||
set(LLVM_TARGET_DEFINITIONS HLFHELinalgOps.td)
|
||||
mlir_tablegen(HLFHELinalgOps.h.inc -gen-op-decls)
|
||||
mlir_tablegen(HLFHELinalgOps.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(HLFHELinalgOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=HLFHELinalg)
|
||||
mlir_tablegen(HLFHELinalgOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=HLFHELinalg)
|
||||
mlir_tablegen(HLFHELinalgOpsDialect.h.inc -gen-dialect-decls -dialect=HLFHELinalg)
|
||||
mlir_tablegen(HLFHELinalgOpsDialect.cpp.inc -gen-dialect-defs -dialect=HLFHELinalg)
|
||||
add_public_tablegen_target(MLIRHLFHELinalgOpsIncGen)
|
||||
add_dependencies(mlir-headers MLIRHLFHELinalgOpsIncGen)
|
||||
|
||||
add_concretelang_doc(HLFHELinalgDialect HLFHELinalgDialect concretelang/ -gen-dialect-doc)
|
||||
add_concretelang_doc(HLFHELinalgOps HLFHELinalgOps concretelang/ -gen-op-doc)
|
||||
add_concretelang_doc(HLFHELinalgTypes HLFHELinalgTypes concretelang/ -gen-typedef-doc)
|
||||
@@ -1,15 +0,0 @@
|
||||
#ifndef CONCRETELANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_DIALECT
|
||||
#define CONCRETELANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def HLFHELinalg_Dialect : Dialect {
|
||||
let name = "HLFHELinalg";
|
||||
let summary = "High Level Fully Homorphic Encryption Linalg dialect";
|
||||
let description = [{
|
||||
A dialect for representation of high level linalg operations on fully homomorphic ciphertexts.
|
||||
}];
|
||||
let cppNamespace = "::mlir::concretelang::HLFHELinalg";
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -1,11 +0,0 @@
|
||||
#ifndef CONCRETELANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_TYPES
|
||||
#define CONCRETELANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_TYPES
|
||||
|
||||
include "concretelang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.td"
|
||||
include "mlir/IR/BuiltinTypes.td"
|
||||
include "concretelang/Dialect/HLFHE/IR/HLFHETypes.td"
|
||||
|
||||
class HLFHELinalg_Type<string name, list<Trait> traits = []> :
|
||||
TypeDef<HLFHELinalg_Dialect, name, traits> { }
|
||||
|
||||
#endif
|
||||
@@ -1,22 +0,0 @@
|
||||
#ifndef CONCRETELANG_HLFHELINALG_TILING_PASS
|
||||
#define CONCRETELANG_HLFHELINALG_TILING_PASS
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def HLFHELinalgTilingMarker : Pass<"hlfhe-linalg-tiling-marker"> {
|
||||
let summary =
|
||||
"Marks HLFHELinalg operations for tiling using a vector of tile sizes";
|
||||
let constructor = "mlir::concretelang::createHLFHELinalgTilingMarkerPass()";
|
||||
let options = [];
|
||||
let dependentDialects = [ "mlir::concretelang::HLFHELinalg::HLFHELinalgDialect" ];
|
||||
}
|
||||
|
||||
def HLFHELinalgTiling : Pass<"hlfhe-linalg-tiling"> {
|
||||
let summary = "Performs tiling of HLFHELinalg operations based on the "
|
||||
"tile-size attribute";
|
||||
let constructor = "mlir::concretelang::createHLFHELinalgTilingPass()";
|
||||
let options = [];
|
||||
let dependentDialects = [ "mlir::concretelang::HLFHELinalg::HLFHELinalgDialect" ];
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -1,13 +0,0 @@
|
||||
set(LLVM_TARGET_DEFINITIONS LowLFHEOps.td)
|
||||
mlir_tablegen(LowLFHEOps.h.inc -gen-op-decls)
|
||||
mlir_tablegen(LowLFHEOps.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(LowLFHEOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=LowLFHE)
|
||||
mlir_tablegen(LowLFHEOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=LowLFHE)
|
||||
mlir_tablegen(LowLFHEOpsDialect.h.inc -gen-dialect-decls -dialect=LowLFHE)
|
||||
mlir_tablegen(LowLFHEOpsDialect.cpp.inc -gen-dialect-defs -dialect=LowLFHE)
|
||||
add_public_tablegen_target(MLIRLowLFHEOpsIncGen)
|
||||
add_dependencies(mlir-headers MLIRLowLFHEOpsIncGen)
|
||||
|
||||
add_concretelang_doc(LowLFHEDialect LowLFHEDialect concretelang/ -gen-dialect-doc)
|
||||
add_concretelang_doc(LowLFHEOps LowLFHEOps concretelang/ -gen-op-doc)
|
||||
add_concretelang_doc(LowLFHETypes LowLFHETypes concretelang/ -gen-typedef-doc)
|
||||
@@ -1,15 +0,0 @@
|
||||
#ifndef CONCRETELANG_DIALECT_LowLFHE_IR_LowLFHE_DIALECT
|
||||
#define CONCRETELANG_DIALECT_LowLFHE_IR_LowLFHE_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def LowLFHE_Dialect : Dialect {
|
||||
let name = "LowLFHE";
|
||||
let summary = "Low Level Fully Homorphic Encryption dialect";
|
||||
let description = [{
|
||||
A dialect for representation of low level operation on fully homomorphic ciphertext.
|
||||
}];
|
||||
let cppNamespace = "::mlir::concretelang::LowLFHE";
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -1,13 +0,0 @@
|
||||
set(LLVM_TARGET_DEFINITIONS MidLFHEOps.td)
|
||||
mlir_tablegen(MidLFHEOps.h.inc -gen-op-decls)
|
||||
mlir_tablegen(MidLFHEOps.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(MidLFHEOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=MidLFHE)
|
||||
mlir_tablegen(MidLFHEOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=MidLFHE)
|
||||
mlir_tablegen(MidLFHEOpsDialect.h.inc -gen-dialect-decls -dialect=MidLFHE)
|
||||
mlir_tablegen(MidLFHEOpsDialect.cpp.inc -gen-dialect-defs -dialect=MidLFHE)
|
||||
add_public_tablegen_target(MLIRMidLFHEOpsIncGen)
|
||||
add_dependencies(mlir-headers MLIRMidLFHEOpsIncGen)
|
||||
|
||||
add_concretelang_doc(MidLFHEDialect MidLFHEDialect concretelang/ -gen-dialect-doc)
|
||||
add_concretelang_doc(MidLFHEOps MidLFHEOps concretelang/ -gen-op-doc)
|
||||
add_concretelang_doc(MidLFHETypes MidLFHETypes concretelang/ -gen-typedef-doc)
|
||||
@@ -8,10 +8,10 @@ def BuildDataflowTaskGraph : Pass<"BuildDataflowTaskGraph", "mlir::ModuleOp"> {
|
||||
"Identify profitable dataflow tasks and build DataflowTaskGraph.";
|
||||
|
||||
let description = [{
|
||||
This pass builds a dataflow graph out of a HLFHE program.
|
||||
This pass builds a dataflow graph out of a FHE program.
|
||||
|
||||
In its current incarnation, it considers some heavier weight
|
||||
operations (e.g., HLFHELinalg Dot and Matmult or bootstraps) as
|
||||
operations (e.g., FHELinalg Dot and Matmult or bootstraps) as
|
||||
candidates for being executed in a discrete task, and then
|
||||
sinks within the task the lighter weight operation that do not
|
||||
increase the graph cut (amount of dependences in or out).
|
||||
@@ -23,21 +23,21 @@ def BuildDataflowTaskGraph : Pass<"BuildDataflowTaskGraph", "mlir::ModuleOp"> {
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
func @main(%arg0: tensor<3x4x!HLFHE.eint<2>>, %arg1: tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> {
|
||||
%0 = "HLFHELinalg.matmul_eint_int"(%arg0, %arg1) : (tensor<3x4x!HLFHE.eint<2>>, tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>>
|
||||
return %0 : tensor<3x2x!HLFHE.eint<2>>
|
||||
func @main(%arg0: tensor<3x4x!FHE.eint<2>>, %arg1: tensor<4x2xi3>) -> tensor<3x2x!FHE.eint<2>> {
|
||||
%0 = "FHELinalg.matmul_eint_int"(%arg0, %arg1) : (tensor<3x4x!FHE.eint<2>>, tensor<4x2xi3>) -> tensor<3x2x!FHE.eint<2>>
|
||||
return %0 : tensor<3x2x!FHE.eint<2>>
|
||||
}
|
||||
```
|
||||
|
||||
Will result in generating a dataflow task for the Matmul operation:
|
||||
|
||||
```mlir
|
||||
func @main(%arg0: tensor<3x4x!HLFHE.eint<2>>, %arg1: tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> {
|
||||
func @main(%arg0: tensor<3x4x!FHE.eint<2>>, %arg1: tensor<4x2xi3>) -> tensor<3x2x!FHE.eint<2>> {
|
||||
%0 = "RT.dataflow_task"(%arg0, %arg1) ( {
|
||||
%1 = "HLFHELinalg.matmul_eint_int"(%arg0, %arg1) : (tensor<3x4x!HLFHE.eint<2>>, tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>>
|
||||
"RT.dataflow_yield"(%1) : (tensor<3x2x!HLFHE.eint<2>>) -> ()
|
||||
}) : (tensor<3x4x!HLFHE.eint<2>>, tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>>
|
||||
return %0 : tensor<3x2x!HLFHE.eint<2>>
|
||||
%1 = "FHELinalg.matmul_eint_int"(%arg0, %arg1) : (tensor<3x4x!FHE.eint<2>>, tensor<4x2xi3>) -> tensor<3x2x!FHE.eint<2>>
|
||||
"RT.dataflow_yield"(%1) : (tensor<3x2x!FHE.eint<2>>) -> ()
|
||||
}) : (tensor<3x4x!FHE.eint<2>>, tensor<4x2xi3>) -> tensor<3x2x!FHE.eint<2>>
|
||||
return %0 : tensor<3x2x!FHE.eint<2>>
|
||||
}
|
||||
```
|
||||
}];
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#ifndef CONCRETELANG_DIALECT_HLFHE_IR_HLFHE_TYPES
|
||||
#define CONCRETELANG_DIALECT_HLFHE_IR_HLFHE_TYPES
|
||||
#ifndef CONCRETELANG_DIALECT_FHE_IR_FHE_TYPES
|
||||
#define CONCRETELANG_DIALECT_FHE_IR_FHE_TYPES
|
||||
|
||||
include "concretelang/Dialect/RT/IR/RTDialect.td"
|
||||
include "mlir/IR/BuiltinTypes.td"
|
||||
|
||||
13
compiler/include/concretelang/Dialect/TFHE/IR/CMakeLists.txt
Normal file
13
compiler/include/concretelang/Dialect/TFHE/IR/CMakeLists.txt
Normal file
@@ -0,0 +1,13 @@
|
||||
set(LLVM_TARGET_DEFINITIONS TFHEOps.td)
|
||||
mlir_tablegen(TFHEOps.h.inc -gen-op-decls)
|
||||
mlir_tablegen(TFHEOps.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(TFHEOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=TFHE)
|
||||
mlir_tablegen(TFHEOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=TFHE)
|
||||
mlir_tablegen(TFHEOpsDialect.h.inc -gen-dialect-decls -dialect=TFHE)
|
||||
mlir_tablegen(TFHEOpsDialect.cpp.inc -gen-dialect-defs -dialect=TFHE)
|
||||
add_public_tablegen_target(MLIRTFHEOpsIncGen)
|
||||
add_dependencies(mlir-headers MLIRTFHEOpsIncGen)
|
||||
|
||||
add_concretelang_doc(TFHEDialect TFHEDialect concretelang/ -gen-dialect-doc)
|
||||
add_concretelang_doc(TFHEOps TFHEOps concretelang/ -gen-op-doc)
|
||||
add_concretelang_doc(TFHETypes TFHETypes concretelang/ -gen-typedef-doc)
|
||||
@@ -1,8 +1,8 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_LowLFHE_IR_LowLFHEDIALECT_H
|
||||
#define CONCRETELANG_DIALECT_LowLFHE_IR_LowLFHEDIALECT_H
|
||||
#ifndef CONCRETELANG_DIALECT_TFHE_IR_TFHEDIALECT_H
|
||||
#define CONCRETELANG_DIALECT_TFHE_IR_TFHEDIALECT_H
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
@@ -11,6 +11,6 @@
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
#include "concretelang/Dialect/LowLFHE/IR/LowLFHEOpsDialect.h.inc"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEOpsDialect.h.inc"
|
||||
|
||||
#endif
|
||||
@@ -1,4 +1,4 @@
|
||||
//===- MidLFHEDialect.td - MidLFHE dialect ----------------*- tablegen -*-===//
|
||||
//===- TFHEDialect.td - TFHE dialect ----------------*- tablegen -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
@@ -6,18 +6,18 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_MidLFHE_IR_MidLFHE_DIALECT
|
||||
#define CONCRETELANG_DIALECT_MidLFHE_IR_MidLFHE_DIALECT
|
||||
#ifndef CONCRETELANG_DIALECT_TFHE_IR_TFHE_DIALECT
|
||||
#define CONCRETELANG_DIALECT_TFHE_IR_TFHE_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def MidLFHE_Dialect : Dialect {
|
||||
let name = "MidLFHE";
|
||||
def TFHE_Dialect : Dialect {
|
||||
let name = "TFHE";
|
||||
let summary = "High Level Fully Homorphic Encryption dialect";
|
||||
let description = [{
|
||||
A dialect for representation of high level operation on fully homomorphic ciphertext.
|
||||
}];
|
||||
let cppNamespace = "::mlir::concretelang::MidLFHE";
|
||||
let cppNamespace = "::mlir::concretelang::TFHE";
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -1,8 +1,8 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_LowLFHE_LowLFHE_OPS_H
|
||||
#define CONCRETELANG_DIALECT_LowLFHE_LowLFHE_OPS_H
|
||||
#ifndef CONCRETELANG_DIALECT_TFHE_IR_TFHEOPS_H
|
||||
#define CONCRETELANG_DIALECT_TFHE_IR_TFHEOPS_H
|
||||
|
||||
#include <mlir/IR/Builders.h>
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
@@ -10,9 +10,9 @@
|
||||
#include <mlir/Interfaces/ControlFlowInterfaces.h>
|
||||
#include <mlir/Interfaces/SideEffectInterfaces.h>
|
||||
|
||||
#include "concretelang/Dialect/LowLFHE/IR/LowLFHETypes.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "concretelang/Dialect/LowLFHE/IR/LowLFHEOps.h.inc"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h.inc"
|
||||
|
||||
#endif
|
||||
@@ -1,4 +1,4 @@
|
||||
//===- MidLFHEOps.td - High level FHE dialect ops ----------------*- tablegen -*-===//
|
||||
//===- TFHEOps.td - High level FHE dialect ops ----------------*- tablegen -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
@@ -6,84 +6,84 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_MidLFHE_IR_MidLFHE_OPS
|
||||
#define CONCRETELANG_DIALECT_MidLFHE_IR_MidLFHE_OPS
|
||||
#ifndef CONCRETELANG_DIALECT_TFHE_IR_TFHE_OPS
|
||||
#define CONCRETELANG_DIALECT_TFHE_IR_TFHE_OPS
|
||||
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
|
||||
include "concretelang/Dialect/MidLFHE/IR/MidLFHEDialect.td"
|
||||
include "concretelang/Dialect/MidLFHE/IR/MidLFHETypes.td"
|
||||
include "concretelang/Dialect/TFHE/IR/TFHEDialect.td"
|
||||
include "concretelang/Dialect/TFHE/IR/TFHETypes.td"
|
||||
|
||||
class MidLFHE_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<MidLFHE_Dialect, mnemonic, traits>;
|
||||
class TFHE_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<TFHE_Dialect, mnemonic, traits>;
|
||||
|
||||
def ZeroGLWEOp : MidLFHE_Op<"zero"> {
|
||||
def ZeroGLWEOp : TFHE_Op<"zero"> {
|
||||
let summary = "Returns a trivial encyption of 0";
|
||||
|
||||
let arguments = (ins);
|
||||
let results = (outs GLWECipherTextType:$out);
|
||||
}
|
||||
|
||||
def AddGLWEIntOp : MidLFHE_Op<"add_glwe_int"> {
|
||||
def AddGLWEIntOp : TFHE_Op<"add_glwe_int"> {
|
||||
let summary = "Returns the sum of a clear integer and a lwe ciphertext";
|
||||
|
||||
let arguments = (ins GLWECipherTextType:$a, AnyInteger:$b);
|
||||
let results = (outs GLWECipherTextType);
|
||||
|
||||
let verifier = [{
|
||||
return mlir::concretelang::MidLFHE::verifyGLWEIntegerOperator<AddGLWEIntOp>(*this);
|
||||
return mlir::concretelang::TFHE::verifyGLWEIntegerOperator<AddGLWEIntOp>(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def AddGLWEOp : MidLFHE_Op<"add_glwe"> {
|
||||
def AddGLWEOp : TFHE_Op<"add_glwe"> {
|
||||
let summary = "Returns the sum of 2 lwe ciphertexts";
|
||||
|
||||
let arguments = (ins GLWECipherTextType:$a, GLWECipherTextType:$b);
|
||||
let results = (outs GLWECipherTextType);
|
||||
|
||||
let verifier = [{
|
||||
return ::mlir::concretelang::MidLFHE::verifyBinaryGLWEOperator<AddGLWEOp>(*this);
|
||||
return ::mlir::concretelang::TFHE::verifyBinaryGLWEOperator<AddGLWEOp>(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def SubIntGLWEOp : MidLFHE_Op<"sub_int_glwe"> {
|
||||
def SubIntGLWEOp : TFHE_Op<"sub_int_glwe"> {
|
||||
let summary = "Substracts an integer and a GLWE ciphertext";
|
||||
|
||||
let arguments = (ins AnyInteger:$a, GLWECipherTextType:$b);
|
||||
let results = (outs GLWECipherTextType);
|
||||
|
||||
let verifier = [{
|
||||
return ::mlir::concretelang::MidLFHE::verifyIntegerGLWEOperator(*this);
|
||||
return ::mlir::concretelang::TFHE::verifyIntegerGLWEOperator(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def NegGLWEOp : MidLFHE_Op<"neg_glwe"> {
|
||||
def NegGLWEOp : TFHE_Op<"neg_glwe"> {
|
||||
let summary = "Negates a glwe ciphertext";
|
||||
|
||||
let arguments = (ins GLWECipherTextType:$a);
|
||||
let results = (outs GLWECipherTextType);
|
||||
|
||||
let verifier = [{
|
||||
return ::mlir::concretelang::MidLFHE::verifyUnaryGLWEOperator<NegGLWEOp>(*this);
|
||||
return ::mlir::concretelang::TFHE::verifyUnaryGLWEOperator<NegGLWEOp>(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
def MulGLWEIntOp : MidLFHE_Op<"mul_glwe_int"> {
|
||||
def MulGLWEIntOp : TFHE_Op<"mul_glwe_int"> {
|
||||
let summary = "Returns the product of a clear integer and a lwe ciphertext";
|
||||
|
||||
let arguments = (ins GLWECipherTextType:$a, AnyInteger:$b);
|
||||
let results = (outs GLWECipherTextType);
|
||||
|
||||
let verifier = [{
|
||||
return mlir::concretelang::MidLFHE::verifyGLWEIntegerOperator<MulGLWEIntOp>(*this);
|
||||
return mlir::concretelang::TFHE::verifyGLWEIntegerOperator<MulGLWEIntOp>(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
|
||||
def ApplyLookupTable : MidLFHE_Op<"apply_lookup_table"> {
|
||||
def ApplyLookupTable : TFHE_Op<"apply_lookup_table"> {
|
||||
let summary = "Applies a lookup table to a GLWE ciphertext";
|
||||
|
||||
|
||||
@@ -96,7 +96,7 @@ def ApplyLookupTable : MidLFHE_Op<"apply_lookup_table"> {
|
||||
let results = (outs GLWECipherTextType);
|
||||
|
||||
let verifier = [{
|
||||
return ::mlir::concretelang::MidLFHE::verifyApplyLookupTable(*this);
|
||||
return ::mlir::concretelang::TFHE::verifyApplyLookupTable(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_MIDLFHE_IR_MIDLFHETYPES_H
|
||||
#define CONCRETELANG_DIALECT_MIDLFHE_IR_MIDLFHETYPES_H
|
||||
#ifndef CONCRETELANG_DIALECT_TFHE_IR_TFHETYPES_H
|
||||
#define CONCRETELANG_DIALECT_TFHE_IR_TFHETYPES_H
|
||||
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include <mlir/Dialect/StandardOps/IR/Ops.h>
|
||||
@@ -11,6 +11,6 @@
|
||||
#include <mlir/IR/DialectImplementation.h>
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "concretelang/Dialect/MidLFHE/IR/MidLFHEOpsTypes.h.inc"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEOpsTypes.h.inc"
|
||||
|
||||
#endif
|
||||
@@ -1,14 +1,14 @@
|
||||
#ifndef CONCRETELANG_DIALECT_MidLFHE_IR_MidLFHE_TYPES
|
||||
#define CONCRETELANG_DIALECT_MidLFHE_IR_MidLFHE_TYPES
|
||||
#ifndef CONCRETELANG_DIALECT_TFHE_IR_TFHE_TYPES
|
||||
#define CONCRETELANG_DIALECT_TFHE_IR_TFHE_TYPES
|
||||
|
||||
// TODO: MLWE / GSW
|
||||
|
||||
include "concretelang/Dialect/MidLFHE/IR/MidLFHEDialect.td"
|
||||
include "concretelang/Dialect/TFHE/IR/TFHEDialect.td"
|
||||
include "mlir/IR/BuiltinTypes.td"
|
||||
|
||||
class MidLFHE_Type<string name, list<Trait> traits = []> : TypeDef<MidLFHE_Dialect, name, traits> { }
|
||||
class TFHE_Type<string name, list<Trait> traits = []> : TypeDef<TFHE_Dialect, name, traits> { }
|
||||
|
||||
def GLWECipherTextType : MidLFHE_Type<"GLWECipherText", [MemRefElementTypeInterface]> {
|
||||
def GLWECipherTextType : TFHE_Type<"GLWECipherText", [MemRefElementTypeInterface]> {
|
||||
let mnemonic = "glwe";
|
||||
|
||||
let summary = "A GLWE ciphertext";
|
||||
@@ -89,22 +89,22 @@ public:
|
||||
ROUND_TRIP,
|
||||
|
||||
// Read sources and exit before any lowering
|
||||
HLFHE,
|
||||
FHE,
|
||||
|
||||
// Read sources and lower all HLFHE operations to MidLFHE
|
||||
// Read sources and lower all FHE operations to TFHE
|
||||
// operations
|
||||
MIDLFHE,
|
||||
TFHE,
|
||||
|
||||
// Read sources and lower all HLFHE and MidLFHE operations to LowLFHE
|
||||
// Read sources and lower all FHE and TFHE operations to Concrete
|
||||
// operations
|
||||
LOWLFHE,
|
||||
CONCRETE,
|
||||
|
||||
// Read sources and lower all HLFHE, MidLFHE and LowLFHE
|
||||
// Read sources and lower all FHE, TFHE and Concrete
|
||||
// operations to canonical MLIR dialects. Cryptographic operations
|
||||
// are lowered to invocations of the concrete library.
|
||||
STD,
|
||||
|
||||
// Read sources and lower all HLFHE, MidLFHE and LowLFHE
|
||||
// Read sources and lower all FHE, TFHE and Concrete
|
||||
// operations to operations from the LLVM dialect. Cryptographic
|
||||
// operations are lowered to invocations of the concrete library.
|
||||
LLVM,
|
||||
@@ -152,14 +152,14 @@ public:
|
||||
void setAutoParallelize(bool v);
|
||||
void setGenerateClientParameters(bool v);
|
||||
void setClientParametersFuncName(const llvm::StringRef &name);
|
||||
void setHLFHELinalgTileSizes(llvm::ArrayRef<int64_t> sizes);
|
||||
void setFHELinalgTileSizes(llvm::ArrayRef<int64_t> sizes);
|
||||
void setEnablePass(std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
protected:
|
||||
llvm::Optional<size_t> overrideMaxEintPrecision;
|
||||
llvm::Optional<size_t> overrideMaxMANP;
|
||||
llvm::Optional<std::string> clientParametersFuncName;
|
||||
llvm::Optional<std::vector<int64_t>> hlfhelinalgTileSizes;
|
||||
llvm::Optional<std::vector<int64_t>> fhelinalgTileSizes;
|
||||
|
||||
bool verifyDiagnostics;
|
||||
bool autoParallelize;
|
||||
|
||||
@@ -19,29 +19,29 @@ mlir::LogicalResult autopar(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
llvm::Expected<llvm::Optional<mlir::concretelang::V0FHEConstraint>>
|
||||
getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
getFHEConstraintsFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
tileMarkedHLFHELinalg(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
tileMarkedFHELinalg(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
markHLFHELinalgForTiling(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
markFHELinalgForTiling(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
llvm::ArrayRef<int64_t> tileSizes,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerHLFHEToMidLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerMidLFHEToLowLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
llvm::Optional<V0FHEContext> &fheContext,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerLowLFHEToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
lowerConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
|
||||
@@ -11,11 +11,11 @@ declare_mlir_python_extension(ConcretelangBindingsPythonExtension.Core
|
||||
ADD_TO_PARENT ConcretelangBindingsPythonExtension
|
||||
SOURCES
|
||||
ConcretelangModule.cpp
|
||||
HLFHEModule.cpp
|
||||
FHEModule.cpp
|
||||
CompilerAPIModule.cpp
|
||||
EMBED_CAPI_LINK_LIBS
|
||||
CONCRETELANGCAPIHLFHE
|
||||
CONCRETELANGCAPIHLFHELINALG
|
||||
CONCRETELANGCAPIFHE
|
||||
CONCRETELANGCAPIFHELINALG
|
||||
CONCRETELANGCAPISupport
|
||||
)
|
||||
|
||||
@@ -42,20 +42,20 @@ declare_mlir_python_sources(ConcretelangBindingsPythonSources.Dialects
|
||||
declare_mlir_dialect_python_bindings(
|
||||
ADD_TO_PARENT ConcretelangBindingsPythonSources.Dialects
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}"
|
||||
CONCRETELANGBindingsPythonHLFHEOps
|
||||
TD_FILE concrete/lang/dialects/HLFHEOps.td
|
||||
CONCRETELANGBindingsPythonFHEOps
|
||||
TD_FILE concrete/lang/dialects/FHEOps.td
|
||||
SOURCES
|
||||
concrete/lang/dialects/hlfhe.py
|
||||
DIALECT_NAME HLFHE)
|
||||
concrete/lang/dialects/fhe.py
|
||||
DIALECT_NAME FHE)
|
||||
|
||||
declare_mlir_dialect_python_bindings(
|
||||
ADD_TO_PARENT ConcretelangBindingsPythonSources.Dialects
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}"
|
||||
CONCRETELANGBindingsPythonHLFHELinalgOps
|
||||
TD_FILE concrete/lang/dialects/HLFHELinalgOps.td
|
||||
CONCRETELANGBindingsPythonFHELinalgOps
|
||||
TD_FILE concrete/lang/dialects/FHELinalgOps.td
|
||||
SOURCES
|
||||
concrete/lang/dialects/hlfhelinalg.py
|
||||
DIALECT_NAME HLFHELinalg)
|
||||
concrete/lang/dialects/fhelinalg.py
|
||||
DIALECT_NAME FHELinalg)
|
||||
|
||||
|
||||
################################################################################
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include "CompilerAPIModule.h"
|
||||
#include "concretelang-c/Support/CompilerEngine.h"
|
||||
#include "concretelang/Dialect/HLFHE/IR/HLFHEOpsDialect.h.inc"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEOpsDialect.h.inc"
|
||||
#include "concretelang/Support/Jit.h"
|
||||
#include "concretelang/Support/JitCompilerEngine.h"
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
|
||||
@@ -7,8 +7,8 @@
|
||||
#include "mlir-c/Bindings/Python/Interop.h"
|
||||
#include "mlir-c/Registration.h"
|
||||
#include "mlir/Bindings/Python/PybindAdaptors.h"
|
||||
#include "concretelang-c/Dialect/HLFHE.h"
|
||||
#include "concretelang-c/Dialect/HLFHELinalg.h"
|
||||
#include "concretelang-c/Dialect/FHE.h"
|
||||
#include "concretelang-c/Dialect/FHELinalg.h"
|
||||
|
||||
#include "llvm-c/ErrorHandling.h"
|
||||
#include "llvm/Support/Signals.h"
|
||||
@@ -29,17 +29,17 @@ PYBIND11_MODULE(_concretelang, m) {
|
||||
MlirContext context = mlirPythonCapsuleToContext(wrappedCapsule.ptr());
|
||||
|
||||
// Collect Concretelang dialects to register.
|
||||
MlirDialectHandle hlfhe = mlirGetDialectHandle__hlfhe__();
|
||||
mlirDialectHandleRegisterDialect(hlfhe, context);
|
||||
mlirDialectHandleLoadDialect(hlfhe, context);
|
||||
MlirDialectHandle hlfhelinalg = mlirGetDialectHandle__hlfhelinalg__();
|
||||
mlirDialectHandleRegisterDialect(hlfhelinalg, context);
|
||||
mlirDialectHandleLoadDialect(hlfhelinalg, context);
|
||||
MlirDialectHandle fhe = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleRegisterDialect(fhe, context);
|
||||
mlirDialectHandleLoadDialect(fhe, context);
|
||||
MlirDialectHandle fhelinalg = mlirGetDialectHandle__fhelinalg__();
|
||||
mlirDialectHandleRegisterDialect(fhelinalg, context);
|
||||
mlirDialectHandleLoadDialect(fhelinalg, context);
|
||||
},
|
||||
"Register Concretelang dialects on a PyMlirContext.");
|
||||
|
||||
py::module hlfhe = m.def_submodule("_hlfhe", "HLFHE API");
|
||||
mlir::concretelang::python::populateDialectHLFHESubmodule(hlfhe);
|
||||
py::module fhe = m.def_submodule("_fhe", "FHE API");
|
||||
mlir::concretelang::python::populateDialectFHESubmodule(fhe);
|
||||
|
||||
py::module api = m.def_submodule("_compiler", "Compiler API");
|
||||
mlir::concretelang::python::populateCompilerAPISubmodule(api);
|
||||
|
||||
@@ -10,7 +10,7 @@ namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace python {
|
||||
|
||||
void populateDialectHLFHESubmodule(pybind11::module &m);
|
||||
void populateDialectFHESubmodule(pybind11::module &m);
|
||||
|
||||
} // namespace python
|
||||
} // namespace concretelang
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include "DialectModules.h"
|
||||
|
||||
#include "concretelang-c/Dialect/HLFHE.h"
|
||||
#include "concretelang-c/Dialect/FHE.h"
|
||||
|
||||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir/Bindings/Python/PybindAdaptors.h"
|
||||
@@ -18,13 +18,13 @@
|
||||
using namespace mlir::concretelang;
|
||||
using namespace mlir::python::adaptors;
|
||||
|
||||
/// Populate the hlfhe python module.
|
||||
void mlir::concretelang::python::populateDialectHLFHESubmodule(
|
||||
/// Populate the fhe python module.
|
||||
void mlir::concretelang::python::populateDialectFHESubmodule(
|
||||
pybind11::module &m) {
|
||||
m.doc() = "HLFHE dialect Python native extension";
|
||||
m.doc() = "FHE dialect Python native extension";
|
||||
|
||||
mlir_type_subclass(m, "EncryptedIntegerType",
|
||||
hlfheTypeIsAnEncryptedIntegerType)
|
||||
fheTypeIsAnEncryptedIntegerType)
|
||||
.def_classmethod("get", [](pybind11::object cls, MlirContext ctx,
|
||||
unsigned width) {
|
||||
// We want the user to receive a python exception for not being able to
|
||||
@@ -33,6 +33,6 @@ void mlir::concretelang::python::populateDialectHLFHESubmodule(
|
||||
throw std::invalid_argument("can't create eint with the given width");
|
||||
};
|
||||
return cls(
|
||||
hlfheEncryptedIntegerTypeGetChecked(ctx, width, emitException));
|
||||
fheEncryptedIntegerTypeGetChecked(ctx, width, emitException));
|
||||
});
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
#ifndef PYTHON_BINDINGS_FHELINALG_OPS
|
||||
#define PYTHON_BINDINGS_FHELINALG_OPS
|
||||
|
||||
include "mlir/Bindings/Python/Attributes.td"
|
||||
include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td"
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,7 @@
|
||||
#ifndef PYTHON_BINDINGS_FHE_OPS
|
||||
#define PYTHON_BINDINGS_FHE_OPS
|
||||
|
||||
include "mlir/Bindings/Python/Attributes.td"
|
||||
include "concretelang/Dialect/FHE/IR/FHEOps.td"
|
||||
|
||||
#endif
|
||||
@@ -1,7 +0,0 @@
|
||||
#ifndef PYTHON_BINDINGS_HLFHELINALG_OPS
|
||||
#define PYTHON_BINDINGS_HLFHELINALG_OPS
|
||||
|
||||
include "mlir/Bindings/Python/Attributes.td"
|
||||
include "concretelang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td"
|
||||
|
||||
#endif
|
||||
@@ -1,7 +0,0 @@
|
||||
#ifndef PYTHON_BINDINGS_HLFHE_OPS
|
||||
#define PYTHON_BINDINGS_HLFHE_OPS
|
||||
|
||||
include "mlir/Bindings/Python/Attributes.td"
|
||||
include "concretelang/Dialect/HLFHE/IR/HLFHEOps.td"
|
||||
|
||||
#endif
|
||||
@@ -1,6 +1,6 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
"""HLFHE dialect module"""
|
||||
from ._HLFHE_ops_gen import *
|
||||
from mlir._mlir_libs._concretelang._hlfhe import *
|
||||
"""FHE dialect module"""
|
||||
from ._FHE_ops_gen import *
|
||||
from mlir._mlir_libs._concretelang._fhe import *
|
||||
@@ -1,5 +1,5 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
"""HLFHELinalg dialect module"""
|
||||
from ._HLFHELinalg_ops_gen import *
|
||||
"""FHELinalg dialect module"""
|
||||
from ._FHELinalg_ops_gen import *
|
||||
@@ -1,2 +1,2 @@
|
||||
add_subdirectory(HLFHE)
|
||||
add_subdirectory(HLFHELinalg)
|
||||
add_subdirectory(FHE)
|
||||
add_subdirectory(FHELinalg)
|
||||
10
compiler/lib/CAPI/Dialect/FHE/CMakeLists.txt
Normal file
10
compiler/lib/CAPI/Dialect/FHE/CMakeLists.txt
Normal file
@@ -0,0 +1,10 @@
|
||||
set(LLVM_OPTIONAL_SOURCES FHE.cpp)
|
||||
|
||||
add_mlir_public_c_api_library(CONCRETELANGCAPIFHE
|
||||
|
||||
FHE.cpp
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRCAPIIR
|
||||
FHEDialect
|
||||
)
|
||||
@@ -1,31 +1,31 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
#include "concretelang-c/Dialect/HLFHE.h"
|
||||
#include "concretelang-c/Dialect/FHE.h"
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/CAPI/Registration.h"
|
||||
#include "mlir/CAPI/Support.h"
|
||||
#include "concretelang/Dialect/HLFHE/IR/HLFHEDialect.h"
|
||||
#include "concretelang/Dialect/HLFHE/IR/HLFHEOps.h"
|
||||
#include "concretelang/Dialect/HLFHE/IR/HLFHETypes.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
|
||||
|
||||
using namespace mlir::concretelang::HLFHE;
|
||||
using namespace mlir::concretelang::FHE;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(HLFHE, hlfhe, HLFHEDialect)
|
||||
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(FHE, fhe, FHEDialect)
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool hlfheTypeIsAnEncryptedIntegerType(MlirType type) {
|
||||
bool fheTypeIsAnEncryptedIntegerType(MlirType type) {
|
||||
return unwrap(type).isa<EncryptedIntegerType>();
|
||||
}
|
||||
|
||||
MlirType hlfheEncryptedIntegerTypeGetChecked(
|
||||
MlirType fheEncryptedIntegerTypeGetChecked(
|
||||
MlirContext ctx, unsigned width,
|
||||
mlir::function_ref<mlir::InFlightDiagnostic()> emitError) {
|
||||
return wrap(EncryptedIntegerType::getChecked(emitError, unwrap(ctx), width));
|
||||
10
compiler/lib/CAPI/Dialect/FHELinalg/CMakeLists.txt
Normal file
10
compiler/lib/CAPI/Dialect/FHELinalg/CMakeLists.txt
Normal file
@@ -0,0 +1,10 @@
|
||||
set(LLVM_OPTIONAL_SOURCES FHELinalg.cpp)
|
||||
|
||||
add_mlir_public_c_api_library(CONCRETELANGCAPIFHELINALG
|
||||
|
||||
FHELinalg.cpp
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRCAPIIR
|
||||
FHELinalgDialect
|
||||
)
|
||||
@@ -1,19 +1,19 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
|
||||
|
||||
#include "concretelang-c/Dialect/HLFHELinalg.h"
|
||||
#include "concretelang-c/Dialect/FHELinalg.h"
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/CAPI/Registration.h"
|
||||
#include "mlir/CAPI/Support.h"
|
||||
#include "concretelang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.h"
|
||||
#include "concretelang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h"
|
||||
#include "concretelang/Dialect/HLFHELinalg/IR/HLFHELinalgTypes.h"
|
||||
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.h"
|
||||
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h"
|
||||
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgTypes.h"
|
||||
|
||||
using namespace mlir::concretelang::HLFHELinalg;
|
||||
using namespace mlir::concretelang::FHELinalg;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(HLFHELinalg, hlfhelinalg,
|
||||
HLFHELinalgDialect)
|
||||
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(FHELinalg, fhelinalg,
|
||||
FHELinalgDialect)
|
||||
@@ -1,10 +0,0 @@
|
||||
set(LLVM_OPTIONAL_SOURCES HLFHE.cpp)
|
||||
|
||||
add_mlir_public_c_api_library(CONCRETELANGCAPIHLFHE
|
||||
|
||||
HLFHE.cpp
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRCAPIIR
|
||||
HLFHEDialect
|
||||
)
|
||||
@@ -1,10 +0,0 @@
|
||||
set(LLVM_OPTIONAL_SOURCES HLFHELinalg.cpp)
|
||||
|
||||
add_mlir_public_c_api_library(CONCRETELANGCAPIHLFHELINALG
|
||||
|
||||
HLFHELinalg.cpp
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRCAPIIR
|
||||
HLFHELinalgDialect
|
||||
)
|
||||
@@ -1,7 +1,7 @@
|
||||
add_subdirectory(HLFHEToMidLFHE)
|
||||
add_subdirectory(MidLFHEGlobalParametrization)
|
||||
add_subdirectory(MidLFHEToLowLFHE)
|
||||
add_subdirectory(HLFHETensorOpsToLinalg)
|
||||
add_subdirectory(LowLFHEToConcreteCAPI)
|
||||
add_subdirectory(FHEToTFHE)
|
||||
add_subdirectory(TFHEGlobalParametrization)
|
||||
add_subdirectory(TFHEToConcrete)
|
||||
add_subdirectory(FHETensorOpsToLinalg)
|
||||
add_subdirectory(ConcreteToConcreteCAPI)
|
||||
add_subdirectory(MLIRLowerableDialectsToLLVM)
|
||||
add_subdirectory(LowLFHEUnparametrize)
|
||||
add_subdirectory(ConcreteUnparametrize)
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
add_mlir_dialect_library(ConcreteToConcreteCAPI
|
||||
ConcreteToConcreteCAPI.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE
|
||||
|
||||
DEPENDS
|
||||
ConcreteDialect
|
||||
MLIRConversionPassIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRTransforms
|
||||
)
|
||||
|
||||
target_link_libraries(ConcreteToConcreteCAPI PUBLIC MLIRIR)
|
||||
@@ -10,20 +10,20 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
|
||||
#include "concretelang/Dialect/LowLFHE/IR/LowLFHEOps.h"
|
||||
#include "concretelang/Dialect/LowLFHE/IR/LowLFHETypes.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
|
||||
#include "concretelang/Support/Constants.h"
|
||||
|
||||
class LowLFHEToConcreteCAPITypeConverter : public mlir::TypeConverter {
|
||||
class ConcreteToConcreteCAPITypeConverter : public mlir::TypeConverter {
|
||||
|
||||
public:
|
||||
LowLFHEToConcreteCAPITypeConverter() {
|
||||
ConcreteToConcreteCAPITypeConverter() {
|
||||
addConversion([](mlir::Type type) { return type; });
|
||||
addConversion([&](mlir::concretelang::LowLFHE::PlaintextType type) {
|
||||
addConversion([&](mlir::concretelang::Concrete::PlaintextType type) {
|
||||
return mlir::IntegerType::get(type.getContext(), 64);
|
||||
});
|
||||
addConversion([&](mlir::concretelang::LowLFHE::CleartextType type) {
|
||||
addConversion([&](mlir::concretelang::Concrete::CleartextType type) {
|
||||
return mlir::IntegerType::get(type.getContext(), 64);
|
||||
});
|
||||
}
|
||||
@@ -64,56 +64,56 @@ mlir::LogicalResult insertForwardDeclaration(mlir::Operation *op,
|
||||
// allocate them. All the calls to the C API should be done using this generic
|
||||
// types, and casting should then be performed back to the appropriate type.
|
||||
|
||||
inline mlir::concretelang::LowLFHE::LweCiphertextType
|
||||
inline mlir::concretelang::Concrete::LweCiphertextType
|
||||
getGenericLweCiphertextType(mlir::MLIRContext *context) {
|
||||
return mlir::concretelang::LowLFHE::LweCiphertextType::get(context, -1, -1);
|
||||
return mlir::concretelang::Concrete::LweCiphertextType::get(context, -1, -1);
|
||||
}
|
||||
|
||||
inline mlir::concretelang::LowLFHE::GlweCiphertextType
|
||||
inline mlir::concretelang::Concrete::GlweCiphertextType
|
||||
getGenericGlweCiphertextType(mlir::MLIRContext *context) {
|
||||
return mlir::concretelang::LowLFHE::GlweCiphertextType::get(context);
|
||||
return mlir::concretelang::Concrete::GlweCiphertextType::get(context);
|
||||
}
|
||||
|
||||
inline mlir::concretelang::LowLFHE::PlaintextType
|
||||
inline mlir::concretelang::Concrete::PlaintextType
|
||||
getGenericPlaintextType(mlir::MLIRContext *context) {
|
||||
return mlir::concretelang::LowLFHE::PlaintextType::get(context, -1);
|
||||
return mlir::concretelang::Concrete::PlaintextType::get(context, -1);
|
||||
}
|
||||
|
||||
inline mlir::concretelang::LowLFHE::PlaintextListType
|
||||
inline mlir::concretelang::Concrete::PlaintextListType
|
||||
getGenericPlaintextListType(mlir::MLIRContext *context) {
|
||||
return mlir::concretelang::LowLFHE::PlaintextListType::get(context);
|
||||
return mlir::concretelang::Concrete::PlaintextListType::get(context);
|
||||
}
|
||||
|
||||
inline mlir::concretelang::LowLFHE::ForeignPlaintextListType
|
||||
inline mlir::concretelang::Concrete::ForeignPlaintextListType
|
||||
getGenericForeignPlaintextListType(mlir::MLIRContext *context) {
|
||||
return mlir::concretelang::LowLFHE::ForeignPlaintextListType::get(context);
|
||||
return mlir::concretelang::Concrete::ForeignPlaintextListType::get(context);
|
||||
}
|
||||
|
||||
inline mlir::concretelang::LowLFHE::CleartextType
|
||||
inline mlir::concretelang::Concrete::CleartextType
|
||||
getGenericCleartextType(mlir::MLIRContext *context) {
|
||||
return mlir::concretelang::LowLFHE::CleartextType::get(context, -1);
|
||||
return mlir::concretelang::Concrete::CleartextType::get(context, -1);
|
||||
}
|
||||
|
||||
inline mlir::concretelang::LowLFHE::LweBootstrapKeyType
|
||||
inline mlir::concretelang::Concrete::LweBootstrapKeyType
|
||||
getGenericLweBootstrapKeyType(mlir::MLIRContext *context) {
|
||||
return mlir::concretelang::LowLFHE::LweBootstrapKeyType::get(context);
|
||||
return mlir::concretelang::Concrete::LweBootstrapKeyType::get(context);
|
||||
}
|
||||
|
||||
inline mlir::concretelang::LowLFHE::LweKeySwitchKeyType
|
||||
inline mlir::concretelang::Concrete::LweKeySwitchKeyType
|
||||
getGenericLweKeySwitchKeyType(mlir::MLIRContext *context) {
|
||||
return mlir::concretelang::LowLFHE::LweKeySwitchKeyType::get(context);
|
||||
return mlir::concretelang::Concrete::LweKeySwitchKeyType::get(context);
|
||||
}
|
||||
|
||||
// Get the generic version of the type.
|
||||
// Useful when iterating over a set of types.
|
||||
mlir::Type getGenericType(mlir::Type baseType) {
|
||||
if (baseType.isa<mlir::concretelang::LowLFHE::LweCiphertextType>()) {
|
||||
if (baseType.isa<mlir::concretelang::Concrete::LweCiphertextType>()) {
|
||||
return getGenericLweCiphertextType(baseType.getContext());
|
||||
}
|
||||
if (baseType.isa<mlir::concretelang::LowLFHE::PlaintextType>()) {
|
||||
if (baseType.isa<mlir::concretelang::Concrete::PlaintextType>()) {
|
||||
return getGenericPlaintextType(baseType.getContext());
|
||||
}
|
||||
if (baseType.isa<mlir::concretelang::LowLFHE::CleartextType>()) {
|
||||
if (baseType.isa<mlir::concretelang::Concrete::CleartextType>()) {
|
||||
return getGenericCleartextType(baseType.getContext());
|
||||
}
|
||||
return baseType;
|
||||
@@ -138,7 +138,7 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
|
||||
auto genericBSKType = getGenericLweBootstrapKeyType(rewriter.getContext());
|
||||
auto genericKSKType = getGenericLweKeySwitchKeyType(rewriter.getContext());
|
||||
auto contextType =
|
||||
mlir::concretelang::LowLFHE::ContextType::get(rewriter.getContext());
|
||||
mlir::concretelang::Concrete::ContextType::get(rewriter.getContext());
|
||||
|
||||
auto errType = mlir::IndexType::get(rewriter.getContext());
|
||||
|
||||
@@ -321,7 +321,7 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
/// LowLFHEOpToConcreteCAPICallPattern<Op> match the `Op` Operation and
|
||||
/// ConcreteOpToConcreteCAPICallPattern<Op> match the `Op` Operation and
|
||||
/// replace with a call to `funcName`, the funcName should be an external
|
||||
/// function that was linked later. It insert the forward declaration of the
|
||||
/// private `funcName` if it not already in the symbol table.
|
||||
@@ -336,8 +336,8 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
|
||||
/// call_op(err, out, arg0, arg1);
|
||||
/// ```
|
||||
template <typename Op>
|
||||
struct LowLFHEOpToConcreteCAPICallPattern : public mlir::OpRewritePattern<Op> {
|
||||
LowLFHEOpToConcreteCAPICallPattern(
|
||||
struct ConcreteOpToConcreteCAPICallPattern : public mlir::OpRewritePattern<Op> {
|
||||
ConcreteOpToConcreteCAPICallPattern(
|
||||
mlir::MLIRContext *context, mlir::StringRef funcName,
|
||||
mlir::StringRef allocName,
|
||||
mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
@@ -346,11 +346,11 @@ struct LowLFHEOpToConcreteCAPICallPattern : public mlir::OpRewritePattern<Op> {
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
|
||||
LowLFHEToConcreteCAPITypeConverter typeConverter;
|
||||
ConcreteToConcreteCAPITypeConverter typeConverter;
|
||||
|
||||
mlir::Type resultType = op->getResultTypes().front();
|
||||
auto lweResultType =
|
||||
resultType.cast<mlir::concretelang::LowLFHE::LweCiphertextType>();
|
||||
resultType.cast<mlir::concretelang::Concrete::LweCiphertextType>();
|
||||
// Replace the operation with a call to the `funcName`
|
||||
{
|
||||
// Create the err value
|
||||
@@ -402,21 +402,21 @@ private:
|
||||
std::string allocName;
|
||||
};
|
||||
|
||||
struct LowLFHEZeroOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::LowLFHE::ZeroLWEOp> {
|
||||
LowLFHEZeroOpPattern(
|
||||
struct ConcreteZeroOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::Concrete::ZeroLWEOp> {
|
||||
ConcreteZeroOpPattern(
|
||||
mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<mlir::concretelang::LowLFHE::ZeroLWEOp>(context,
|
||||
: mlir::OpRewritePattern<mlir::concretelang::Concrete::ZeroLWEOp>(context,
|
||||
benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::LowLFHE::ZeroLWEOp op,
|
||||
matchAndRewrite(mlir::concretelang::Concrete::ZeroLWEOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
mlir::Type resultType = op->getResultTypes().front();
|
||||
auto lweResultType =
|
||||
resultType.cast<mlir::concretelang::LowLFHE::LweCiphertextType>();
|
||||
resultType.cast<mlir::concretelang::Concrete::LweCiphertextType>();
|
||||
// Create the err value
|
||||
auto errOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getIndexAttr(0));
|
||||
@@ -438,16 +438,16 @@ struct LowLFHEZeroOpPattern
|
||||
};
|
||||
};
|
||||
|
||||
struct LowLFHEEncodeIntOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::LowLFHE::EncodeIntOp> {
|
||||
LowLFHEEncodeIntOpPattern(
|
||||
struct ConcreteEncodeIntOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::Concrete::EncodeIntOp> {
|
||||
ConcreteEncodeIntOpPattern(
|
||||
mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<mlir::concretelang::LowLFHE::EncodeIntOp>(context,
|
||||
: mlir::OpRewritePattern<mlir::concretelang::Concrete::EncodeIntOp>(context,
|
||||
benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::LowLFHE::EncodeIntOp op,
|
||||
matchAndRewrite(mlir::concretelang::Concrete::EncodeIntOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
{
|
||||
mlir::Value castedInt = rewriter.create<mlir::arith::ExtUIOp>(
|
||||
@@ -463,16 +463,16 @@ struct LowLFHEEncodeIntOpPattern
|
||||
};
|
||||
};
|
||||
|
||||
struct LowLFHEIntToCleartextOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::LowLFHE::IntToCleartextOp> {
|
||||
LowLFHEIntToCleartextOpPattern(
|
||||
struct ConcreteIntToCleartextOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::Concrete::IntToCleartextOp> {
|
||||
ConcreteIntToCleartextOpPattern(
|
||||
mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<mlir::concretelang::LowLFHE::IntToCleartextOp>(
|
||||
: mlir::OpRewritePattern<mlir::concretelang::Concrete::IntToCleartextOp>(
|
||||
context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::LowLFHE::IntToCleartextOp op,
|
||||
matchAndRewrite(mlir::concretelang::Concrete::IntToCleartextOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<mlir::arith::ExtUIOp>(
|
||||
op, rewriter.getIntegerType(64), op->getOperands().front());
|
||||
@@ -489,17 +489,17 @@ struct LowLFHEIntToCleartextOpPattern
|
||||
// - construct the GLWE accumulator by adding the plaintext_list to a freshly
|
||||
// allocated GLWE
|
||||
struct GlweFromTableOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::LowLFHE::GlweFromTable> {
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::Concrete::GlweFromTable> {
|
||||
GlweFromTableOpPattern(
|
||||
mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<mlir::concretelang::LowLFHE::GlweFromTable>(
|
||||
: mlir::OpRewritePattern<mlir::concretelang::Concrete::GlweFromTable>(
|
||||
context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::LowLFHE::GlweFromTable op,
|
||||
matchAndRewrite(mlir::concretelang::Concrete::GlweFromTable op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
LowLFHEToConcreteCAPITypeConverter typeConverter;
|
||||
ConcreteToConcreteCAPITypeConverter typeConverter;
|
||||
auto errType = mlir::IndexType::get(rewriter.getContext());
|
||||
|
||||
// TODO: move this to insertForwardDeclarations
|
||||
@@ -589,8 +589,8 @@ mlir::Value getContextArgument(mlir::Operation *op) {
|
||||
|
||||
mlir::Value context = block->getArguments().back();
|
||||
|
||||
assert(context.getType().isa<mlir::concretelang::LowLFHE::ContextType>() &&
|
||||
"the LowLFHE.context should be the last argument of the enclosing "
|
||||
assert(context.getType().isa<mlir::concretelang::Concrete::ContextType>() &&
|
||||
"the Concrete.context should be the last argument of the enclosing "
|
||||
"function of the op");
|
||||
|
||||
return context;
|
||||
@@ -606,23 +606,23 @@ mlir::Value getContextArgument(mlir::Operation *op) {
|
||||
// - get the global bootstrapping key
|
||||
// - use the key and the input accumulator (GLWE) to bootstrap the input
|
||||
// ciphertext
|
||||
struct LowLFHEBootstrapLweOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::LowLFHE::BootstrapLweOp> {
|
||||
LowLFHEBootstrapLweOpPattern(
|
||||
struct ConcreteBootstrapLweOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::Concrete::BootstrapLweOp> {
|
||||
ConcreteBootstrapLweOpPattern(
|
||||
mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<mlir::concretelang::LowLFHE::BootstrapLweOp>(
|
||||
: mlir::OpRewritePattern<mlir::concretelang::Concrete::BootstrapLweOp>(
|
||||
context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::LowLFHE::BootstrapLweOp op,
|
||||
matchAndRewrite(mlir::concretelang::Concrete::BootstrapLweOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto resultType = op->getResultTypes().front();
|
||||
auto errOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getIndexAttr(0));
|
||||
// Get the size from the dimension
|
||||
int64_t outputLweDimension =
|
||||
resultType.cast<mlir::concretelang::LowLFHE::LweCiphertextType>()
|
||||
resultType.cast<mlir::concretelang::Concrete::LweCiphertextType>()
|
||||
.getDimension();
|
||||
int64_t outputLweSize = outputLweDimension + 1;
|
||||
mlir::Value lweSizeOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
@@ -670,16 +670,16 @@ struct LowLFHEBootstrapLweOpPattern
|
||||
// - allocate the result LWE ciphertext
|
||||
// - get the global keyswitch key
|
||||
// - use the key to keyswitch the input ciphertext
|
||||
struct LowLFHEKeySwitchLweOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::LowLFHE::KeySwitchLweOp> {
|
||||
LowLFHEKeySwitchLweOpPattern(
|
||||
struct ConcreteKeySwitchLweOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::Concrete::KeySwitchLweOp> {
|
||||
ConcreteKeySwitchLweOpPattern(
|
||||
mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<mlir::concretelang::LowLFHE::KeySwitchLweOp>(
|
||||
: mlir::OpRewritePattern<mlir::concretelang::Concrete::KeySwitchLweOp>(
|
||||
context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::LowLFHE::KeySwitchLweOp op,
|
||||
matchAndRewrite(mlir::concretelang::Concrete::KeySwitchLweOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto errOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getIndexAttr(0));
|
||||
@@ -687,7 +687,7 @@ struct LowLFHEKeySwitchLweOpPattern
|
||||
int64_t lweDimension =
|
||||
op.getResult()
|
||||
.getType()
|
||||
.cast<mlir::concretelang::LowLFHE::LweCiphertextType>()
|
||||
.cast<mlir::concretelang::Concrete::LweCiphertextType>()
|
||||
.getDimension();
|
||||
int64_t lweSize = lweDimension + 1;
|
||||
mlir::Value lweSizeOp = rewriter.create<mlir::arith::ConstantOp>(
|
||||
@@ -724,31 +724,31 @@ struct LowLFHEKeySwitchLweOpPattern
|
||||
};
|
||||
};
|
||||
|
||||
/// Populate the RewritePatternSet with all patterns that rewrite LowLFHE
|
||||
/// Populate the RewritePatternSet with all patterns that rewrite Concrete
|
||||
/// operators to the corresponding function call to the `Concrete C API`.
|
||||
void populateLowLFHEToConcreteCAPICall(mlir::RewritePatternSet &patterns) {
|
||||
patterns.add<LowLFHEOpToConcreteCAPICallPattern<
|
||||
mlir::concretelang::LowLFHE::AddLweCiphertextsOp>>(
|
||||
void populateConcreteToConcreteCAPICall(mlir::RewritePatternSet &patterns) {
|
||||
patterns.add<ConcreteOpToConcreteCAPICallPattern<
|
||||
mlir::concretelang::Concrete::AddLweCiphertextsOp>>(
|
||||
patterns.getContext(), "add_lwe_ciphertexts_u64",
|
||||
"allocate_lwe_ciphertext_u64");
|
||||
patterns.add<LowLFHEOpToConcreteCAPICallPattern<
|
||||
mlir::concretelang::LowLFHE::AddPlaintextLweCiphertextOp>>(
|
||||
patterns.add<ConcreteOpToConcreteCAPICallPattern<
|
||||
mlir::concretelang::Concrete::AddPlaintextLweCiphertextOp>>(
|
||||
patterns.getContext(), "add_plaintext_lwe_ciphertext_u64",
|
||||
"allocate_lwe_ciphertext_u64");
|
||||
patterns.add<LowLFHEOpToConcreteCAPICallPattern<
|
||||
mlir::concretelang::LowLFHE::MulCleartextLweCiphertextOp>>(
|
||||
patterns.add<ConcreteOpToConcreteCAPICallPattern<
|
||||
mlir::concretelang::Concrete::MulCleartextLweCiphertextOp>>(
|
||||
patterns.getContext(), "mul_cleartext_lwe_ciphertext_u64",
|
||||
"allocate_lwe_ciphertext_u64");
|
||||
patterns.add<LowLFHEOpToConcreteCAPICallPattern<
|
||||
mlir::concretelang::LowLFHE::NegateLweCiphertextOp>>(
|
||||
patterns.add<ConcreteOpToConcreteCAPICallPattern<
|
||||
mlir::concretelang::Concrete::NegateLweCiphertextOp>>(
|
||||
patterns.getContext(), "negate_lwe_ciphertext_u64",
|
||||
"allocate_lwe_ciphertext_u64");
|
||||
patterns.add<LowLFHEEncodeIntOpPattern>(patterns.getContext());
|
||||
patterns.add<LowLFHEIntToCleartextOpPattern>(patterns.getContext());
|
||||
patterns.add<LowLFHEZeroOpPattern>(patterns.getContext());
|
||||
patterns.add<ConcreteEncodeIntOpPattern>(patterns.getContext());
|
||||
patterns.add<ConcreteIntToCleartextOpPattern>(patterns.getContext());
|
||||
patterns.add<ConcreteZeroOpPattern>(patterns.getContext());
|
||||
patterns.add<GlweFromTableOpPattern>(patterns.getContext());
|
||||
patterns.add<LowLFHEKeySwitchLweOpPattern>(patterns.getContext());
|
||||
patterns.add<LowLFHEBootstrapLweOpPattern>(patterns.getContext());
|
||||
patterns.add<ConcreteKeySwitchLweOpPattern>(patterns.getContext());
|
||||
patterns.add<ConcreteBootstrapLweOpPattern>(patterns.getContext());
|
||||
}
|
||||
|
||||
struct AddRuntimeContextToFuncOpPattern
|
||||
@@ -764,11 +764,11 @@ struct AddRuntimeContextToFuncOpPattern
|
||||
mlir::OpBuilder::InsertionGuard guard(rewriter);
|
||||
mlir::FunctionType oldFuncType = oldFuncOp.getType();
|
||||
|
||||
// Add a LowLFHE.context to the function signature
|
||||
// Add a Concrete.context to the function signature
|
||||
mlir::SmallVector<mlir::Type> newInputs(oldFuncType.getInputs().begin(),
|
||||
oldFuncType.getInputs().end());
|
||||
newInputs.push_back(
|
||||
rewriter.getType<mlir::concretelang::LowLFHE::ContextType>());
|
||||
rewriter.getType<mlir::concretelang::Concrete::ContextType>());
|
||||
mlir::FunctionType newFuncTy = rewriter.getType<mlir::FunctionType>(
|
||||
newInputs, oldFuncType.getResults());
|
||||
// Create the new func
|
||||
@@ -793,20 +793,20 @@ struct AddRuntimeContextToFuncOpPattern
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// Legal function are one that are private or has a LowLFHE.context as last
|
||||
// Legal function are one that are private or has a Concrete.context as last
|
||||
// arguments.
|
||||
static bool isLegal(mlir::FuncOp funcOp) {
|
||||
if (!funcOp.isPublic()) {
|
||||
return true;
|
||||
}
|
||||
// TODO : Don't need to add a runtime context for function that doesn't
|
||||
// manipulates lowlfhe types.
|
||||
// manipulates concrete types.
|
||||
//
|
||||
// if (!llvm::any_of(funcOp.getType().getInputs(), [](mlir::Type t) {
|
||||
// if (auto tensorTy = t.dyn_cast_or_null<mlir::TensorType>()) {
|
||||
// t = tensorTy.getElementType();
|
||||
// }
|
||||
// return llvm::isa<mlir::concretelang::LowLFHE::LowLFHEDialect>(
|
||||
// return llvm::isa<mlir::concretelang::Concrete::ConcreteDialect>(
|
||||
// t.getDialect());
|
||||
// })) {
|
||||
// return true;
|
||||
@@ -815,21 +815,21 @@ struct AddRuntimeContextToFuncOpPattern
|
||||
funcOp.getType()
|
||||
.getInputs()
|
||||
.back()
|
||||
.isa<mlir::concretelang::LowLFHE::ContextType>();
|
||||
.isa<mlir::concretelang::Concrete::ContextType>();
|
||||
}
|
||||
};
|
||||
|
||||
namespace {
|
||||
struct LowLFHEToConcreteCAPIPass
|
||||
: public LowLFHEToConcreteCAPIBase<LowLFHEToConcreteCAPIPass> {
|
||||
struct ConcreteToConcreteCAPIPass
|
||||
: public ConcreteToConcreteCAPIBase<ConcreteToConcreteCAPIPass> {
|
||||
void runOnOperation() final;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void LowLFHEToConcreteCAPIPass::runOnOperation() {
|
||||
void ConcreteToConcreteCAPIPass::runOnOperation() {
|
||||
mlir::ModuleOp op = getOperation();
|
||||
|
||||
// First of all add the LowLFHE.context to the block arguments of function
|
||||
// First of all add the Concrete.context to the block arguments of function
|
||||
// that manipulates ciphertexts.
|
||||
{
|
||||
mlir::ConversionTarget target(getContext());
|
||||
@@ -854,17 +854,17 @@ void LowLFHEToConcreteCAPIPass::runOnOperation() {
|
||||
if (insertForwardDeclarations(op, rewriter).failed()) {
|
||||
this->signalPassFailure();
|
||||
}
|
||||
// Rewrite LowLFHE ops to CallOp to the Concrete C API
|
||||
// Rewrite Concrete ops to CallOp to the Concrete C API
|
||||
{
|
||||
mlir::ConversionTarget target(getContext());
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
|
||||
target.addIllegalDialect<mlir::concretelang::LowLFHE::LowLFHEDialect>();
|
||||
target.addIllegalDialect<mlir::concretelang::Concrete::ConcreteDialect>();
|
||||
target.addLegalDialect<mlir::BuiltinDialect, mlir::StandardOpsDialect,
|
||||
mlir::memref::MemRefDialect,
|
||||
mlir::arith::ArithmeticDialect>();
|
||||
|
||||
populateLowLFHEToConcreteCAPICall(patterns);
|
||||
populateConcreteToConcreteCAPICall(patterns);
|
||||
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns))
|
||||
.failed()) {
|
||||
@@ -876,8 +876,8 @@ void LowLFHEToConcreteCAPIPass::runOnOperation() {
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertLowLFHEToConcreteCAPIPass() {
|
||||
return std::make_unique<LowLFHEToConcreteCAPIPass>();
|
||||
createConvertConcreteToConcreteCAPIPass() {
|
||||
return std::make_unique<ConcreteToConcreteCAPIPass>();
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
16
compiler/lib/Conversion/ConcreteUnparametrize/CMakeLists.txt
Normal file
16
compiler/lib/Conversion/ConcreteUnparametrize/CMakeLists.txt
Normal file
@@ -0,0 +1,16 @@
|
||||
add_mlir_dialect_library(ConcreteUnparametrize
|
||||
ConcreteUnparametrize.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE
|
||||
|
||||
DEPENDS
|
||||
ConcreteDialect
|
||||
MLIRConversionPassIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRTransforms
|
||||
)
|
||||
|
||||
target_link_libraries(ConcreteUnparametrize PUBLIC MLIRIR)
|
||||
@@ -7,32 +7,32 @@
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
|
||||
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
|
||||
#include "concretelang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
|
||||
#include "concretelang/Dialect/LowLFHE/IR/LowLFHEOps.h"
|
||||
#include "concretelang/Dialect/LowLFHE/IR/LowLFHETypes.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
|
||||
#include "concretelang/Dialect/RT/IR/RTOps.h"
|
||||
#include "concretelang/Support/Constants.h"
|
||||
|
||||
/// LowLFHEUnparametrizeTypeConverter is a type converter that unparametrize
|
||||
/// LowLFHE types
|
||||
class LowLFHEUnparametrizeTypeConverter : public mlir::TypeConverter {
|
||||
/// ConcreteUnparametrizeTypeConverter is a type converter that unparametrize
|
||||
/// Concrete types
|
||||
class ConcreteUnparametrizeTypeConverter : public mlir::TypeConverter {
|
||||
|
||||
public:
|
||||
static mlir::Type unparematrizeLowLFHEType(mlir::Type type) {
|
||||
if (type.isa<mlir::concretelang::LowLFHE::PlaintextType>()) {
|
||||
static mlir::Type unparematrizeConcreteType(mlir::Type type) {
|
||||
if (type.isa<mlir::concretelang::Concrete::PlaintextType>()) {
|
||||
return mlir::IntegerType::get(type.getContext(), 64);
|
||||
}
|
||||
if (type.isa<mlir::concretelang::LowLFHE::CleartextType>()) {
|
||||
if (type.isa<mlir::concretelang::Concrete::CleartextType>()) {
|
||||
return mlir::IntegerType::get(type.getContext(), 64);
|
||||
}
|
||||
if (type.isa<mlir::concretelang::LowLFHE::LweCiphertextType>()) {
|
||||
return mlir::concretelang::LowLFHE::LweCiphertextType::get(type.getContext(),
|
||||
if (type.isa<mlir::concretelang::Concrete::LweCiphertextType>()) {
|
||||
return mlir::concretelang::Concrete::LweCiphertextType::get(type.getContext(),
|
||||
-1, -1);
|
||||
}
|
||||
auto tensorType = type.dyn_cast_or_null<mlir::RankedTensorType>();
|
||||
if (tensorType != nullptr) {
|
||||
auto eltTy0 = tensorType.getElementType();
|
||||
auto eltTy1 = unparematrizeLowLFHEType(eltTy0);
|
||||
auto eltTy1 = unparematrizeConcreteType(eltTy0);
|
||||
if (eltTy0 == eltTy1) {
|
||||
return type;
|
||||
}
|
||||
@@ -41,17 +41,17 @@ public:
|
||||
return type;
|
||||
}
|
||||
|
||||
LowLFHEUnparametrizeTypeConverter() {
|
||||
ConcreteUnparametrizeTypeConverter() {
|
||||
addConversion(
|
||||
[](mlir::Type type) { return unparematrizeLowLFHEType(type); });
|
||||
[](mlir::Type type) { return unparematrizeConcreteType(type); });
|
||||
}
|
||||
};
|
||||
|
||||
/// Replace `%1 = unrealized_conversion_cast %0 : t0 to t1` to `%0` where t0 or
|
||||
/// t1 are a LowLFHE type.
|
||||
struct LowLFHEUnrealizedCastReplacementPattern
|
||||
/// t1 are a Concrete type.
|
||||
struct ConcreteUnrealizedCastReplacementPattern
|
||||
: public mlir::OpRewritePattern<mlir::UnrealizedConversionCastOp> {
|
||||
LowLFHEUnrealizedCastReplacementPattern(
|
||||
ConcreteUnrealizedCastReplacementPattern(
|
||||
mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
: mlir::OpRewritePattern<mlir::UnrealizedConversionCastOp>(context,
|
||||
@@ -60,9 +60,9 @@ struct LowLFHEUnrealizedCastReplacementPattern
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::UnrealizedConversionCastOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
if (mlir::isa<mlir::concretelang::LowLFHE::LowLFHEDialect>(
|
||||
if (mlir::isa<mlir::concretelang::Concrete::ConcreteDialect>(
|
||||
op.getOperandTypes()[0].getDialect()) ||
|
||||
mlir::isa<mlir::concretelang::LowLFHE::LowLFHEDialect>(
|
||||
mlir::isa<mlir::concretelang::Concrete::ConcreteDialect>(
|
||||
op.getType(0).getDialect())) {
|
||||
rewriter.replaceOp(op, op.getOperands());
|
||||
return mlir::success();
|
||||
@@ -71,21 +71,21 @@ struct LowLFHEUnrealizedCastReplacementPattern
|
||||
};
|
||||
};
|
||||
|
||||
/// LowLFHEUnparametrizePass remove all parameters of LowLFHE types and remove
|
||||
/// ConcreteUnparametrizePass remove all parameters of Concrete types and remove
|
||||
/// the unrealized_conversion_cast operation that operates on parametrized
|
||||
/// LowLFHE types.
|
||||
struct LowLFHEUnparametrizePass
|
||||
: public LowLFHEUnparametrizeBase<LowLFHEUnparametrizePass> {
|
||||
/// Concrete types.
|
||||
struct ConcreteUnparametrizePass
|
||||
: public ConcreteUnparametrizeBase<ConcreteUnparametrizePass> {
|
||||
void runOnOperation() final;
|
||||
};
|
||||
|
||||
void LowLFHEUnparametrizePass::runOnOperation() {
|
||||
void ConcreteUnparametrizePass::runOnOperation() {
|
||||
auto op = this->getOperation();
|
||||
|
||||
mlir::ConversionTarget target(getContext());
|
||||
mlir::OwningRewritePatternList patterns(&getContext());
|
||||
|
||||
LowLFHEUnparametrizeTypeConverter converter;
|
||||
ConcreteUnparametrizeTypeConverter converter;
|
||||
|
||||
// Conversion of linalg.generic operation
|
||||
target
|
||||
@@ -97,13 +97,13 @@ void LowLFHEUnparametrizePass::runOnOperation() {
|
||||
converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
|
||||
});
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::linalg::GenericOp,
|
||||
LowLFHEUnparametrizeTypeConverter>>(
|
||||
ConcreteUnparametrizeTypeConverter>>(
|
||||
&getContext(), converter);
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
|
||||
LowLFHEUnparametrizeTypeConverter>>(
|
||||
ConcreteUnparametrizeTypeConverter>>(
|
||||
&getContext(), converter);
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::scf::ForOp,
|
||||
LowLFHEUnparametrizeTypeConverter>>(
|
||||
ConcreteUnparametrizeTypeConverter>>(
|
||||
&getContext(), converter);
|
||||
|
||||
// Conversion of function signature and arguments
|
||||
@@ -116,7 +116,7 @@ void LowLFHEUnparametrizePass::runOnOperation() {
|
||||
// Replacement of unrealized_conversion_cast
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::UnrealizedConversionCastOp>(
|
||||
target, converter);
|
||||
patterns.add<LowLFHEUnrealizedCastReplacementPattern>(patterns.getContext());
|
||||
patterns.add<ConcreteUnrealizedCastReplacementPattern>(patterns.getContext());
|
||||
|
||||
// Conversion of tensor operators
|
||||
mlir::concretelang::populateWithTensorTypeConverterPatterns(patterns, target,
|
||||
@@ -142,8 +142,8 @@ void LowLFHEUnparametrizePass::runOnOperation() {
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertLowLFHEUnparametrizePass() {
|
||||
return std::make_unique<LowLFHEUnparametrizePass>();
|
||||
createConvertConcreteUnparametrizePass() {
|
||||
return std::make_unique<ConcreteUnparametrizePass>();
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
17
compiler/lib/Conversion/FHETensorOpsToLinalg/CMakeLists.txt
Normal file
17
compiler/lib/Conversion/FHETensorOpsToLinalg/CMakeLists.txt
Normal file
@@ -0,0 +1,17 @@
|
||||
add_mlir_dialect_library(FHETensorOpsToLinalg
|
||||
TensorOpsToLinalg.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE
|
||||
|
||||
DEPENDS
|
||||
FHEDialect
|
||||
FHELinalgDialect
|
||||
MLIRConversionPassIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
FHEDialect
|
||||
FHELinalgDialect)
|
||||
|
||||
target_link_libraries(FHEDialect PUBLIC MLIRIR)
|
||||
@@ -14,59 +14,59 @@
|
||||
#include <iostream>
|
||||
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Dialect/HLFHE/IR/HLFHEDialect.h"
|
||||
#include "concretelang/Dialect/HLFHE/IR/HLFHEOps.h"
|
||||
#include "concretelang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.h"
|
||||
#include "concretelang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
|
||||
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.h"
|
||||
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h"
|
||||
#include "concretelang/Support/Constants.h"
|
||||
|
||||
struct DotToLinalgGeneric
|
||||
: public ::mlir::OpRewritePattern<mlir::concretelang::HLFHELinalg::Dot> {
|
||||
: public ::mlir::OpRewritePattern<mlir::concretelang::FHELinalg::Dot> {
|
||||
DotToLinalgGeneric(::mlir::MLIRContext *context)
|
||||
: ::mlir::OpRewritePattern<::mlir::concretelang::HLFHELinalg::Dot>(
|
||||
: ::mlir::OpRewritePattern<::mlir::concretelang::FHELinalg::Dot>(
|
||||
context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
|
||||
|
||||
// This rewrite pattern transforms any instance of
|
||||
// `HLFHELinalg.dot_eint_int` to an instance of `linalg.generic` with an
|
||||
// appropriate region using `HLFHE.mul_eint_int` and
|
||||
// `HLFHE.add_eint` operations, an appropriate specification for the
|
||||
// `FHELinalg.dot_eint_int` to an instance of `linalg.generic` with an
|
||||
// appropriate region using `FHE.mul_eint_int` and
|
||||
// `FHE.add_eint` operations, an appropriate specification for the
|
||||
// iteration dimensions and appropriate operations managing the
|
||||
// accumulator of `linalg.generic`.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// %o = "HLFHELinalg.dot_eint_int"(%arg0, %arg1) :
|
||||
// (tensor<4x!HLFHE.eint<0>>,
|
||||
// tensor<4xi32>) -> (!HLFHE.eint<0>)
|
||||
// %o = "FHELinalg.dot_eint_int"(%arg0, %arg1) :
|
||||
// (tensor<4x!FHE.eint<0>>,
|
||||
// tensor<4xi32>) -> (!FHE.eint<0>)
|
||||
//
|
||||
// becomes:
|
||||
//
|
||||
// %0 = "HLFHE.zero"() : () -> !HLFHE.eint<0>
|
||||
// %1 = tensor.from_elements %0 : tensor<1x!HLFHE.eint<0>>
|
||||
// %0 = "FHE.zero"() : () -> !FHE.eint<0>
|
||||
// %1 = tensor.from_elements %0 : tensor<1x!FHE.eint<0>>
|
||||
// %2 = linalg.generic {
|
||||
// indexing_maps = [#map0, #map0, #map1],
|
||||
// iterator_types = ["reduction"]
|
||||
// }
|
||||
// ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<0>>, tensor<2xi32>)
|
||||
// outs(%1 : tensor<1x!HLFHE.eint<0>>) {
|
||||
// ^bb0(%arg2: !HLFHE.eint<0>, %arg3: i32, %arg4: !HLFHE.eint<0>):
|
||||
// %4 = "HLFHE.mul_eint_int"(%arg2, %arg3) :
|
||||
// (!HLFHE.eint<0>, i32) -> !HLFHE.eint<0>
|
||||
// ins(%arg0, %arg1 : tensor<2x!FHE.eint<0>>, tensor<2xi32>)
|
||||
// outs(%1 : tensor<1x!FHE.eint<0>>) {
|
||||
// ^bb0(%arg2: !FHE.eint<0>, %arg3: i32, %arg4: !FHE.eint<0>):
|
||||
// %4 = "FHE.mul_eint_int"(%arg2, %arg3) :
|
||||
// (!FHE.eint<0>, i32) -> !FHE.eint<0>
|
||||
//
|
||||
// %5 = "HLFHE.add_eint"(%4, %arg4) :
|
||||
// (!HLFHE.eint<0>, !HLFHE.eint<0>) -> !HLFHE.eint<0>
|
||||
// %5 = "FHE.add_eint"(%4, %arg4) :
|
||||
// (!FHE.eint<0>, !FHE.eint<0>) -> !FHE.eint<0>
|
||||
//
|
||||
// linalg.yield %5 : !HLFHE.eint<0>
|
||||
// } -> tensor<1x!HLFHE.eint<0>>
|
||||
// linalg.yield %5 : !FHE.eint<0>
|
||||
// } -> tensor<1x!FHE.eint<0>>
|
||||
//
|
||||
// %c0 = constant 0 : index
|
||||
// %o = tensor.extract %2[%c0] : tensor<1x!HLFHE.eint<0>>
|
||||
// %o = tensor.extract %2[%c0] : tensor<1x!FHE.eint<0>>
|
||||
//
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(::mlir::concretelang::HLFHELinalg::Dot dotOp,
|
||||
matchAndRewrite(::mlir::concretelang::FHELinalg::Dot dotOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
// Zero value to initialize accumulator
|
||||
mlir::Value zeroCst = rewriter.create<mlir::concretelang::HLFHE::ZeroEintOp>(
|
||||
mlir::Value zeroCst = rewriter.create<mlir::concretelang::FHE::ZeroEintOp>(
|
||||
dotOp.getLoc(),
|
||||
dotOp.lhs().getType().cast<mlir::ShapedType>().getElementType());
|
||||
|
||||
@@ -95,11 +95,11 @@ struct DotToLinalgGeneric
|
||||
auto regBuilder = [&](mlir::OpBuilder &nestedBuilder,
|
||||
mlir::Location nestedLoc,
|
||||
mlir::ValueRange blockArgs) {
|
||||
mlir::concretelang::HLFHE::MulEintIntOp mul =
|
||||
nestedBuilder.create<mlir::concretelang::HLFHE::MulEintIntOp>(
|
||||
mlir::concretelang::FHE::MulEintIntOp mul =
|
||||
nestedBuilder.create<mlir::concretelang::FHE::MulEintIntOp>(
|
||||
dotOp.getLoc(), blockArgs[0], blockArgs[1]);
|
||||
mlir::concretelang::HLFHE::AddEintOp add =
|
||||
nestedBuilder.create<mlir::concretelang::HLFHE::AddEintOp>(
|
||||
mlir::concretelang::FHE::AddEintOp add =
|
||||
nestedBuilder.create<mlir::concretelang::FHE::AddEintOp>(
|
||||
dotOp.getLoc(), mul, blockArgs[2]);
|
||||
|
||||
nestedBuilder.create<mlir::linalg::YieldOp>(dotOp.getLoc(),
|
||||
@@ -180,16 +180,16 @@ getBroadcastedAffineMapMultiLUT(const mlir::RankedTensorType &resultType,
|
||||
}
|
||||
|
||||
// This template rewrite pattern transforms any instance of
|
||||
// operators `HLFHELinalgOp` that implements the broadasting rules to an
|
||||
// instance of `linalg.generic` with an appropriate region using `HLFHEOp`
|
||||
// operators `FHELinalgOp` that implements the broadasting rules to an
|
||||
// instance of `linalg.generic` with an appropriate region using `FHEOp`
|
||||
// operation, an appropriate specification for the iteration dimensions and
|
||||
// appropriate operations managing the accumulator of `linalg.generic`.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// %res = HLFHELinalg.op(%lhs, %rhs):
|
||||
// (tensor<D$Ax...xD1x!HLFHE.eint<p>>, tensor<D$B'x...xD1'xT>)
|
||||
// -> tensor<DR"x...xD1"x!HLFHE.eint<p>>
|
||||
// %res = FHELinalg.op(%lhs, %rhs):
|
||||
// (tensor<D$Ax...xD1x!FHE.eint<p>>, tensor<D$B'x...xD1'xT>)
|
||||
// -> tensor<DR"x...xD1"x!FHE.eint<p>>
|
||||
//
|
||||
// becomes:
|
||||
//
|
||||
@@ -205,28 +205,28 @@ getBroadcastedAffineMapMultiLUT(const mlir::RankedTensorType &resultType,
|
||||
// iterator_types = ["parallel", ..., "parallel"], // $R" parallel
|
||||
// }
|
||||
// %init = linalg.init_tensor [DR",...,D1"]
|
||||
// : tensor<DR"x...xD1"x!HLFHE.eint<p>>
|
||||
// : tensor<DR"x...xD1"x!FHE.eint<p>>
|
||||
// %res = linalg.generic {
|
||||
// ins(%lhs, %rhs: tensor<DAx...xD1x!HLFHE.eint<p>>,tensor<DB'x...xD1'xT>)
|
||||
// outs(%init : tensor<DR"x...xD1"x!HLFHE.eint<p>>)
|
||||
// ins(%lhs, %rhs: tensor<DAx...xD1x!FHE.eint<p>>,tensor<DB'x...xD1'xT>)
|
||||
// outs(%init : tensor<DR"x...xD1"x!FHE.eint<p>>)
|
||||
// {
|
||||
// ^bb0(%arg0: !HLFHE.eint<p>, %arg1: T):
|
||||
// %0 = HLFHE.op(%arg0, %arg1): !HLFHE.eint<p>, T ->
|
||||
// !HLFHE.eint<p>
|
||||
// linalg.yield %0 : !HLFHE.eint<p>
|
||||
// ^bb0(%arg0: !FHE.eint<p>, %arg1: T):
|
||||
// %0 = FHE.op(%arg0, %arg1): !FHE.eint<p>, T ->
|
||||
// !FHE.eint<p>
|
||||
// linalg.yield %0 : !FHE.eint<p>
|
||||
// }
|
||||
// }
|
||||
//
|
||||
template <typename HLFHELinalgOp, typename HLFHEOp>
|
||||
struct HLFHELinalgOpToLinalgGeneric
|
||||
: public mlir::OpRewritePattern<HLFHELinalgOp> {
|
||||
HLFHELinalgOpToLinalgGeneric(
|
||||
template <typename FHELinalgOp, typename FHEOp>
|
||||
struct FHELinalgOpToLinalgGeneric
|
||||
: public mlir::OpRewritePattern<FHELinalgOp> {
|
||||
FHELinalgOpToLinalgGeneric(
|
||||
::mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
: ::mlir::OpRewritePattern<HLFHELinalgOp>(context, benefit) {}
|
||||
: ::mlir::OpRewritePattern<FHELinalgOp>(context, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(HLFHELinalgOp linalgOp,
|
||||
matchAndRewrite(FHELinalgOp linalgOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::RankedTensorType resultTy =
|
||||
((mlir::Type)linalgOp->getResult(0).getType())
|
||||
@@ -254,11 +254,11 @@ struct HLFHELinalgOpToLinalgGeneric
|
||||
auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
|
||||
mlir::Location nestedLoc,
|
||||
mlir::ValueRange blockArgs) {
|
||||
HLFHEOp hlfheOp = nestedBuilder.create<HLFHEOp>(
|
||||
FHEOp fheOp = nestedBuilder.create<FHEOp>(
|
||||
linalgOp.getLoc(), blockArgs[0], blockArgs[1]);
|
||||
|
||||
nestedBuilder.create<mlir::linalg::YieldOp>(linalgOp.getLoc(),
|
||||
hlfheOp.getResult());
|
||||
fheOp.getResult());
|
||||
};
|
||||
|
||||
// Create the `linalg.generic` op
|
||||
@@ -288,9 +288,9 @@ llvm::SmallVector<llvm::StringRef> parallelIteratorType(int n) {
|
||||
}
|
||||
|
||||
// This class rewrite pattern transforms any instance of
|
||||
// operators `HLFHELinalg.ApplyMappedLookupTableEintOp` that implements the
|
||||
// operators `FHELinalg.ApplyMappedLookupTableEintOp` that implements the
|
||||
// broadasting rules to an instance of `linalg.generic` with an appropriate
|
||||
// region using `HLFHE.ApplyLookupTableEintOp` operation, an appropriate
|
||||
// region using `FHE.ApplyLookupTableEintOp` operation, an appropriate
|
||||
// specification for the iteration dimensions and appropriate operations
|
||||
// managing the accumulator of `linalg.generic`.
|
||||
//
|
||||
@@ -298,20 +298,20 @@ llvm::SmallVector<llvm::StringRef> parallelIteratorType(int n) {
|
||||
// because of a bug in lowering this operation.
|
||||
//
|
||||
// Example:
|
||||
// %res = "HLFHELinalg.apply_mapped_lookup_table"(%t, %luts, %map)
|
||||
// : (tensor<2x3x!HLFHE.eint<2>>, tensor<5x4xi64>, tensor<2x3xindex>)
|
||||
// -> tensor<2x3x!HLFHE.eint<2>>
|
||||
// %res = "FHELinalg.apply_mapped_lookup_table"(%t, %luts, %map)
|
||||
// : (tensor<2x3x!FHE.eint<2>>, tensor<5x4xi64>, tensor<2x3xindex>)
|
||||
// -> tensor<2x3x!FHE.eint<2>>
|
||||
//
|
||||
// becomes:
|
||||
//
|
||||
// #map = affine_map<(d0, d1) -> (d0, d1)>
|
||||
// %init = linalg.init_tensor [2, 3] : tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>
|
||||
// %init = linalg.init_tensor [2, 3] : tensor<2x3x!TFHE.glwe<{_,_,_}{2}>>
|
||||
// %output = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types
|
||||
// = ["parallel", "parallel"]} ins(%arg0, %arg2 :
|
||||
// tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>, tensor<2x3xindex>) outs(%0 :
|
||||
// tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>) {
|
||||
// ^bb0(%arg3: !MidLFHE.glwe<{_,_,_}{2}>, %lut_idx: index, %arg5:
|
||||
// !MidLFHE.glwe<{_,_,_}{2}>): // no predecessors
|
||||
// tensor<2x3x!TFHE.glwe<{_,_,_}{2}>>, tensor<2x3xindex>) outs(%0 :
|
||||
// tensor<2x3x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
// ^bb0(%arg3: !TFHE.glwe<{_,_,_}{2}>, %lut_idx: index, %arg5:
|
||||
// !TFHE.glwe<{_,_,_}{2}>): // no predecessors
|
||||
// // SHOULD BE
|
||||
// %lut = tensor.extract_slice %arg1[%[[LUTIDX]], 0] [1,4] [1, 1]
|
||||
// : tensor<5x4xi64> to tensor<4xi64>
|
||||
@@ -323,31 +323,31 @@ llvm::SmallVector<llvm::StringRef> parallelIteratorType(int n) {
|
||||
// ...
|
||||
// %e3 = tensor.extract %arg5[%lut_idx, %i3] : tensor<5x4xi64>
|
||||
// %lut = tensor.from_elements %e0, ..., %e3 : tensor<4xi64>
|
||||
// %res = "MidLFHE.apply_lookup_table"(%arg3, %[[LUT]])
|
||||
// %res = "TFHE.apply_lookup_table"(%arg3, %[[LUT]])
|
||||
// {baseLogBS = -1 : i32, baseLogKS = -1 : i32, k = -1 : i32,
|
||||
// levelBS = -1 : i32, levelKS = -1 : i32, outputSizeKS =
|
||||
// -1 : i32, polynomialSize = -1 : i32}
|
||||
// : (!MidLFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) ->
|
||||
// !MidLFHE.glwe<{_,_,_}{2}> linalg.yield %res :
|
||||
// !MidLFHE.glwe<{_,_,_}{2}>
|
||||
// } -> tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>
|
||||
// : (!TFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) ->
|
||||
// !TFHE.glwe<{_,_,_}{2}> linalg.yield %res :
|
||||
// !TFHE.glwe<{_,_,_}{2}>
|
||||
// } -> tensor<2x3x!TFHE.glwe<{_,_,_}{2}>>
|
||||
|
||||
namespace HLFHELinalg = mlir::concretelang::HLFHELinalg;
|
||||
namespace FHELinalg = mlir::concretelang::FHELinalg;
|
||||
|
||||
struct HLFHELinalgApplyMappedLookupTableToLinalgGeneric
|
||||
: public mlir::OpRewritePattern<HLFHELinalg::ApplyMappedLookupTableEintOp> {
|
||||
HLFHELinalgApplyMappedLookupTableToLinalgGeneric(
|
||||
struct FHELinalgApplyMappedLookupTableToLinalgGeneric
|
||||
: public mlir::OpRewritePattern<FHELinalg::ApplyMappedLookupTableEintOp> {
|
||||
FHELinalgApplyMappedLookupTableToLinalgGeneric(
|
||||
::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<HLFHELinalg::ApplyMappedLookupTableEintOp>(
|
||||
: ::mlir::OpRewritePattern<FHELinalg::ApplyMappedLookupTableEintOp>(
|
||||
context, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(HLFHELinalg::ApplyMappedLookupTableEintOp mappedLookup,
|
||||
matchAndRewrite(FHELinalg::ApplyMappedLookupTableEintOp mappedLookup,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
namespace arith = mlir::arith;
|
||||
namespace linalg = mlir::linalg;
|
||||
namespace tensor = mlir::tensor;
|
||||
namespace HLFHE = mlir::concretelang::HLFHE;
|
||||
namespace FHE = mlir::concretelang::FHE;
|
||||
using Values = llvm::SmallVector<mlir::Value>;
|
||||
using Types = llvm::SmallVector<mlir::Type>;
|
||||
using AffineMaps = llvm::SmallVector<mlir::AffineMap>;
|
||||
@@ -421,9 +421,9 @@ struct HLFHELinalgApplyMappedLookupTableToLinalgGeneric
|
||||
lut = nestedBuilder.create<tensor::FromElementsOp>(loc, extracts);
|
||||
} // WORKAROUND END
|
||||
// %res1 = apply_lookup_table %arg0 %lut
|
||||
auto lookup = nestedBuilder.create<HLFHE::ApplyLookupTableEintOp>(
|
||||
auto lookup = nestedBuilder.create<FHE::ApplyLookupTableEintOp>(
|
||||
loc, elementTy, tElmt, lut);
|
||||
// linalg.yield %res1 : !HLFHE.eint<2>
|
||||
// linalg.yield %res1 : !FHE.eint<2>
|
||||
nestedBuilder.create<linalg::YieldOp>(loc, lookup.getResult());
|
||||
};
|
||||
|
||||
@@ -446,16 +446,16 @@ struct HLFHELinalgApplyMappedLookupTableToLinalgGeneric
|
||||
};
|
||||
|
||||
// This class rewrite pattern transforms any instance of
|
||||
// operators `HLFHELinalg.ApplyMultiLookupTableEintOp` that implements the
|
||||
// operators `FHELinalg.ApplyMultiLookupTableEintOp` that implements the
|
||||
// broadasting rules to an instance of `linalg.generic` with an appropriate
|
||||
// region using `HLFHE.ApplyLookupTableEintOp` operation, an appropriate
|
||||
// region using `FHE.ApplyLookupTableEintOp` operation, an appropriate
|
||||
// specification for the iteration dimensions and appropriate operaztions
|
||||
// managing the accumulator of `linalg.generic`.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// %res = "HLFHELinalg.apply_multi_lookup_table"(%t, %luts):
|
||||
// (tensor<4x3x!HLFHE.eint<2>>, tensor<3x4xi64>) -> tensor<4x3x!HLFHE.eint<2>>
|
||||
// %res = "FHELinalg.apply_multi_lookup_table"(%t, %luts):
|
||||
// (tensor<4x3x!FHE.eint<2>>, tensor<3x4xi64>) -> tensor<4x3x!FHE.eint<2>>
|
||||
//
|
||||
// becomes:
|
||||
//
|
||||
@@ -471,50 +471,50 @@ struct HLFHELinalgApplyMappedLookupTableToLinalgGeneric
|
||||
// iterator_types = ["parallel", "parallel"],
|
||||
// }
|
||||
// %init = linalg.init_tensor [4, 3]
|
||||
// : tensor<4x3x!HLFHE.eint<2>>
|
||||
// : tensor<4x3x!FHE.eint<2>>
|
||||
// %res = linalg.generic {
|
||||
// ins(%t, %luts, %luts, %luts, %luts: tensor<4x3x!HLFHE.eint<p>>,
|
||||
// ins(%t, %luts, %luts, %luts, %luts: tensor<4x3x!FHE.eint<p>>,
|
||||
// tensor<3x4xi64>, tensor<3x4xi64>, tensor<3x4xi64>, tensor<3x4xi64>)
|
||||
// outs(%init : tensor<4x3x!HLFHE.eint<2>>)
|
||||
// outs(%init : tensor<4x3x!FHE.eint<2>>)
|
||||
// {
|
||||
// ^bb0(%arg0: !HLFHE.eint<2>, %arg1: i64, %arg2: i64, %arg3: i64,
|
||||
// %arg4: i64, %arg5: !HLFHE.eint<2>):
|
||||
// ^bb0(%arg0: !FHE.eint<2>, %arg1: i64, %arg2: i64, %arg3: i64,
|
||||
// %arg4: i64, %arg5: !FHE.eint<2>):
|
||||
// %lut = tensor.from_elements %arg1, %arg2, %arg3, %arg4 :
|
||||
// tensor<4xi64> %0 = "MidLFHE.apply_lookup_table"(%arg0, %lut)
|
||||
// tensor<4xi64> %0 = "TFHE.apply_lookup_table"(%arg0, %lut)
|
||||
// {baseLogBS = -1 : i32, baseLogKS = -1 : i32, glweDimension = -1 :
|
||||
// i32, levelBS = -1 : i32, levelKS = -1 : i32, outputSizeKS = -1 :
|
||||
// i32, polynomialSize = -1 : i32} : (!MidLFHE.glwe<{_,_,_}{2}>,
|
||||
// tensor<4xi64>) -> !MidLFHE.glwe<{_,_,_}{2}>
|
||||
// linalg.yield %0 : !HLFHE.eint<2>
|
||||
// i32, polynomialSize = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>,
|
||||
// tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
// linalg.yield %0 : !FHE.eint<2>
|
||||
// }
|
||||
// }
|
||||
//
|
||||
struct HLFHELinalgApplyMultiLookupTableToLinalgGeneric
|
||||
struct FHELinalgApplyMultiLookupTableToLinalgGeneric
|
||||
: public mlir::OpRewritePattern<
|
||||
mlir::concretelang::HLFHELinalg::ApplyMultiLookupTableEintOp> {
|
||||
HLFHELinalgApplyMultiLookupTableToLinalgGeneric(
|
||||
mlir::concretelang::FHELinalg::ApplyMultiLookupTableEintOp> {
|
||||
FHELinalgApplyMultiLookupTableToLinalgGeneric(
|
||||
::mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
: ::mlir::OpRewritePattern<
|
||||
mlir::concretelang::HLFHELinalg::ApplyMultiLookupTableEintOp>(context,
|
||||
mlir::concretelang::FHELinalg::ApplyMultiLookupTableEintOp>(context,
|
||||
benefit) {
|
||||
}
|
||||
|
||||
::mlir::LogicalResult matchAndRewrite(
|
||||
mlir::concretelang::HLFHELinalg::ApplyMultiLookupTableEintOp hlfheLinalgLutOp,
|
||||
mlir::concretelang::FHELinalg::ApplyMultiLookupTableEintOp fheLinalgLutOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::RankedTensorType resultTy =
|
||||
((mlir::Type)hlfheLinalgLutOp->getResult(0).getType())
|
||||
((mlir::Type)fheLinalgLutOp->getResult(0).getType())
|
||||
.cast<mlir::RankedTensorType>();
|
||||
mlir::RankedTensorType tensorTy =
|
||||
((mlir::Type)hlfheLinalgLutOp.t().getType())
|
||||
((mlir::Type)fheLinalgLutOp.t().getType())
|
||||
.cast<mlir::RankedTensorType>();
|
||||
mlir::RankedTensorType lutsTy =
|
||||
((mlir::Type)hlfheLinalgLutOp.luts().getType())
|
||||
((mlir::Type)fheLinalgLutOp.luts().getType())
|
||||
.cast<mlir::RankedTensorType>();
|
||||
// linalg.init_tensor for initial value
|
||||
mlir::Value init = rewriter.create<mlir::linalg::InitTensorOp>(
|
||||
hlfheLinalgLutOp.getLoc(), resultTy.getShape(),
|
||||
fheLinalgLutOp.getLoc(), resultTy.getShape(),
|
||||
resultTy.getElementType());
|
||||
|
||||
auto lutsShape = lutsTy.getShape();
|
||||
@@ -541,51 +541,51 @@ struct HLFHELinalgApplyMultiLookupTableToLinalgGeneric
|
||||
mlir::ValueRange blockArgs) {
|
||||
mlir::tensor::FromElementsOp lut =
|
||||
nestedBuilder.create<mlir::tensor::FromElementsOp>(
|
||||
hlfheLinalgLutOp.getLoc(), blockArgs.slice(1, lut_size));
|
||||
mlir::concretelang::HLFHE::ApplyLookupTableEintOp lutOp =
|
||||
nestedBuilder.create<mlir::concretelang::HLFHE::ApplyLookupTableEintOp>(
|
||||
hlfheLinalgLutOp.getLoc(), resultTy.getElementType(),
|
||||
fheLinalgLutOp.getLoc(), blockArgs.slice(1, lut_size));
|
||||
mlir::concretelang::FHE::ApplyLookupTableEintOp lutOp =
|
||||
nestedBuilder.create<mlir::concretelang::FHE::ApplyLookupTableEintOp>(
|
||||
fheLinalgLutOp.getLoc(), resultTy.getElementType(),
|
||||
blockArgs[0], lut.result());
|
||||
|
||||
nestedBuilder.create<mlir::linalg::YieldOp>(hlfheLinalgLutOp.getLoc(),
|
||||
nestedBuilder.create<mlir::linalg::YieldOp>(fheLinalgLutOp.getLoc(),
|
||||
lutOp.getResult());
|
||||
};
|
||||
|
||||
// Create the `linalg.generic` op
|
||||
llvm::SmallVector<mlir::Type, 1> resTypes{init.getType()};
|
||||
llvm::SmallVector<mlir::Value> ins{hlfheLinalgLutOp.t()};
|
||||
llvm::SmallVector<mlir::Value> ins{fheLinalgLutOp.t()};
|
||||
ins.reserve(lut_size + 2);
|
||||
// We extract one value at a time from one LUT using different maps, so we
|
||||
// need to pass the LUT `lut_size` time
|
||||
for (auto i = 0; i < lut_size; i++)
|
||||
ins.push_back(hlfheLinalgLutOp.luts());
|
||||
ins.push_back(fheLinalgLutOp.luts());
|
||||
llvm::SmallVector<mlir::Value, 1> outs{init};
|
||||
llvm::StringRef doc{""};
|
||||
llvm::StringRef call{""};
|
||||
|
||||
mlir::linalg::GenericOp genericOp =
|
||||
rewriter.create<mlir::linalg::GenericOp>(
|
||||
hlfheLinalgLutOp.getLoc(), resTypes, ins, outs, maps, iteratorTypes,
|
||||
fheLinalgLutOp.getLoc(), resTypes, ins, outs, maps, iteratorTypes,
|
||||
doc, call, bodyBuilder);
|
||||
|
||||
rewriter.replaceOp(hlfheLinalgLutOp, {genericOp.getResult(0)});
|
||||
rewriter.replaceOp(fheLinalgLutOp, {genericOp.getResult(0)});
|
||||
|
||||
return ::mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
// This template rewrite pattern transforms any instance of
|
||||
// operators `HLFHELinalg.apply_lookup_table` that implements the broadasting
|
||||
// operators `FHELinalg.apply_lookup_table` that implements the broadasting
|
||||
// rules to an instance of `linalg.generic` with an appropriate region using
|
||||
// `HLFHE.apply_lookup_table` operation, an appropriate specification for the
|
||||
// `FHE.apply_lookup_table` operation, an appropriate specification for the
|
||||
// iteration dimensions and appropriate operations managing the accumulator of
|
||||
// `linalg.generic`.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// HLFHELinalg.apply_lookup_table(%t, %lut):
|
||||
// tensor<DNx...xD1x!HLFHE.eint<p>>, tensor<DAxi64>
|
||||
// -> tensor<DNx...xD1x!HLFHE.eint<p'>>
|
||||
// FHELinalg.apply_lookup_table(%t, %lut):
|
||||
// tensor<DNx...xD1x!FHE.eint<p>>, tensor<DAxi64>
|
||||
// -> tensor<DNx...xD1x!FHE.eint<p'>>
|
||||
//
|
||||
// becomes:
|
||||
//
|
||||
@@ -598,30 +598,30 @@ struct HLFHELinalgApplyMultiLookupTableToLinalgGeneric
|
||||
// iterator_types = ["parallel",..],//N parallel
|
||||
// }
|
||||
// %init = linalg.init_tensor [DN,...,D1]
|
||||
// : tensor<DNx...xD1x!HLFHE.eint<p'>>
|
||||
// : tensor<DNx...xD1x!FHE.eint<p'>>
|
||||
// %res = linalg.generic {
|
||||
// ins(%t: tensor<DNx...xD1x!HLFHE.eint<p>>)
|
||||
// outs(%init : tensor<DNx...xD1x!HLFHE.eint<p'>>)
|
||||
// ins(%t: tensor<DNx...xD1x!FHE.eint<p>>)
|
||||
// outs(%init : tensor<DNx...xD1x!FHE.eint<p'>>)
|
||||
// {
|
||||
// ^bb0(%arg0: !HLFHE.eint<p>):
|
||||
// %0 = HLFHE.apply_lookup_table(%arg0, %lut): !HLFHE.eint<p>,
|
||||
// tensor<4xi64> -> !HLFHE.eint<p'>
|
||||
// linalg.yield %0 : !HLFHE.eint<p'>
|
||||
// ^bb0(%arg0: !FHE.eint<p>):
|
||||
// %0 = FHE.apply_lookup_table(%arg0, %lut): !FHE.eint<p>,
|
||||
// tensor<4xi64> -> !FHE.eint<p'>
|
||||
// linalg.yield %0 : !FHE.eint<p'>
|
||||
// }
|
||||
// }
|
||||
//
|
||||
struct HLFHELinalgApplyLookupTableToLinalgGeneric
|
||||
struct FHELinalgApplyLookupTableToLinalgGeneric
|
||||
: public mlir::OpRewritePattern<
|
||||
mlir::concretelang::HLFHELinalg::ApplyLookupTableEintOp> {
|
||||
HLFHELinalgApplyLookupTableToLinalgGeneric(
|
||||
mlir::concretelang::FHELinalg::ApplyLookupTableEintOp> {
|
||||
FHELinalgApplyLookupTableToLinalgGeneric(
|
||||
::mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
: ::mlir::OpRewritePattern<
|
||||
mlir::concretelang::HLFHELinalg::ApplyLookupTableEintOp>(context,
|
||||
mlir::concretelang::FHELinalg::ApplyLookupTableEintOp>(context,
|
||||
benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::HLFHELinalg::ApplyLookupTableEintOp lutOp,
|
||||
matchAndRewrite(mlir::concretelang::FHELinalg::ApplyLookupTableEintOp lutOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::RankedTensorType resultTy =
|
||||
((mlir::Type)lutOp->getResult(0).getType())
|
||||
@@ -649,13 +649,13 @@ struct HLFHELinalgApplyLookupTableToLinalgGeneric
|
||||
auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
|
||||
mlir::Location nestedLoc,
|
||||
mlir::ValueRange blockArgs) {
|
||||
mlir::concretelang::HLFHE::ApplyLookupTableEintOp hlfheOp =
|
||||
nestedBuilder.create<mlir::concretelang::HLFHE::ApplyLookupTableEintOp>(
|
||||
mlir::concretelang::FHE::ApplyLookupTableEintOp fheOp =
|
||||
nestedBuilder.create<mlir::concretelang::FHE::ApplyLookupTableEintOp>(
|
||||
lutOp.getLoc(), resultTy.getElementType(), blockArgs[0],
|
||||
lutOp.lut());
|
||||
|
||||
nestedBuilder.create<mlir::linalg::YieldOp>(lutOp.getLoc(),
|
||||
hlfheOp.getResult());
|
||||
fheOp.getResult());
|
||||
};
|
||||
|
||||
// Create the `linalg.generic` op
|
||||
@@ -677,15 +677,15 @@ struct HLFHELinalgApplyLookupTableToLinalgGeneric
|
||||
};
|
||||
|
||||
// This template rewrite pattern transforms any instance of
|
||||
// operators `HLFHELinalg.neg_eint` to an instance of `linalg.generic` with an
|
||||
// appropriate region using `HLFHE.neg_eint` operation, an appropriate
|
||||
// operators `FHELinalg.neg_eint` to an instance of `linalg.generic` with an
|
||||
// appropriate region using `FHE.neg_eint` operation, an appropriate
|
||||
// specification for the iteration dimensions and appropriate operations
|
||||
// managing the accumulator of `linalg.generic`.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// HLFHELinalg.neg_eint(%tensor):
|
||||
// tensor<DNx...xD1x!HLFHE.eint<p>> -> tensor<DNx...xD1x!HLFHE.eint<p'>>
|
||||
// FHELinalg.neg_eint(%tensor):
|
||||
// tensor<DNx...xD1x!FHE.eint<p>> -> tensor<DNx...xD1x!FHE.eint<p'>>
|
||||
//
|
||||
// becomes:
|
||||
//
|
||||
@@ -698,27 +698,27 @@ struct HLFHELinalgApplyLookupTableToLinalgGeneric
|
||||
// iterator_types = ["parallel",..],//N parallel
|
||||
// }
|
||||
// %init = linalg.init_tensor [DN,...,D1]
|
||||
// : tensor<DNx...xD1x!HLFHE.eint<p'>>
|
||||
// : tensor<DNx...xD1x!FHE.eint<p'>>
|
||||
// %res = linalg.generic {
|
||||
// ins(%tensor: tensor<DNx...xD1x!HLFHE.eint<p>>)
|
||||
// outs(%init : tensor<DNx...xD1x!HLFHE.eint<p'>>)
|
||||
// ins(%tensor: tensor<DNx...xD1x!FHE.eint<p>>)
|
||||
// outs(%init : tensor<DNx...xD1x!FHE.eint<p'>>)
|
||||
// {
|
||||
// ^bb0(%arg0: !HLFHE.eint<p>):
|
||||
// %0 = HLFHE.neg_eint(%arg0): !HLFHE.eint<p> -> !HLFHE.eint<p'>
|
||||
// linalg.yield %0 : !HLFHE.eint<p'>
|
||||
// ^bb0(%arg0: !FHE.eint<p>):
|
||||
// %0 = FHE.neg_eint(%arg0): !FHE.eint<p> -> !FHE.eint<p'>
|
||||
// linalg.yield %0 : !FHE.eint<p'>
|
||||
// }
|
||||
// }
|
||||
//
|
||||
struct HLFHELinalgNegEintToLinalgGeneric
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::HLFHELinalg::NegEintOp> {
|
||||
HLFHELinalgNegEintToLinalgGeneric(
|
||||
struct FHELinalgNegEintToLinalgGeneric
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::FHELinalg::NegEintOp> {
|
||||
FHELinalgNegEintToLinalgGeneric(
|
||||
::mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
: ::mlir::OpRewritePattern<mlir::concretelang::HLFHELinalg::NegEintOp>(
|
||||
: ::mlir::OpRewritePattern<mlir::concretelang::FHELinalg::NegEintOp>(
|
||||
context, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::HLFHELinalg::NegEintOp negEintOp,
|
||||
matchAndRewrite(mlir::concretelang::FHELinalg::NegEintOp negEintOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::RankedTensorType resultTy =
|
||||
((mlir::Type)negEintOp->getResult(0).getType())
|
||||
@@ -746,12 +746,12 @@ struct HLFHELinalgNegEintToLinalgGeneric
|
||||
auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
|
||||
mlir::Location nestedLoc,
|
||||
mlir::ValueRange blockArgs) {
|
||||
mlir::concretelang::HLFHE::NegEintOp hlfheOp =
|
||||
nestedBuilder.create<mlir::concretelang::HLFHE::NegEintOp>(
|
||||
mlir::concretelang::FHE::NegEintOp fheOp =
|
||||
nestedBuilder.create<mlir::concretelang::FHE::NegEintOp>(
|
||||
negEintOp.getLoc(), resultTy.getElementType(), blockArgs[0]);
|
||||
|
||||
nestedBuilder.create<mlir::linalg::YieldOp>(negEintOp.getLoc(),
|
||||
hlfheOp.getResult());
|
||||
fheOp.getResult());
|
||||
};
|
||||
|
||||
// Create the `linalg.generic` op
|
||||
@@ -773,17 +773,17 @@ struct HLFHELinalgNegEintToLinalgGeneric
|
||||
};
|
||||
|
||||
// This template rewrite pattern transforms any instance of
|
||||
// operators `HLFHELinalgMatmulOp` to an instance of `linalg.generic`
|
||||
// operators `FHELinalgMatmulOp` to an instance of `linalg.generic`
|
||||
// with an appropriate region using a builder that create the multiplication
|
||||
// operators and `HLFHE.add_eint` operation, an appropriate specification for
|
||||
// operators and `FHE.add_eint` operation, an appropriate specification for
|
||||
// the iteration dimensions and appropriate operations managing the accumulator
|
||||
// of `linalg.generic`.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// "HLFHELinalg.matmul_eint_int(%a, %b) :
|
||||
// (tensor<MxPx!HLFHE.eint<p>>, tensor<PxNxip'>) ->
|
||||
// tensor<MxNx!HLFHE.eint<p>>"
|
||||
// "FHELinalg.matmul_eint_int(%a, %b) :
|
||||
// (tensor<MxPx!FHE.eint<p>>, tensor<PxNxip'>) ->
|
||||
// tensor<MxNx!FHE.eint<p>>"
|
||||
|
||||
//
|
||||
// becomes:
|
||||
@@ -799,36 +799,36 @@ struct HLFHELinalgNegEintToLinalgGeneric
|
||||
// }
|
||||
// %init = linalg.generate {
|
||||
// ^bb0(%i : index, %j : index, %k : index):
|
||||
// %z = "HLFHE.zero" : () -> !HLFHE.eint<2>
|
||||
// %z = "FHE.zero" : () -> !FHE.eint<2>
|
||||
// linalg.yield %z
|
||||
// }: tensor<MxNx!HLFHE.eint<p>>
|
||||
// }: tensor<MxNx!FHE.eint<p>>
|
||||
// linalg.generic #attributes_0
|
||||
// ins(%A, %B : tensor<MxPx!HLFHE.eint<p>>,
|
||||
// ins(%A, %B : tensor<MxPx!FHE.eint<p>>,
|
||||
// tensor<PxNxip'>)
|
||||
// outs(%C : tensor<MxNx!HLFHE.eint<p>>)
|
||||
// outs(%C : tensor<MxNx!FHE.eint<p>>)
|
||||
// {
|
||||
// ^bb0(%a: !HLFHE.eint<p>, %b: ip', %c: !HLFHE.eint<p>) :
|
||||
// %d = createMulOp(%a, %b): !HLFHE.eint<p>
|
||||
// %e = "HLFHE.add_eint"(%c, %d):
|
||||
// (!HLFHE.eint<p>, !HLFHE.eint<p>) -> !HLFHE.eint<p>
|
||||
// linalg.yield %e : !HLFHE.eint<p>
|
||||
// ^bb0(%a: !FHE.eint<p>, %b: ip', %c: !FHE.eint<p>) :
|
||||
// %d = createMulOp(%a, %b): !FHE.eint<p>
|
||||
// %e = "FHE.add_eint"(%c, %d):
|
||||
// (!FHE.eint<p>, !FHE.eint<p>) -> !FHE.eint<p>
|
||||
// linalg.yield %e : !FHE.eint<p>
|
||||
// }
|
||||
//
|
||||
template <typename HLFHELinalgMatmulOp>
|
||||
struct HLFHELinalgMatmulToLinalgGeneric
|
||||
: public mlir::OpRewritePattern<HLFHELinalgMatmulOp> {
|
||||
HLFHELinalgMatmulToLinalgGeneric(
|
||||
template <typename FHELinalgMatmulOp>
|
||||
struct FHELinalgMatmulToLinalgGeneric
|
||||
: public mlir::OpRewritePattern<FHELinalgMatmulOp> {
|
||||
FHELinalgMatmulToLinalgGeneric(
|
||||
mlir::MLIRContext *context,
|
||||
std::function<mlir::concretelang::HLFHE::MulEintIntOp(
|
||||
std::function<mlir::concretelang::FHE::MulEintIntOp(
|
||||
mlir::OpBuilder &, mlir::Location, mlir::Type, mlir::Value,
|
||||
mlir::Value)>
|
||||
createMulOp,
|
||||
mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
: ::mlir::OpRewritePattern<HLFHELinalgMatmulOp>(context, benefit),
|
||||
: ::mlir::OpRewritePattern<FHELinalgMatmulOp>(context, benefit),
|
||||
createMulOp(createMulOp) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(HLFHELinalgMatmulOp matmulOp,
|
||||
matchAndRewrite(FHELinalgMatmulOp matmulOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::Location matmulLoc = matmulOp.getLoc();
|
||||
mlir::RankedTensorType resultTy =
|
||||
@@ -839,11 +839,11 @@ struct HLFHELinalgMatmulToLinalgGeneric
|
||||
auto generateBody = [&](mlir::OpBuilder &nestedBuilder,
|
||||
mlir::Location nestedLoc,
|
||||
mlir::ValueRange blockArgs) {
|
||||
// %z = "HLFHE.zero" : () -> !HLFHE.eint<2>
|
||||
mlir::concretelang::HLFHE::ZeroEintOp zeroOp =
|
||||
nestedBuilder.create<mlir::concretelang::HLFHE::ZeroEintOp>(
|
||||
// %z = "FHE.zero" : () -> !FHE.eint<2>
|
||||
mlir::concretelang::FHE::ZeroEintOp zeroOp =
|
||||
nestedBuilder.create<mlir::concretelang::FHE::ZeroEintOp>(
|
||||
matmulLoc, resultElementTy);
|
||||
// linalg.yield %z : !HLFHE.eint<p>
|
||||
// linalg.yield %z : !FHE.eint<p>
|
||||
nestedBuilder.create<mlir::tensor::YieldOp>(matmulLoc,
|
||||
zeroOp.getResult());
|
||||
};
|
||||
@@ -873,16 +873,16 @@ struct HLFHELinalgMatmulToLinalgGeneric
|
||||
auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
|
||||
mlir::Location nestedLoc,
|
||||
mlir::ValueRange blockArgs) {
|
||||
// "HLFHE.mul_eint_int"(%a, %b) : (!HLFHE.eint<p>, ip') -> !HLFHE.eint<p>
|
||||
mlir::concretelang::HLFHE::MulEintIntOp mulEintIntOp =
|
||||
// "FHE.mul_eint_int"(%a, %b) : (!FHE.eint<p>, ip') -> !FHE.eint<p>
|
||||
mlir::concretelang::FHE::MulEintIntOp mulEintIntOp =
|
||||
createMulOp(nestedBuilder, matmulLoc, resultElementTy, blockArgs[0],
|
||||
blockArgs[1]);
|
||||
// "HLFHE.add_eint"(%c, %d): (!HLFHE.eint<p>, !HLFHE.eint<p>) ->
|
||||
// !HLFHE.eint<p>
|
||||
mlir::concretelang::HLFHE::AddEintOp addEintOp =
|
||||
nestedBuilder.create<mlir::concretelang::HLFHE::AddEintOp>(
|
||||
// "FHE.add_eint"(%c, %d): (!FHE.eint<p>, !FHE.eint<p>) ->
|
||||
// !FHE.eint<p>
|
||||
mlir::concretelang::FHE::AddEintOp addEintOp =
|
||||
nestedBuilder.create<mlir::concretelang::FHE::AddEintOp>(
|
||||
matmulLoc, resultElementTy, blockArgs[2], mulEintIntOp);
|
||||
// linalg.yield %e : !HLFHE.eint<p>
|
||||
// linalg.yield %e : !FHE.eint<p>
|
||||
nestedBuilder.create<mlir::linalg::YieldOp>(matmulLoc,
|
||||
addEintOp.getResult());
|
||||
};
|
||||
@@ -905,38 +905,38 @@ struct HLFHELinalgMatmulToLinalgGeneric
|
||||
};
|
||||
|
||||
private:
|
||||
std::function<mlir::concretelang::HLFHE::MulEintIntOp(
|
||||
std::function<mlir::concretelang::FHE::MulEintIntOp(
|
||||
mlir::OpBuilder &, mlir::Location, mlir::Type, mlir::Value, mlir::Value)>
|
||||
createMulOp;
|
||||
};
|
||||
|
||||
// This rewrite pattern transforms any instance of operators
|
||||
// `HLFHELinalg.zero` to an instance of `linalg.generate` with an
|
||||
// `FHELinalg.zero` to an instance of `linalg.generate` with an
|
||||
// appropriate region yielding a zero value.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// %out = "HLFHELinalg.zero"() : () -> tensor<MxNx!HLFHE.eint<p>>
|
||||
// %out = "FHELinalg.zero"() : () -> tensor<MxNx!FHE.eint<p>>
|
||||
//
|
||||
// becomes:
|
||||
//
|
||||
// %0 = tensor.generate {
|
||||
// ^bb0(%arg2: index, %arg3: index):
|
||||
// %zero = "HLFHE.zero"() : () -> !HLFHE.eint<p>
|
||||
// tensor.yield %zero : !HLFHE.eint<p>
|
||||
// } : tensor<MxNx!HLFHE.eint<p>>
|
||||
// %zero = "FHE.zero"() : () -> !FHE.eint<p>
|
||||
// tensor.yield %zero : !FHE.eint<p>
|
||||
// } : tensor<MxNx!FHE.eint<p>>
|
||||
//
|
||||
struct HLFHELinalgZeroToLinalgGenerate
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::HLFHELinalg::ZeroOp> {
|
||||
HLFHELinalgZeroToLinalgGenerate(
|
||||
struct FHELinalgZeroToLinalgGenerate
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::FHELinalg::ZeroOp> {
|
||||
FHELinalgZeroToLinalgGenerate(
|
||||
::mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
|
||||
: ::mlir::OpRewritePattern<mlir::concretelang::HLFHELinalg::ZeroOp>(context,
|
||||
: ::mlir::OpRewritePattern<mlir::concretelang::FHELinalg::ZeroOp>(context,
|
||||
benefit) {
|
||||
}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::HLFHELinalg::ZeroOp zeroOp,
|
||||
matchAndRewrite(mlir::concretelang::FHELinalg::ZeroOp zeroOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::RankedTensorType resultTy =
|
||||
zeroOp->getResult(0).getType().cast<mlir::RankedTensorType>();
|
||||
@@ -945,7 +945,7 @@ struct HLFHELinalgZeroToLinalgGenerate
|
||||
mlir::Location nestedLoc,
|
||||
mlir::ValueRange blockArgs) {
|
||||
mlir::Value zeroScalar =
|
||||
nestedBuilder.create<mlir::concretelang::HLFHE::ZeroEintOp>(
|
||||
nestedBuilder.create<mlir::concretelang::FHE::ZeroEintOp>(
|
||||
zeroOp.getLoc(), resultTy.getElementType());
|
||||
nestedBuilder.create<mlir::tensor::YieldOp>(zeroOp.getLoc(), zeroScalar);
|
||||
};
|
||||
@@ -960,13 +960,13 @@ struct HLFHELinalgZeroToLinalgGenerate
|
||||
};
|
||||
|
||||
namespace {
|
||||
struct HLFHETensorOpsToLinalg
|
||||
: public HLFHETensorOpsToLinalgBase<HLFHETensorOpsToLinalg> {
|
||||
struct FHETensorOpsToLinalg
|
||||
: public FHETensorOpsToLinalgBase<FHETensorOpsToLinalg> {
|
||||
|
||||
void runOnFunction() final;
|
||||
};
|
||||
|
||||
void HLFHETensorOpsToLinalg::runOnFunction() {
|
||||
void FHETensorOpsToLinalg::runOnFunction() {
|
||||
mlir::FuncOp function = this->getFunction();
|
||||
|
||||
mlir::ConversionTarget target(getContext());
|
||||
@@ -974,51 +974,51 @@ void HLFHETensorOpsToLinalg::runOnFunction() {
|
||||
target.addLegalDialect<mlir::linalg::LinalgDialect>();
|
||||
target.addLegalDialect<mlir::StandardOpsDialect>();
|
||||
target.addLegalDialect<mlir::memref::MemRefDialect>();
|
||||
target.addLegalDialect<mlir::concretelang::HLFHE::HLFHEDialect>();
|
||||
target.addLegalDialect<mlir::concretelang::FHE::FHEDialect>();
|
||||
target.addLegalDialect<mlir::tensor::TensorDialect>();
|
||||
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
|
||||
target.addIllegalOp<mlir::concretelang::HLFHELinalg::Dot>();
|
||||
target.addIllegalDialect<mlir::concretelang::HLFHELinalg::HLFHELinalgDialect>();
|
||||
target.addIllegalOp<mlir::concretelang::FHELinalg::Dot>();
|
||||
target.addIllegalDialect<mlir::concretelang::FHELinalg::FHELinalgDialect>();
|
||||
|
||||
mlir::OwningRewritePatternList patterns(&getContext());
|
||||
patterns.insert<DotToLinalgGeneric>(&getContext());
|
||||
patterns.insert<
|
||||
HLFHELinalgOpToLinalgGeneric<mlir::concretelang::HLFHELinalg::AddEintOp,
|
||||
mlir::concretelang::HLFHE::AddEintOp>>(
|
||||
FHELinalgOpToLinalgGeneric<mlir::concretelang::FHELinalg::AddEintOp,
|
||||
mlir::concretelang::FHE::AddEintOp>>(
|
||||
&getContext());
|
||||
patterns.insert<
|
||||
HLFHELinalgOpToLinalgGeneric<mlir::concretelang::HLFHELinalg::AddEintIntOp,
|
||||
mlir::concretelang::HLFHE::AddEintIntOp>>(
|
||||
FHELinalgOpToLinalgGeneric<mlir::concretelang::FHELinalg::AddEintIntOp,
|
||||
mlir::concretelang::FHE::AddEintIntOp>>(
|
||||
&getContext());
|
||||
patterns.insert<
|
||||
HLFHELinalgOpToLinalgGeneric<mlir::concretelang::HLFHELinalg::SubIntEintOp,
|
||||
mlir::concretelang::HLFHE::SubIntEintOp>>(
|
||||
FHELinalgOpToLinalgGeneric<mlir::concretelang::FHELinalg::SubIntEintOp,
|
||||
mlir::concretelang::FHE::SubIntEintOp>>(
|
||||
&getContext());
|
||||
patterns.insert<
|
||||
HLFHELinalgOpToLinalgGeneric<mlir::concretelang::HLFHELinalg::MulEintIntOp,
|
||||
mlir::concretelang::HLFHE::MulEintIntOp>>(
|
||||
FHELinalgOpToLinalgGeneric<mlir::concretelang::FHELinalg::MulEintIntOp,
|
||||
mlir::concretelang::FHE::MulEintIntOp>>(
|
||||
&getContext());
|
||||
patterns.insert<HLFHELinalgApplyLookupTableToLinalgGeneric>(&getContext());
|
||||
patterns.insert<HLFHELinalgNegEintToLinalgGeneric>(&getContext());
|
||||
patterns.insert<HLFHELinalgMatmulToLinalgGeneric<
|
||||
mlir::concretelang::HLFHELinalg::MatMulEintIntOp>>(
|
||||
patterns.insert<FHELinalgApplyLookupTableToLinalgGeneric>(&getContext());
|
||||
patterns.insert<FHELinalgNegEintToLinalgGeneric>(&getContext());
|
||||
patterns.insert<FHELinalgMatmulToLinalgGeneric<
|
||||
mlir::concretelang::FHELinalg::MatMulEintIntOp>>(
|
||||
&getContext(), [](mlir::OpBuilder &builder, mlir::Location loc,
|
||||
mlir::Type type, mlir::Value arg0, mlir::Value arg1) {
|
||||
return builder.create<mlir::concretelang::HLFHE::MulEintIntOp>(loc, type,
|
||||
return builder.create<mlir::concretelang::FHE::MulEintIntOp>(loc, type,
|
||||
arg0, arg1);
|
||||
});
|
||||
patterns.insert<HLFHELinalgMatmulToLinalgGeneric<
|
||||
mlir::concretelang::HLFHELinalg::MatMulIntEintOp>>(
|
||||
patterns.insert<FHELinalgMatmulToLinalgGeneric<
|
||||
mlir::concretelang::FHELinalg::MatMulIntEintOp>>(
|
||||
&getContext(), [](mlir::OpBuilder &builder, mlir::Location loc,
|
||||
mlir::Type type, mlir::Value arg0, mlir::Value arg1) {
|
||||
return builder.create<mlir::concretelang::HLFHE::MulEintIntOp>(loc, type,
|
||||
return builder.create<mlir::concretelang::FHE::MulEintIntOp>(loc, type,
|
||||
arg1, arg0);
|
||||
});
|
||||
patterns.insert<HLFHELinalgApplyMultiLookupTableToLinalgGeneric>(
|
||||
patterns.insert<FHELinalgApplyMultiLookupTableToLinalgGeneric>(
|
||||
&getContext());
|
||||
patterns.insert<HLFHELinalgApplyMappedLookupTableToLinalgGeneric>(
|
||||
patterns.insert<FHELinalgApplyMappedLookupTableToLinalgGeneric>(
|
||||
&getContext());
|
||||
patterns.insert<HLFHELinalgZeroToLinalgGenerate>(&getContext());
|
||||
patterns.insert<FHELinalgZeroToLinalgGenerate>(&getContext());
|
||||
|
||||
if (mlir::applyPartialConversion(function, target, std::move(patterns))
|
||||
.failed())
|
||||
@@ -1029,8 +1029,8 @@ void HLFHETensorOpsToLinalg::runOnFunction() {
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<mlir::FunctionPass> createConvertHLFHETensorOpsToLinalg() {
|
||||
return std::make_unique<HLFHETensorOpsToLinalg>();
|
||||
std::unique_ptr<mlir::FunctionPass> createConvertFHETensorOpsToLinalg() {
|
||||
return std::make_unique<FHETensorOpsToLinalg>();
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
16
compiler/lib/Conversion/FHEToTFHE/CMakeLists.txt
Normal file
16
compiler/lib/Conversion/FHEToTFHE/CMakeLists.txt
Normal file
@@ -0,0 +1,16 @@
|
||||
add_mlir_dialect_library(FHEToTFHE
|
||||
FHEToTFHE.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE
|
||||
|
||||
DEPENDS
|
||||
FHEDialect
|
||||
FHEToTFHEPatternsIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRTransforms
|
||||
MLIRMath)
|
||||
|
||||
target_link_libraries(FHEToTFHE PUBLIC MLIRIR)
|
||||
@@ -6,31 +6,31 @@
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "concretelang/Conversion/HLFHEToMidLFHE/Patterns.h"
|
||||
#include "concretelang/Conversion/FHEToTFHE/Patterns.h"
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
|
||||
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
|
||||
#include "concretelang/Dialect/HLFHE/IR/HLFHEDialect.h"
|
||||
#include "concretelang/Dialect/HLFHE/IR/HLFHETypes.h"
|
||||
#include "concretelang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
|
||||
#include "concretelang/Dialect/MidLFHE/IR/MidLFHETypes.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
|
||||
#include "concretelang/Dialect/RT/IR/RTOps.h"
|
||||
|
||||
namespace {
|
||||
struct HLFHEToMidLFHEPass : public HLFHEToMidLFHEBase<HLFHEToMidLFHEPass> {
|
||||
struct FHEToTFHEPass : public FHEToTFHEBase<FHEToTFHEPass> {
|
||||
void runOnOperation() final;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
using mlir::concretelang::HLFHE::EncryptedIntegerType;
|
||||
using mlir::concretelang::MidLFHE::GLWECipherTextType;
|
||||
using mlir::concretelang::FHE::EncryptedIntegerType;
|
||||
using mlir::concretelang::TFHE::GLWECipherTextType;
|
||||
|
||||
/// HLFHEToMidLFHETypeConverter is a TypeConverter that transform
|
||||
/// `HLFHE.eint<p>` to `MidLFHE.glwe<{_,_,_}{p}>`
|
||||
class HLFHEToMidLFHETypeConverter : public mlir::TypeConverter {
|
||||
/// FHEToTFHETypeConverter is a TypeConverter that transform
|
||||
/// `FHE.eint<p>` to `TFHE.glwe<{_,_,_}{p}>`
|
||||
class FHEToTFHETypeConverter : public mlir::TypeConverter {
|
||||
|
||||
public:
|
||||
HLFHEToMidLFHETypeConverter() {
|
||||
FHEToTFHETypeConverter() {
|
||||
addConversion([](mlir::Type type) { return type; });
|
||||
addConversion([](EncryptedIntegerType type) {
|
||||
return mlir::concretelang::convertTypeEncryptedIntegerToGLWE(
|
||||
@@ -50,17 +50,17 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
void HLFHEToMidLFHEPass::runOnOperation() {
|
||||
void FHEToTFHEPass::runOnOperation() {
|
||||
auto op = this->getOperation();
|
||||
|
||||
mlir::ConversionTarget target(getContext());
|
||||
HLFHEToMidLFHETypeConverter converter;
|
||||
FHEToTFHETypeConverter converter;
|
||||
|
||||
// Mark ops from the target dialect as legal operations
|
||||
target.addLegalDialect<mlir::concretelang::MidLFHE::MidLFHEDialect>();
|
||||
target.addLegalDialect<mlir::concretelang::TFHE::TFHEDialect>();
|
||||
|
||||
// Make sure that no ops from `HLFHE` remain after the lowering
|
||||
target.addIllegalDialect<mlir::concretelang::HLFHE::HLFHEDialect>();
|
||||
// Make sure that no ops from `FHE` remain after the lowering
|
||||
target.addIllegalDialect<mlir::concretelang::FHE::FHEDialect>();
|
||||
|
||||
// Make sure that no ops `linalg.generic` that have illegal types
|
||||
target
|
||||
@@ -77,19 +77,19 @@ void HLFHEToMidLFHEPass::runOnOperation() {
|
||||
return converter.isSignatureLegal(funcOp.getType()) &&
|
||||
converter.isLegal(&funcOp.getBody());
|
||||
});
|
||||
// Add all patterns required to lower all ops from `HLFHE` to
|
||||
// `MidLFHE`
|
||||
// Add all patterns required to lower all ops from `FHE` to
|
||||
// `TFHE`
|
||||
mlir::OwningRewritePatternList patterns(&getContext());
|
||||
|
||||
populateWithGeneratedHLFHEToMidLFHE(patterns);
|
||||
populateWithGeneratedFHEToTFHE(patterns);
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::linalg::GenericOp,
|
||||
HLFHEToMidLFHETypeConverter>>(
|
||||
FHEToTFHETypeConverter>>(
|
||||
&getContext(), converter);
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
|
||||
HLFHEToMidLFHETypeConverter>>(
|
||||
FHEToTFHETypeConverter>>(
|
||||
&getContext(), converter);
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::scf::ForOp,
|
||||
HLFHEToMidLFHETypeConverter>>(
|
||||
FHEToTFHETypeConverter>>(
|
||||
&getContext(), converter);
|
||||
|
||||
mlir::concretelang::populateWithTensorTypeConverterPatterns(patterns, target,
|
||||
@@ -110,8 +110,8 @@ void HLFHEToMidLFHEPass::runOnOperation() {
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertHLFHEToMidLFHEPass() {
|
||||
return std::make_unique<HLFHEToMidLFHEPass>();
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertFHEToTFHEPass() {
|
||||
return std::make_unique<FHEToTFHEPass>();
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -1,17 +0,0 @@
|
||||
add_mlir_dialect_library(HLFHETensorOpsToLinalg
|
||||
TensorOpsToLinalg.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/HLFHE
|
||||
|
||||
DEPENDS
|
||||
HLFHEDialect
|
||||
HLFHELinalgDialect
|
||||
MLIRConversionPassIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
HLFHEDialect
|
||||
HLFHELinalgDialect)
|
||||
|
||||
target_link_libraries(HLFHEDialect PUBLIC MLIRIR)
|
||||
@@ -1,16 +0,0 @@
|
||||
add_mlir_dialect_library(HLFHEToMidLFHE
|
||||
HLFHEToMidLFHE.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/HLFHE
|
||||
|
||||
DEPENDS
|
||||
HLFHEDialect
|
||||
HLFHEToMidLFHEPatternsIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRTransforms
|
||||
MLIRMath)
|
||||
|
||||
target_link_libraries(HLFHEToMidLFHE PUBLIC MLIRIR)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user