chore: rename dialects

HLFHE to FHE
MidLFHE to TFHE
LowLFHE to Concrete
This commit is contained in:
youben11
2021-12-29 11:27:02 +01:00
committed by Ayoub Benaissa
parent 47ef595a2a
commit 940cb96be4
288 changed files with 4710 additions and 4710 deletions

View File

@@ -68,8 +68,8 @@ build-end-to-end-jit-clear-tensor: build-initialized
build-end-to-end-jit-encrypted-tensor: build-initialized
cmake --build $(BUILD_DIR) --target end_to_end_jit_encrypted_tensor
build-end-to-end-jit-hlfhelinalg: build-initialized
cmake --build $(BUILD_DIR) --target end_to_end_jit_hlfhelinalg
build-end-to-end-jit-fhelinalg: build-initialized
cmake --build $(BUILD_DIR) --target end_to_end_jit_fhelinalg
build-end-to-end-jit-lambda: build-initialized
cmake --build $(BUILD_DIR) --target end_to_end_jit_lambda
@@ -80,7 +80,7 @@ build-end-to-end-jit-dfr: build-initialized
build-end-to-end-jit-auto-parallelization: build-initialized
cmake --build $(BUILD_DIR) --target end_to_end_jit_auto_parallelization
build-end-to-end-jit: build-end-to-end-jit-test build-end-to-end-jit-clear-tensor build-end-to-end-jit-encrypted-tensor build-end-to-end-jit-hlfhelinalg
build-end-to-end-jit: build-end-to-end-jit-test build-end-to-end-jit-clear-tensor build-end-to-end-jit-encrypted-tensor build-end-to-end-jit-fhelinalg
test-end-to-end-jit-test: build-end-to-end-jit-test
@@ -92,8 +92,8 @@ test-end-to-end-jit-clear-tensor: build-end-to-end-jit-clear-tensor
test-end-to-end-jit-encrypted-tensor: build-end-to-end-jit-encrypted-tensor
$(BUILD_DIR)/bin/end_to_end_jit_encrypted_tensor
test-end-to-end-jit-hlfhelinalg: build-end-to-end-jit-hlfhelinalg
$(BUILD_DIR)/bin/end_to_end_jit_hlfhelinalg
test-end-to-end-jit-fhelinalg: build-end-to-end-jit-fhelinalg
$(BUILD_DIR)/bin/end_to_end_jit_fhelinalg
test-end-to-end-jit-lambda: build-initialized build-end-to-end-jit-lambda
$(BUILD_DIR)/bin/end_to_end_jit_lambda
@@ -104,7 +104,7 @@ test-end-to-end-jit-dfr: build-end-to-end-jit-dfr
test-end-to-end-jit-auto-parallelization: build-end-to-end-jit-auto-parallelization
$(BUILD_DIR)/bin/end_to_end_jit_auto_parallelization
test-end-to-end-jit: test-end-to-end-jit-test test-end-to-end-jit-clear-tensor test-end-to-end-jit-encrypted-tensor test-end-to-end-jit-hlfhelinalg
test-end-to-end-jit: test-end-to-end-jit-test test-end-to-end-jit-clear-tensor test-end-to-end-jit-encrypted-tensor test-end-to-end-jit-fhelinalg
show-stress-tests-summary:
@echo '------ Stress tests summary ------'
@@ -163,7 +163,7 @@ update_python_version:
test-end-to-end-jit-test \
test-end-to-end-jit-clear-tensor \
test-end-to-end-jit-encrypted-tensor \
test-end-to-end-jit-hlfhelinalg \
test-end-to-end-jit-fhelinalg \
test-python \
test \
add-deps \

View File

@@ -1,8 +1,8 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
#ifndef CONCRETELANG_C_DIALECT_HLFHE_H
#define CONCRETELANG_C_DIALECT_HLFHE_H
#ifndef CONCRETELANG_C_DIALECT_FHE_H
#define CONCRETELANG_C_DIALECT_FHE_H
#include "mlir-c/IR.h"
#include "mlir-c/Registration.h"
@@ -13,18 +13,18 @@
extern "C" {
#endif
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(HLFHE, hlfhe);
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(FHE, fhe);
/// Creates an encrypted integer type of `width` bits
MLIR_CAPI_EXPORTED MlirType hlfheEncryptedIntegerTypeGetChecked(
MLIR_CAPI_EXPORTED MlirType fheEncryptedIntegerTypeGetChecked(
MlirContext context, unsigned width,
mlir::function_ref<mlir::InFlightDiagnostic()> emitError);
/// If the type is an EncryptedInteger
MLIR_CAPI_EXPORTED bool hlfheTypeIsAnEncryptedIntegerType(MlirType);
MLIR_CAPI_EXPORTED bool fheTypeIsAnEncryptedIntegerType(MlirType);
#ifdef __cplusplus
}
#endif
#endif // CONCRETELANG_C_DIALECT_HLFHE_H
#endif // CONCRETELANG_C_DIALECT_FHE_H

View File

@@ -1,8 +1,8 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
#ifndef CONCRETELANG_C_DIALECT_HLFHELINALG_H
#define CONCRETELANG_C_DIALECT_HLFHELINALG_H
#ifndef CONCRETELANG_C_DIALECT_FHELINALG_H
#define CONCRETELANG_C_DIALECT_FHELINALG_H
#include "mlir-c/IR.h"
#include "mlir-c/Registration.h"
@@ -13,10 +13,10 @@
extern "C" {
#endif
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(HLFHELinalg, hlfhelinalg);
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(FHELinalg, fhelinalg);
#ifdef __cplusplus
}
#endif
#endif // CONCRETELANG_C_DIALECT_HLFHELINALG_H
#endif // CONCRETELANG_C_DIALECT_FHELINALG_H

View File

@@ -3,5 +3,5 @@ mlir_tablegen(Passes.h.inc -gen-pass-decls -name Conversion)
add_public_tablegen_target(ConcretelangConversionPassIncGen)
add_subdirectory(HLFHEToMidLFHE)
add_subdirectory(MidLFHEToLowLFHE)
add_subdirectory(FHEToTFHE)
add_subdirectory(TFHEToConcrete)

View File

@@ -2,8 +2,8 @@
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
#ifndef CONCRETELANG_CONVERSION_LOWLFHETOCONCRETECAPI_PASS_H_
#define CONCRETELANG_CONVERSION_LOWLFHETOCONCRETECAPI_PASS_H_
#ifndef CONCRETELANG_CONVERSION_CONCRETETOCONCRETECAPI_PASS_H_
#define CONCRETELANG_CONVERSION_CONCRETETOCONCRETECAPI_PASS_H_
#include "mlir/Pass/Pass.h"
@@ -11,10 +11,10 @@
namespace mlir {
namespace concretelang {
/// Create a pass to convert `LowLFHE` operators to function call to the
/// Create a pass to convert `Concrete` operators to function call to the
/// `ConcreteCAPI`
std::unique_ptr<OperationPass<ModuleOp>>
createConvertLowLFHEToConcreteCAPIPass();
createConvertConcreteToConcreteCAPIPass();
} // namespace concretelang
} // namespace mlir

View File

@@ -2,15 +2,15 @@
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
#ifndef CONCRETELANG_CONVERSION_LOWLFHEUNPARAMETRIZE_PASS_H_
#define CONCRETELANG_CONVERSION_LOWLFHEUNPARAMETRIZE_PASS_H_
#ifndef CONCRETELANG_CONVERSION_CONCRETEUNPARAMETRIZE_PASS_H_
#define CONCRETELANG_CONVERSION_CONCRETEUNPARAMETRIZE_PASS_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertLowLFHEUnparametrizePass();
createConvertConcreteUnparametrizePass();
} // namespace concretelang
} // namespace mlir

View File

@@ -2,16 +2,16 @@
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
#ifndef CONCRETELANG_CONVERSION_HLFHETENSOROPSTOLINALG_PASS_H_
#define CONCRETELANG_CONVERSION_HLFHETENSOROPSTOLINALG_PASS_H_
#ifndef CONCRETELANG_CONVERSION_FHETENSOROPSTOLINALG_PASS_H_
#define CONCRETELANG_CONVERSION_FHETENSOROPSTOLINALG_PASS_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace concretelang {
/// Create a pass to convert `HLFHE` tensor operators to linal.generic
/// Create a pass to convert `FHE` tensor operators to linal.generic
/// operators.
std::unique_ptr<mlir::FunctionPass> createConvertHLFHETensorOpsToLinalg();
std::unique_ptr<mlir::FunctionPass> createConvertFHETensorOpsToLinalg();
} // namespace concretelang
} // namespace mlir

View File

@@ -0,0 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Patterns.td)
mlir_tablegen(Patterns.h.inc -gen-rewriters -name FHE)
add_public_tablegen_target(FHEToTFHEPatternsIncGen)
add_concretelang_doc(Patterns FHEToTFHEPatterns concretelang/ -gen-pass-doc)

View File

@@ -2,15 +2,15 @@
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
#ifndef CONCRETELANG_CONVERSION_MIDLFHETOLOWLFHE_PASS_H_
#define CONCRETELANG_CONVERSION_MIDLFHETOLOWLFHE_PASS_H_
#ifndef CONCRETELANG_CONVERSION_FHETOTFHE_PASS_H_
#define CONCRETELANG_CONVERSION_FHETOTFHE_PASS_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace concretelang {
/// Create a pass to convert `MidLFHE` dialect to `LowLFHE` dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertMidLFHEToLowLFHEPass();
/// Create a pass to convert `FHE` dialect to `TFHE` dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertFHEToTFHEPass();
} // namespace concretelang
} // namespace mlir

View File

@@ -1,50 +1,50 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
#ifndef CONCRETELANG_CONVERSION_HLFHETOMIDLFHE_PATTERNS_H_
#define CONCRETELANG_CONVERSION_HLFHETOMIDLFHE_PATTERNS_H_
#ifndef CONCRETELANG_CONVERSION_FHETOTFHE_PATTERNS_H_
#define CONCRETELANG_CONVERSION_FHETOTFHE_PATTERNS_H_
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "concretelang/Dialect/HLFHE/IR/HLFHEOps.h"
#include "concretelang/Dialect/MidLFHE/IR/MidLFHEOps.h"
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
namespace mlir {
namespace concretelang {
using HLFHE::EncryptedIntegerType;
using MidLFHE::GLWECipherTextType;
using FHE::EncryptedIntegerType;
using TFHE::GLWECipherTextType;
/// Converts HLFHE::EncryptedInteger into MidLFHE::GlweCiphetext
/// Converts FHE::EncryptedInteger into TFHE::GlweCiphetext
GLWECipherTextType
convertTypeEncryptedIntegerToGLWE(mlir::MLIRContext *context,
EncryptedIntegerType &eint) {
return GLWECipherTextType::get(context, -1, -1, -1, eint.getWidth());
}
mlir::Value createZeroGLWEOpFromHLFHE(mlir::PatternRewriter &rewriter,
mlir::Value createZeroGLWEOpFromFHE(mlir::PatternRewriter &rewriter,
mlir::Location loc,
mlir::OpResult result) {
mlir::SmallVector<mlir::Value> args{};
mlir::SmallVector<mlir::NamedAttribute, 0> attrs;
auto eint =
result.getType().cast<mlir::concretelang::HLFHE::EncryptedIntegerType>();
result.getType().cast<mlir::concretelang::FHE::EncryptedIntegerType>();
mlir::SmallVector<mlir::Type, 1> resTypes{
convertTypeEncryptedIntegerToGLWE(rewriter.getContext(), eint)};
MidLFHE::ZeroGLWEOp op =
rewriter.create<MidLFHE::ZeroGLWEOp>(loc, resTypes, args, attrs);
TFHE::ZeroGLWEOp op =
rewriter.create<TFHE::ZeroGLWEOp>(loc, resTypes, args, attrs);
return op.getODSResults(0).front();
}
template <class Operator>
mlir::Value createGLWEOpFromHLFHE(mlir::PatternRewriter &rewriter,
mlir::Value createGLWEOpFromFHE(mlir::PatternRewriter &rewriter,
mlir::Location loc, mlir::Value arg0,
mlir::Value arg1, mlir::OpResult result) {
mlir::SmallVector<mlir::Value, 2> args{arg0, arg1};
mlir::SmallVector<mlir::NamedAttribute, 0> attrs;
auto eint =
result.getType().cast<mlir::concretelang::HLFHE::EncryptedIntegerType>();
result.getType().cast<mlir::concretelang::FHE::EncryptedIntegerType>();
mlir::SmallVector<mlir::Type, 1> resTypes{
convertTypeEncryptedIntegerToGLWE(rewriter.getContext(), eint)};
Operator op = rewriter.create<Operator>(loc, resTypes, args, attrs);
@@ -52,13 +52,13 @@ mlir::Value createGLWEOpFromHLFHE(mlir::PatternRewriter &rewriter,
}
template <class Operator>
mlir::Value createGLWEOpFromHLFHE(mlir::PatternRewriter &rewriter,
mlir::Value createGLWEOpFromFHE(mlir::PatternRewriter &rewriter,
mlir::Location loc, mlir::Value arg0,
mlir::OpResult result) {
mlir::SmallVector<mlir::Value, 1> args{arg0};
mlir::SmallVector<mlir::NamedAttribute, 0> attrs;
auto eint =
result.getType().cast<mlir::concretelang::HLFHE::EncryptedIntegerType>();
result.getType().cast<mlir::concretelang::FHE::EncryptedIntegerType>();
mlir::SmallVector<mlir::Type, 1> resTypes{
convertTypeEncryptedIntegerToGLWE(rewriter.getContext(), eint)};
Operator op = rewriter.create<Operator>(loc, resTypes, args, attrs);
@@ -66,7 +66,7 @@ mlir::Value createGLWEOpFromHLFHE(mlir::PatternRewriter &rewriter,
}
mlir::Value
createApplyLookupTableGLWEOpFromHLFHE(mlir::PatternRewriter &rewriter,
createApplyLookupTableGLWEOpFromFHE(mlir::PatternRewriter &rewriter,
mlir::Location loc, mlir::Value arg0,
mlir::Value arg1, mlir::OpResult result) {
mlir::SmallVector<mlir::Value, 2> args{arg0, arg1};
@@ -86,10 +86,10 @@ createApplyLookupTableGLWEOpFromHLFHE(mlir::PatternRewriter &rewriter,
unset),
};
auto eint =
result.getType().cast<mlir::concretelang::HLFHE::EncryptedIntegerType>();
result.getType().cast<mlir::concretelang::FHE::EncryptedIntegerType>();
mlir::SmallVector<mlir::Type, 1> resTypes{
convertTypeEncryptedIntegerToGLWE(rewriter.getContext(), eint)};
auto op = rewriter.create<concretelang::MidLFHE::ApplyLookupTable>(loc, resTypes,
auto op = rewriter.create<concretelang::TFHE::ApplyLookupTable>(loc, resTypes,
args, attrs);
return op.getODSResults(0).front();
}
@@ -98,10 +98,10 @@ createApplyLookupTableGLWEOpFromHLFHE(mlir::PatternRewriter &rewriter,
} // namespace mlir
namespace {
#include "concretelang/Conversion/HLFHEToMidLFHE/Patterns.h.inc"
#include "concretelang/Conversion/FHEToTFHE/Patterns.h.inc"
}
void populateWithGeneratedHLFHEToMidLFHE(mlir::RewritePatternSet &patterns) {
void populateWithGeneratedFHEToTFHE(mlir::RewritePatternSet &patterns) {
populateWithGenerated(patterns);
}

View File

@@ -1,47 +1,47 @@
#ifndef CONCRETELANG_CONVERSION_HLFHETOMIDLFHE_PATTERNS
#define CONCRETELANG_CONVERSION_HLFHETOMIDLFHE_PATTERNS
#ifndef CONCRETELANG_CONVERSION_FHETOTFHE_PATTERNS
#define CONCRETELANG_CONVERSION_FHETOTFHE_PATTERNS
include "mlir/Pass/PassBase.td"
include "concretelang/Dialect/HLFHE/IR/HLFHEOps.td"
include "concretelang/Dialect/MidLFHE/IR/MidLFHEOps.td"
include "concretelang/Dialect/FHE/IR/FHEOps.td"
include "concretelang/Dialect/TFHE/IR/TFHEOps.td"
def createZeroGLWEOp : NativeCodeCall<"mlir::concretelang::createZeroGLWEOpFromHLFHE($_builder, $_loc, $0)">;
def createZeroGLWEOp : NativeCodeCall<"mlir::concretelang::createZeroGLWEOpFromFHE($_builder, $_loc, $0)">;
def ZeroEintPattern : Pat<
(ZeroEintOp:$result),
(createZeroGLWEOp $result)>;
def createAddGLWEIntOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromHLFHE<mlir::concretelang::MidLFHE::AddGLWEIntOp>($_builder, $_loc, $0, $1, $2)">;
def createAddGLWEIntOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromFHE<mlir::concretelang::TFHE::AddGLWEIntOp>($_builder, $_loc, $0, $1, $2)">;
def AddEintIntPattern : Pat<
(AddEintIntOp:$result $arg0, $arg1),
(createAddGLWEIntOp $arg0, $arg1, $result)>;
def createAddGLWEOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromHLFHE<mlir::concretelang::MidLFHE::AddGLWEOp>($_builder, $_loc, $0, $1, $2)">;
def createAddGLWEOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromFHE<mlir::concretelang::TFHE::AddGLWEOp>($_builder, $_loc, $0, $1, $2)">;
def AddEintPattern : Pat<
(AddEintOp:$result $arg0, $arg1),
(createAddGLWEOp $arg0, $arg1, $result)>;
def createSubIntGLWEOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromHLFHE<mlir::concretelang::MidLFHE::SubIntGLWEOp>($_builder, $_loc, $0, $1, $2)">;
def createSubIntGLWEOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromFHE<mlir::concretelang::TFHE::SubIntGLWEOp>($_builder, $_loc, $0, $1, $2)">;
def SubIntEintPattern : Pat<
(SubIntEintOp:$result $arg0, $arg1),
(createSubIntGLWEOp $arg0, $arg1, $result)>;
def createNegGLWEOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromHLFHE<mlir::concretelang::MidLFHE::NegGLWEOp>($_builder, $_loc, $0, $1)">;
def createNegGLWEOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromFHE<mlir::concretelang::TFHE::NegGLWEOp>($_builder, $_loc, $0, $1)">;
def NegEintPattern : Pat<
(NegEintOp:$result $arg0),
(createNegGLWEOp $arg0, $result)>;
def createMulGLWEIntOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromHLFHE<mlir::concretelang::MidLFHE::MulGLWEIntOp>($_builder, $_loc, $0, $1, $2)">;
def createMulGLWEIntOp : NativeCodeCall<"mlir::concretelang::createGLWEOpFromFHE<mlir::concretelang::TFHE::MulGLWEIntOp>($_builder, $_loc, $0, $1, $2)">;
def MulEintIntPattern : Pat<
(MulEintIntOp:$result $arg0, $arg1),
(createMulGLWEIntOp $arg0, $arg1, $result)>;
def createApplyLookupTableGLWEOp : NativeCodeCall<"mlir::concretelang::createApplyLookupTableGLWEOpFromHLFHE($_builder, $_loc, $0, $1, $2)">;
def createApplyLookupTableGLWEOp : NativeCodeCall<"mlir::concretelang::createApplyLookupTableGLWEOpFromFHE($_builder, $_loc, $0, $1, $2)">;
def ApplyLookupTableEintPattern : Pat<
(ApplyLookupTableEintOp:$result $arg0, $arg1),

View File

@@ -1,5 +0,0 @@
set(LLVM_TARGET_DEFINITIONS Patterns.td)
mlir_tablegen(Patterns.h.inc -gen-rewriters -name HLFHE)
add_public_tablegen_target(HLFHEToMidLFHEPatternsIncGen)
add_concretelang_doc(Patterns HLFHEToMidLFHEPatterns concretelang/ -gen-pass-doc)

View File

@@ -1,5 +0,0 @@
set(LLVM_TARGET_DEFINITIONS Patterns.td)
mlir_tablegen(Patterns.h.inc -gen-rewriters -name MidLFHE)
add_public_tablegen_target(MidLFHEToLowLFHEPatternsIncGen)
add_concretelang_doc(Patterns MidLFHEToLowLFHEPatterns concretelang/ -gen-pass-doc)

View File

@@ -9,16 +9,16 @@
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "concretelang/Conversion/HLFHETensorOpsToLinalg/Pass.h"
#include "concretelang/Conversion/HLFHEToMidLFHE/Pass.h"
#include "concretelang/Conversion/LowLFHEToConcreteCAPI/Pass.h"
#include "concretelang/Conversion/LowLFHEUnparametrize/Pass.h"
#include "concretelang/Conversion/FHETensorOpsToLinalg/Pass.h"
#include "concretelang/Conversion/FHEToTFHE/Pass.h"
#include "concretelang/Conversion/ConcreteToConcreteCAPI/Pass.h"
#include "concretelang/Conversion/ConcreteUnparametrize/Pass.h"
#include "concretelang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h"
#include "concretelang/Conversion/MidLFHEGlobalParametrization/Pass.h"
#include "concretelang/Conversion/MidLFHEToLowLFHE/Pass.h"
#include "concretelang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "concretelang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
#include "concretelang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
#include "concretelang/Conversion/TFHEGlobalParametrization/Pass.h"
#include "concretelang/Conversion/TFHEToConcrete/Pass.h"
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
#define GEN_PASS_CLASSES
#include "concretelang/Conversion/Passes.h.inc"

View File

@@ -3,45 +3,45 @@
include "mlir/Pass/PassBase.td"
def HLFHETensorOpsToLinalg : FunctionPass<"hlfhe-tensor-ops-to-linalg"> {
let summary = "Lowers tensor operations of HLFHE dialect to linalg.generic";
let constructor = "mlir::concretelang::createConvertHLFHETensorOpsToLinalg()";
def FHETensorOpsToLinalg : FunctionPass<"fhe-tensor-ops-to-linalg"> {
let summary = "Lowers tensor operations of FHE dialect to linalg.generic";
let constructor = "mlir::concretelang::createConvertFHETensorOpsToLinalg()";
let dependentDialects = ["mlir::linalg::LinalgDialect"];
}
def HLFHEToMidLFHE : Pass<"hlfhe-to-midlfhe", "mlir::ModuleOp"> {
let summary = "Lowers operations from the HLFHE dialect to MidLFHE";
let description = [{ Lowers operations from the HLFHE dialect to Std + Math }];
let constructor = "mlir::concretelang::createConvertHLFHEToMidLFHEPass()";
def FHEToTFHE : Pass<"fhe-to-tfhe", "mlir::ModuleOp"> {
let summary = "Lowers operations from the FHE dialect to TFHE";
let description = [{ Lowers operations from the FHE dialect to Std + Math }];
let constructor = "mlir::concretelang::createConvertFHEToTFHEPass()";
let options = [];
let dependentDialects = ["mlir::linalg::LinalgDialect"];
}
def MidLFHEGlobalParametrization : Pass<"midlfhe-global-parametrization", "mlir::ModuleOp"> {
let summary = "Inject global fhe parameters to the MidLFHE dialect";
let constructor = "mlir::concretelang::createConvertMidLFHEToLowLFHEPass()";
def TFHEGlobalParametrization : Pass<"tfhe-global-parametrization", "mlir::ModuleOp"> {
let summary = "Inject global fhe parameters to the TFHE dialect";
let constructor = "mlir::concretelang::createConvertTFHEToConcretePass()";
let options = [];
let dependentDialects = ["mlir::concretelang::MidLFHE::MidLFHEDialect"];
let dependentDialects = ["mlir::concretelang::TFHE::TFHEDialect"];
}
def MidLFHEToLowLFHE : Pass<"midlfhe-to-lowlfhe", "mlir::ModuleOp"> {
let summary = "Lowers operations from the MidLFHE dialect to LowLFHE";
let description = [{ Lowers operations from the MidLFHE dialect to LowLFHE }];
let constructor = "mlir::concretelang::createConvertMidLFHEToLowLFHEPass()";
def TFHEToConcrete : Pass<"tfhe-to-concrete", "mlir::ModuleOp"> {
let summary = "Lowers operations from the TFHE dialect to Concrete";
let description = [{ Lowers operations from the TFHE dialect to Concrete }];
let constructor = "mlir::concretelang::createConvertTFHEToConcretePass()";
let options = [];
let dependentDialects = ["mlir::linalg::LinalgDialect"];
}
def LowLFHEToConcreteCAPI : Pass<"lowlfhe-to-concrete-c-api", "mlir::ModuleOp"> {
let summary = "Lower operations from the LowLFHE dialect to std with function call to the Concrete C API";
let constructor = "mlir::concretelang::createConvertLowLFHEToConcreteCAPIPass()";
let dependentDialects = ["mlir::concretelang::LowLFHE::LowLFHEDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"];
def ConcreteToConcreteCAPI : Pass<"concrete-to-concrete-c-api", "mlir::ModuleOp"> {
let summary = "Lower operations from the Concrete dialect to std with function call to the Concrete C API";
let constructor = "mlir::concretelang::createConvertConcreteToConcreteCAPIPass()";
let dependentDialects = ["mlir::concretelang::Concrete::ConcreteDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"];
}
def LowLFHEUnparametrize : Pass<"lowlfhe-unparametrize", "mlir::ModuleOp"> {
let summary = "Unparametrize LowLFHE types and remove unrealized_conversion_cast";
let constructor = "mlir::concretelang::createConvertLowLFHEToConcreteCAPIPass()";
let dependentDialects = ["mlir::concretelang::LowLFHE::LowLFHEDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"];
def ConcreteUnparametrize : Pass<"concrete-unparametrize", "mlir::ModuleOp"> {
let summary = "Unparametrize Concrete types and remove unrealized_conversion_cast";
let constructor = "mlir::concretelang::createConvertConcreteToConcreteCAPIPass()";
let dependentDialects = ["mlir::concretelang::Concrete::ConcreteDialect", "mlir::StandardOpsDialect", "mlir::memref::MemRefDialect"];
}
def MLIRLowerableDialectsToLLVM : Pass<"mlir-lowerable-dialects-to-llvm", "mlir::ModuleOp"> {

View File

@@ -2,8 +2,8 @@
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
#ifndef CONCRETELANG_CONVERSION_MIDLFHEGLOBALPARAMETRIZATION_PASS_H_
#define CONCRETELANG_CONVERSION_MIDLFHEGLOBALPARAMETRIZATION_PASS_H_
#ifndef CONCRETELANG_CONVERSION_TFHEGLOBALPARAMETRIZATION_PASS_H_
#define CONCRETELANG_CONVERSION_TFHEGLOBALPARAMETRIZATION_PASS_H_
#include "mlir/Pass/Pass.h"
@@ -11,9 +11,9 @@
namespace mlir {
namespace concretelang {
/// Create a pass to inject fhe parameters to the MidLFHE types and operators.
/// Create a pass to inject fhe parameters to the TFHE types and operators.
std::unique_ptr<OperationPass<ModuleOp>>
createConvertMidLFHEGlobalParametrizationPass(
createConvertTFHEGlobalParametrizationPass(
mlir::concretelang::V0FHEContext &fheContext);
} // namespace concretelang
} // namespace mlir

View File

@@ -0,0 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Patterns.td)
mlir_tablegen(Patterns.h.inc -gen-rewriters -name TFHE)
add_public_tablegen_target(TFHEToConcretePatternsIncGen)
add_concretelang_doc(Patterns TFHEToConcretePatterns concretelang/ -gen-pass-doc)

View File

@@ -2,15 +2,15 @@
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
#ifndef CONCRETELANG_CONVERSION_HLFHETOMIDLFHE_PASS_H_
#define CONCRETELANG_CONVERSION_HLFHETOMIDLFHE_PASS_H_
#ifndef CONCRETELANG_CONVERSION_TFHETOCONCRETE_PASS_H_
#define CONCRETELANG_CONVERSION_TFHETOCONCRETE_PASS_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace concretelang {
/// Create a pass to convert `HLFHE` dialect to `MidLFHE` dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertHLFHEToMidLFHEPass();
/// Create a pass to convert `TFHE` dialect to `Concrete` dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertTFHEToConcretePass();
} // namespace concretelang
} // namespace mlir

View File

@@ -1,22 +1,22 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
#ifndef CONCRETELANG_CONVERSION_MIDLFHETOLOWLFHE_PATTERNS_H_
#define CONCRETELANG_CONVERSION_MIDLFHETOLOWLFHE_PATTERNS_H_
#ifndef CONCRETELANG_CONVERSION_TFHETOCONCRETE_PATTERNS_H_
#define CONCRETELANG_CONVERSION_TFHETOCONCRETE_PATTERNS_H_
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "concretelang/Dialect/LowLFHE/IR/LowLFHEOps.h"
#include "concretelang/Dialect/MidLFHE/IR/MidLFHEOps.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h"
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
namespace mlir {
namespace concretelang {
using LowLFHE::CleartextType;
using LowLFHE::LweCiphertextType;
using LowLFHE::PlaintextType;
using MidLFHE::GLWECipherTextType;
using Concrete::CleartextType;
using Concrete::LweCiphertextType;
using Concrete::PlaintextType;
using TFHE::GLWECipherTextType;
LweCiphertextType convertTypeToLWE(mlir::MLIRContext *context,
mlir::Type type) {
@@ -81,7 +81,7 @@ CleartextType convertCleartextTypeFromType(mlir::MLIRContext *context,
return nullptr;
}
mlir::Value createZeroLWEOpFromMidLFHE(mlir::PatternRewriter &rewriter,
mlir::Value createZeroLWEOpFromTFHE(mlir::PatternRewriter &rewriter,
mlir::Location loc,
mlir::OpResult result) {
mlir::SmallVector<mlir::Value> args{};
@@ -89,13 +89,13 @@ mlir::Value createZeroLWEOpFromMidLFHE(mlir::PatternRewriter &rewriter,
auto glwe = result.getType().cast<GLWECipherTextType>();
mlir::SmallVector<mlir::Type, 1> resTypes{
convertTypeToLWE(rewriter.getContext(), glwe)};
LowLFHE::ZeroLWEOp op =
rewriter.create<LowLFHE::ZeroLWEOp>(loc, resTypes, args, attrs);
Concrete::ZeroLWEOp op =
rewriter.create<Concrete::ZeroLWEOp>(loc, resTypes, args, attrs);
return op.getODSResults(0).front();
}
template <class Operator>
mlir::Value createLowLFHEOpFromMidLFHE(mlir::PatternRewriter &rewriter,
mlir::Value createConcreteOpFromTFHE(mlir::PatternRewriter &rewriter,
mlir::Location loc, mlir::Value arg0,
mlir::Value arg1,
mlir::OpResult result) {
@@ -116,14 +116,14 @@ mlir::Value createAddPlainLweCiphertextWithGlwe(
// encode int into plaintext
mlir::Value encoded =
rewriter
.create<mlir::concretelang::LowLFHE::EncodeIntOp>(loc, encoded_type, arg1)
.create<mlir::concretelang::Concrete::EncodeIntOp>(loc, encoded_type, arg1)
.plaintext();
// convert result type
LweCiphertextType lwe_type =
convertTypeToLWE(rewriter.getContext(), result.getType());
// replace op using the encoded plaintext instead of int
auto op =
rewriter.create<mlir::concretelang::LowLFHE::AddPlaintextLweCiphertextOp>(
rewriter.create<mlir::concretelang::Concrete::AddPlaintextLweCiphertextOp>(
loc, lwe_type, arg0, encoded);
return op.getODSResults(0).front();
}
@@ -142,7 +142,7 @@ mlir::Value createSubIntLweCiphertext(mlir::PatternRewriter &rewriter,
auto arg1_type = arg1.getType();
auto negated_arg1 =
rewriter
.create<mlir::concretelang::LowLFHE::NegateLweCiphertextOp>(
.create<mlir::concretelang::Concrete::NegateLweCiphertextOp>(
loc, convertTypeToLWE(rewriter.getContext(), arg1_type), arg1)
.result();
return createAddPlainLweCiphertextWithGlwe(rewriter, loc, negated_arg1, arg0,
@@ -154,7 +154,7 @@ mlir::Value createNegLweCiphertext(mlir::PatternRewriter &rewriter,
mlir::OpResult result) {
auto arg0_type = arg0.getType();
auto negated =
rewriter.create<mlir::concretelang::LowLFHE::NegateLweCiphertextOp>(
rewriter.create<mlir::concretelang::Concrete::NegateLweCiphertextOp>(
loc, convertTypeToLWE(rewriter.getContext(), arg0_type), arg0);
return negated.getODSResults(0).front();
}
@@ -168,7 +168,7 @@ mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter &rewriter,
convertCleartextTypeFromType(rewriter.getContext(), inType);
// encode int into plaintext
mlir::Value encoded = rewriter
.create<mlir::concretelang::LowLFHE::IntToCleartextOp>(
.create<mlir::concretelang::Concrete::IntToCleartextOp>(
loc, encoded_type, arg1)
.cleartext();
// convert result type
@@ -176,12 +176,12 @@ mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter &rewriter,
LweCiphertextType lwe_type = convertTypeToLWE(rewriter.getContext(), resType);
// replace op using the encoded plaintext instead of int
auto op =
rewriter.create<mlir::concretelang::LowLFHE::MulCleartextLweCiphertextOp>(
rewriter.create<mlir::concretelang::Concrete::MulCleartextLweCiphertextOp>(
loc, lwe_type, arg0, encoded);
return op.getODSResults(0).front();
}
// This is the rewritting of the HLFHE::ApplyLookupTable operation, it will be
// This is the rewritting of the FHE::ApplyLookupTable operation, it will be
// rewritten as 3 new operations:
// - Create the required GLWE ciphertext out of the plain lookup table
// - Keyswitch the input ciphertext to match the input key of the bootstrapping
@@ -189,7 +189,7 @@ mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter &rewriter,
// Example:
// from:
// ```
// "%result = MidLFHE.apply_lookup_table"(% arg0, % tlu){
// "%result = TFHE.apply_lookup_table"(% arg0, % tlu){
// glweDimension = 1 : i32,
// polynomialSize = 2048 : i32,
// levelKS = 3 : i32,
@@ -197,29 +197,29 @@ mlir::Value createMulClearLweCiphertext(mlir::PatternRewriter &rewriter,
// levelBS = 5 : i32,
// baseLogBS = 4 : i32,
// outputSizeKS = 600 : i32
// } : (!MidLFHE.glwe<{2048, 1, 64} {4}>, tensor<16xi4>)
// ->(!MidLFHE.glwe<{2048, 1, 64} {4}>)
// } : (!TFHE.glwe<{2048, 1, 64} {4}>, tensor<16xi4>)
// ->(!TFHE.glwe<{2048, 1, 64} {4}>)
// ```
// to:
// ```
// % accumulator =
// "LowLFHE.glwe_from_table"(
// "Concrete.glwe_from_table"(
// % [[TABLE]]){glweDimension = 1 : i32, p = 4 : i32, polynomialSize =
// 2048 : i32}
// : (tensor<16xi4>)
// ->!LowLFHE.glwe_ciphertext
// % keyswitched = "LowLFHE.keyswitch_lwe"(% arg0){
// ->!Concrete.glwe_ciphertext
// % keyswitched = "Concrete.keyswitch_lwe"(% arg0){
// baseLog = 2 : i32,
// level = 3 : i32
// } : (!LowLFHE.lwe_ciphertext<2048, 4>)
// ->!LowLFHE.lwe_ciphertext<600, 4>
// % result = "LowLFHE.bootstrap_lwe"(% keyswitched, % accumulator){
// } : (!Concrete.lwe_ciphertext<2048, 4>)
// ->!Concrete.lwe_ciphertext<600, 4>
// % result = "Concrete.bootstrap_lwe"(% keyswitched, % accumulator){
// baseLog = 4 : i32,
// glweDimension = 1 : i32,
// level = 5 : i32,
// polynomialSize = 2048 : i32
// } : (!LowLFHE.lwe_ciphertext<600, 4>, !LowLFHE.glwe_ciphertext)
// ->!LowLFHE.lwe_ciphertext<2048, 4>
// } : (!Concrete.lwe_ciphertext<600, 4>, !Concrete.glwe_ciphertext)
// ->!Concrete.lwe_ciphertext<2048, 4>
// ```
mlir::Value createPBS(mlir::PatternRewriter &rewriter, mlir::Location loc,
mlir::Value ct, mlir::Value table,
@@ -236,8 +236,8 @@ mlir::Value createPBS(mlir::PatternRewriter &rewriter, mlir::Location loc,
mlir::IntegerAttr precision = rewriter.getI32IntegerAttr(lwe_type.getP());
mlir::Value accumulator =
rewriter
.create<mlir::concretelang::LowLFHE::GlweFromTable>(
loc, LowLFHE::GlweCiphertextType::get(rewriter.getContext()),
.create<mlir::concretelang::Concrete::GlweFromTable>(
loc, Concrete::GlweCiphertextType::get(rewriter.getContext()),
table, polynomialSize, glweDimension, precision)
.result();
@@ -255,7 +255,7 @@ mlir::Value createPBS(mlir::PatternRewriter &rewriter, mlir::Location loc,
convertTypeToLWE(rewriter.getContext(), result.getType());
mlir::Value keyswitched =
rewriter
.create<mlir::concretelang::LowLFHE::KeySwitchLweOp>(loc, ksOutType,
.create<mlir::concretelang::Concrete::KeySwitchLweOp>(loc, ksOutType,
ksArgs, ksAttrs)
.result();
@@ -275,7 +275,7 @@ mlir::Value createPBS(mlir::PatternRewriter &rewriter, mlir::Location loc,
};
mlir::Value bootstrapped =
rewriter
.create<mlir::concretelang::LowLFHE::BootstrapLweOp>(loc, lwe_type,
.create<mlir::concretelang::Concrete::BootstrapLweOp>(loc, lwe_type,
bsArgs, bsAttrs)
.result();
@@ -286,10 +286,10 @@ mlir::Value createPBS(mlir::PatternRewriter &rewriter, mlir::Location loc,
} // namespace mlir
namespace {
#include "concretelang/Conversion/MidLFHEToLowLFHE/Patterns.h.inc"
#include "concretelang/Conversion/TFHEToConcrete/Patterns.h.inc"
}
void populateWithGeneratedMidLFHEToLowLFHE(mlir::RewritePatternSet &patterns) {
void populateWithGeneratedTFHEToConcrete(mlir::RewritePatternSet &patterns) {
populateWithGenerated(patterns);
}

View File

@@ -1,18 +1,18 @@
#ifndef CONCRETELANG_CONVERSION_MIDLFHETOLOWLFHE_PATTERNS
#define CONCRETELANG_CONVERSION_MIDLFHETOLOWLFHE_PATTERNS
#ifndef CONCRETELANG_CONVERSION_TFHETOCONCRETE_PATTERNS
#define CONCRETELANG_CONVERSION_TFHETOCONCRETE_PATTERNS
include "mlir/Pass/PassBase.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "concretelang/Dialect/LowLFHE/IR/LowLFHEOps.td"
include "concretelang/Dialect/MidLFHE/IR/MidLFHEOps.td"
include "concretelang/Dialect/Concrete/IR/ConcreteOps.td"
include "concretelang/Dialect/TFHE/IR/TFHEOps.td"
def createZeroLWEOp : NativeCodeCall<"mlir::concretelang::createZeroLWEOpFromMidLFHE($_builder, $_loc, $0)">;
def createZeroLWEOp : NativeCodeCall<"mlir::concretelang::createZeroLWEOpFromTFHE($_builder, $_loc, $0)">;
def ZeroGLWEPattern : Pat<
(ZeroGLWEOp:$result),
(createZeroLWEOp $result)>;
def createAddLWEOp : NativeCodeCall<"mlir::concretelang::createLowLFHEOpFromMidLFHE<mlir::concretelang::LowLFHE::AddLweCiphertextsOp>($_builder, $_loc, $0, $1, $2)">;
def createAddLWEOp : NativeCodeCall<"mlir::concretelang::createConcreteOpFromTFHE<mlir::concretelang::Concrete::AddLweCiphertextsOp>($_builder, $_loc, $0, $1, $2)">;
def AddGLWEPattern : Pat<
(AddGLWEOp:$result $arg0, $arg1),

View File

@@ -1,5 +1,5 @@
add_subdirectory(HLFHE)
add_subdirectory(HLFHELinalg)
add_subdirectory(MidLFHE)
add_subdirectory(LowLFHE)
add_subdirectory(FHE)
add_subdirectory(FHELinalg)
add_subdirectory(TFHE)
add_subdirectory(Concrete)
add_subdirectory(RT)

View File

@@ -0,0 +1,13 @@
set(LLVM_TARGET_DEFINITIONS ConcreteOps.td)
mlir_tablegen(ConcreteOps.h.inc -gen-op-decls)
mlir_tablegen(ConcreteOps.cpp.inc -gen-op-defs)
mlir_tablegen(ConcreteOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=Concrete)
mlir_tablegen(ConcreteOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=Concrete)
mlir_tablegen(ConcreteOpsDialect.h.inc -gen-dialect-decls -dialect=Concrete)
mlir_tablegen(ConcreteOpsDialect.cpp.inc -gen-dialect-defs -dialect=Concrete)
add_public_tablegen_target(MLIRConcreteOpsIncGen)
add_dependencies(mlir-headers MLIRConcreteOpsIncGen)
add_concretelang_doc(ConcreteDialect ConcreteDialect concretelang/ -gen-dialect-doc)
add_concretelang_doc(ConcreteOps ConcreteOps concretelang/ -gen-op-doc)
add_concretelang_doc(ConcreteTypes ConcreteTypes concretelang/ -gen-typedef-doc)

View File

@@ -1,8 +1,8 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
#ifndef CONCRETELANG_DIALECT_MIDLFHE_IR_MIDLFHEDIALECT_H
#define CONCRETELANG_DIALECT_MIDLFHE_IR_MIDLFHEDIALECT_H
#ifndef CONCRETELANG_DIALECT_Concrete_IR_ConcreteDIALECT_H
#define CONCRETELANG_DIALECT_Concrete_IR_ConcreteDIALECT_H
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
@@ -11,6 +11,6 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "concretelang/Dialect/MidLFHE/IR/MidLFHEOpsDialect.h.inc"
#include "concretelang/Dialect/Concrete/IR/ConcreteOpsDialect.h.inc"
#endif

View File

@@ -0,0 +1,15 @@
#ifndef CONCRETELANG_DIALECT_Concrete_IR_Concrete_DIALECT
#define CONCRETELANG_DIALECT_Concrete_IR_Concrete_DIALECT
include "mlir/IR/OpBase.td"
def Concrete_Dialect : Dialect {
let name = "Concrete";
let summary = "Low Level Fully Homorphic Encryption dialect";
let description = [{
A dialect for representation of low level operation on fully homomorphic ciphertext.
}];
let cppNamespace = "::mlir::concretelang::Concrete";
}
#endif

View File

@@ -1,8 +1,8 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
#ifndef CONCRETELANG_DIALECT_MIDLFHE_IR_MIDLFHEOPS_H
#define CONCRETELANG_DIALECT_MIDLFHE_IR_MIDLFHEOPS_H
#ifndef CONCRETELANG_DIALECT_Concrete_Concrete_OPS_H
#define CONCRETELANG_DIALECT_Concrete_Concrete_OPS_H
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinOps.h>
@@ -10,9 +10,9 @@
#include <mlir/Interfaces/ControlFlowInterfaces.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include "concretelang/Dialect/MidLFHE/IR/MidLFHETypes.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
#define GET_OP_CLASSES
#include "concretelang/Dialect/MidLFHE/IR/MidLFHEOps.h.inc"
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h.inc"
#endif

View File

@@ -1,58 +1,58 @@
#ifndef CONCRETELANG_DIALECT_LowLFHE_IR_LowLFHE_OPS
#define CONCRETELANG_DIALECT_LowLFHE_IR_LowLFHE_OPS
#ifndef CONCRETELANG_DIALECT_Concrete_IR_Concrete_OPS
#define CONCRETELANG_DIALECT_Concrete_IR_Concrete_OPS
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "concretelang/Dialect/LowLFHE/IR/LowLFHEDialect.td"
include "concretelang/Dialect/LowLFHE/IR/LowLFHETypes.td"
include "concretelang/Dialect/Concrete/IR/ConcreteDialect.td"
include "concretelang/Dialect/Concrete/IR/ConcreteTypes.td"
class LowLFHE_Op<string mnemonic, list<OpTrait> traits = []> :
Op<LowLFHE_Dialect, mnemonic, traits>;
class Concrete_Op<string mnemonic, list<OpTrait> traits = []> :
Op<Concrete_Dialect, mnemonic, traits>;
def ZeroLWEOp : LowLFHE_Op<"zero"> {
def ZeroLWEOp : Concrete_Op<"zero"> {
let summary = "Returns a trivial encyption of 0";
let arguments = (ins);
let results = (outs LweCiphertextType:$out);
}
def AddLweCiphertextsOp : LowLFHE_Op<"add_lwe_ciphertexts"> {
def AddLweCiphertextsOp : Concrete_Op<"add_lwe_ciphertexts"> {
let summary = "Returns the sum of 2 lwe ciphertexts";
let arguments = (ins LweCiphertextType:$lhs, LweCiphertextType:$rhs);
let results = (outs LweCiphertextType:$result);
}
def AddPlaintextLweCiphertextOp : LowLFHE_Op<"add_plaintext_lwe_ciphertext"> {
def AddPlaintextLweCiphertextOp : Concrete_Op<"add_plaintext_lwe_ciphertext"> {
let summary = "Returns the sum of a clear integer and a lwe ciphertext";
let arguments = (ins LweCiphertextType:$lhs, PlaintextType:$rhs);
let results = (outs LweCiphertextType:$result);
}
def MulCleartextLweCiphertextOp : LowLFHE_Op<"mul_cleartext_lwe_ciphertext"> {
def MulCleartextLweCiphertextOp : Concrete_Op<"mul_cleartext_lwe_ciphertext"> {
let summary = "Returns the product of a clear integer and a lwe ciphertext";
let arguments = (ins LweCiphertextType:$lhs, CleartextType:$rhs);
let results = (outs LweCiphertextType:$result);
}
def NegateLweCiphertextOp : LowLFHE_Op<"negate_lwe_ciphertext"> {
def NegateLweCiphertextOp : Concrete_Op<"negate_lwe_ciphertext"> {
let summary = "Negates a lwe ciphertext";
let arguments = (ins LweCiphertextType:$ciphertext);
let results = (outs LweCiphertextType:$result);
}
def GlweFromTable : LowLFHE_Op<"glwe_from_table"> {
def GlweFromTable : Concrete_Op<"glwe_from_table"> {
let summary = "Creates a GLWE ciphertext which is the trivial encrytion of a the input table interpreted as a polynomial (to use later in a bootstrap)";
let arguments = (ins TensorOf<[AnyInteger]>:$table, I32Attr:$polynomialSize, I32Attr:$glweDimension, I32Attr:$p);
let results = (outs GlweCiphertextType:$result);
}
def BootstrapLweOp : LowLFHE_Op<"bootstrap_lwe"> {
def BootstrapLweOp : Concrete_Op<"bootstrap_lwe"> {
let summary = "Bootstraps a LWE ciphertext with a GLWE trivial encryption of the lookup table";
@@ -68,7 +68,7 @@ def BootstrapLweOp : LowLFHE_Op<"bootstrap_lwe"> {
let results = (outs LweCiphertextType:$result);
}
def KeySwitchLweOp : LowLFHE_Op<"keyswitch_lwe"> {
def KeySwitchLweOp : Concrete_Op<"keyswitch_lwe"> {
let summary = "Keyswitches a LWE ciphertext";
let arguments = (ins
@@ -80,14 +80,14 @@ def KeySwitchLweOp : LowLFHE_Op<"keyswitch_lwe"> {
let results = (outs LweCiphertextType:$result);
}
def EncodeIntOp : LowLFHE_Op<"encode_int"> {
def EncodeIntOp : Concrete_Op<"encode_int"> {
let summary = "Encodes an integer (for it to later be added to a LWE ciphertext)";
let arguments = (ins AnyInteger:$i);
let results = (outs PlaintextType:$plaintext);
}
def IntToCleartextOp : LowLFHE_Op<"int_to_cleartext"> {
def IntToCleartextOp : Concrete_Op<"int_to_cleartext"> {
let summary = "Keyswitches a LWE ciphertext";
let arguments = (ins AnyInteger:$i);

View File

@@ -1,8 +1,8 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
#ifndef CONCRETELANG_DIALECT_LowLFHE_IR_LowLFHETYPES_H
#define CONCRETELANG_DIALECT_LowLFHE_IR_LowLFHETYPES_H
#ifndef CONCRETELANG_DIALECT_Concrete_IR_ConcreteTYPES_H
#define CONCRETELANG_DIALECT_Concrete_IR_ConcreteTYPES_H
#include "llvm/ADT/TypeSwitch.h"
#include <mlir/IR/BuiltinOps.h>
@@ -10,6 +10,6 @@
#include <mlir/IR/DialectImplementation.h>
#define GET_TYPEDEF_CLASSES
#include "concretelang/Dialect/LowLFHE/IR/LowLFHEOpsTypes.h.inc"
#include "concretelang/Dialect/Concrete/IR/ConcreteOpsTypes.h.inc"
#endif

View File

@@ -1,13 +1,13 @@
#ifndef CONCRETELANG_DIALECT_LowLFHE_IR_LowLFHE_TYPES
#define CONCRETELANG_DIALECT_LowLFHE_IR_LowLFHE_TYPES
#ifndef CONCRETELANG_DIALECT_Concrete_IR_Concrete_TYPES
#define CONCRETELANG_DIALECT_Concrete_IR_Concrete_TYPES
include "mlir/IR/BuiltinTypes.td"
include "concretelang/Dialect/LowLFHE/IR/LowLFHEDialect.td"
include "concretelang/Dialect/Concrete/IR/ConcreteDialect.td"
class LowLFHE_Type<string name, list<Trait> traits = []> : TypeDef<LowLFHE_Dialect, name, traits> { }
class Concrete_Type<string name, list<Trait> traits = []> : TypeDef<Concrete_Dialect, name, traits> { }
def GlweCiphertextType : LowLFHE_Type<"GlweCiphertext"> {
def GlweCiphertextType : Concrete_Type<"GlweCiphertext"> {
let mnemonic = "glwe_ciphertext";
let summary = "A GLWE ciphertext (encryption of a polynomial of fixed-precision integers)";
@@ -25,7 +25,7 @@ def GlweCiphertextType : LowLFHE_Type<"GlweCiphertext"> {
}];
}
def LweCiphertextType : LowLFHE_Type<"LweCiphertext", [MemRefElementTypeInterface]> {
def LweCiphertextType : Concrete_Type<"LweCiphertext", [MemRefElementTypeInterface]> {
let mnemonic = "lwe_ciphertext";
let summary = "A LWE ciphertext (encryption of a fixed-precision integer)";
@@ -70,7 +70,7 @@ def LweCiphertextType : LowLFHE_Type<"LweCiphertext", [MemRefElementTypeInterfac
}];
}
def CleartextType : LowLFHE_Type<"Cleartext"> {
def CleartextType : Concrete_Type<"Cleartext"> {
let mnemonic = "cleartext";
let summary = "A cleartext (a fixed-precision integer) ready to be multiplied to a LWE ciphertext";
@@ -104,7 +104,7 @@ def CleartextType : LowLFHE_Type<"Cleartext"> {
}];
}
def PlaintextType : LowLFHE_Type<"Plaintext"> {
def PlaintextType : Concrete_Type<"Plaintext"> {
let mnemonic = "plaintext";
let summary = "A Plaintext (a fixed-precision integer) ready to be added to a LWE ciphertext";
@@ -138,7 +138,7 @@ def PlaintextType : LowLFHE_Type<"Plaintext"> {
}];
}
def PlaintextListType : LowLFHE_Type<"PlaintextList"> {
def PlaintextListType : Concrete_Type<"PlaintextList"> {
let mnemonic = "plaintext_list";
let summary = "List of plaintexts";
@@ -156,7 +156,7 @@ def PlaintextListType : LowLFHE_Type<"PlaintextList"> {
}];
}
def ForeignPlaintextListType : LowLFHE_Type<"ForeignPlaintextList"> {
def ForeignPlaintextListType : Concrete_Type<"ForeignPlaintextList"> {
let mnemonic = "foreign_plaintext_list";
let summary = "A foreign (reference to a independently allocated memory space) plaintext list";
@@ -174,7 +174,7 @@ def ForeignPlaintextListType : LowLFHE_Type<"ForeignPlaintextList"> {
}];
}
def LweKeySwitchKeyType : LowLFHE_Type<"LweKeySwitchKey"> {
def LweKeySwitchKeyType : Concrete_Type<"LweKeySwitchKey"> {
let mnemonic = "lwe_key_switch_key";
let summary = "A LWE keyswitching key";
@@ -192,7 +192,7 @@ def LweKeySwitchKeyType : LowLFHE_Type<"LweKeySwitchKey"> {
}];
}
def LweBootstrapKeyType : LowLFHE_Type<"LweBootstrapKey"> {
def LweBootstrapKeyType : Concrete_Type<"LweBootstrapKey"> {
let mnemonic = "lwe_bootstrap_key";
let summary = "A LWE bootstrapping key";
@@ -210,7 +210,7 @@ def LweBootstrapKeyType : LowLFHE_Type<"LweBootstrapKey"> {
}];
}
def Context : LowLFHE_Type<"Context"> {
def Context : Concrete_Type<"Context"> {
let mnemonic = "context";
let summary = "A runtime context";

View File

@@ -1,8 +1,8 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
#ifndef CONCRETELANG_DIALECT_HLFHE_ANALYSIS_MANP_H
#define CONCRETELANG_DIALECT_HLFHE_ANALYSIS_MANP_H
#ifndef CONCRETELANG_DIALECT_FHE_ANALYSIS_MANP_H
#define CONCRETELANG_DIALECT_FHE_ANALYSIS_MANP_H
#include <functional>
#include <mlir/Pass/Pass.h>

View File

@@ -1,10 +1,10 @@
#ifndef CONCRETELANG_DIALECT_HLFHE_ANALYSIS_MANP
#define CONCRETELANG_DIALECT_HLFHE_ANALYSIS_MANP
#ifndef CONCRETELANG_DIALECT_FHE_ANALYSIS_MANP
#define CONCRETELANG_DIALECT_FHE_ANALYSIS_MANP
include "mlir/Pass/PassBase.td"
def MANP : FunctionPass<"MANP"> {
let summary = "HLFHE Minimal Arithmetic Noise Padding Pass";
let summary = "FHE Minimal Arithmetic Noise Padding Pass";
let description = [{
This pass calculates the Minimal Arithmetic Noise Padding
(MANP) for each operation of a function and stores the result in an
@@ -15,14 +15,14 @@ def MANP : FunctionPass<"MANP"> {
The pass supports the following operations:
- HLFHELinalg.dot_eint_int
- HLFHE.zero
- HLFHE.add_eint_int
- HLFHE.add_eint
- HLFHE.sub_int_eint
- HLFHE.neg_eint
- HLFHE.mul_eint_int
- HLFHE.apply_lookup_table
- FHELinalg.dot_eint_int
- FHE.zero
- FHE.add_eint_int
- FHE.add_eint
- FHE.sub_int_eint
- FHE.neg_eint
- FHE.mul_eint_int
- FHE.apply_lookup_table
If any other operation is encountered, the pass conservatively
fails. The pass further makes the optimistic assumption that all
@@ -42,24 +42,24 @@ def MANP : FunctionPass<"MANP"> {
with the following replacement rules:
- Function argument a -> HLFHELinalg.dot_eint_int([a], [1])
- HLFHE.apply_lookup_table -> HLFHELinalg.dot_eint_int([LUT result], [1])
- HLFHE.zero() -> HLFHELinalg.dot_eint_int([encrypted 0], [1])
- HLFHE.add_eint_int(e, c) -> HLFHELinalg.dot_eint_int([e, 1], [1, c])
- HLFHE.add_eint(e0, e1) -> HLFHELinalg.dot_eint_int([e0, e1], [1, 1])
- HLFHE.sub_int_eint(c, e) -> HLFHELinalg.dot_eint_int([e, c], [1, -1])
- HLFHE.neg_eint(e) -> HLFHELinalg.dot_eint_int([e], [-1])
- HLFHE.mul_eint_int(e, c) -> HLFHELinalg.dot_eint_int([e], [c])
- Function argument a -> FHELinalg.dot_eint_int([a], [1])
- FHE.apply_lookup_table -> FHELinalg.dot_eint_int([LUT result], [1])
- FHE.zero() -> FHELinalg.dot_eint_int([encrypted 0], [1])
- FHE.add_eint_int(e, c) -> FHELinalg.dot_eint_int([e, 1], [1, c])
- FHE.add_eint(e0, e1) -> FHELinalg.dot_eint_int([e0, e1], [1, 1])
- FHE.sub_int_eint(c, e) -> FHELinalg.dot_eint_int([e, c], [1, -1])
- FHE.neg_eint(e) -> FHELinalg.dot_eint_int([e], [-1])
- FHE.mul_eint_int(e, c) -> FHELinalg.dot_eint_int([e], [c])
Dependent dot operations, e.g.,
a = HLFHELinalg.dot_eint_int([a0, a1, ...], [c0, c1, ...])
b = HLFHELinalg.dot_eint_int([b0, b1, ...], [d0, d1, ...])
x = HLFHELinalg.dot_eint_int([a, b, ...], [f0, f1, ...])
a = FHELinalg.dot_eint_int([a0, a1, ...], [c0, c1, ...])
b = FHELinalg.dot_eint_int([b0, b1, ...], [d0, d1, ...])
x = FHELinalg.dot_eint_int([a, b, ...], [f0, f1, ...])
are merged as follows:
x = HLFHELinalg.dot_eint_int([a0, a1, ..., b0, b1, ...],
x = FHELinalg.dot_eint_int([a0, a1, ..., b0, b1, ...],
[f0*c0, f0*c1, ..., f1*d0, f1*d1, ...])
However, the implementation does not explicitly create the
@@ -80,15 +80,15 @@ def MANP : FunctionPass<"MANP"> {
for the supported operations:
- Function argument -> 1
- HLFHE.apply_lookup_table -> 1
- HLFHE.zero() -> 1
- HLFHELinalg.dot_eint_int([e0, e1, ...], [c0, c1, ...]) ->
- FHE.apply_lookup_table -> 1
- FHE.zero() -> 1
- FHELinalg.dot_eint_int([e0, e1, ...], [c0, c1, ...]) ->
c0*c0*sqN(e0) + c1*c1*sqN(e1) + ...
- HLFHE.add_eint_int(e, c) -> 1*1*sqN(e) + c*c*1*1 = sqN(e) + c*c
- HLFHE.add_eint(e0, e1) -> 1*1*sqN(e0) + 1*1*sqN(e2) = sqN(e1) + sqN(e2)
- HLFHE.sub_int_eint(c, e) -> 1*1*sqN(e) + c*c*(-1)*(-1) = sqN(e) + c*c
- HLFHE.neg_eint(e) -> (-1)*(-1)*sqN(e) = sqN(e)
- HLFHE.mul_eint_int(e, c) -> c*c*sqN(e)
- FHE.add_eint_int(e, c) -> 1*1*sqN(e) + c*c*1*1 = sqN(e) + c*c
- FHE.add_eint(e0, e1) -> 1*1*sqN(e0) + 1*1*sqN(e2) = sqN(e1) + sqN(e2)
- FHE.sub_int_eint(c, e) -> 1*1*sqN(e) + c*c*(-1)*(-1) = sqN(e) + c*c
- FHE.neg_eint(e) -> (-1)*(-1)*sqN(e) = sqN(e)
- FHE.mul_eint_int(e, c) -> c*c*sqN(e)
The final, non-squared 2-norm of an operation is the square root of the
squared value rounded to the next highest integer.
@@ -96,7 +96,7 @@ def MANP : FunctionPass<"MANP"> {
}
def MaxMANP : FunctionPass<"MaxMANP"> {
let summary = "Extract maximum HLFHE Minimal Arithmetic Noise Padding and maximum encrypted integer width";
let summary = "Extract maximum FHE Minimal Arithmetic Noise Padding and maximum encrypted integer width";
let description = [{
This pass calculates the squared Minimal Arithmetic Noise Padding
(MANP) for each operation using the MANP pass and extracts the

View File

@@ -0,0 +1,13 @@
set(LLVM_TARGET_DEFINITIONS FHEOps.td)
mlir_tablegen(FHEOps.h.inc -gen-op-decls)
mlir_tablegen(FHEOps.cpp.inc -gen-op-defs)
mlir_tablegen(FHEOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=FHE)
mlir_tablegen(FHEOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=FHE)
mlir_tablegen(FHEOpsDialect.h.inc -gen-dialect-decls -dialect=FHE)
mlir_tablegen(FHEOpsDialect.cpp.inc -gen-dialect-defs -dialect=FHE)
add_public_tablegen_target(MLIRFHEOpsIncGen)
add_dependencies(mlir-headers MLIRFHEOpsIncGen)
add_concretelang_doc(FHEDialect FHEDialect concretelang/ -gen-dialect-doc)
add_concretelang_doc(FHEOps FHEOps concretelang/ -gen-op-doc)
add_concretelang_doc(FHETypes FHETypes concretelang/ -gen-typedef-doc)

View File

@@ -1,8 +1,8 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
#ifndef CONCRETELANG_DIALECT_HLFHE_IR_HLFHEDIALECT_H
#define CONCRETELANG_DIALECT_HLFHE_IR_HLFHEDIALECT_H
#ifndef CONCRETELANG_DIALECT_FHE_IR_FHEDIALECT_H
#define CONCRETELANG_DIALECT_FHE_IR_FHEDIALECT_H
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
@@ -11,6 +11,6 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "concretelang/Dialect/HLFHE/IR/HLFHEOpsDialect.h.inc"
#include "concretelang/Dialect/FHE/IR/FHEOpsDialect.h.inc"
#endif

View File

@@ -1,4 +1,4 @@
//===- HLFHEDialect.td - HLFHE dialect ----------------*- tablegen -*-===//
//===- FHEDialect.td - FHE dialect ----------------*- tablegen -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,18 +6,18 @@
//
//===----------------------------------------------------------------------===//
#ifndef CONCRETELANG_DIALECT_HLFHE_IR_HLFHE_DIALECT
#define CONCRETELANG_DIALECT_HLFHE_IR_HLFHE_DIALECT
#ifndef CONCRETELANG_DIALECT_FHE_IR_FHE_DIALECT
#define CONCRETELANG_DIALECT_FHE_IR_FHE_DIALECT
include "mlir/IR/OpBase.td"
def HLFHE_Dialect : Dialect {
let name = "HLFHE";
def FHE_Dialect : Dialect {
let name = "FHE";
let summary = "High Level Fully Homorphic Encryption dialect";
let description = [{
A dialect for representation of high level operation on fully homomorphic ciphertext.
}];
let cppNamespace = "::mlir::concretelang::HLFHE";
let cppNamespace = "::mlir::concretelang::FHE";
}
#endif

View File

@@ -1,19 +1,19 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
#ifndef CONCRETELANG_DIALECT_HLFHE_IR_HLFHEOPS_H
#define CONCRETELANG_DIALECT_HLFHE_IR_HLFHEOPS_H
#ifndef CONCRETELANG_DIALECT_FHE_IR_FHEOPS_H
#define CONCRETELANG_DIALECT_FHE_IR_FHEOPS_H
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/Interfaces/ControlFlowInterfaces.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include "concretelang/Dialect/HLFHE/IR/HLFHETypes.h"
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
namespace mlir {
namespace concretelang {
namespace HLFHE {
namespace FHE {
bool verifyEncryptedIntegerInputAndResultConsistency(
OpState &op, EncryptedIntegerType &input, EncryptedIntegerType &result);
@@ -23,7 +23,7 @@ bool verifyEncryptedIntegerAndIntegerInputsConsistency(OpState &op,
IntegerType &b);
/** Shared error message for all ApplyLookupTable variant Op (several Dialect)
* E.g. HLFHE.apply_lookup_table(input, lut)
* E.g. FHE.apply_lookup_table(input, lut)
* Message when the lut tensor has an invalid size,
* i.e. it cannot accomodate the input elements bitwidth
*/
@@ -38,11 +38,11 @@ void emitErrorBadLutSize(Op &op, std::string lutName, std::string inputName,
<< " elements bitwidth (" << bitWidth << ")";
}
} // namespace HLFHE
} // namespace FHE
} // namespace concretelang
} // namespace mlir
#define GET_OP_CLASSES
#include "concretelang/Dialect/HLFHE/IR/HLFHEOps.h.inc"
#include "concretelang/Dialect/FHE/IR/FHEOps.h.inc"
#endif

View File

@@ -1,4 +1,4 @@
//===- HLFHEOps.td - High level FHE dialect ops ----------------*- tablegen -*-===//
//===- FHEOps.td - High level FHE dialect ops ----------------*- tablegen -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,27 +6,27 @@
//
//===----------------------------------------------------------------------===//
#ifndef CONCRETELANG_DIALECT_HLFHE_IR_HLFHE_OPS
#define CONCRETELANG_DIALECT_HLFHE_IR_HLFHE_OPS
#ifndef CONCRETELANG_DIALECT_FHE_IR_FHE_OPS
#define CONCRETELANG_DIALECT_FHE_IR_FHE_OPS
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "concretelang/Dialect/HLFHE/IR/HLFHEDialect.td"
include "concretelang/Dialect/HLFHE/IR/HLFHETypes.td"
include "concretelang/Dialect/FHE/IR/FHEDialect.td"
include "concretelang/Dialect/FHE/IR/FHETypes.td"
class HLFHE_Op<string mnemonic, list<OpTrait> traits = []> :
Op<HLFHE_Dialect, mnemonic, traits>;
class FHE_Op<string mnemonic, list<OpTrait> traits = []> :
Op<FHE_Dialect, mnemonic, traits>;
// Generates an encrypted zero constant
def ZeroEintOp : HLFHE_Op<"zero", [NoSideEffect]> {
def ZeroEintOp : FHE_Op<"zero", [NoSideEffect]> {
let summary = "Return an encryption of 0";
let arguments = (ins);
let results = (outs EncryptedIntegerType:$out);
}
def AddEintIntOp : HLFHE_Op<"add_eint_int"> {
def AddEintIntOp : FHE_Op<"add_eint_int"> {
let summary = "Adds an encrypted integer and a clear integer";
@@ -40,11 +40,11 @@ def AddEintIntOp : HLFHE_Op<"add_eint_int"> {
];
let verifier = [{
return ::mlir::concretelang::HLFHE::verifyAddEintIntOp(*this);
return ::mlir::concretelang::FHE::verifyAddEintIntOp(*this);
}];
}
def AddEintOp : HLFHE_Op<"add_eint"> {
def AddEintOp : FHE_Op<"add_eint"> {
let summary = "Adds two encrypted integers";
@@ -58,11 +58,11 @@ def AddEintOp : HLFHE_Op<"add_eint"> {
];
let verifier = [{
return ::mlir::concretelang::HLFHE::verifyAddEintOp(*this);
return ::mlir::concretelang::FHE::verifyAddEintOp(*this);
}];
}
def SubIntEintOp : HLFHE_Op<"sub_int_eint"> {
def SubIntEintOp : FHE_Op<"sub_int_eint"> {
let summary = "Substract a clear integer and an encrypted integer";
@@ -76,11 +76,11 @@ def SubIntEintOp : HLFHE_Op<"sub_int_eint"> {
];
let verifier = [{
return ::mlir::concretelang::HLFHE::verifySubIntEintOp(*this);
return ::mlir::concretelang::FHE::verifySubIntEintOp(*this);
}];
}
def NegEintOp : HLFHE_Op<"neg_eint"> {
def NegEintOp : FHE_Op<"neg_eint"> {
let summary = "Negates an encrypted integer";
@@ -94,11 +94,11 @@ def NegEintOp : HLFHE_Op<"neg_eint"> {
];
let verifier = [{
return ::mlir::concretelang::HLFHE::verifyNegEintOp(*this);
return ::mlir::concretelang::FHE::verifyNegEintOp(*this);
}];
}
def MulEintIntOp : HLFHE_Op<"mul_eint_int"> {
def MulEintIntOp : FHE_Op<"mul_eint_int"> {
let summary = "Mulitplies an encrypted integer and a clear integer";
@@ -112,11 +112,11 @@ def MulEintIntOp : HLFHE_Op<"mul_eint_int"> {
];
let verifier = [{
return ::mlir::concretelang::HLFHE::verifyMulEintIntOp(*this);
return ::mlir::concretelang::FHE::verifyMulEintIntOp(*this);
}];
}
def ApplyLookupTableEintOp : HLFHE_Op<"apply_lookup_table"> {
def ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table"> {
let summary = "Applies a clear lookup table to an encrypted integer";
@@ -125,7 +125,7 @@ def ApplyLookupTableEintOp : HLFHE_Op<"apply_lookup_table"> {
let results = (outs EncryptedIntegerType);
let verifier = [{
return ::mlir::concretelang::HLFHE::verifyApplyLookupTable(*this);
return ::mlir::concretelang::FHE::verifyApplyLookupTable(*this);
}];
}

View File

@@ -1,8 +1,8 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
#ifndef CONCRETELANG_DIALECT_HLFHE_IR_HLFHETYPES_H
#define CONCRETELANG_DIALECT_HLFHE_IR_HLFHETYPES_H
#ifndef CONCRETELANG_DIALECT_FHE_IR_FHETYPES_H
#define CONCRETELANG_DIALECT_FHE_IR_FHETYPES_H
#include "llvm/ADT/TypeSwitch.h"
#include <mlir/IR/BuiltinOps.h>
@@ -10,6 +10,6 @@
#include <mlir/IR/DialectImplementation.h>
#define GET_TYPEDEF_CLASSES
#include "concretelang/Dialect/HLFHE/IR/HLFHEOpsTypes.h.inc"
#include "concretelang/Dialect/FHE/IR/FHEOpsTypes.h.inc"
#endif

View File

@@ -1,13 +1,13 @@
#ifndef CONCRETELANG_DIALECT_HLFHE_IR_HLFHE_TYPES
#define CONCRETELANG_DIALECT_HLFHE_IR_HLFHE_TYPES
#ifndef CONCRETELANG_DIALECT_FHE_IR_FHE_TYPES
#define CONCRETELANG_DIALECT_FHE_IR_FHE_TYPES
include "concretelang/Dialect/HLFHE/IR/HLFHEDialect.td"
include "concretelang/Dialect/FHE/IR/FHEDialect.td"
include "mlir/IR/BuiltinTypes.td"
class HLFHE_Type<string name, list<Trait> traits = []> :
TypeDef<HLFHE_Dialect, name, traits> { }
class FHE_Type<string name, list<Trait> traits = []> :
TypeDef<FHE_Dialect, name, traits> { }
def EncryptedIntegerType : HLFHE_Type<"EncryptedInteger",
def EncryptedIntegerType : FHE_Type<"EncryptedInteger",
[MemRefElementTypeInterface]> {
let mnemonic = "eint";

View File

@@ -0,0 +1,13 @@
set(LLVM_TARGET_DEFINITIONS FHELinalgOps.td)
mlir_tablegen(FHELinalgOps.h.inc -gen-op-decls)
mlir_tablegen(FHELinalgOps.cpp.inc -gen-op-defs)
mlir_tablegen(FHELinalgOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=FHELinalg)
mlir_tablegen(FHELinalgOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=FHELinalg)
mlir_tablegen(FHELinalgOpsDialect.h.inc -gen-dialect-decls -dialect=FHELinalg)
mlir_tablegen(FHELinalgOpsDialect.cpp.inc -gen-dialect-defs -dialect=FHELinalg)
add_public_tablegen_target(MLIRFHELinalgOpsIncGen)
add_dependencies(mlir-headers MLIRFHELinalgOpsIncGen)
add_concretelang_doc(FHELinalgDialect FHELinalgDialect concretelang/ -gen-dialect-doc)
add_concretelang_doc(FHELinalgOps FHELinalgOps concretelang/ -gen-op-doc)
add_concretelang_doc(FHELinalgTypes FHELinalgTypes concretelang/ -gen-typedef-doc)

View File

@@ -1,13 +1,13 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
#ifndef CONCRETELANG_DIALECT_HLFHELinalg_IR_HLFHELinalgDIALECT_H
#define CONCRETELANG_DIALECT_HLFHELinalg_IR_HLFHELinalgDIALECT_H
#ifndef CONCRETELANG_DIALECT_FHELinalg_IR_FHELinalgDIALECT_H
#define CONCRETELANG_DIALECT_FHELinalg_IR_FHELinalgDIALECT_H
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "concretelang/Dialect/HLFHELinalg/IR/HLFHELinalgOpsDialect.h.inc"
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOpsDialect.h.inc"
#endif

View File

@@ -0,0 +1,15 @@
#ifndef CONCRETELANG_DIALECT_FHELinalg_IR_FHELinalg_DIALECT
#define CONCRETELANG_DIALECT_FHELinalg_IR_FHELinalg_DIALECT
include "mlir/IR/OpBase.td"
def FHELinalg_Dialect : Dialect {
let name = "FHELinalg";
let summary = "High Level Fully Homorphic Encryption Linalg dialect";
let description = [{
A dialect for representation of high level linalg operations on fully homomorphic ciphertexts.
}];
let cppNamespace = "::mlir::concretelang::FHELinalg";
}
#endif

View File

@@ -1,13 +1,13 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
#ifndef CONCRETELANG_DIALECT_HLFHELinalg_IR_HLFHELinalgOPS_H
#define CONCRETELANG_DIALECT_HLFHELinalg_IR_HLFHELinalgOPS_H
#ifndef CONCRETELANG_DIALECT_FHELinalg_IR_FHELinalgOPS_H
#define CONCRETELANG_DIALECT_FHELinalg_IR_FHELinalgOPS_H
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "concretelang/Dialect/HLFHE/IR/HLFHETypes.h"
#include "concretelang/Dialect/HLFHELinalg/IR/HLFHELinalgTypes.h"
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgTypes.h"
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
@@ -42,8 +42,8 @@ public:
/// TensorBinaryEintInt verifies that the operation matches the following
/// signature
/// `(tensor<...x!HLFHE.eint<$p>>, tensor<...xi$p'>) ->
/// tensor<...x!HLFHE.eint<$p>>` where `$p <= $p+1`.
/// `(tensor<...x!FHE.eint<$p>>, tensor<...xi$p'>) ->
/// tensor<...x!FHE.eint<$p>>` where `$p <= $p+1`.
template <typename ConcreteType>
class TensorBinaryEintInt
: public mlir::OpTrait::TraitBase<ConcreteType, TensorBinaryEintInt> {
@@ -55,8 +55,8 @@ public:
/// TensorBinaryEintInt verifies that the operation matches the following
/// signature
/// `(tensor<...xi$p'>, tensor<...x!HLFHE.eint<$p>>) ->
/// tensor<...x!HLFHE.eint<$p>>` where `$p <= $p+1`.
/// `(tensor<...xi$p'>, tensor<...x!FHE.eint<$p>>) ->
/// tensor<...x!FHE.eint<$p>>` where `$p <= $p+1`.
template <typename ConcreteType>
class TensorBinaryIntEint
: public mlir::OpTrait::TraitBase<ConcreteType, TensorBinaryEintInt> {
@@ -67,8 +67,8 @@ public:
};
/// TensorBinary verify the operation match the following signature
/// `(tensor<...x!HLFHE.eint<$p>>, tensor<...x!HLFHE.eint<$p>>) ->
/// tensor<...x!HLFHE.eint<$p>>`
/// `(tensor<...x!FHE.eint<$p>>, tensor<...x!FHE.eint<$p>>) ->
/// tensor<...x!FHE.eint<$p>>`
template <typename ConcreteType>
class TensorBinaryEint
: public mlir::OpTrait::TraitBase<ConcreteType, TensorBinaryEint> {
@@ -79,7 +79,7 @@ public:
};
/// TensorBinary verify the operation match the following signature
/// `(tensor<...x!HLFHE.eint<$p>>) -> tensor<...x!HLFHE.eint<$p>>`
/// `(tensor<...x!FHE.eint<$p>>) -> tensor<...x!FHE.eint<$p>>`
template <typename ConcreteType>
class TensorUnaryEint
: public mlir::OpTrait::TraitBase<ConcreteType, TensorUnaryEint> {
@@ -93,6 +93,6 @@ public:
} // namespace mlir
#define GET_OP_CLASSES
#include "concretelang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h.inc"
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h.inc"
#endif

View File

@@ -1,14 +1,14 @@
#ifndef CONCRETELANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_OPS
#define CONCRETELANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_OPS
#ifndef CONCRETELANG_DIALECT_FHELinalg_IR_FHELinalg_OPS
#define CONCRETELANG_DIALECT_FHELinalg_IR_FHELinalg_OPS
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "concretelang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.td"
include "concretelang/Dialect/HLFHELinalg/IR/HLFHELinalgTypes.td"
include "concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.td"
include "concretelang/Dialect/FHELinalg/IR/FHELinalgTypes.td"
class HLFHELinalg_Op<string mnemonic, list<OpTrait> traits = []> :
Op<HLFHELinalg_Dialect, mnemonic, traits>;
class FHELinalg_Op<string mnemonic, list<OpTrait> traits = []> :
Op<FHELinalg_Dialect, mnemonic, traits>;
// TensorBroadcastingRules verify that the operands and result verify the broadcasting rules
def TensorBroadcastingRules : NativeOpTrait<"TensorBroadcastingRules">;
@@ -18,7 +18,7 @@ def TensorBinaryEint : NativeOpTrait<"TensorBinaryEint">;
def TensorUnaryEint : NativeOpTrait<"TensorUnaryEint">;
def AddEintIntOp : HLFHELinalg_Op<"add_eint_int", [TensorBroadcastingRules, TensorBinaryEintInt]> {
def AddEintIntOp : FHELinalg_Op<"add_eint_int", [TensorBroadcastingRules, TensorBinaryEintInt]> {
let summary = "Returns a tensor that contains the addition of a tensor of encrypted integers and a tensor of clear integers.";
let description = [{
@@ -28,10 +28,10 @@ def AddEintIntOp : HLFHELinalg_Op<"add_eint_int", [TensorBroadcastingRules, Tens
Examples:
```mlir
// Returns the term to term addition of `%a0` with `%a1`
"HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<4>>, tensor<4xi5>) -> tensor<4x!HLFHE.eint<4>>
"FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<4>>, tensor<4xi5>) -> tensor<4x!FHE.eint<4>>
// Returns the term to term addition of `%a0` with `%a1`, where dimensions equal to one are stretched.
"HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x1x4x!HLFHE.eint<4>>, tensor<1x4x4xi5>) -> tensor<4x4x4x!HLFHE.eint<4>>
"FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x1x4x!FHE.eint<4>>, tensor<1x4x4xi5>) -> tensor<4x4x4x!FHE.eint<4>>
// Returns the addition of a 3x3 matrix of encrypted integers and a 3x1 matrix (a column) of integers.
//
@@ -40,7 +40,7 @@ def AddEintIntOp : HLFHELinalg_Op<"add_eint_int", [TensorBroadcastingRules, Tens
// [7,8,9] [3] [10,11,12]
//
// The dimension #1 of operand #2 is stretched as it is equals to 1.
"HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x1xi5>) -> tensor<3x3x!HLFHE.eint<4>>
"FHELinalg.add_eint_int"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3x1xi5>) -> tensor<3x3x!FHE.eint<4>>
// Returns the addition of a 3x3 matrix of encrypted integers and a 1x3 matrix (a line) of integers.
//
@@ -49,10 +49,10 @@ def AddEintIntOp : HLFHELinalg_Op<"add_eint_int", [TensorBroadcastingRules, Tens
// [7,8,9] [8,10,12]
//
// The dimension #2 of operand #2 is stretched as it is equals to 1.
"HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<1x3xi5>) -> tensor<3x3x!HLFHE.eint<4>>
"FHELinalg.add_eint_int"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<1x3xi5>) -> tensor<3x3x!FHE.eint<4>>
// Same behavior than the previous one, but as the dimension #2 is missing of operand #2.
"HLFHELinalg.add_eint_int(%a0, %a1)" : (tensor<3x4x!HLFHE.eint<4>>, tensor<3xi5>) -> tensor<4x4x4x!HLFHE.eint<4>>
"FHELinalg.add_eint_int(%a0, %a1)" : (tensor<3x4x!FHE.eint<4>>, tensor<3xi5>) -> tensor<4x4x4x!FHE.eint<4>>
```
}];
@@ -71,7 +71,7 @@ def AddEintIntOp : HLFHELinalg_Op<"add_eint_int", [TensorBroadcastingRules, Tens
];
}
def AddEintOp : HLFHELinalg_Op<"add_eint", [TensorBroadcastingRules, TensorBinaryEint]> {
def AddEintOp : FHELinalg_Op<"add_eint", [TensorBroadcastingRules, TensorBinaryEint]> {
let summary = "Returns a tensor that contains the addition of two tensor of encrypted integers.";
let description = [{
@@ -81,10 +81,10 @@ def AddEintOp : HLFHELinalg_Op<"add_eint", [TensorBroadcastingRules, TensorBinar
Examples:
```mlir
// Returns the term to term addition of `%a0` with `%a1`
"HLFHELinalg.add_eint"(%a0, %a1) : (tensor<4x!HLFHE.eint<4>>, tensor<4x!HLFHE.eint<4>>) -> tensor<4x!HLFHE.eint<4>>
"FHELinalg.add_eint"(%a0, %a1) : (tensor<4x!FHE.eint<4>>, tensor<4x!FHE.eint<4>>) -> tensor<4x!FHE.eint<4>>
// Returns the term to term addition of `%a0` with `%a1`, where dimensions equal to one are stretched.
"HLFHELinalg.add_eint"(%a0, %a1) : (tensor<4x1x4x!HLFHE.eint<4>>, tensor<1x4x4x!HLFHE.eint<4>>) -> tensor<4x4x4x!HLFHE.eint<4>>
"FHELinalg.add_eint"(%a0, %a1) : (tensor<4x1x4x!FHE.eint<4>>, tensor<1x4x4x!FHE.eint<4>>) -> tensor<4x4x4x!FHE.eint<4>>
// Returns the addition of a 3x3 matrix of encrypted integers and a 3x1 matrix (a column) of encrypted integers.
//
@@ -93,7 +93,7 @@ def AddEintOp : HLFHELinalg_Op<"add_eint", [TensorBroadcastingRules, TensorBinar
// [7,8,9] [3] [10,11,12]
//
// The dimension #1 of operand #2 is stretched as it is equals to 1.
"HLFHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x1x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
"FHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3x1x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>>
// Returns the addition of a 3x3 matrix of encrypted integers and a 1x3 matrix (a line) of encrypted integers.
//
@@ -102,10 +102,10 @@ def AddEintOp : HLFHELinalg_Op<"add_eint", [TensorBroadcastingRules, TensorBinar
// [7,8,9] [8,10,12]
//
// The dimension #2 of operand #2 is stretched as it is equals to 1.
"HLFHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<1x3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
"FHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<1x3x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>>
// Same behavior than the previous one, but as the dimension #2 of operand #2 is missing.
"HLFHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
"FHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>>
```
}];
@@ -123,7 +123,7 @@ def AddEintOp : HLFHELinalg_Op<"add_eint", [TensorBroadcastingRules, TensorBinar
];
}
def SubIntEintOp : HLFHELinalg_Op<"sub_int_eint", [TensorBroadcastingRules, TensorBinaryIntEint]> {
def SubIntEintOp : FHELinalg_Op<"sub_int_eint", [TensorBroadcastingRules, TensorBinaryIntEint]> {
let summary = "Returns a tensor that contains the substraction of a tensor of clear integers and a tensor of encrypted integers.";
let description = [{
@@ -133,10 +133,10 @@ def SubIntEintOp : HLFHELinalg_Op<"sub_int_eint", [TensorBroadcastingRules, Tens
Examples:
```mlir
// Returns the term to term substraction of `%a0` with `%a1`
"HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4xi5>, tensor<4x!HLFHE.eint<4>>) -> tensor<4x!HLFHE.eint<4>>
"FHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4xi5>, tensor<4x!FHE.eint<4>>) -> tensor<4x!FHE.eint<4>>
// Returns the term to term substraction of `%a0` with `%a1`, where dimensions equal to one are stretched.
"HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4x1x4xi5>, tensor<1x4x4x!HLFHE.eint<4>>) -> tensor<4x4x4x!HLFHE.eint<4>>
"FHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4x1x4xi5>, tensor<1x4x4x!FHE.eint<4>>) -> tensor<4x4x4x!FHE.eint<4>>
// Returns the substraction of a 3x3 matrix of integers and a 3x1 matrix (a column) of encrypted integers.
//
@@ -145,7 +145,7 @@ def SubIntEintOp : HLFHELinalg_Op<"sub_int_eint", [TensorBroadcastingRules, Tens
// [7,8,9] [3] [4,5,6]
//
// The dimension #1 of operand #2 is stretched as it is equals to 1.
"HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<3x3xi5>, tensor<3x1x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
"FHELinalg.sub_int_eint"(%a0, %a1) : (tensor<3x3xi5>, tensor<3x1x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>>
// Returns the substraction of a 3x3 matrix of integers and a 1x3 matrix (a line) of encrypted integers.
//
@@ -154,10 +154,10 @@ def SubIntEintOp : HLFHELinalg_Op<"sub_int_eint", [TensorBroadcastingRules, Tens
// [7,8,9] [6,6,6]
//
// The dimension #2 of operand #2 is stretched as it is equals to 1.
"HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<3x3xi5>, tensor<1x3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
"FHELinalg.sub_int_eint"(%a0, %a1) : (tensor<3x3xi5>, tensor<1x3x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>>
// Same behavior than the previous one, but as the dimension #2 is missing of operand #2.
"HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<3x3xi5>, tensor<3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
"FHELinalg.sub_int_eint"(%a0, %a1) : (tensor<3x3xi5>, tensor<3x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>>
```
}];
@@ -176,7 +176,7 @@ def SubIntEintOp : HLFHELinalg_Op<"sub_int_eint", [TensorBroadcastingRules, Tens
];
}
def NegEintOp : HLFHELinalg_Op<"neg_eint", [TensorUnaryEint]> {
def NegEintOp : FHELinalg_Op<"neg_eint", [TensorUnaryEint]> {
let summary = "Returns a tensor that contains the negation of a tensor of encrypted integers.";
let description = [{
@@ -185,7 +185,7 @@ def NegEintOp : HLFHELinalg_Op<"neg_eint", [TensorUnaryEint]> {
Examples:
```mlir
// Returns the term to term negation of `%a0`
"HLFHELinalg.neg_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
"FHELinalg.neg_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>>
//
// ( [1,2,3] ) [31,30,29]
// negate ( [4,5,6] ) = [28,27,26]
@@ -208,7 +208,7 @@ def NegEintOp : HLFHELinalg_Op<"neg_eint", [TensorUnaryEint]> {
];
}
def MulEintIntOp : HLFHELinalg_Op<"mul_eint_int", [TensorBroadcastingRules, TensorBinaryEintInt]> {
def MulEintIntOp : FHELinalg_Op<"mul_eint_int", [TensorBroadcastingRules, TensorBinaryEintInt]> {
let summary = "Returns a tensor that contains the multiplication of a tensor of encrypted integers and a tensor of clear integers.";
let description = [{
@@ -218,10 +218,10 @@ def MulEintIntOp : HLFHELinalg_Op<"mul_eint_int", [TensorBroadcastingRules, Tens
Examples:
```mlir
// Returns the term to term multiplication of `%a0` with `%a1`
"HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<4>>, tensor<4xi5>) -> tensor<4x!HLFHE.eint<4>>
"FHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<4>>, tensor<4xi5>) -> tensor<4x!FHE.eint<4>>
// Returns the term to term multiplication of `%a0` with `%a1`, where dimensions equal to one are stretched.
"HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x1x4x!HLFHE.eint<4>>, tensor<1x4x4xi5>) -> tensor<4x4x4x!HLFHE.eint<4>>
"FHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x1x4x!FHE.eint<4>>, tensor<1x4x4xi5>) -> tensor<4x4x4x!FHE.eint<4>>
// Returns the multiplication of a 3x3 matrix of encrypted integers and a 3x1 matrix (a column) of integers.
//
@@ -230,7 +230,7 @@ def MulEintIntOp : HLFHELinalg_Op<"mul_eint_int", [TensorBroadcastingRules, Tens
// [7,8,9] [3] [21,24,27]
//
// The dimension #1 of operand #2 is stretched as it is equals to 1.
"HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x1xi5>) -> tensor<3x3x!HLFHE.eint<4>>
"FHELinalg.mul_eint_int"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3x1xi5>) -> tensor<3x3x!FHE.eint<4>>
// Returns the multiplication of a 3x3 matrix of encrypted integers and a 1x3 matrix (a line) of integers.
//
@@ -239,10 +239,10 @@ def MulEintIntOp : HLFHELinalg_Op<"mul_eint_int", [TensorBroadcastingRules, Tens
// [7,8,9] [8,10,12]
//
// The dimension #2 of operand #2 is stretched as it is equals to 1.
"HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<1x3xi5>) -> tensor<3x3x!HLFHE.eint<4>>
"FHELinalg.mul_eint_int"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<1x3xi5>) -> tensor<3x3x!FHE.eint<4>>
// Same behavior than the previous one, but as the dimension #2 is missing of operand #2.
"HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3xi5>) -> tensor<3x3x!HLFHE.eint<4>>
"FHELinalg.mul_eint_int"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3xi5>) -> tensor<3x3x!FHE.eint<4>>
```
}];
@@ -255,7 +255,7 @@ def MulEintIntOp : HLFHELinalg_Op<"mul_eint_int", [TensorBroadcastingRules, Tens
let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
}
def ApplyLookupTableEintOp : HLFHELinalg_Op<"apply_lookup_table", []> {
def ApplyLookupTableEintOp : FHELinalg_Op<"apply_lookup_table", []> {
let summary = "Returns a tensor that contains the result of the lookup on a table.";
let description = [{
@@ -264,7 +264,7 @@ def ApplyLookupTableEintOp : HLFHELinalg_Op<"apply_lookup_table", []> {
```mlir
// The result of this operation, is a tensor that contains the result of the lookup on a table.
// i.e. %res[i, ..., k] = %lut[%t[i, ..., k]]
%res = HLFHELinalg.apply_lookup_table(%t, %lut): tensor<DNx...xD1x!HLFHE.eint<$p>>, tensor<D2^$pxi64> -> tensor<DNx...xD1x!HLFHE.eint<$p>>
%res = FHELinalg.apply_lookup_table(%t, %lut): tensor<DNx...xD1x!FHE.eint<$p>>, tensor<D2^$pxi64> -> tensor<DNx...xD1x!FHE.eint<$p>>
```
The `%lut` argument must be a tensor with one dimension, where its dimension is equals to `2^p` where `p` is the width of the encrypted integers.
@@ -277,7 +277,7 @@ def ApplyLookupTableEintOp : HLFHELinalg_Op<"apply_lookup_table", []> {
// [0,1,2] [1,3,5]
// [3,0,1] lut [1,3,5,7] = [7,1,3]
// [2,3,0] [5,7,1]
"HLFHELinalg.apply_lookup_table"(%t, %lut) : (tensor<3x3x!HLFHE.eint<2>>, tensor<4xi64>) -> tensor<3x3x!HLFHE.eint<3>>
"FHELinalg.apply_lookup_table"(%t, %lut) : (tensor<3x3x!FHE.eint<2>>, tensor<4xi64>) -> tensor<3x3x!FHE.eint<3>>
```
}];
@@ -289,11 +289,11 @@ def ApplyLookupTableEintOp : HLFHELinalg_Op<"apply_lookup_table", []> {
let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
let verifier = [{
return ::mlir::concretelang::HLFHELinalg::verifyApplyLookupTable(*this);
return ::mlir::concretelang::FHELinalg::verifyApplyLookupTable(*this);
}];
}
def ApplyMultiLookupTableEintOp : HLFHELinalg_Op<"apply_multi_lookup_table", []> {
def ApplyMultiLookupTableEintOp : FHELinalg_Op<"apply_multi_lookup_table", []> {
let summary = "Returns a tensor that contains the result of the lookup on a table, using a different lookup table for each element.";
let description = [{
@@ -303,7 +303,7 @@ def ApplyMultiLookupTableEintOp : HLFHELinalg_Op<"apply_multi_lookup_table", []>
```mlir
// The result of this operation, is a tensor that contains the result of the lookup on different tables.
// i.e. %res[i, ..., k] = [ %luts[i][%t[i]], ..., %luts[k][%t[k]] ]
%res = HLFHELinalg.apply_multi_lookup_table(%t, %lut): tensor<DNx...xD1x!HLFHE.eint<$p>>, tensor<DMx...xD1xD2^$pxi64> -> tensor<DNx...xD1x!HLFHE.eint<$p>>
%res = FHELinalg.apply_multi_lookup_table(%t, %lut): tensor<DNx...xD1x!FHE.eint<$p>>, tensor<DMx...xD1xD2^$pxi64> -> tensor<DNx...xD1x!FHE.eint<$p>>
```
The `%luts` argument should be a tensor with M dimension, where the first M-1 dimensions are broadcastable with the N dimensions of the encrypted tensor,
@@ -318,7 +318,7 @@ def ApplyMultiLookupTableEintOp : HLFHELinalg_Op<"apply_multi_lookup_table", []>
// [0,1] = [1,2]
// [3,0] lut [[1,3,5,7], [0,2,4,6]] = [7,0]
// [2,3] = [5,6]
"HLFHELinalg.apply_multi_lookup_table"(%t, %luts) : (tensor<3x2x!HLFHE.eint<2>>, tensor<2x4xi64>) -> tensor<3x2x!HLFHE.eint<3>>
"FHELinalg.apply_multi_lookup_table"(%t, %luts) : (tensor<3x2x!FHE.eint<2>>, tensor<2x4xi64>) -> tensor<3x2x!FHE.eint<3>>
```
```mlir
@@ -326,7 +326,7 @@ def ApplyMultiLookupTableEintOp : HLFHELinalg_Op<"apply_multi_lookup_table", []>
// Returns the lookup of a vector of 3 encrypted indices of width 2 on a vector of 3 tables of size 4=2² of clear integers.
//
// [3,0,1] lut [[1,3,5,7], [0,2,4,6], [1,2,3,4]] = [7,0,2]
"HLFHELinalg.apply_multi_lookup_table"(%t, %luts) : (tensor<3x!HLFHE.eint<2>>, tensor<3x4xi64>) -> tensor<3x!HLFHE.eint<3>>
"FHELinalg.apply_multi_lookup_table"(%t, %luts) : (tensor<3x!FHE.eint<2>>, tensor<3x4xi64>) -> tensor<3x!FHE.eint<3>>
```
}];
@@ -338,11 +338,11 @@ def ApplyMultiLookupTableEintOp : HLFHELinalg_Op<"apply_multi_lookup_table", []>
let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
let verifier = [{
return ::mlir::concretelang::HLFHELinalg::verifyApplyMultiLookupTable(*this);
return ::mlir::concretelang::FHELinalg::verifyApplyMultiLookupTable(*this);
}];
}
def ApplyMappedLookupTableEintOp : HLFHELinalg_Op<"apply_mapped_lookup_table", []> {
def ApplyMappedLookupTableEintOp : FHELinalg_Op<"apply_mapped_lookup_table", []> {
let summary = "Returns a tensor that contains the result of the lookup on a table, using a different lookup table for each element, specified by a map.";
let description = [{
@@ -352,7 +352,7 @@ def ApplyMappedLookupTableEintOp : HLFHELinalg_Op<"apply_mapped_lookup_table", [
```mlir
// The result of this operation, is a tensor that contains the result of the lookup on different tables.
// i.e. %res[i, ..., k] = %luts[ %map[i, ..., k] ][ %t[i, ..., k] ]
%res = HLFHELinalg.apply_mapped_lookup_table(%t, %luts, %map): tensor<DNx...xD1x!HLFHE.eint<$p>>, tensor<DM x ^$p>, tensor<DNx...xD1xindex> -> tensor<DNx...xD1x!HLFHE.eint<$p>>
%res = FHELinalg.apply_mapped_lookup_table(%t, %luts, %map): tensor<DNx...xD1x!FHE.eint<$p>>, tensor<DM x ^$p>, tensor<DNx...xD1xindex> -> tensor<DNx...xD1x!FHE.eint<$p>>
```
Examples:
@@ -363,7 +363,7 @@ def ApplyMappedLookupTableEintOp : HLFHELinalg_Op<"apply_mapped_lookup_table", [
// [0,1] [0, 1] = [1,2]
// [3,0] lut [[1,3,5,7], [0,2,4,6]] with [0, 1] = [7,0]
// [2,3] [0, 1] = [5,6]
"HLFHELinalg.apply_mapped_lookup_table"(%t, %luts, %map) : (tensor<3x2x!HLFHE.eint<2>>, tensor<2x4xi64>, tensor<3x2xindex>) -> tensor<3x2x!HLFHE.eint<3>>
"FHELinalg.apply_mapped_lookup_table"(%t, %luts, %map) : (tensor<3x2x!FHE.eint<2>>, tensor<2x4xi64>, tensor<3x2xindex>) -> tensor<3x2x!FHE.eint<3>>
```
Others examples:
@@ -394,12 +394,12 @@ def ApplyMappedLookupTableEintOp : HLFHELinalg_Op<"apply_mapped_lookup_table", [
let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
let verifier = [{
return ::mlir::concretelang::HLFHELinalg::verifyApplyMappedLookupTable(*this);
return ::mlir::concretelang::FHELinalg::verifyApplyMappedLookupTable(*this);
}];
}
// Dot product
def Dot : HLFHELinalg_Op<"dot_eint_int"> {
def Dot : FHELinalg_Op<"dot_eint_int"> {
let summary = "Returns the encrypted dot product between a vector of encrypted integers and a vector of clean integers.";
let description = [{
@@ -408,7 +408,7 @@ def Dot : HLFHELinalg_Op<"dot_eint_int"> {
Examples:
```mlir
// Returns the dot product of `%a0` with `%a1`
"HLFHELinalg.dot_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<4>>, tensor<4xi5>) -> !HLFHE.eint<4>
"FHELinalg.dot_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<4>>, tensor<4xi5>) -> !FHE.eint<4>
```
}];
@@ -420,11 +420,11 @@ def Dot : HLFHELinalg_Op<"dot_eint_int"> {
let results = (outs EncryptedIntegerType:$out);
let verifier = [{
return ::mlir::concretelang::HLFHELinalg::verifyDotEintInt(*this);
return ::mlir::concretelang::FHELinalg::verifyDotEintInt(*this);
}];
}
def MatMulEintIntOp : HLFHELinalg_Op<"matmul_eint_int", [TensorBinaryEintInt]> {
def MatMulEintIntOp : FHELinalg_Op<"matmul_eint_int", [TensorBinaryEintInt]> {
let summary = "Returns a tensor that contains the result of the matrix multiplication of a matrix of encrypted integers and a matrix of clear integers.";
let description = [{
@@ -432,7 +432,7 @@ def MatMulEintIntOp : HLFHELinalg_Op<"matmul_eint_int", [TensorBinaryEintInt]> {
The width of the clear integers must be less than or equals to the witdh of encrypted integers.
```mlir
"HLFHELinalg.matmul_eint_int(%a, %b) : (tensor<MxNx!HLFHE.eint<p>>, tensor<NxPxip'>) -> tensor<MxPx!HLFHE.eint<p>>"
"FHELinalg.matmul_eint_int(%a, %b) : (tensor<MxNx!FHE.eint<p>>, tensor<NxPxip'>) -> tensor<MxPx!FHE.eint<p>>"
```
Examples:
@@ -445,7 +445,7 @@ def MatMulEintIntOp : HLFHELinalg_Op<"matmul_eint_int", [TensorBinaryEintInt]> {
// [3,4] = [11,18,25]
// [5,6] [17,28,39]
//
"HLFHELinalg.matmul_eint_int"(%a, %b) : (tensor<3x2x!HLFHE.eint<6>>, tensor<2x3xi7>) -> tensor<3x3x!HLFHE.eint<6>>
"FHELinalg.matmul_eint_int"(%a, %b) : (tensor<3x2x!FHE.eint<6>>, tensor<2x3xi7>) -> tensor<3x3x!FHE.eint<6>>
```
}];
@@ -458,11 +458,11 @@ def MatMulEintIntOp : HLFHELinalg_Op<"matmul_eint_int", [TensorBinaryEintInt]> {
let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
let verifier = [{
return ::mlir::concretelang::HLFHELinalg::verifyMatmul<mlir::concretelang::HLFHELinalg::MatMulEintIntOp>(*this);
return ::mlir::concretelang::FHELinalg::verifyMatmul<mlir::concretelang::FHELinalg::MatMulEintIntOp>(*this);
}];
}
def MatMulIntEintOp : HLFHELinalg_Op<"matmul_int_eint", [TensorBinaryIntEint]> {
def MatMulIntEintOp : FHELinalg_Op<"matmul_int_eint", [TensorBinaryIntEint]> {
let summary = "Returns a tensor that contains the result of the matrix multiplication of a matrix of clear integers and a matrix of encrypted integers.";
let description = [{
@@ -470,7 +470,7 @@ def MatMulIntEintOp : HLFHELinalg_Op<"matmul_int_eint", [TensorBinaryIntEint]> {
The width of the clear integers must be less than or equals to the witdh of encrypted integers.
```mlir
"HLFHELinalg.matmul_int_eint(%a, %b) : (tensor<MxNxip'>, tensor<NxPxHLFHE.eint<p>>) -> tensor<MxPx!HLFHE.eint<p>>"
"FHELinalg.matmul_int_eint(%a, %b) : (tensor<MxNxip'>, tensor<NxPxFHE.eint<p>>) -> tensor<MxPx!FHE.eint<p>>"
```
Examples:
@@ -483,7 +483,7 @@ def MatMulIntEintOp : HLFHELinalg_Op<"matmul_int_eint", [TensorBinaryIntEint]> {
// [3,4] = [11,18,25]
// [5,6] [17,28,39]
//
"HLFHELinalg.matmul_int_eint"(%a, %b) : (tensor<3x2xi7>, tensor<2x3x!HLFHE.eint<6>>) -> tensor<3x3x!HLFHE.eint<6>>
"FHELinalg.matmul_int_eint"(%a, %b) : (tensor<3x2xi7>, tensor<2x3x!FHE.eint<6>>) -> tensor<3x3x!FHE.eint<6>>
```
}];
@@ -496,11 +496,11 @@ def MatMulIntEintOp : HLFHELinalg_Op<"matmul_int_eint", [TensorBinaryIntEint]> {
let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
let verifier = [{
return ::mlir::concretelang::HLFHELinalg::verifyMatmul<mlir::concretelang::HLFHELinalg::MatMulIntEintOp>(*this);
return ::mlir::concretelang::FHELinalg::verifyMatmul<mlir::concretelang::FHELinalg::MatMulIntEintOp>(*this);
}];
}
def ZeroOp : HLFHELinalg_Op<"zero", []> {
def ZeroOp : FHELinalg_Op<"zero", []> {
let summary = "Creates a new tensor with all elements initialized to an encrypted zero.";
let description = [{
@@ -508,7 +508,7 @@ def ZeroOp : HLFHELinalg_Op<"zero", []> {
Example:
```mlir
%tensor = "HLFHELinalg.zero"() : () -> tensor<5x!HLFHE.eint<4>>
%tensor = "FHELinalg.zero"() : () -> tensor<5x!FHE.eint<4>>
```
}];

View File

@@ -1,14 +1,14 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
#ifndef CONCRETELANG_DIALECT_HLFHELinalg_IR_HLFHELinalgTYPES_H
#define CONCRETELANG_DIALECT_HLFHELinalg_IR_HLFHELinalgTYPES_H
#ifndef CONCRETELANG_DIALECT_FHELinalg_IR_FHELinalgTYPES_H
#define CONCRETELANG_DIALECT_FHELinalg_IR_FHELinalgTYPES_H
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/DialectImplementation.h>
#define GET_TYPEDEF_CLASSES
#include "concretelang/Dialect/HLFHELinalg/IR/HLFHELinalgOpsTypes.h.inc"
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOpsTypes.h.inc"
#endif

View File

@@ -0,0 +1,11 @@
#ifndef CONCRETELANG_DIALECT_FHELinalg_IR_FHELinalg_TYPES
#define CONCRETELANG_DIALECT_FHELinalg_IR_FHELinalg_TYPES
include "concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.td"
include "mlir/IR/BuiltinTypes.td"
include "concretelang/Dialect/FHE/IR/FHETypes.td"
class FHELinalg_Type<string name, list<Trait> traits = []> :
TypeDef<FHELinalg_Dialect, name, traits> { }
#endif

View File

@@ -1,3 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Tiling.td)
mlir_tablegen(Tiling.h.inc -gen-pass-decls -name Transforms)
add_public_tablegen_target(ConcretelangHLFHELinalgTilingPassIncGen)
add_public_tablegen_target(ConcretelangFHELinalgTilingPassIncGen)

View File

@@ -1,21 +1,21 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
#ifndef CONCRETELANG_HLFHELINALG_TILING_PASS_H
#define CONCRETELANG_HLFHELINALG_TILING_PASS_H
#ifndef CONCRETELANG_FHELINALG_TILING_PASS_H
#define CONCRETELANG_FHELINALG_TILING_PASS_H
#include <mlir/Pass/Pass.h>
#include <concretelang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.h>
#include <concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.h>
#define GEN_PASS_CLASSES
#include <concretelang/Dialect/HLFHELinalg/Transforms/Tiling.h.inc>
#include <concretelang/Dialect/FHELinalg/Transforms/Tiling.h.inc>
namespace mlir {
namespace concretelang {
std::unique_ptr<mlir::OperationPass<>>
createHLFHELinalgTilingMarkerPass(llvm::ArrayRef<int64_t> tileSizes);
createFHELinalgTilingMarkerPass(llvm::ArrayRef<int64_t> tileSizes);
std::unique_ptr<mlir::OperationPass<>> createHLFHELinalgTilingPass();
std::unique_ptr<mlir::OperationPass<>> createFHELinalgTilingPass();
} // namespace concretelang
} // namespace mlir

View File

@@ -0,0 +1,22 @@
#ifndef CONCRETELANG_FHELINALG_TILING_PASS
#define CONCRETELANG_FHELINALG_TILING_PASS
include "mlir/Pass/PassBase.td"
def FHELinalgTilingMarker : Pass<"fhe-linalg-tiling-marker"> {
let summary =
"Marks FHELinalg operations for tiling using a vector of tile sizes";
let constructor = "mlir::concretelang::createFHELinalgTilingMarkerPass()";
let options = [];
let dependentDialects = [ "mlir::concretelang::FHELinalg::FHELinalgDialect" ];
}
def FHELinalgTiling : Pass<"fhe-linalg-tiling"> {
let summary = "Performs tiling of FHELinalg operations based on the "
"tile-size attribute";
let constructor = "mlir::concretelang::createFHELinalgTilingPass()";
let options = [];
let dependentDialects = [ "mlir::concretelang::FHELinalg::FHELinalgDialect" ];
}
#endif

View File

@@ -1,13 +0,0 @@
set(LLVM_TARGET_DEFINITIONS HLFHEOps.td)
mlir_tablegen(HLFHEOps.h.inc -gen-op-decls)
mlir_tablegen(HLFHEOps.cpp.inc -gen-op-defs)
mlir_tablegen(HLFHEOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=HLFHE)
mlir_tablegen(HLFHEOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=HLFHE)
mlir_tablegen(HLFHEOpsDialect.h.inc -gen-dialect-decls -dialect=HLFHE)
mlir_tablegen(HLFHEOpsDialect.cpp.inc -gen-dialect-defs -dialect=HLFHE)
add_public_tablegen_target(MLIRHLFHEOpsIncGen)
add_dependencies(mlir-headers MLIRHLFHEOpsIncGen)
add_concretelang_doc(HLFHEDialect HLFHEDialect concretelang/ -gen-dialect-doc)
add_concretelang_doc(HLFHEOps HLFHEOps concretelang/ -gen-op-doc)
add_concretelang_doc(HLFHETypes HLFHETypes concretelang/ -gen-typedef-doc)

View File

@@ -1,13 +0,0 @@
set(LLVM_TARGET_DEFINITIONS HLFHELinalgOps.td)
mlir_tablegen(HLFHELinalgOps.h.inc -gen-op-decls)
mlir_tablegen(HLFHELinalgOps.cpp.inc -gen-op-defs)
mlir_tablegen(HLFHELinalgOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=HLFHELinalg)
mlir_tablegen(HLFHELinalgOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=HLFHELinalg)
mlir_tablegen(HLFHELinalgOpsDialect.h.inc -gen-dialect-decls -dialect=HLFHELinalg)
mlir_tablegen(HLFHELinalgOpsDialect.cpp.inc -gen-dialect-defs -dialect=HLFHELinalg)
add_public_tablegen_target(MLIRHLFHELinalgOpsIncGen)
add_dependencies(mlir-headers MLIRHLFHELinalgOpsIncGen)
add_concretelang_doc(HLFHELinalgDialect HLFHELinalgDialect concretelang/ -gen-dialect-doc)
add_concretelang_doc(HLFHELinalgOps HLFHELinalgOps concretelang/ -gen-op-doc)
add_concretelang_doc(HLFHELinalgTypes HLFHELinalgTypes concretelang/ -gen-typedef-doc)

View File

@@ -1,15 +0,0 @@
#ifndef CONCRETELANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_DIALECT
#define CONCRETELANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_DIALECT
include "mlir/IR/OpBase.td"
def HLFHELinalg_Dialect : Dialect {
let name = "HLFHELinalg";
let summary = "High Level Fully Homorphic Encryption Linalg dialect";
let description = [{
A dialect for representation of high level linalg operations on fully homomorphic ciphertexts.
}];
let cppNamespace = "::mlir::concretelang::HLFHELinalg";
}
#endif

View File

@@ -1,11 +0,0 @@
#ifndef CONCRETELANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_TYPES
#define CONCRETELANG_DIALECT_HLFHELinalg_IR_HLFHELinalg_TYPES
include "concretelang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.td"
include "mlir/IR/BuiltinTypes.td"
include "concretelang/Dialect/HLFHE/IR/HLFHETypes.td"
class HLFHELinalg_Type<string name, list<Trait> traits = []> :
TypeDef<HLFHELinalg_Dialect, name, traits> { }
#endif

View File

@@ -1,22 +0,0 @@
#ifndef CONCRETELANG_HLFHELINALG_TILING_PASS
#define CONCRETELANG_HLFHELINALG_TILING_PASS
include "mlir/Pass/PassBase.td"
def HLFHELinalgTilingMarker : Pass<"hlfhe-linalg-tiling-marker"> {
let summary =
"Marks HLFHELinalg operations for tiling using a vector of tile sizes";
let constructor = "mlir::concretelang::createHLFHELinalgTilingMarkerPass()";
let options = [];
let dependentDialects = [ "mlir::concretelang::HLFHELinalg::HLFHELinalgDialect" ];
}
def HLFHELinalgTiling : Pass<"hlfhe-linalg-tiling"> {
let summary = "Performs tiling of HLFHELinalg operations based on the "
"tile-size attribute";
let constructor = "mlir::concretelang::createHLFHELinalgTilingPass()";
let options = [];
let dependentDialects = [ "mlir::concretelang::HLFHELinalg::HLFHELinalgDialect" ];
}
#endif

View File

@@ -1,13 +0,0 @@
set(LLVM_TARGET_DEFINITIONS LowLFHEOps.td)
mlir_tablegen(LowLFHEOps.h.inc -gen-op-decls)
mlir_tablegen(LowLFHEOps.cpp.inc -gen-op-defs)
mlir_tablegen(LowLFHEOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=LowLFHE)
mlir_tablegen(LowLFHEOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=LowLFHE)
mlir_tablegen(LowLFHEOpsDialect.h.inc -gen-dialect-decls -dialect=LowLFHE)
mlir_tablegen(LowLFHEOpsDialect.cpp.inc -gen-dialect-defs -dialect=LowLFHE)
add_public_tablegen_target(MLIRLowLFHEOpsIncGen)
add_dependencies(mlir-headers MLIRLowLFHEOpsIncGen)
add_concretelang_doc(LowLFHEDialect LowLFHEDialect concretelang/ -gen-dialect-doc)
add_concretelang_doc(LowLFHEOps LowLFHEOps concretelang/ -gen-op-doc)
add_concretelang_doc(LowLFHETypes LowLFHETypes concretelang/ -gen-typedef-doc)

View File

@@ -1,15 +0,0 @@
#ifndef CONCRETELANG_DIALECT_LowLFHE_IR_LowLFHE_DIALECT
#define CONCRETELANG_DIALECT_LowLFHE_IR_LowLFHE_DIALECT
include "mlir/IR/OpBase.td"
def LowLFHE_Dialect : Dialect {
let name = "LowLFHE";
let summary = "Low Level Fully Homorphic Encryption dialect";
let description = [{
A dialect for representation of low level operation on fully homomorphic ciphertext.
}];
let cppNamespace = "::mlir::concretelang::LowLFHE";
}
#endif

View File

@@ -1,13 +0,0 @@
set(LLVM_TARGET_DEFINITIONS MidLFHEOps.td)
mlir_tablegen(MidLFHEOps.h.inc -gen-op-decls)
mlir_tablegen(MidLFHEOps.cpp.inc -gen-op-defs)
mlir_tablegen(MidLFHEOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=MidLFHE)
mlir_tablegen(MidLFHEOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=MidLFHE)
mlir_tablegen(MidLFHEOpsDialect.h.inc -gen-dialect-decls -dialect=MidLFHE)
mlir_tablegen(MidLFHEOpsDialect.cpp.inc -gen-dialect-defs -dialect=MidLFHE)
add_public_tablegen_target(MLIRMidLFHEOpsIncGen)
add_dependencies(mlir-headers MLIRMidLFHEOpsIncGen)
add_concretelang_doc(MidLFHEDialect MidLFHEDialect concretelang/ -gen-dialect-doc)
add_concretelang_doc(MidLFHEOps MidLFHEOps concretelang/ -gen-op-doc)
add_concretelang_doc(MidLFHETypes MidLFHETypes concretelang/ -gen-typedef-doc)

View File

@@ -8,10 +8,10 @@ def BuildDataflowTaskGraph : Pass<"BuildDataflowTaskGraph", "mlir::ModuleOp"> {
"Identify profitable dataflow tasks and build DataflowTaskGraph.";
let description = [{
This pass builds a dataflow graph out of a HLFHE program.
This pass builds a dataflow graph out of a FHE program.
In its current incarnation, it considers some heavier weight
operations (e.g., HLFHELinalg Dot and Matmult or bootstraps) as
operations (e.g., FHELinalg Dot and Matmult or bootstraps) as
candidates for being executed in a discrete task, and then
sinks within the task the lighter weight operation that do not
increase the graph cut (amount of dependences in or out).
@@ -23,21 +23,21 @@ def BuildDataflowTaskGraph : Pass<"BuildDataflowTaskGraph", "mlir::ModuleOp"> {
Example:
```mlir
func @main(%arg0: tensor<3x4x!HLFHE.eint<2>>, %arg1: tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> {
%0 = "HLFHELinalg.matmul_eint_int"(%arg0, %arg1) : (tensor<3x4x!HLFHE.eint<2>>, tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>>
return %0 : tensor<3x2x!HLFHE.eint<2>>
func @main(%arg0: tensor<3x4x!FHE.eint<2>>, %arg1: tensor<4x2xi3>) -> tensor<3x2x!FHE.eint<2>> {
%0 = "FHELinalg.matmul_eint_int"(%arg0, %arg1) : (tensor<3x4x!FHE.eint<2>>, tensor<4x2xi3>) -> tensor<3x2x!FHE.eint<2>>
return %0 : tensor<3x2x!FHE.eint<2>>
}
```
Will result in generating a dataflow task for the Matmul operation:
```mlir
func @main(%arg0: tensor<3x4x!HLFHE.eint<2>>, %arg1: tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> {
func @main(%arg0: tensor<3x4x!FHE.eint<2>>, %arg1: tensor<4x2xi3>) -> tensor<3x2x!FHE.eint<2>> {
%0 = "RT.dataflow_task"(%arg0, %arg1) ( {
%1 = "HLFHELinalg.matmul_eint_int"(%arg0, %arg1) : (tensor<3x4x!HLFHE.eint<2>>, tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>>
"RT.dataflow_yield"(%1) : (tensor<3x2x!HLFHE.eint<2>>) -> ()
}) : (tensor<3x4x!HLFHE.eint<2>>, tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>>
return %0 : tensor<3x2x!HLFHE.eint<2>>
%1 = "FHELinalg.matmul_eint_int"(%arg0, %arg1) : (tensor<3x4x!FHE.eint<2>>, tensor<4x2xi3>) -> tensor<3x2x!FHE.eint<2>>
"RT.dataflow_yield"(%1) : (tensor<3x2x!FHE.eint<2>>) -> ()
}) : (tensor<3x4x!FHE.eint<2>>, tensor<4x2xi3>) -> tensor<3x2x!FHE.eint<2>>
return %0 : tensor<3x2x!FHE.eint<2>>
}
```
}];

View File

@@ -1,5 +1,5 @@
#ifndef CONCRETELANG_DIALECT_HLFHE_IR_HLFHE_TYPES
#define CONCRETELANG_DIALECT_HLFHE_IR_HLFHE_TYPES
#ifndef CONCRETELANG_DIALECT_FHE_IR_FHE_TYPES
#define CONCRETELANG_DIALECT_FHE_IR_FHE_TYPES
include "concretelang/Dialect/RT/IR/RTDialect.td"
include "mlir/IR/BuiltinTypes.td"

View File

@@ -0,0 +1,13 @@
set(LLVM_TARGET_DEFINITIONS TFHEOps.td)
mlir_tablegen(TFHEOps.h.inc -gen-op-decls)
mlir_tablegen(TFHEOps.cpp.inc -gen-op-defs)
mlir_tablegen(TFHEOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=TFHE)
mlir_tablegen(TFHEOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=TFHE)
mlir_tablegen(TFHEOpsDialect.h.inc -gen-dialect-decls -dialect=TFHE)
mlir_tablegen(TFHEOpsDialect.cpp.inc -gen-dialect-defs -dialect=TFHE)
add_public_tablegen_target(MLIRTFHEOpsIncGen)
add_dependencies(mlir-headers MLIRTFHEOpsIncGen)
add_concretelang_doc(TFHEDialect TFHEDialect concretelang/ -gen-dialect-doc)
add_concretelang_doc(TFHEOps TFHEOps concretelang/ -gen-op-doc)
add_concretelang_doc(TFHETypes TFHETypes concretelang/ -gen-typedef-doc)

View File

@@ -1,8 +1,8 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
#ifndef CONCRETELANG_DIALECT_LowLFHE_IR_LowLFHEDIALECT_H
#define CONCRETELANG_DIALECT_LowLFHE_IR_LowLFHEDIALECT_H
#ifndef CONCRETELANG_DIALECT_TFHE_IR_TFHEDIALECT_H
#define CONCRETELANG_DIALECT_TFHE_IR_TFHEDIALECT_H
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
@@ -11,6 +11,6 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "concretelang/Dialect/LowLFHE/IR/LowLFHEOpsDialect.h.inc"
#include "concretelang/Dialect/TFHE/IR/TFHEOpsDialect.h.inc"
#endif

View File

@@ -1,4 +1,4 @@
//===- MidLFHEDialect.td - MidLFHE dialect ----------------*- tablegen -*-===//
//===- TFHEDialect.td - TFHE dialect ----------------*- tablegen -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,18 +6,18 @@
//
//===----------------------------------------------------------------------===//
#ifndef CONCRETELANG_DIALECT_MidLFHE_IR_MidLFHE_DIALECT
#define CONCRETELANG_DIALECT_MidLFHE_IR_MidLFHE_DIALECT
#ifndef CONCRETELANG_DIALECT_TFHE_IR_TFHE_DIALECT
#define CONCRETELANG_DIALECT_TFHE_IR_TFHE_DIALECT
include "mlir/IR/OpBase.td"
def MidLFHE_Dialect : Dialect {
let name = "MidLFHE";
def TFHE_Dialect : Dialect {
let name = "TFHE";
let summary = "High Level Fully Homorphic Encryption dialect";
let description = [{
A dialect for representation of high level operation on fully homomorphic ciphertext.
}];
let cppNamespace = "::mlir::concretelang::MidLFHE";
let cppNamespace = "::mlir::concretelang::TFHE";
}
#endif

View File

@@ -1,8 +1,8 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
#ifndef CONCRETELANG_DIALECT_LowLFHE_LowLFHE_OPS_H
#define CONCRETELANG_DIALECT_LowLFHE_LowLFHE_OPS_H
#ifndef CONCRETELANG_DIALECT_TFHE_IR_TFHEOPS_H
#define CONCRETELANG_DIALECT_TFHE_IR_TFHEOPS_H
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinOps.h>
@@ -10,9 +10,9 @@
#include <mlir/Interfaces/ControlFlowInterfaces.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include "concretelang/Dialect/LowLFHE/IR/LowLFHETypes.h"
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
#define GET_OP_CLASSES
#include "concretelang/Dialect/LowLFHE/IR/LowLFHEOps.h.inc"
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h.inc"
#endif

View File

@@ -1,4 +1,4 @@
//===- MidLFHEOps.td - High level FHE dialect ops ----------------*- tablegen -*-===//
//===- TFHEOps.td - High level FHE dialect ops ----------------*- tablegen -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,84 +6,84 @@
//
//===----------------------------------------------------------------------===//
#ifndef CONCRETELANG_DIALECT_MidLFHE_IR_MidLFHE_OPS
#define CONCRETELANG_DIALECT_MidLFHE_IR_MidLFHE_OPS
#ifndef CONCRETELANG_DIALECT_TFHE_IR_TFHE_OPS
#define CONCRETELANG_DIALECT_TFHE_IR_TFHE_OPS
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "concretelang/Dialect/MidLFHE/IR/MidLFHEDialect.td"
include "concretelang/Dialect/MidLFHE/IR/MidLFHETypes.td"
include "concretelang/Dialect/TFHE/IR/TFHEDialect.td"
include "concretelang/Dialect/TFHE/IR/TFHETypes.td"
class MidLFHE_Op<string mnemonic, list<OpTrait> traits = []> :
Op<MidLFHE_Dialect, mnemonic, traits>;
class TFHE_Op<string mnemonic, list<OpTrait> traits = []> :
Op<TFHE_Dialect, mnemonic, traits>;
def ZeroGLWEOp : MidLFHE_Op<"zero"> {
def ZeroGLWEOp : TFHE_Op<"zero"> {
let summary = "Returns a trivial encyption of 0";
let arguments = (ins);
let results = (outs GLWECipherTextType:$out);
}
def AddGLWEIntOp : MidLFHE_Op<"add_glwe_int"> {
def AddGLWEIntOp : TFHE_Op<"add_glwe_int"> {
let summary = "Returns the sum of a clear integer and a lwe ciphertext";
let arguments = (ins GLWECipherTextType:$a, AnyInteger:$b);
let results = (outs GLWECipherTextType);
let verifier = [{
return mlir::concretelang::MidLFHE::verifyGLWEIntegerOperator<AddGLWEIntOp>(*this);
return mlir::concretelang::TFHE::verifyGLWEIntegerOperator<AddGLWEIntOp>(*this);
}];
}
def AddGLWEOp : MidLFHE_Op<"add_glwe"> {
def AddGLWEOp : TFHE_Op<"add_glwe"> {
let summary = "Returns the sum of 2 lwe ciphertexts";
let arguments = (ins GLWECipherTextType:$a, GLWECipherTextType:$b);
let results = (outs GLWECipherTextType);
let verifier = [{
return ::mlir::concretelang::MidLFHE::verifyBinaryGLWEOperator<AddGLWEOp>(*this);
return ::mlir::concretelang::TFHE::verifyBinaryGLWEOperator<AddGLWEOp>(*this);
}];
}
def SubIntGLWEOp : MidLFHE_Op<"sub_int_glwe"> {
def SubIntGLWEOp : TFHE_Op<"sub_int_glwe"> {
let summary = "Substracts an integer and a GLWE ciphertext";
let arguments = (ins AnyInteger:$a, GLWECipherTextType:$b);
let results = (outs GLWECipherTextType);
let verifier = [{
return ::mlir::concretelang::MidLFHE::verifyIntegerGLWEOperator(*this);
return ::mlir::concretelang::TFHE::verifyIntegerGLWEOperator(*this);
}];
}
def NegGLWEOp : MidLFHE_Op<"neg_glwe"> {
def NegGLWEOp : TFHE_Op<"neg_glwe"> {
let summary = "Negates a glwe ciphertext";
let arguments = (ins GLWECipherTextType:$a);
let results = (outs GLWECipherTextType);
let verifier = [{
return ::mlir::concretelang::MidLFHE::verifyUnaryGLWEOperator<NegGLWEOp>(*this);
return ::mlir::concretelang::TFHE::verifyUnaryGLWEOperator<NegGLWEOp>(*this);
}];
}
def MulGLWEIntOp : MidLFHE_Op<"mul_glwe_int"> {
def MulGLWEIntOp : TFHE_Op<"mul_glwe_int"> {
let summary = "Returns the product of a clear integer and a lwe ciphertext";
let arguments = (ins GLWECipherTextType:$a, AnyInteger:$b);
let results = (outs GLWECipherTextType);
let verifier = [{
return mlir::concretelang::MidLFHE::verifyGLWEIntegerOperator<MulGLWEIntOp>(*this);
return mlir::concretelang::TFHE::verifyGLWEIntegerOperator<MulGLWEIntOp>(*this);
}];
}
def ApplyLookupTable : MidLFHE_Op<"apply_lookup_table"> {
def ApplyLookupTable : TFHE_Op<"apply_lookup_table"> {
let summary = "Applies a lookup table to a GLWE ciphertext";
@@ -96,7 +96,7 @@ def ApplyLookupTable : MidLFHE_Op<"apply_lookup_table"> {
let results = (outs GLWECipherTextType);
let verifier = [{
return ::mlir::concretelang::MidLFHE::verifyApplyLookupTable(*this);
return ::mlir::concretelang::TFHE::verifyApplyLookupTable(*this);
}];
}

View File

@@ -1,8 +1,8 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
#ifndef CONCRETELANG_DIALECT_MIDLFHE_IR_MIDLFHETYPES_H
#define CONCRETELANG_DIALECT_MIDLFHE_IR_MIDLFHETYPES_H
#ifndef CONCRETELANG_DIALECT_TFHE_IR_TFHETYPES_H
#define CONCRETELANG_DIALECT_TFHE_IR_TFHETYPES_H
#include "llvm/ADT/TypeSwitch.h"
#include <mlir/Dialect/StandardOps/IR/Ops.h>
@@ -11,6 +11,6 @@
#include <mlir/IR/DialectImplementation.h>
#define GET_TYPEDEF_CLASSES
#include "concretelang/Dialect/MidLFHE/IR/MidLFHEOpsTypes.h.inc"
#include "concretelang/Dialect/TFHE/IR/TFHEOpsTypes.h.inc"
#endif

View File

@@ -1,14 +1,14 @@
#ifndef CONCRETELANG_DIALECT_MidLFHE_IR_MidLFHE_TYPES
#define CONCRETELANG_DIALECT_MidLFHE_IR_MidLFHE_TYPES
#ifndef CONCRETELANG_DIALECT_TFHE_IR_TFHE_TYPES
#define CONCRETELANG_DIALECT_TFHE_IR_TFHE_TYPES
// TODO: MLWE / GSW
include "concretelang/Dialect/MidLFHE/IR/MidLFHEDialect.td"
include "concretelang/Dialect/TFHE/IR/TFHEDialect.td"
include "mlir/IR/BuiltinTypes.td"
class MidLFHE_Type<string name, list<Trait> traits = []> : TypeDef<MidLFHE_Dialect, name, traits> { }
class TFHE_Type<string name, list<Trait> traits = []> : TypeDef<TFHE_Dialect, name, traits> { }
def GLWECipherTextType : MidLFHE_Type<"GLWECipherText", [MemRefElementTypeInterface]> {
def GLWECipherTextType : TFHE_Type<"GLWECipherText", [MemRefElementTypeInterface]> {
let mnemonic = "glwe";
let summary = "A GLWE ciphertext";

View File

@@ -89,22 +89,22 @@ public:
ROUND_TRIP,
// Read sources and exit before any lowering
HLFHE,
FHE,
// Read sources and lower all HLFHE operations to MidLFHE
// Read sources and lower all FHE operations to TFHE
// operations
MIDLFHE,
TFHE,
// Read sources and lower all HLFHE and MidLFHE operations to LowLFHE
// Read sources and lower all FHE and TFHE operations to Concrete
// operations
LOWLFHE,
CONCRETE,
// Read sources and lower all HLFHE, MidLFHE and LowLFHE
// Read sources and lower all FHE, TFHE and Concrete
// operations to canonical MLIR dialects. Cryptographic operations
// are lowered to invocations of the concrete library.
STD,
// Read sources and lower all HLFHE, MidLFHE and LowLFHE
// Read sources and lower all FHE, TFHE and Concrete
// operations to operations from the LLVM dialect. Cryptographic
// operations are lowered to invocations of the concrete library.
LLVM,
@@ -152,14 +152,14 @@ public:
void setAutoParallelize(bool v);
void setGenerateClientParameters(bool v);
void setClientParametersFuncName(const llvm::StringRef &name);
void setHLFHELinalgTileSizes(llvm::ArrayRef<int64_t> sizes);
void setFHELinalgTileSizes(llvm::ArrayRef<int64_t> sizes);
void setEnablePass(std::function<bool(mlir::Pass *)> enablePass);
protected:
llvm::Optional<size_t> overrideMaxEintPrecision;
llvm::Optional<size_t> overrideMaxMANP;
llvm::Optional<std::string> clientParametersFuncName;
llvm::Optional<std::vector<int64_t>> hlfhelinalgTileSizes;
llvm::Optional<std::vector<int64_t>> fhelinalgTileSizes;
bool verifyDiagnostics;
bool autoParallelize;

View File

@@ -19,29 +19,29 @@ mlir::LogicalResult autopar(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
llvm::Expected<llvm::Optional<mlir::concretelang::V0FHEConstraint>>
getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
getFHEConstraintsFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
tileMarkedHLFHELinalg(mlir::MLIRContext &context, mlir::ModuleOp &module,
tileMarkedFHELinalg(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
markHLFHELinalgForTiling(mlir::MLIRContext &context, mlir::ModuleOp &module,
markFHELinalgForTiling(mlir::MLIRContext &context, mlir::ModuleOp &module,
llvm::ArrayRef<int64_t> tileSizes,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
lowerHLFHEToMidLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
lowerMidLFHEToLowLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
llvm::Optional<V0FHEContext> &fheContext,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
lowerLowLFHEToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
lowerConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult

View File

@@ -11,11 +11,11 @@ declare_mlir_python_extension(ConcretelangBindingsPythonExtension.Core
ADD_TO_PARENT ConcretelangBindingsPythonExtension
SOURCES
ConcretelangModule.cpp
HLFHEModule.cpp
FHEModule.cpp
CompilerAPIModule.cpp
EMBED_CAPI_LINK_LIBS
CONCRETELANGCAPIHLFHE
CONCRETELANGCAPIHLFHELINALG
CONCRETELANGCAPIFHE
CONCRETELANGCAPIFHELINALG
CONCRETELANGCAPISupport
)
@@ -42,20 +42,20 @@ declare_mlir_python_sources(ConcretelangBindingsPythonSources.Dialects
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT ConcretelangBindingsPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}"
CONCRETELANGBindingsPythonHLFHEOps
TD_FILE concrete/lang/dialects/HLFHEOps.td
CONCRETELANGBindingsPythonFHEOps
TD_FILE concrete/lang/dialects/FHEOps.td
SOURCES
concrete/lang/dialects/hlfhe.py
DIALECT_NAME HLFHE)
concrete/lang/dialects/fhe.py
DIALECT_NAME FHE)
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT ConcretelangBindingsPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}"
CONCRETELANGBindingsPythonHLFHELinalgOps
TD_FILE concrete/lang/dialects/HLFHELinalgOps.td
CONCRETELANGBindingsPythonFHELinalgOps
TD_FILE concrete/lang/dialects/FHELinalgOps.td
SOURCES
concrete/lang/dialects/hlfhelinalg.py
DIALECT_NAME HLFHELinalg)
concrete/lang/dialects/fhelinalg.py
DIALECT_NAME FHELinalg)
################################################################################

View File

@@ -3,7 +3,7 @@
#include "CompilerAPIModule.h"
#include "concretelang-c/Support/CompilerEngine.h"
#include "concretelang/Dialect/HLFHE/IR/HLFHEOpsDialect.h.inc"
#include "concretelang/Dialect/FHE/IR/FHEOpsDialect.h.inc"
#include "concretelang/Support/Jit.h"
#include "concretelang/Support/JitCompilerEngine.h"
#include <mlir/Dialect/MemRef/IR/MemRef.h>

View File

@@ -7,8 +7,8 @@
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Registration.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "concretelang-c/Dialect/HLFHE.h"
#include "concretelang-c/Dialect/HLFHELinalg.h"
#include "concretelang-c/Dialect/FHE.h"
#include "concretelang-c/Dialect/FHELinalg.h"
#include "llvm-c/ErrorHandling.h"
#include "llvm/Support/Signals.h"
@@ -29,17 +29,17 @@ PYBIND11_MODULE(_concretelang, m) {
MlirContext context = mlirPythonCapsuleToContext(wrappedCapsule.ptr());
// Collect Concretelang dialects to register.
MlirDialectHandle hlfhe = mlirGetDialectHandle__hlfhe__();
mlirDialectHandleRegisterDialect(hlfhe, context);
mlirDialectHandleLoadDialect(hlfhe, context);
MlirDialectHandle hlfhelinalg = mlirGetDialectHandle__hlfhelinalg__();
mlirDialectHandleRegisterDialect(hlfhelinalg, context);
mlirDialectHandleLoadDialect(hlfhelinalg, context);
MlirDialectHandle fhe = mlirGetDialectHandle__fhe__();
mlirDialectHandleRegisterDialect(fhe, context);
mlirDialectHandleLoadDialect(fhe, context);
MlirDialectHandle fhelinalg = mlirGetDialectHandle__fhelinalg__();
mlirDialectHandleRegisterDialect(fhelinalg, context);
mlirDialectHandleLoadDialect(fhelinalg, context);
},
"Register Concretelang dialects on a PyMlirContext.");
py::module hlfhe = m.def_submodule("_hlfhe", "HLFHE API");
mlir::concretelang::python::populateDialectHLFHESubmodule(hlfhe);
py::module fhe = m.def_submodule("_fhe", "FHE API");
mlir::concretelang::python::populateDialectFHESubmodule(fhe);
py::module api = m.def_submodule("_compiler", "Compiler API");
mlir::concretelang::python::populateCompilerAPISubmodule(api);

View File

@@ -10,7 +10,7 @@ namespace mlir {
namespace concretelang {
namespace python {
void populateDialectHLFHESubmodule(pybind11::module &m);
void populateDialectFHESubmodule(pybind11::module &m);
} // namespace python
} // namespace concretelang

View File

@@ -3,7 +3,7 @@
#include "DialectModules.h"
#include "concretelang-c/Dialect/HLFHE.h"
#include "concretelang-c/Dialect/FHE.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
@@ -18,13 +18,13 @@
using namespace mlir::concretelang;
using namespace mlir::python::adaptors;
/// Populate the hlfhe python module.
void mlir::concretelang::python::populateDialectHLFHESubmodule(
/// Populate the fhe python module.
void mlir::concretelang::python::populateDialectFHESubmodule(
pybind11::module &m) {
m.doc() = "HLFHE dialect Python native extension";
m.doc() = "FHE dialect Python native extension";
mlir_type_subclass(m, "EncryptedIntegerType",
hlfheTypeIsAnEncryptedIntegerType)
fheTypeIsAnEncryptedIntegerType)
.def_classmethod("get", [](pybind11::object cls, MlirContext ctx,
unsigned width) {
// We want the user to receive a python exception for not being able to
@@ -33,6 +33,6 @@ void mlir::concretelang::python::populateDialectHLFHESubmodule(
throw std::invalid_argument("can't create eint with the given width");
};
return cls(
hlfheEncryptedIntegerTypeGetChecked(ctx, width, emitException));
fheEncryptedIntegerTypeGetChecked(ctx, width, emitException));
});
}

View File

@@ -0,0 +1,7 @@
#ifndef PYTHON_BINDINGS_FHELINALG_OPS
#define PYTHON_BINDINGS_FHELINALG_OPS
include "mlir/Bindings/Python/Attributes.td"
include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td"
#endif

View File

@@ -0,0 +1,7 @@
#ifndef PYTHON_BINDINGS_FHE_OPS
#define PYTHON_BINDINGS_FHE_OPS
include "mlir/Bindings/Python/Attributes.td"
include "concretelang/Dialect/FHE/IR/FHEOps.td"
#endif

View File

@@ -1,7 +0,0 @@
#ifndef PYTHON_BINDINGS_HLFHELINALG_OPS
#define PYTHON_BINDINGS_HLFHELINALG_OPS
include "mlir/Bindings/Python/Attributes.td"
include "concretelang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td"
#endif

View File

@@ -1,7 +0,0 @@
#ifndef PYTHON_BINDINGS_HLFHE_OPS
#define PYTHON_BINDINGS_HLFHE_OPS
include "mlir/Bindings/Python/Attributes.td"
include "concretelang/Dialect/HLFHE/IR/HLFHEOps.td"
#endif

View File

@@ -1,6 +1,6 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
"""HLFHE dialect module"""
from ._HLFHE_ops_gen import *
from mlir._mlir_libs._concretelang._hlfhe import *
"""FHE dialect module"""
from ._FHE_ops_gen import *
from mlir._mlir_libs._concretelang._fhe import *

View File

@@ -1,5 +1,5 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
"""HLFHELinalg dialect module"""
from ._HLFHELinalg_ops_gen import *
"""FHELinalg dialect module"""
from ._FHELinalg_ops_gen import *

View File

@@ -1,2 +1,2 @@
add_subdirectory(HLFHE)
add_subdirectory(HLFHELinalg)
add_subdirectory(FHE)
add_subdirectory(FHELinalg)

View File

@@ -0,0 +1,10 @@
set(LLVM_OPTIONAL_SOURCES FHE.cpp)
add_mlir_public_c_api_library(CONCRETELANGCAPIFHE
FHE.cpp
LINK_LIBS PUBLIC
MLIRCAPIIR
FHEDialect
)

View File

@@ -1,31 +1,31 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
#include "concretelang-c/Dialect/HLFHE.h"
#include "concretelang-c/Dialect/FHE.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/CAPI/Support.h"
#include "concretelang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "concretelang/Dialect/HLFHE/IR/HLFHEOps.h"
#include "concretelang/Dialect/HLFHE/IR/HLFHETypes.h"
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
using namespace mlir::concretelang::HLFHE;
using namespace mlir::concretelang::FHE;
//===----------------------------------------------------------------------===//
// Dialect API.
//===----------------------------------------------------------------------===//
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(HLFHE, hlfhe, HLFHEDialect)
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(FHE, fhe, FHEDialect)
//===----------------------------------------------------------------------===//
// Type API.
//===----------------------------------------------------------------------===//
bool hlfheTypeIsAnEncryptedIntegerType(MlirType type) {
bool fheTypeIsAnEncryptedIntegerType(MlirType type) {
return unwrap(type).isa<EncryptedIntegerType>();
}
MlirType hlfheEncryptedIntegerTypeGetChecked(
MlirType fheEncryptedIntegerTypeGetChecked(
MlirContext ctx, unsigned width,
mlir::function_ref<mlir::InFlightDiagnostic()> emitError) {
return wrap(EncryptedIntegerType::getChecked(emitError, unwrap(ctx), width));

View File

@@ -0,0 +1,10 @@
set(LLVM_OPTIONAL_SOURCES FHELinalg.cpp)
add_mlir_public_c_api_library(CONCRETELANGCAPIFHELINALG
FHELinalg.cpp
LINK_LIBS PUBLIC
MLIRCAPIIR
FHELinalgDialect
)

View File

@@ -1,19 +1,19 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
// See https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license information.
#include "concretelang-c/Dialect/HLFHELinalg.h"
#include "concretelang-c/Dialect/FHELinalg.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/CAPI/Support.h"
#include "concretelang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.h"
#include "concretelang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h"
#include "concretelang/Dialect/HLFHELinalg/IR/HLFHELinalgTypes.h"
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.h"
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h"
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgTypes.h"
using namespace mlir::concretelang::HLFHELinalg;
using namespace mlir::concretelang::FHELinalg;
//===----------------------------------------------------------------------===//
// Dialect API.
//===----------------------------------------------------------------------===//
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(HLFHELinalg, hlfhelinalg,
HLFHELinalgDialect)
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(FHELinalg, fhelinalg,
FHELinalgDialect)

View File

@@ -1,10 +0,0 @@
set(LLVM_OPTIONAL_SOURCES HLFHE.cpp)
add_mlir_public_c_api_library(CONCRETELANGCAPIHLFHE
HLFHE.cpp
LINK_LIBS PUBLIC
MLIRCAPIIR
HLFHEDialect
)

View File

@@ -1,10 +0,0 @@
set(LLVM_OPTIONAL_SOURCES HLFHELinalg.cpp)
add_mlir_public_c_api_library(CONCRETELANGCAPIHLFHELINALG
HLFHELinalg.cpp
LINK_LIBS PUBLIC
MLIRCAPIIR
HLFHELinalgDialect
)

View File

@@ -1,7 +1,7 @@
add_subdirectory(HLFHEToMidLFHE)
add_subdirectory(MidLFHEGlobalParametrization)
add_subdirectory(MidLFHEToLowLFHE)
add_subdirectory(HLFHETensorOpsToLinalg)
add_subdirectory(LowLFHEToConcreteCAPI)
add_subdirectory(FHEToTFHE)
add_subdirectory(TFHEGlobalParametrization)
add_subdirectory(TFHEToConcrete)
add_subdirectory(FHETensorOpsToLinalg)
add_subdirectory(ConcreteToConcreteCAPI)
add_subdirectory(MLIRLowerableDialectsToLLVM)
add_subdirectory(LowLFHEUnparametrize)
add_subdirectory(ConcreteUnparametrize)

View File

@@ -0,0 +1,16 @@
add_mlir_dialect_library(ConcreteToConcreteCAPI
ConcreteToConcreteCAPI.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE
DEPENDS
ConcreteDialect
MLIRConversionPassIncGen
LINK_LIBS PUBLIC
MLIRIR
MLIRTransforms
)
target_link_libraries(ConcreteToConcreteCAPI PUBLIC MLIRIR)

View File

@@ -10,20 +10,20 @@
#include "mlir/Transforms/DialectConversion.h"
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
#include "concretelang/Dialect/LowLFHE/IR/LowLFHEOps.h"
#include "concretelang/Dialect/LowLFHE/IR/LowLFHETypes.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
#include "concretelang/Support/Constants.h"
class LowLFHEToConcreteCAPITypeConverter : public mlir::TypeConverter {
class ConcreteToConcreteCAPITypeConverter : public mlir::TypeConverter {
public:
LowLFHEToConcreteCAPITypeConverter() {
ConcreteToConcreteCAPITypeConverter() {
addConversion([](mlir::Type type) { return type; });
addConversion([&](mlir::concretelang::LowLFHE::PlaintextType type) {
addConversion([&](mlir::concretelang::Concrete::PlaintextType type) {
return mlir::IntegerType::get(type.getContext(), 64);
});
addConversion([&](mlir::concretelang::LowLFHE::CleartextType type) {
addConversion([&](mlir::concretelang::Concrete::CleartextType type) {
return mlir::IntegerType::get(type.getContext(), 64);
});
}
@@ -64,56 +64,56 @@ mlir::LogicalResult insertForwardDeclaration(mlir::Operation *op,
// allocate them. All the calls to the C API should be done using this generic
// types, and casting should then be performed back to the appropriate type.
inline mlir::concretelang::LowLFHE::LweCiphertextType
inline mlir::concretelang::Concrete::LweCiphertextType
getGenericLweCiphertextType(mlir::MLIRContext *context) {
return mlir::concretelang::LowLFHE::LweCiphertextType::get(context, -1, -1);
return mlir::concretelang::Concrete::LweCiphertextType::get(context, -1, -1);
}
inline mlir::concretelang::LowLFHE::GlweCiphertextType
inline mlir::concretelang::Concrete::GlweCiphertextType
getGenericGlweCiphertextType(mlir::MLIRContext *context) {
return mlir::concretelang::LowLFHE::GlweCiphertextType::get(context);
return mlir::concretelang::Concrete::GlweCiphertextType::get(context);
}
inline mlir::concretelang::LowLFHE::PlaintextType
inline mlir::concretelang::Concrete::PlaintextType
getGenericPlaintextType(mlir::MLIRContext *context) {
return mlir::concretelang::LowLFHE::PlaintextType::get(context, -1);
return mlir::concretelang::Concrete::PlaintextType::get(context, -1);
}
inline mlir::concretelang::LowLFHE::PlaintextListType
inline mlir::concretelang::Concrete::PlaintextListType
getGenericPlaintextListType(mlir::MLIRContext *context) {
return mlir::concretelang::LowLFHE::PlaintextListType::get(context);
return mlir::concretelang::Concrete::PlaintextListType::get(context);
}
inline mlir::concretelang::LowLFHE::ForeignPlaintextListType
inline mlir::concretelang::Concrete::ForeignPlaintextListType
getGenericForeignPlaintextListType(mlir::MLIRContext *context) {
return mlir::concretelang::LowLFHE::ForeignPlaintextListType::get(context);
return mlir::concretelang::Concrete::ForeignPlaintextListType::get(context);
}
inline mlir::concretelang::LowLFHE::CleartextType
inline mlir::concretelang::Concrete::CleartextType
getGenericCleartextType(mlir::MLIRContext *context) {
return mlir::concretelang::LowLFHE::CleartextType::get(context, -1);
return mlir::concretelang::Concrete::CleartextType::get(context, -1);
}
inline mlir::concretelang::LowLFHE::LweBootstrapKeyType
inline mlir::concretelang::Concrete::LweBootstrapKeyType
getGenericLweBootstrapKeyType(mlir::MLIRContext *context) {
return mlir::concretelang::LowLFHE::LweBootstrapKeyType::get(context);
return mlir::concretelang::Concrete::LweBootstrapKeyType::get(context);
}
inline mlir::concretelang::LowLFHE::LweKeySwitchKeyType
inline mlir::concretelang::Concrete::LweKeySwitchKeyType
getGenericLweKeySwitchKeyType(mlir::MLIRContext *context) {
return mlir::concretelang::LowLFHE::LweKeySwitchKeyType::get(context);
return mlir::concretelang::Concrete::LweKeySwitchKeyType::get(context);
}
// Get the generic version of the type.
// Useful when iterating over a set of types.
mlir::Type getGenericType(mlir::Type baseType) {
if (baseType.isa<mlir::concretelang::LowLFHE::LweCiphertextType>()) {
if (baseType.isa<mlir::concretelang::Concrete::LweCiphertextType>()) {
return getGenericLweCiphertextType(baseType.getContext());
}
if (baseType.isa<mlir::concretelang::LowLFHE::PlaintextType>()) {
if (baseType.isa<mlir::concretelang::Concrete::PlaintextType>()) {
return getGenericPlaintextType(baseType.getContext());
}
if (baseType.isa<mlir::concretelang::LowLFHE::CleartextType>()) {
if (baseType.isa<mlir::concretelang::Concrete::CleartextType>()) {
return getGenericCleartextType(baseType.getContext());
}
return baseType;
@@ -138,7 +138,7 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
auto genericBSKType = getGenericLweBootstrapKeyType(rewriter.getContext());
auto genericKSKType = getGenericLweKeySwitchKeyType(rewriter.getContext());
auto contextType =
mlir::concretelang::LowLFHE::ContextType::get(rewriter.getContext());
mlir::concretelang::Concrete::ContextType::get(rewriter.getContext());
auto errType = mlir::IndexType::get(rewriter.getContext());
@@ -321,7 +321,7 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
return mlir::success();
}
/// LowLFHEOpToConcreteCAPICallPattern<Op> match the `Op` Operation and
/// ConcreteOpToConcreteCAPICallPattern<Op> match the `Op` Operation and
/// replace with a call to `funcName`, the funcName should be an external
/// function that was linked later. It insert the forward declaration of the
/// private `funcName` if it not already in the symbol table.
@@ -336,8 +336,8 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
/// call_op(err, out, arg0, arg1);
/// ```
template <typename Op>
struct LowLFHEOpToConcreteCAPICallPattern : public mlir::OpRewritePattern<Op> {
LowLFHEOpToConcreteCAPICallPattern(
struct ConcreteOpToConcreteCAPICallPattern : public mlir::OpRewritePattern<Op> {
ConcreteOpToConcreteCAPICallPattern(
mlir::MLIRContext *context, mlir::StringRef funcName,
mlir::StringRef allocName,
mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
@@ -346,11 +346,11 @@ struct LowLFHEOpToConcreteCAPICallPattern : public mlir::OpRewritePattern<Op> {
mlir::LogicalResult
matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
LowLFHEToConcreteCAPITypeConverter typeConverter;
ConcreteToConcreteCAPITypeConverter typeConverter;
mlir::Type resultType = op->getResultTypes().front();
auto lweResultType =
resultType.cast<mlir::concretelang::LowLFHE::LweCiphertextType>();
resultType.cast<mlir::concretelang::Concrete::LweCiphertextType>();
// Replace the operation with a call to the `funcName`
{
// Create the err value
@@ -402,21 +402,21 @@ private:
std::string allocName;
};
struct LowLFHEZeroOpPattern
: public mlir::OpRewritePattern<mlir::concretelang::LowLFHE::ZeroLWEOp> {
LowLFHEZeroOpPattern(
struct ConcreteZeroOpPattern
: public mlir::OpRewritePattern<mlir::concretelang::Concrete::ZeroLWEOp> {
ConcreteZeroOpPattern(
mlir::MLIRContext *context,
mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<mlir::concretelang::LowLFHE::ZeroLWEOp>(context,
: mlir::OpRewritePattern<mlir::concretelang::Concrete::ZeroLWEOp>(context,
benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::concretelang::LowLFHE::ZeroLWEOp op,
matchAndRewrite(mlir::concretelang::Concrete::ZeroLWEOp op,
mlir::PatternRewriter &rewriter) const override {
mlir::Type resultType = op->getResultTypes().front();
auto lweResultType =
resultType.cast<mlir::concretelang::LowLFHE::LweCiphertextType>();
resultType.cast<mlir::concretelang::Concrete::LweCiphertextType>();
// Create the err value
auto errOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(0));
@@ -438,16 +438,16 @@ struct LowLFHEZeroOpPattern
};
};
struct LowLFHEEncodeIntOpPattern
: public mlir::OpRewritePattern<mlir::concretelang::LowLFHE::EncodeIntOp> {
LowLFHEEncodeIntOpPattern(
struct ConcreteEncodeIntOpPattern
: public mlir::OpRewritePattern<mlir::concretelang::Concrete::EncodeIntOp> {
ConcreteEncodeIntOpPattern(
mlir::MLIRContext *context,
mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<mlir::concretelang::LowLFHE::EncodeIntOp>(context,
: mlir::OpRewritePattern<mlir::concretelang::Concrete::EncodeIntOp>(context,
benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::concretelang::LowLFHE::EncodeIntOp op,
matchAndRewrite(mlir::concretelang::Concrete::EncodeIntOp op,
mlir::PatternRewriter &rewriter) const override {
{
mlir::Value castedInt = rewriter.create<mlir::arith::ExtUIOp>(
@@ -463,16 +463,16 @@ struct LowLFHEEncodeIntOpPattern
};
};
struct LowLFHEIntToCleartextOpPattern
: public mlir::OpRewritePattern<mlir::concretelang::LowLFHE::IntToCleartextOp> {
LowLFHEIntToCleartextOpPattern(
struct ConcreteIntToCleartextOpPattern
: public mlir::OpRewritePattern<mlir::concretelang::Concrete::IntToCleartextOp> {
ConcreteIntToCleartextOpPattern(
mlir::MLIRContext *context,
mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<mlir::concretelang::LowLFHE::IntToCleartextOp>(
: mlir::OpRewritePattern<mlir::concretelang::Concrete::IntToCleartextOp>(
context, benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::concretelang::LowLFHE::IntToCleartextOp op,
matchAndRewrite(mlir::concretelang::Concrete::IntToCleartextOp op,
mlir::PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<mlir::arith::ExtUIOp>(
op, rewriter.getIntegerType(64), op->getOperands().front());
@@ -489,17 +489,17 @@ struct LowLFHEIntToCleartextOpPattern
// - construct the GLWE accumulator by adding the plaintext_list to a freshly
// allocated GLWE
struct GlweFromTableOpPattern
: public mlir::OpRewritePattern<mlir::concretelang::LowLFHE::GlweFromTable> {
: public mlir::OpRewritePattern<mlir::concretelang::Concrete::GlweFromTable> {
GlweFromTableOpPattern(
mlir::MLIRContext *context,
mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<mlir::concretelang::LowLFHE::GlweFromTable>(
: mlir::OpRewritePattern<mlir::concretelang::Concrete::GlweFromTable>(
context, benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::concretelang::LowLFHE::GlweFromTable op,
matchAndRewrite(mlir::concretelang::Concrete::GlweFromTable op,
mlir::PatternRewriter &rewriter) const override {
LowLFHEToConcreteCAPITypeConverter typeConverter;
ConcreteToConcreteCAPITypeConverter typeConverter;
auto errType = mlir::IndexType::get(rewriter.getContext());
// TODO: move this to insertForwardDeclarations
@@ -589,8 +589,8 @@ mlir::Value getContextArgument(mlir::Operation *op) {
mlir::Value context = block->getArguments().back();
assert(context.getType().isa<mlir::concretelang::LowLFHE::ContextType>() &&
"the LowLFHE.context should be the last argument of the enclosing "
assert(context.getType().isa<mlir::concretelang::Concrete::ContextType>() &&
"the Concrete.context should be the last argument of the enclosing "
"function of the op");
return context;
@@ -606,23 +606,23 @@ mlir::Value getContextArgument(mlir::Operation *op) {
// - get the global bootstrapping key
// - use the key and the input accumulator (GLWE) to bootstrap the input
// ciphertext
struct LowLFHEBootstrapLweOpPattern
: public mlir::OpRewritePattern<mlir::concretelang::LowLFHE::BootstrapLweOp> {
LowLFHEBootstrapLweOpPattern(
struct ConcreteBootstrapLweOpPattern
: public mlir::OpRewritePattern<mlir::concretelang::Concrete::BootstrapLweOp> {
ConcreteBootstrapLweOpPattern(
mlir::MLIRContext *context,
mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<mlir::concretelang::LowLFHE::BootstrapLweOp>(
: mlir::OpRewritePattern<mlir::concretelang::Concrete::BootstrapLweOp>(
context, benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::concretelang::LowLFHE::BootstrapLweOp op,
matchAndRewrite(mlir::concretelang::Concrete::BootstrapLweOp op,
mlir::PatternRewriter &rewriter) const override {
auto resultType = op->getResultTypes().front();
auto errOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(0));
// Get the size from the dimension
int64_t outputLweDimension =
resultType.cast<mlir::concretelang::LowLFHE::LweCiphertextType>()
resultType.cast<mlir::concretelang::Concrete::LweCiphertextType>()
.getDimension();
int64_t outputLweSize = outputLweDimension + 1;
mlir::Value lweSizeOp = rewriter.create<mlir::arith::ConstantOp>(
@@ -670,16 +670,16 @@ struct LowLFHEBootstrapLweOpPattern
// - allocate the result LWE ciphertext
// - get the global keyswitch key
// - use the key to keyswitch the input ciphertext
struct LowLFHEKeySwitchLweOpPattern
: public mlir::OpRewritePattern<mlir::concretelang::LowLFHE::KeySwitchLweOp> {
LowLFHEKeySwitchLweOpPattern(
struct ConcreteKeySwitchLweOpPattern
: public mlir::OpRewritePattern<mlir::concretelang::Concrete::KeySwitchLweOp> {
ConcreteKeySwitchLweOpPattern(
mlir::MLIRContext *context,
mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<mlir::concretelang::LowLFHE::KeySwitchLweOp>(
: mlir::OpRewritePattern<mlir::concretelang::Concrete::KeySwitchLweOp>(
context, benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::concretelang::LowLFHE::KeySwitchLweOp op,
matchAndRewrite(mlir::concretelang::Concrete::KeySwitchLweOp op,
mlir::PatternRewriter &rewriter) const override {
auto errOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(0));
@@ -687,7 +687,7 @@ struct LowLFHEKeySwitchLweOpPattern
int64_t lweDimension =
op.getResult()
.getType()
.cast<mlir::concretelang::LowLFHE::LweCiphertextType>()
.cast<mlir::concretelang::Concrete::LweCiphertextType>()
.getDimension();
int64_t lweSize = lweDimension + 1;
mlir::Value lweSizeOp = rewriter.create<mlir::arith::ConstantOp>(
@@ -724,31 +724,31 @@ struct LowLFHEKeySwitchLweOpPattern
};
};
/// Populate the RewritePatternSet with all patterns that rewrite LowLFHE
/// Populate the RewritePatternSet with all patterns that rewrite Concrete
/// operators to the corresponding function call to the `Concrete C API`.
void populateLowLFHEToConcreteCAPICall(mlir::RewritePatternSet &patterns) {
patterns.add<LowLFHEOpToConcreteCAPICallPattern<
mlir::concretelang::LowLFHE::AddLweCiphertextsOp>>(
void populateConcreteToConcreteCAPICall(mlir::RewritePatternSet &patterns) {
patterns.add<ConcreteOpToConcreteCAPICallPattern<
mlir::concretelang::Concrete::AddLweCiphertextsOp>>(
patterns.getContext(), "add_lwe_ciphertexts_u64",
"allocate_lwe_ciphertext_u64");
patterns.add<LowLFHEOpToConcreteCAPICallPattern<
mlir::concretelang::LowLFHE::AddPlaintextLweCiphertextOp>>(
patterns.add<ConcreteOpToConcreteCAPICallPattern<
mlir::concretelang::Concrete::AddPlaintextLweCiphertextOp>>(
patterns.getContext(), "add_plaintext_lwe_ciphertext_u64",
"allocate_lwe_ciphertext_u64");
patterns.add<LowLFHEOpToConcreteCAPICallPattern<
mlir::concretelang::LowLFHE::MulCleartextLweCiphertextOp>>(
patterns.add<ConcreteOpToConcreteCAPICallPattern<
mlir::concretelang::Concrete::MulCleartextLweCiphertextOp>>(
patterns.getContext(), "mul_cleartext_lwe_ciphertext_u64",
"allocate_lwe_ciphertext_u64");
patterns.add<LowLFHEOpToConcreteCAPICallPattern<
mlir::concretelang::LowLFHE::NegateLweCiphertextOp>>(
patterns.add<ConcreteOpToConcreteCAPICallPattern<
mlir::concretelang::Concrete::NegateLweCiphertextOp>>(
patterns.getContext(), "negate_lwe_ciphertext_u64",
"allocate_lwe_ciphertext_u64");
patterns.add<LowLFHEEncodeIntOpPattern>(patterns.getContext());
patterns.add<LowLFHEIntToCleartextOpPattern>(patterns.getContext());
patterns.add<LowLFHEZeroOpPattern>(patterns.getContext());
patterns.add<ConcreteEncodeIntOpPattern>(patterns.getContext());
patterns.add<ConcreteIntToCleartextOpPattern>(patterns.getContext());
patterns.add<ConcreteZeroOpPattern>(patterns.getContext());
patterns.add<GlweFromTableOpPattern>(patterns.getContext());
patterns.add<LowLFHEKeySwitchLweOpPattern>(patterns.getContext());
patterns.add<LowLFHEBootstrapLweOpPattern>(patterns.getContext());
patterns.add<ConcreteKeySwitchLweOpPattern>(patterns.getContext());
patterns.add<ConcreteBootstrapLweOpPattern>(patterns.getContext());
}
struct AddRuntimeContextToFuncOpPattern
@@ -764,11 +764,11 @@ struct AddRuntimeContextToFuncOpPattern
mlir::OpBuilder::InsertionGuard guard(rewriter);
mlir::FunctionType oldFuncType = oldFuncOp.getType();
// Add a LowLFHE.context to the function signature
// Add a Concrete.context to the function signature
mlir::SmallVector<mlir::Type> newInputs(oldFuncType.getInputs().begin(),
oldFuncType.getInputs().end());
newInputs.push_back(
rewriter.getType<mlir::concretelang::LowLFHE::ContextType>());
rewriter.getType<mlir::concretelang::Concrete::ContextType>());
mlir::FunctionType newFuncTy = rewriter.getType<mlir::FunctionType>(
newInputs, oldFuncType.getResults());
// Create the new func
@@ -793,20 +793,20 @@ struct AddRuntimeContextToFuncOpPattern
return mlir::success();
}
// Legal function are one that are private or has a LowLFHE.context as last
// Legal function are one that are private or has a Concrete.context as last
// arguments.
static bool isLegal(mlir::FuncOp funcOp) {
if (!funcOp.isPublic()) {
return true;
}
// TODO : Don't need to add a runtime context for function that doesn't
// manipulates lowlfhe types.
// manipulates concrete types.
//
// if (!llvm::any_of(funcOp.getType().getInputs(), [](mlir::Type t) {
// if (auto tensorTy = t.dyn_cast_or_null<mlir::TensorType>()) {
// t = tensorTy.getElementType();
// }
// return llvm::isa<mlir::concretelang::LowLFHE::LowLFHEDialect>(
// return llvm::isa<mlir::concretelang::Concrete::ConcreteDialect>(
// t.getDialect());
// })) {
// return true;
@@ -815,21 +815,21 @@ struct AddRuntimeContextToFuncOpPattern
funcOp.getType()
.getInputs()
.back()
.isa<mlir::concretelang::LowLFHE::ContextType>();
.isa<mlir::concretelang::Concrete::ContextType>();
}
};
namespace {
struct LowLFHEToConcreteCAPIPass
: public LowLFHEToConcreteCAPIBase<LowLFHEToConcreteCAPIPass> {
struct ConcreteToConcreteCAPIPass
: public ConcreteToConcreteCAPIBase<ConcreteToConcreteCAPIPass> {
void runOnOperation() final;
};
} // namespace
void LowLFHEToConcreteCAPIPass::runOnOperation() {
void ConcreteToConcreteCAPIPass::runOnOperation() {
mlir::ModuleOp op = getOperation();
// First of all add the LowLFHE.context to the block arguments of function
// First of all add the Concrete.context to the block arguments of function
// that manipulates ciphertexts.
{
mlir::ConversionTarget target(getContext());
@@ -854,17 +854,17 @@ void LowLFHEToConcreteCAPIPass::runOnOperation() {
if (insertForwardDeclarations(op, rewriter).failed()) {
this->signalPassFailure();
}
// Rewrite LowLFHE ops to CallOp to the Concrete C API
// Rewrite Concrete ops to CallOp to the Concrete C API
{
mlir::ConversionTarget target(getContext());
mlir::RewritePatternSet patterns(&getContext());
target.addIllegalDialect<mlir::concretelang::LowLFHE::LowLFHEDialect>();
target.addIllegalDialect<mlir::concretelang::Concrete::ConcreteDialect>();
target.addLegalDialect<mlir::BuiltinDialect, mlir::StandardOpsDialect,
mlir::memref::MemRefDialect,
mlir::arith::ArithmeticDialect>();
populateLowLFHEToConcreteCAPICall(patterns);
populateConcreteToConcreteCAPICall(patterns);
if (mlir::applyPartialConversion(op, target, std::move(patterns))
.failed()) {
@@ -876,8 +876,8 @@ void LowLFHEToConcreteCAPIPass::runOnOperation() {
namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertLowLFHEToConcreteCAPIPass() {
return std::make_unique<LowLFHEToConcreteCAPIPass>();
createConvertConcreteToConcreteCAPIPass() {
return std::make_unique<ConcreteToConcreteCAPIPass>();
}
} // namespace concretelang
} // namespace mlir

View File

@@ -0,0 +1,16 @@
add_mlir_dialect_library(ConcreteUnparametrize
ConcreteUnparametrize.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE
DEPENDS
ConcreteDialect
MLIRConversionPassIncGen
LINK_LIBS PUBLIC
MLIRIR
MLIRTransforms
)
target_link_libraries(ConcreteUnparametrize PUBLIC MLIRIR)

View File

@@ -7,32 +7,32 @@
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
#include "concretelang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
#include "concretelang/Dialect/LowLFHE/IR/LowLFHEOps.h"
#include "concretelang/Dialect/LowLFHE/IR/LowLFHETypes.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
#include "concretelang/Dialect/RT/IR/RTOps.h"
#include "concretelang/Support/Constants.h"
/// LowLFHEUnparametrizeTypeConverter is a type converter that unparametrize
/// LowLFHE types
class LowLFHEUnparametrizeTypeConverter : public mlir::TypeConverter {
/// ConcreteUnparametrizeTypeConverter is a type converter that unparametrize
/// Concrete types
class ConcreteUnparametrizeTypeConverter : public mlir::TypeConverter {
public:
static mlir::Type unparematrizeLowLFHEType(mlir::Type type) {
if (type.isa<mlir::concretelang::LowLFHE::PlaintextType>()) {
static mlir::Type unparematrizeConcreteType(mlir::Type type) {
if (type.isa<mlir::concretelang::Concrete::PlaintextType>()) {
return mlir::IntegerType::get(type.getContext(), 64);
}
if (type.isa<mlir::concretelang::LowLFHE::CleartextType>()) {
if (type.isa<mlir::concretelang::Concrete::CleartextType>()) {
return mlir::IntegerType::get(type.getContext(), 64);
}
if (type.isa<mlir::concretelang::LowLFHE::LweCiphertextType>()) {
return mlir::concretelang::LowLFHE::LweCiphertextType::get(type.getContext(),
if (type.isa<mlir::concretelang::Concrete::LweCiphertextType>()) {
return mlir::concretelang::Concrete::LweCiphertextType::get(type.getContext(),
-1, -1);
}
auto tensorType = type.dyn_cast_or_null<mlir::RankedTensorType>();
if (tensorType != nullptr) {
auto eltTy0 = tensorType.getElementType();
auto eltTy1 = unparematrizeLowLFHEType(eltTy0);
auto eltTy1 = unparematrizeConcreteType(eltTy0);
if (eltTy0 == eltTy1) {
return type;
}
@@ -41,17 +41,17 @@ public:
return type;
}
LowLFHEUnparametrizeTypeConverter() {
ConcreteUnparametrizeTypeConverter() {
addConversion(
[](mlir::Type type) { return unparematrizeLowLFHEType(type); });
[](mlir::Type type) { return unparematrizeConcreteType(type); });
}
};
/// Replace `%1 = unrealized_conversion_cast %0 : t0 to t1` to `%0` where t0 or
/// t1 are a LowLFHE type.
struct LowLFHEUnrealizedCastReplacementPattern
/// t1 are a Concrete type.
struct ConcreteUnrealizedCastReplacementPattern
: public mlir::OpRewritePattern<mlir::UnrealizedConversionCastOp> {
LowLFHEUnrealizedCastReplacementPattern(
ConcreteUnrealizedCastReplacementPattern(
mlir::MLIRContext *context,
mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<mlir::UnrealizedConversionCastOp>(context,
@@ -60,9 +60,9 @@ struct LowLFHEUnrealizedCastReplacementPattern
mlir::LogicalResult
matchAndRewrite(mlir::UnrealizedConversionCastOp op,
mlir::PatternRewriter &rewriter) const override {
if (mlir::isa<mlir::concretelang::LowLFHE::LowLFHEDialect>(
if (mlir::isa<mlir::concretelang::Concrete::ConcreteDialect>(
op.getOperandTypes()[0].getDialect()) ||
mlir::isa<mlir::concretelang::LowLFHE::LowLFHEDialect>(
mlir::isa<mlir::concretelang::Concrete::ConcreteDialect>(
op.getType(0).getDialect())) {
rewriter.replaceOp(op, op.getOperands());
return mlir::success();
@@ -71,21 +71,21 @@ struct LowLFHEUnrealizedCastReplacementPattern
};
};
/// LowLFHEUnparametrizePass remove all parameters of LowLFHE types and remove
/// ConcreteUnparametrizePass remove all parameters of Concrete types and remove
/// the unrealized_conversion_cast operation that operates on parametrized
/// LowLFHE types.
struct LowLFHEUnparametrizePass
: public LowLFHEUnparametrizeBase<LowLFHEUnparametrizePass> {
/// Concrete types.
struct ConcreteUnparametrizePass
: public ConcreteUnparametrizeBase<ConcreteUnparametrizePass> {
void runOnOperation() final;
};
void LowLFHEUnparametrizePass::runOnOperation() {
void ConcreteUnparametrizePass::runOnOperation() {
auto op = this->getOperation();
mlir::ConversionTarget target(getContext());
mlir::OwningRewritePatternList patterns(&getContext());
LowLFHEUnparametrizeTypeConverter converter;
ConcreteUnparametrizeTypeConverter converter;
// Conversion of linalg.generic operation
target
@@ -97,13 +97,13 @@ void LowLFHEUnparametrizePass::runOnOperation() {
converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
});
patterns.add<RegionOpTypeConverterPattern<mlir::linalg::GenericOp,
LowLFHEUnparametrizeTypeConverter>>(
ConcreteUnparametrizeTypeConverter>>(
&getContext(), converter);
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
LowLFHEUnparametrizeTypeConverter>>(
ConcreteUnparametrizeTypeConverter>>(
&getContext(), converter);
patterns.add<RegionOpTypeConverterPattern<mlir::scf::ForOp,
LowLFHEUnparametrizeTypeConverter>>(
ConcreteUnparametrizeTypeConverter>>(
&getContext(), converter);
// Conversion of function signature and arguments
@@ -116,7 +116,7 @@ void LowLFHEUnparametrizePass::runOnOperation() {
// Replacement of unrealized_conversion_cast
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::UnrealizedConversionCastOp>(
target, converter);
patterns.add<LowLFHEUnrealizedCastReplacementPattern>(patterns.getContext());
patterns.add<ConcreteUnrealizedCastReplacementPattern>(patterns.getContext());
// Conversion of tensor operators
mlir::concretelang::populateWithTensorTypeConverterPatterns(patterns, target,
@@ -142,8 +142,8 @@ void LowLFHEUnparametrizePass::runOnOperation() {
namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertLowLFHEUnparametrizePass() {
return std::make_unique<LowLFHEUnparametrizePass>();
createConvertConcreteUnparametrizePass() {
return std::make_unique<ConcreteUnparametrizePass>();
}
} // namespace concretelang
} // namespace mlir

View File

@@ -0,0 +1,17 @@
add_mlir_dialect_library(FHETensorOpsToLinalg
TensorOpsToLinalg.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE
DEPENDS
FHEDialect
FHELinalgDialect
MLIRConversionPassIncGen
LINK_LIBS PUBLIC
MLIRIR
FHEDialect
FHELinalgDialect)
target_link_libraries(FHEDialect PUBLIC MLIRIR)

View File

@@ -14,59 +14,59 @@
#include <iostream>
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "concretelang/Dialect/HLFHE/IR/HLFHEOps.h"
#include "concretelang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.h"
#include "concretelang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h"
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.h"
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h"
#include "concretelang/Support/Constants.h"
struct DotToLinalgGeneric
: public ::mlir::OpRewritePattern<mlir::concretelang::HLFHELinalg::Dot> {
: public ::mlir::OpRewritePattern<mlir::concretelang::FHELinalg::Dot> {
DotToLinalgGeneric(::mlir::MLIRContext *context)
: ::mlir::OpRewritePattern<::mlir::concretelang::HLFHELinalg::Dot>(
: ::mlir::OpRewritePattern<::mlir::concretelang::FHELinalg::Dot>(
context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
// This rewrite pattern transforms any instance of
// `HLFHELinalg.dot_eint_int` to an instance of `linalg.generic` with an
// appropriate region using `HLFHE.mul_eint_int` and
// `HLFHE.add_eint` operations, an appropriate specification for the
// `FHELinalg.dot_eint_int` to an instance of `linalg.generic` with an
// appropriate region using `FHE.mul_eint_int` and
// `FHE.add_eint` operations, an appropriate specification for the
// iteration dimensions and appropriate operations managing the
// accumulator of `linalg.generic`.
//
// Example:
//
// %o = "HLFHELinalg.dot_eint_int"(%arg0, %arg1) :
// (tensor<4x!HLFHE.eint<0>>,
// tensor<4xi32>) -> (!HLFHE.eint<0>)
// %o = "FHELinalg.dot_eint_int"(%arg0, %arg1) :
// (tensor<4x!FHE.eint<0>>,
// tensor<4xi32>) -> (!FHE.eint<0>)
//
// becomes:
//
// %0 = "HLFHE.zero"() : () -> !HLFHE.eint<0>
// %1 = tensor.from_elements %0 : tensor<1x!HLFHE.eint<0>>
// %0 = "FHE.zero"() : () -> !FHE.eint<0>
// %1 = tensor.from_elements %0 : tensor<1x!FHE.eint<0>>
// %2 = linalg.generic {
// indexing_maps = [#map0, #map0, #map1],
// iterator_types = ["reduction"]
// }
// ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<0>>, tensor<2xi32>)
// outs(%1 : tensor<1x!HLFHE.eint<0>>) {
// ^bb0(%arg2: !HLFHE.eint<0>, %arg3: i32, %arg4: !HLFHE.eint<0>):
// %4 = "HLFHE.mul_eint_int"(%arg2, %arg3) :
// (!HLFHE.eint<0>, i32) -> !HLFHE.eint<0>
// ins(%arg0, %arg1 : tensor<2x!FHE.eint<0>>, tensor<2xi32>)
// outs(%1 : tensor<1x!FHE.eint<0>>) {
// ^bb0(%arg2: !FHE.eint<0>, %arg3: i32, %arg4: !FHE.eint<0>):
// %4 = "FHE.mul_eint_int"(%arg2, %arg3) :
// (!FHE.eint<0>, i32) -> !FHE.eint<0>
//
// %5 = "HLFHE.add_eint"(%4, %arg4) :
// (!HLFHE.eint<0>, !HLFHE.eint<0>) -> !HLFHE.eint<0>
// %5 = "FHE.add_eint"(%4, %arg4) :
// (!FHE.eint<0>, !FHE.eint<0>) -> !FHE.eint<0>
//
// linalg.yield %5 : !HLFHE.eint<0>
// } -> tensor<1x!HLFHE.eint<0>>
// linalg.yield %5 : !FHE.eint<0>
// } -> tensor<1x!FHE.eint<0>>
//
// %c0 = constant 0 : index
// %o = tensor.extract %2[%c0] : tensor<1x!HLFHE.eint<0>>
// %o = tensor.extract %2[%c0] : tensor<1x!FHE.eint<0>>
//
::mlir::LogicalResult
matchAndRewrite(::mlir::concretelang::HLFHELinalg::Dot dotOp,
matchAndRewrite(::mlir::concretelang::FHELinalg::Dot dotOp,
::mlir::PatternRewriter &rewriter) const override {
// Zero value to initialize accumulator
mlir::Value zeroCst = rewriter.create<mlir::concretelang::HLFHE::ZeroEintOp>(
mlir::Value zeroCst = rewriter.create<mlir::concretelang::FHE::ZeroEintOp>(
dotOp.getLoc(),
dotOp.lhs().getType().cast<mlir::ShapedType>().getElementType());
@@ -95,11 +95,11 @@ struct DotToLinalgGeneric
auto regBuilder = [&](mlir::OpBuilder &nestedBuilder,
mlir::Location nestedLoc,
mlir::ValueRange blockArgs) {
mlir::concretelang::HLFHE::MulEintIntOp mul =
nestedBuilder.create<mlir::concretelang::HLFHE::MulEintIntOp>(
mlir::concretelang::FHE::MulEintIntOp mul =
nestedBuilder.create<mlir::concretelang::FHE::MulEintIntOp>(
dotOp.getLoc(), blockArgs[0], blockArgs[1]);
mlir::concretelang::HLFHE::AddEintOp add =
nestedBuilder.create<mlir::concretelang::HLFHE::AddEintOp>(
mlir::concretelang::FHE::AddEintOp add =
nestedBuilder.create<mlir::concretelang::FHE::AddEintOp>(
dotOp.getLoc(), mul, blockArgs[2]);
nestedBuilder.create<mlir::linalg::YieldOp>(dotOp.getLoc(),
@@ -180,16 +180,16 @@ getBroadcastedAffineMapMultiLUT(const mlir::RankedTensorType &resultType,
}
// This template rewrite pattern transforms any instance of
// operators `HLFHELinalgOp` that implements the broadasting rules to an
// instance of `linalg.generic` with an appropriate region using `HLFHEOp`
// operators `FHELinalgOp` that implements the broadasting rules to an
// instance of `linalg.generic` with an appropriate region using `FHEOp`
// operation, an appropriate specification for the iteration dimensions and
// appropriate operations managing the accumulator of `linalg.generic`.
//
// Example:
//
// %res = HLFHELinalg.op(%lhs, %rhs):
// (tensor<D$Ax...xD1x!HLFHE.eint<p>>, tensor<D$B'x...xD1'xT>)
// -> tensor<DR"x...xD1"x!HLFHE.eint<p>>
// %res = FHELinalg.op(%lhs, %rhs):
// (tensor<D$Ax...xD1x!FHE.eint<p>>, tensor<D$B'x...xD1'xT>)
// -> tensor<DR"x...xD1"x!FHE.eint<p>>
//
// becomes:
//
@@ -205,28 +205,28 @@ getBroadcastedAffineMapMultiLUT(const mlir::RankedTensorType &resultType,
// iterator_types = ["parallel", ..., "parallel"], // $R" parallel
// }
// %init = linalg.init_tensor [DR",...,D1"]
// : tensor<DR"x...xD1"x!HLFHE.eint<p>>
// : tensor<DR"x...xD1"x!FHE.eint<p>>
// %res = linalg.generic {
// ins(%lhs, %rhs: tensor<DAx...xD1x!HLFHE.eint<p>>,tensor<DB'x...xD1'xT>)
// outs(%init : tensor<DR"x...xD1"x!HLFHE.eint<p>>)
// ins(%lhs, %rhs: tensor<DAx...xD1x!FHE.eint<p>>,tensor<DB'x...xD1'xT>)
// outs(%init : tensor<DR"x...xD1"x!FHE.eint<p>>)
// {
// ^bb0(%arg0: !HLFHE.eint<p>, %arg1: T):
// %0 = HLFHE.op(%arg0, %arg1): !HLFHE.eint<p>, T ->
// !HLFHE.eint<p>
// linalg.yield %0 : !HLFHE.eint<p>
// ^bb0(%arg0: !FHE.eint<p>, %arg1: T):
// %0 = FHE.op(%arg0, %arg1): !FHE.eint<p>, T ->
// !FHE.eint<p>
// linalg.yield %0 : !FHE.eint<p>
// }
// }
//
template <typename HLFHELinalgOp, typename HLFHEOp>
struct HLFHELinalgOpToLinalgGeneric
: public mlir::OpRewritePattern<HLFHELinalgOp> {
HLFHELinalgOpToLinalgGeneric(
template <typename FHELinalgOp, typename FHEOp>
struct FHELinalgOpToLinalgGeneric
: public mlir::OpRewritePattern<FHELinalgOp> {
FHELinalgOpToLinalgGeneric(
::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: ::mlir::OpRewritePattern<HLFHELinalgOp>(context, benefit) {}
: ::mlir::OpRewritePattern<FHELinalgOp>(context, benefit) {}
::mlir::LogicalResult
matchAndRewrite(HLFHELinalgOp linalgOp,
matchAndRewrite(FHELinalgOp linalgOp,
::mlir::PatternRewriter &rewriter) const override {
mlir::RankedTensorType resultTy =
((mlir::Type)linalgOp->getResult(0).getType())
@@ -254,11 +254,11 @@ struct HLFHELinalgOpToLinalgGeneric
auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
mlir::Location nestedLoc,
mlir::ValueRange blockArgs) {
HLFHEOp hlfheOp = nestedBuilder.create<HLFHEOp>(
FHEOp fheOp = nestedBuilder.create<FHEOp>(
linalgOp.getLoc(), blockArgs[0], blockArgs[1]);
nestedBuilder.create<mlir::linalg::YieldOp>(linalgOp.getLoc(),
hlfheOp.getResult());
fheOp.getResult());
};
// Create the `linalg.generic` op
@@ -288,9 +288,9 @@ llvm::SmallVector<llvm::StringRef> parallelIteratorType(int n) {
}
// This class rewrite pattern transforms any instance of
// operators `HLFHELinalg.ApplyMappedLookupTableEintOp` that implements the
// operators `FHELinalg.ApplyMappedLookupTableEintOp` that implements the
// broadasting rules to an instance of `linalg.generic` with an appropriate
// region using `HLFHE.ApplyLookupTableEintOp` operation, an appropriate
// region using `FHE.ApplyLookupTableEintOp` operation, an appropriate
// specification for the iteration dimensions and appropriate operations
// managing the accumulator of `linalg.generic`.
//
@@ -298,20 +298,20 @@ llvm::SmallVector<llvm::StringRef> parallelIteratorType(int n) {
// because of a bug in lowering this operation.
//
// Example:
// %res = "HLFHELinalg.apply_mapped_lookup_table"(%t, %luts, %map)
// : (tensor<2x3x!HLFHE.eint<2>>, tensor<5x4xi64>, tensor<2x3xindex>)
// -> tensor<2x3x!HLFHE.eint<2>>
// %res = "FHELinalg.apply_mapped_lookup_table"(%t, %luts, %map)
// : (tensor<2x3x!FHE.eint<2>>, tensor<5x4xi64>, tensor<2x3xindex>)
// -> tensor<2x3x!FHE.eint<2>>
//
// becomes:
//
// #map = affine_map<(d0, d1) -> (d0, d1)>
// %init = linalg.init_tensor [2, 3] : tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>
// %init = linalg.init_tensor [2, 3] : tensor<2x3x!TFHE.glwe<{_,_,_}{2}>>
// %output = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types
// = ["parallel", "parallel"]} ins(%arg0, %arg2 :
// tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>, tensor<2x3xindex>) outs(%0 :
// tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>) {
// ^bb0(%arg3: !MidLFHE.glwe<{_,_,_}{2}>, %lut_idx: index, %arg5:
// !MidLFHE.glwe<{_,_,_}{2}>): // no predecessors
// tensor<2x3x!TFHE.glwe<{_,_,_}{2}>>, tensor<2x3xindex>) outs(%0 :
// tensor<2x3x!TFHE.glwe<{_,_,_}{2}>>) {
// ^bb0(%arg3: !TFHE.glwe<{_,_,_}{2}>, %lut_idx: index, %arg5:
// !TFHE.glwe<{_,_,_}{2}>): // no predecessors
// // SHOULD BE
// %lut = tensor.extract_slice %arg1[%[[LUTIDX]], 0] [1,4] [1, 1]
// : tensor<5x4xi64> to tensor<4xi64>
@@ -323,31 +323,31 @@ llvm::SmallVector<llvm::StringRef> parallelIteratorType(int n) {
// ...
// %e3 = tensor.extract %arg5[%lut_idx, %i3] : tensor<5x4xi64>
// %lut = tensor.from_elements %e0, ..., %e3 : tensor<4xi64>
// %res = "MidLFHE.apply_lookup_table"(%arg3, %[[LUT]])
// %res = "TFHE.apply_lookup_table"(%arg3, %[[LUT]])
// {baseLogBS = -1 : i32, baseLogKS = -1 : i32, k = -1 : i32,
// levelBS = -1 : i32, levelKS = -1 : i32, outputSizeKS =
// -1 : i32, polynomialSize = -1 : i32}
// : (!MidLFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) ->
// !MidLFHE.glwe<{_,_,_}{2}> linalg.yield %res :
// !MidLFHE.glwe<{_,_,_}{2}>
// } -> tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>
// : (!TFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) ->
// !TFHE.glwe<{_,_,_}{2}> linalg.yield %res :
// !TFHE.glwe<{_,_,_}{2}>
// } -> tensor<2x3x!TFHE.glwe<{_,_,_}{2}>>
namespace HLFHELinalg = mlir::concretelang::HLFHELinalg;
namespace FHELinalg = mlir::concretelang::FHELinalg;
struct HLFHELinalgApplyMappedLookupTableToLinalgGeneric
: public mlir::OpRewritePattern<HLFHELinalg::ApplyMappedLookupTableEintOp> {
HLFHELinalgApplyMappedLookupTableToLinalgGeneric(
struct FHELinalgApplyMappedLookupTableToLinalgGeneric
: public mlir::OpRewritePattern<FHELinalg::ApplyMappedLookupTableEintOp> {
FHELinalgApplyMappedLookupTableToLinalgGeneric(
::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<HLFHELinalg::ApplyMappedLookupTableEintOp>(
: ::mlir::OpRewritePattern<FHELinalg::ApplyMappedLookupTableEintOp>(
context, benefit) {}
::mlir::LogicalResult
matchAndRewrite(HLFHELinalg::ApplyMappedLookupTableEintOp mappedLookup,
matchAndRewrite(FHELinalg::ApplyMappedLookupTableEintOp mappedLookup,
::mlir::PatternRewriter &rewriter) const override {
namespace arith = mlir::arith;
namespace linalg = mlir::linalg;
namespace tensor = mlir::tensor;
namespace HLFHE = mlir::concretelang::HLFHE;
namespace FHE = mlir::concretelang::FHE;
using Values = llvm::SmallVector<mlir::Value>;
using Types = llvm::SmallVector<mlir::Type>;
using AffineMaps = llvm::SmallVector<mlir::AffineMap>;
@@ -421,9 +421,9 @@ struct HLFHELinalgApplyMappedLookupTableToLinalgGeneric
lut = nestedBuilder.create<tensor::FromElementsOp>(loc, extracts);
} // WORKAROUND END
// %res1 = apply_lookup_table %arg0 %lut
auto lookup = nestedBuilder.create<HLFHE::ApplyLookupTableEintOp>(
auto lookup = nestedBuilder.create<FHE::ApplyLookupTableEintOp>(
loc, elementTy, tElmt, lut);
// linalg.yield %res1 : !HLFHE.eint<2>
// linalg.yield %res1 : !FHE.eint<2>
nestedBuilder.create<linalg::YieldOp>(loc, lookup.getResult());
};
@@ -446,16 +446,16 @@ struct HLFHELinalgApplyMappedLookupTableToLinalgGeneric
};
// This class rewrite pattern transforms any instance of
// operators `HLFHELinalg.ApplyMultiLookupTableEintOp` that implements the
// operators `FHELinalg.ApplyMultiLookupTableEintOp` that implements the
// broadasting rules to an instance of `linalg.generic` with an appropriate
// region using `HLFHE.ApplyLookupTableEintOp` operation, an appropriate
// region using `FHE.ApplyLookupTableEintOp` operation, an appropriate
// specification for the iteration dimensions and appropriate operaztions
// managing the accumulator of `linalg.generic`.
//
// Example:
//
// %res = "HLFHELinalg.apply_multi_lookup_table"(%t, %luts):
// (tensor<4x3x!HLFHE.eint<2>>, tensor<3x4xi64>) -> tensor<4x3x!HLFHE.eint<2>>
// %res = "FHELinalg.apply_multi_lookup_table"(%t, %luts):
// (tensor<4x3x!FHE.eint<2>>, tensor<3x4xi64>) -> tensor<4x3x!FHE.eint<2>>
//
// becomes:
//
@@ -471,50 +471,50 @@ struct HLFHELinalgApplyMappedLookupTableToLinalgGeneric
// iterator_types = ["parallel", "parallel"],
// }
// %init = linalg.init_tensor [4, 3]
// : tensor<4x3x!HLFHE.eint<2>>
// : tensor<4x3x!FHE.eint<2>>
// %res = linalg.generic {
// ins(%t, %luts, %luts, %luts, %luts: tensor<4x3x!HLFHE.eint<p>>,
// ins(%t, %luts, %luts, %luts, %luts: tensor<4x3x!FHE.eint<p>>,
// tensor<3x4xi64>, tensor<3x4xi64>, tensor<3x4xi64>, tensor<3x4xi64>)
// outs(%init : tensor<4x3x!HLFHE.eint<2>>)
// outs(%init : tensor<4x3x!FHE.eint<2>>)
// {
// ^bb0(%arg0: !HLFHE.eint<2>, %arg1: i64, %arg2: i64, %arg3: i64,
// %arg4: i64, %arg5: !HLFHE.eint<2>):
// ^bb0(%arg0: !FHE.eint<2>, %arg1: i64, %arg2: i64, %arg3: i64,
// %arg4: i64, %arg5: !FHE.eint<2>):
// %lut = tensor.from_elements %arg1, %arg2, %arg3, %arg4 :
// tensor<4xi64> %0 = "MidLFHE.apply_lookup_table"(%arg0, %lut)
// tensor<4xi64> %0 = "TFHE.apply_lookup_table"(%arg0, %lut)
// {baseLogBS = -1 : i32, baseLogKS = -1 : i32, glweDimension = -1 :
// i32, levelBS = -1 : i32, levelKS = -1 : i32, outputSizeKS = -1 :
// i32, polynomialSize = -1 : i32} : (!MidLFHE.glwe<{_,_,_}{2}>,
// tensor<4xi64>) -> !MidLFHE.glwe<{_,_,_}{2}>
// linalg.yield %0 : !HLFHE.eint<2>
// i32, polynomialSize = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>,
// tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{2}>
// linalg.yield %0 : !FHE.eint<2>
// }
// }
//
struct HLFHELinalgApplyMultiLookupTableToLinalgGeneric
struct FHELinalgApplyMultiLookupTableToLinalgGeneric
: public mlir::OpRewritePattern<
mlir::concretelang::HLFHELinalg::ApplyMultiLookupTableEintOp> {
HLFHELinalgApplyMultiLookupTableToLinalgGeneric(
mlir::concretelang::FHELinalg::ApplyMultiLookupTableEintOp> {
FHELinalgApplyMultiLookupTableToLinalgGeneric(
::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: ::mlir::OpRewritePattern<
mlir::concretelang::HLFHELinalg::ApplyMultiLookupTableEintOp>(context,
mlir::concretelang::FHELinalg::ApplyMultiLookupTableEintOp>(context,
benefit) {
}
::mlir::LogicalResult matchAndRewrite(
mlir::concretelang::HLFHELinalg::ApplyMultiLookupTableEintOp hlfheLinalgLutOp,
mlir::concretelang::FHELinalg::ApplyMultiLookupTableEintOp fheLinalgLutOp,
::mlir::PatternRewriter &rewriter) const override {
mlir::RankedTensorType resultTy =
((mlir::Type)hlfheLinalgLutOp->getResult(0).getType())
((mlir::Type)fheLinalgLutOp->getResult(0).getType())
.cast<mlir::RankedTensorType>();
mlir::RankedTensorType tensorTy =
((mlir::Type)hlfheLinalgLutOp.t().getType())
((mlir::Type)fheLinalgLutOp.t().getType())
.cast<mlir::RankedTensorType>();
mlir::RankedTensorType lutsTy =
((mlir::Type)hlfheLinalgLutOp.luts().getType())
((mlir::Type)fheLinalgLutOp.luts().getType())
.cast<mlir::RankedTensorType>();
// linalg.init_tensor for initial value
mlir::Value init = rewriter.create<mlir::linalg::InitTensorOp>(
hlfheLinalgLutOp.getLoc(), resultTy.getShape(),
fheLinalgLutOp.getLoc(), resultTy.getShape(),
resultTy.getElementType());
auto lutsShape = lutsTy.getShape();
@@ -541,51 +541,51 @@ struct HLFHELinalgApplyMultiLookupTableToLinalgGeneric
mlir::ValueRange blockArgs) {
mlir::tensor::FromElementsOp lut =
nestedBuilder.create<mlir::tensor::FromElementsOp>(
hlfheLinalgLutOp.getLoc(), blockArgs.slice(1, lut_size));
mlir::concretelang::HLFHE::ApplyLookupTableEintOp lutOp =
nestedBuilder.create<mlir::concretelang::HLFHE::ApplyLookupTableEintOp>(
hlfheLinalgLutOp.getLoc(), resultTy.getElementType(),
fheLinalgLutOp.getLoc(), blockArgs.slice(1, lut_size));
mlir::concretelang::FHE::ApplyLookupTableEintOp lutOp =
nestedBuilder.create<mlir::concretelang::FHE::ApplyLookupTableEintOp>(
fheLinalgLutOp.getLoc(), resultTy.getElementType(),
blockArgs[0], lut.result());
nestedBuilder.create<mlir::linalg::YieldOp>(hlfheLinalgLutOp.getLoc(),
nestedBuilder.create<mlir::linalg::YieldOp>(fheLinalgLutOp.getLoc(),
lutOp.getResult());
};
// Create the `linalg.generic` op
llvm::SmallVector<mlir::Type, 1> resTypes{init.getType()};
llvm::SmallVector<mlir::Value> ins{hlfheLinalgLutOp.t()};
llvm::SmallVector<mlir::Value> ins{fheLinalgLutOp.t()};
ins.reserve(lut_size + 2);
// We extract one value at a time from one LUT using different maps, so we
// need to pass the LUT `lut_size` time
for (auto i = 0; i < lut_size; i++)
ins.push_back(hlfheLinalgLutOp.luts());
ins.push_back(fheLinalgLutOp.luts());
llvm::SmallVector<mlir::Value, 1> outs{init};
llvm::StringRef doc{""};
llvm::StringRef call{""};
mlir::linalg::GenericOp genericOp =
rewriter.create<mlir::linalg::GenericOp>(
hlfheLinalgLutOp.getLoc(), resTypes, ins, outs, maps, iteratorTypes,
fheLinalgLutOp.getLoc(), resTypes, ins, outs, maps, iteratorTypes,
doc, call, bodyBuilder);
rewriter.replaceOp(hlfheLinalgLutOp, {genericOp.getResult(0)});
rewriter.replaceOp(fheLinalgLutOp, {genericOp.getResult(0)});
return ::mlir::success();
};
};
// This template rewrite pattern transforms any instance of
// operators `HLFHELinalg.apply_lookup_table` that implements the broadasting
// operators `FHELinalg.apply_lookup_table` that implements the broadasting
// rules to an instance of `linalg.generic` with an appropriate region using
// `HLFHE.apply_lookup_table` operation, an appropriate specification for the
// `FHE.apply_lookup_table` operation, an appropriate specification for the
// iteration dimensions and appropriate operations managing the accumulator of
// `linalg.generic`.
//
// Example:
//
// HLFHELinalg.apply_lookup_table(%t, %lut):
// tensor<DNx...xD1x!HLFHE.eint<p>>, tensor<DAxi64>
// -> tensor<DNx...xD1x!HLFHE.eint<p'>>
// FHELinalg.apply_lookup_table(%t, %lut):
// tensor<DNx...xD1x!FHE.eint<p>>, tensor<DAxi64>
// -> tensor<DNx...xD1x!FHE.eint<p'>>
//
// becomes:
//
@@ -598,30 +598,30 @@ struct HLFHELinalgApplyMultiLookupTableToLinalgGeneric
// iterator_types = ["parallel",..],//N parallel
// }
// %init = linalg.init_tensor [DN,...,D1]
// : tensor<DNx...xD1x!HLFHE.eint<p'>>
// : tensor<DNx...xD1x!FHE.eint<p'>>
// %res = linalg.generic {
// ins(%t: tensor<DNx...xD1x!HLFHE.eint<p>>)
// outs(%init : tensor<DNx...xD1x!HLFHE.eint<p'>>)
// ins(%t: tensor<DNx...xD1x!FHE.eint<p>>)
// outs(%init : tensor<DNx...xD1x!FHE.eint<p'>>)
// {
// ^bb0(%arg0: !HLFHE.eint<p>):
// %0 = HLFHE.apply_lookup_table(%arg0, %lut): !HLFHE.eint<p>,
// tensor<4xi64> -> !HLFHE.eint<p'>
// linalg.yield %0 : !HLFHE.eint<p'>
// ^bb0(%arg0: !FHE.eint<p>):
// %0 = FHE.apply_lookup_table(%arg0, %lut): !FHE.eint<p>,
// tensor<4xi64> -> !FHE.eint<p'>
// linalg.yield %0 : !FHE.eint<p'>
// }
// }
//
struct HLFHELinalgApplyLookupTableToLinalgGeneric
struct FHELinalgApplyLookupTableToLinalgGeneric
: public mlir::OpRewritePattern<
mlir::concretelang::HLFHELinalg::ApplyLookupTableEintOp> {
HLFHELinalgApplyLookupTableToLinalgGeneric(
mlir::concretelang::FHELinalg::ApplyLookupTableEintOp> {
FHELinalgApplyLookupTableToLinalgGeneric(
::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: ::mlir::OpRewritePattern<
mlir::concretelang::HLFHELinalg::ApplyLookupTableEintOp>(context,
mlir::concretelang::FHELinalg::ApplyLookupTableEintOp>(context,
benefit) {}
::mlir::LogicalResult
matchAndRewrite(mlir::concretelang::HLFHELinalg::ApplyLookupTableEintOp lutOp,
matchAndRewrite(mlir::concretelang::FHELinalg::ApplyLookupTableEintOp lutOp,
::mlir::PatternRewriter &rewriter) const override {
mlir::RankedTensorType resultTy =
((mlir::Type)lutOp->getResult(0).getType())
@@ -649,13 +649,13 @@ struct HLFHELinalgApplyLookupTableToLinalgGeneric
auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
mlir::Location nestedLoc,
mlir::ValueRange blockArgs) {
mlir::concretelang::HLFHE::ApplyLookupTableEintOp hlfheOp =
nestedBuilder.create<mlir::concretelang::HLFHE::ApplyLookupTableEintOp>(
mlir::concretelang::FHE::ApplyLookupTableEintOp fheOp =
nestedBuilder.create<mlir::concretelang::FHE::ApplyLookupTableEintOp>(
lutOp.getLoc(), resultTy.getElementType(), blockArgs[0],
lutOp.lut());
nestedBuilder.create<mlir::linalg::YieldOp>(lutOp.getLoc(),
hlfheOp.getResult());
fheOp.getResult());
};
// Create the `linalg.generic` op
@@ -677,15 +677,15 @@ struct HLFHELinalgApplyLookupTableToLinalgGeneric
};
// This template rewrite pattern transforms any instance of
// operators `HLFHELinalg.neg_eint` to an instance of `linalg.generic` with an
// appropriate region using `HLFHE.neg_eint` operation, an appropriate
// operators `FHELinalg.neg_eint` to an instance of `linalg.generic` with an
// appropriate region using `FHE.neg_eint` operation, an appropriate
// specification for the iteration dimensions and appropriate operations
// managing the accumulator of `linalg.generic`.
//
// Example:
//
// HLFHELinalg.neg_eint(%tensor):
// tensor<DNx...xD1x!HLFHE.eint<p>> -> tensor<DNx...xD1x!HLFHE.eint<p'>>
// FHELinalg.neg_eint(%tensor):
// tensor<DNx...xD1x!FHE.eint<p>> -> tensor<DNx...xD1x!FHE.eint<p'>>
//
// becomes:
//
@@ -698,27 +698,27 @@ struct HLFHELinalgApplyLookupTableToLinalgGeneric
// iterator_types = ["parallel",..],//N parallel
// }
// %init = linalg.init_tensor [DN,...,D1]
// : tensor<DNx...xD1x!HLFHE.eint<p'>>
// : tensor<DNx...xD1x!FHE.eint<p'>>
// %res = linalg.generic {
// ins(%tensor: tensor<DNx...xD1x!HLFHE.eint<p>>)
// outs(%init : tensor<DNx...xD1x!HLFHE.eint<p'>>)
// ins(%tensor: tensor<DNx...xD1x!FHE.eint<p>>)
// outs(%init : tensor<DNx...xD1x!FHE.eint<p'>>)
// {
// ^bb0(%arg0: !HLFHE.eint<p>):
// %0 = HLFHE.neg_eint(%arg0): !HLFHE.eint<p> -> !HLFHE.eint<p'>
// linalg.yield %0 : !HLFHE.eint<p'>
// ^bb0(%arg0: !FHE.eint<p>):
// %0 = FHE.neg_eint(%arg0): !FHE.eint<p> -> !FHE.eint<p'>
// linalg.yield %0 : !FHE.eint<p'>
// }
// }
//
struct HLFHELinalgNegEintToLinalgGeneric
: public mlir::OpRewritePattern<mlir::concretelang::HLFHELinalg::NegEintOp> {
HLFHELinalgNegEintToLinalgGeneric(
struct FHELinalgNegEintToLinalgGeneric
: public mlir::OpRewritePattern<mlir::concretelang::FHELinalg::NegEintOp> {
FHELinalgNegEintToLinalgGeneric(
::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: ::mlir::OpRewritePattern<mlir::concretelang::HLFHELinalg::NegEintOp>(
: ::mlir::OpRewritePattern<mlir::concretelang::FHELinalg::NegEintOp>(
context, benefit) {}
::mlir::LogicalResult
matchAndRewrite(mlir::concretelang::HLFHELinalg::NegEintOp negEintOp,
matchAndRewrite(mlir::concretelang::FHELinalg::NegEintOp negEintOp,
::mlir::PatternRewriter &rewriter) const override {
mlir::RankedTensorType resultTy =
((mlir::Type)negEintOp->getResult(0).getType())
@@ -746,12 +746,12 @@ struct HLFHELinalgNegEintToLinalgGeneric
auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
mlir::Location nestedLoc,
mlir::ValueRange blockArgs) {
mlir::concretelang::HLFHE::NegEintOp hlfheOp =
nestedBuilder.create<mlir::concretelang::HLFHE::NegEintOp>(
mlir::concretelang::FHE::NegEintOp fheOp =
nestedBuilder.create<mlir::concretelang::FHE::NegEintOp>(
negEintOp.getLoc(), resultTy.getElementType(), blockArgs[0]);
nestedBuilder.create<mlir::linalg::YieldOp>(negEintOp.getLoc(),
hlfheOp.getResult());
fheOp.getResult());
};
// Create the `linalg.generic` op
@@ -773,17 +773,17 @@ struct HLFHELinalgNegEintToLinalgGeneric
};
// This template rewrite pattern transforms any instance of
// operators `HLFHELinalgMatmulOp` to an instance of `linalg.generic`
// operators `FHELinalgMatmulOp` to an instance of `linalg.generic`
// with an appropriate region using a builder that create the multiplication
// operators and `HLFHE.add_eint` operation, an appropriate specification for
// operators and `FHE.add_eint` operation, an appropriate specification for
// the iteration dimensions and appropriate operations managing the accumulator
// of `linalg.generic`.
//
// Example:
//
// "HLFHELinalg.matmul_eint_int(%a, %b) :
// (tensor<MxPx!HLFHE.eint<p>>, tensor<PxNxip'>) ->
// tensor<MxNx!HLFHE.eint<p>>"
// "FHELinalg.matmul_eint_int(%a, %b) :
// (tensor<MxPx!FHE.eint<p>>, tensor<PxNxip'>) ->
// tensor<MxNx!FHE.eint<p>>"
//
// becomes:
@@ -799,36 +799,36 @@ struct HLFHELinalgNegEintToLinalgGeneric
// }
// %init = linalg.generate {
// ^bb0(%i : index, %j : index, %k : index):
// %z = "HLFHE.zero" : () -> !HLFHE.eint<2>
// %z = "FHE.zero" : () -> !FHE.eint<2>
// linalg.yield %z
// }: tensor<MxNx!HLFHE.eint<p>>
// }: tensor<MxNx!FHE.eint<p>>
// linalg.generic #attributes_0
// ins(%A, %B : tensor<MxPx!HLFHE.eint<p>>,
// ins(%A, %B : tensor<MxPx!FHE.eint<p>>,
// tensor<PxNxip'>)
// outs(%C : tensor<MxNx!HLFHE.eint<p>>)
// outs(%C : tensor<MxNx!FHE.eint<p>>)
// {
// ^bb0(%a: !HLFHE.eint<p>, %b: ip', %c: !HLFHE.eint<p>) :
// %d = createMulOp(%a, %b): !HLFHE.eint<p>
// %e = "HLFHE.add_eint"(%c, %d):
// (!HLFHE.eint<p>, !HLFHE.eint<p>) -> !HLFHE.eint<p>
// linalg.yield %e : !HLFHE.eint<p>
// ^bb0(%a: !FHE.eint<p>, %b: ip', %c: !FHE.eint<p>) :
// %d = createMulOp(%a, %b): !FHE.eint<p>
// %e = "FHE.add_eint"(%c, %d):
// (!FHE.eint<p>, !FHE.eint<p>) -> !FHE.eint<p>
// linalg.yield %e : !FHE.eint<p>
// }
//
template <typename HLFHELinalgMatmulOp>
struct HLFHELinalgMatmulToLinalgGeneric
: public mlir::OpRewritePattern<HLFHELinalgMatmulOp> {
HLFHELinalgMatmulToLinalgGeneric(
template <typename FHELinalgMatmulOp>
struct FHELinalgMatmulToLinalgGeneric
: public mlir::OpRewritePattern<FHELinalgMatmulOp> {
FHELinalgMatmulToLinalgGeneric(
mlir::MLIRContext *context,
std::function<mlir::concretelang::HLFHE::MulEintIntOp(
std::function<mlir::concretelang::FHE::MulEintIntOp(
mlir::OpBuilder &, mlir::Location, mlir::Type, mlir::Value,
mlir::Value)>
createMulOp,
mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: ::mlir::OpRewritePattern<HLFHELinalgMatmulOp>(context, benefit),
: ::mlir::OpRewritePattern<FHELinalgMatmulOp>(context, benefit),
createMulOp(createMulOp) {}
::mlir::LogicalResult
matchAndRewrite(HLFHELinalgMatmulOp matmulOp,
matchAndRewrite(FHELinalgMatmulOp matmulOp,
::mlir::PatternRewriter &rewriter) const override {
mlir::Location matmulLoc = matmulOp.getLoc();
mlir::RankedTensorType resultTy =
@@ -839,11 +839,11 @@ struct HLFHELinalgMatmulToLinalgGeneric
auto generateBody = [&](mlir::OpBuilder &nestedBuilder,
mlir::Location nestedLoc,
mlir::ValueRange blockArgs) {
// %z = "HLFHE.zero" : () -> !HLFHE.eint<2>
mlir::concretelang::HLFHE::ZeroEintOp zeroOp =
nestedBuilder.create<mlir::concretelang::HLFHE::ZeroEintOp>(
// %z = "FHE.zero" : () -> !FHE.eint<2>
mlir::concretelang::FHE::ZeroEintOp zeroOp =
nestedBuilder.create<mlir::concretelang::FHE::ZeroEintOp>(
matmulLoc, resultElementTy);
// linalg.yield %z : !HLFHE.eint<p>
// linalg.yield %z : !FHE.eint<p>
nestedBuilder.create<mlir::tensor::YieldOp>(matmulLoc,
zeroOp.getResult());
};
@@ -873,16 +873,16 @@ struct HLFHELinalgMatmulToLinalgGeneric
auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
mlir::Location nestedLoc,
mlir::ValueRange blockArgs) {
// "HLFHE.mul_eint_int"(%a, %b) : (!HLFHE.eint<p>, ip') -> !HLFHE.eint<p>
mlir::concretelang::HLFHE::MulEintIntOp mulEintIntOp =
// "FHE.mul_eint_int"(%a, %b) : (!FHE.eint<p>, ip') -> !FHE.eint<p>
mlir::concretelang::FHE::MulEintIntOp mulEintIntOp =
createMulOp(nestedBuilder, matmulLoc, resultElementTy, blockArgs[0],
blockArgs[1]);
// "HLFHE.add_eint"(%c, %d): (!HLFHE.eint<p>, !HLFHE.eint<p>) ->
// !HLFHE.eint<p>
mlir::concretelang::HLFHE::AddEintOp addEintOp =
nestedBuilder.create<mlir::concretelang::HLFHE::AddEintOp>(
// "FHE.add_eint"(%c, %d): (!FHE.eint<p>, !FHE.eint<p>) ->
// !FHE.eint<p>
mlir::concretelang::FHE::AddEintOp addEintOp =
nestedBuilder.create<mlir::concretelang::FHE::AddEintOp>(
matmulLoc, resultElementTy, blockArgs[2], mulEintIntOp);
// linalg.yield %e : !HLFHE.eint<p>
// linalg.yield %e : !FHE.eint<p>
nestedBuilder.create<mlir::linalg::YieldOp>(matmulLoc,
addEintOp.getResult());
};
@@ -905,38 +905,38 @@ struct HLFHELinalgMatmulToLinalgGeneric
};
private:
std::function<mlir::concretelang::HLFHE::MulEintIntOp(
std::function<mlir::concretelang::FHE::MulEintIntOp(
mlir::OpBuilder &, mlir::Location, mlir::Type, mlir::Value, mlir::Value)>
createMulOp;
};
// This rewrite pattern transforms any instance of operators
// `HLFHELinalg.zero` to an instance of `linalg.generate` with an
// `FHELinalg.zero` to an instance of `linalg.generate` with an
// appropriate region yielding a zero value.
//
// Example:
//
// %out = "HLFHELinalg.zero"() : () -> tensor<MxNx!HLFHE.eint<p>>
// %out = "FHELinalg.zero"() : () -> tensor<MxNx!FHE.eint<p>>
//
// becomes:
//
// %0 = tensor.generate {
// ^bb0(%arg2: index, %arg3: index):
// %zero = "HLFHE.zero"() : () -> !HLFHE.eint<p>
// tensor.yield %zero : !HLFHE.eint<p>
// } : tensor<MxNx!HLFHE.eint<p>>
// %zero = "FHE.zero"() : () -> !FHE.eint<p>
// tensor.yield %zero : !FHE.eint<p>
// } : tensor<MxNx!FHE.eint<p>>
//
struct HLFHELinalgZeroToLinalgGenerate
: public mlir::OpRewritePattern<mlir::concretelang::HLFHELinalg::ZeroOp> {
HLFHELinalgZeroToLinalgGenerate(
struct FHELinalgZeroToLinalgGenerate
: public mlir::OpRewritePattern<mlir::concretelang::FHELinalg::ZeroOp> {
FHELinalgZeroToLinalgGenerate(
::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: ::mlir::OpRewritePattern<mlir::concretelang::HLFHELinalg::ZeroOp>(context,
: ::mlir::OpRewritePattern<mlir::concretelang::FHELinalg::ZeroOp>(context,
benefit) {
}
::mlir::LogicalResult
matchAndRewrite(mlir::concretelang::HLFHELinalg::ZeroOp zeroOp,
matchAndRewrite(mlir::concretelang::FHELinalg::ZeroOp zeroOp,
::mlir::PatternRewriter &rewriter) const override {
mlir::RankedTensorType resultTy =
zeroOp->getResult(0).getType().cast<mlir::RankedTensorType>();
@@ -945,7 +945,7 @@ struct HLFHELinalgZeroToLinalgGenerate
mlir::Location nestedLoc,
mlir::ValueRange blockArgs) {
mlir::Value zeroScalar =
nestedBuilder.create<mlir::concretelang::HLFHE::ZeroEintOp>(
nestedBuilder.create<mlir::concretelang::FHE::ZeroEintOp>(
zeroOp.getLoc(), resultTy.getElementType());
nestedBuilder.create<mlir::tensor::YieldOp>(zeroOp.getLoc(), zeroScalar);
};
@@ -960,13 +960,13 @@ struct HLFHELinalgZeroToLinalgGenerate
};
namespace {
struct HLFHETensorOpsToLinalg
: public HLFHETensorOpsToLinalgBase<HLFHETensorOpsToLinalg> {
struct FHETensorOpsToLinalg
: public FHETensorOpsToLinalgBase<FHETensorOpsToLinalg> {
void runOnFunction() final;
};
void HLFHETensorOpsToLinalg::runOnFunction() {
void FHETensorOpsToLinalg::runOnFunction() {
mlir::FuncOp function = this->getFunction();
mlir::ConversionTarget target(getContext());
@@ -974,51 +974,51 @@ void HLFHETensorOpsToLinalg::runOnFunction() {
target.addLegalDialect<mlir::linalg::LinalgDialect>();
target.addLegalDialect<mlir::StandardOpsDialect>();
target.addLegalDialect<mlir::memref::MemRefDialect>();
target.addLegalDialect<mlir::concretelang::HLFHE::HLFHEDialect>();
target.addLegalDialect<mlir::concretelang::FHE::FHEDialect>();
target.addLegalDialect<mlir::tensor::TensorDialect>();
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
target.addIllegalOp<mlir::concretelang::HLFHELinalg::Dot>();
target.addIllegalDialect<mlir::concretelang::HLFHELinalg::HLFHELinalgDialect>();
target.addIllegalOp<mlir::concretelang::FHELinalg::Dot>();
target.addIllegalDialect<mlir::concretelang::FHELinalg::FHELinalgDialect>();
mlir::OwningRewritePatternList patterns(&getContext());
patterns.insert<DotToLinalgGeneric>(&getContext());
patterns.insert<
HLFHELinalgOpToLinalgGeneric<mlir::concretelang::HLFHELinalg::AddEintOp,
mlir::concretelang::HLFHE::AddEintOp>>(
FHELinalgOpToLinalgGeneric<mlir::concretelang::FHELinalg::AddEintOp,
mlir::concretelang::FHE::AddEintOp>>(
&getContext());
patterns.insert<
HLFHELinalgOpToLinalgGeneric<mlir::concretelang::HLFHELinalg::AddEintIntOp,
mlir::concretelang::HLFHE::AddEintIntOp>>(
FHELinalgOpToLinalgGeneric<mlir::concretelang::FHELinalg::AddEintIntOp,
mlir::concretelang::FHE::AddEintIntOp>>(
&getContext());
patterns.insert<
HLFHELinalgOpToLinalgGeneric<mlir::concretelang::HLFHELinalg::SubIntEintOp,
mlir::concretelang::HLFHE::SubIntEintOp>>(
FHELinalgOpToLinalgGeneric<mlir::concretelang::FHELinalg::SubIntEintOp,
mlir::concretelang::FHE::SubIntEintOp>>(
&getContext());
patterns.insert<
HLFHELinalgOpToLinalgGeneric<mlir::concretelang::HLFHELinalg::MulEintIntOp,
mlir::concretelang::HLFHE::MulEintIntOp>>(
FHELinalgOpToLinalgGeneric<mlir::concretelang::FHELinalg::MulEintIntOp,
mlir::concretelang::FHE::MulEintIntOp>>(
&getContext());
patterns.insert<HLFHELinalgApplyLookupTableToLinalgGeneric>(&getContext());
patterns.insert<HLFHELinalgNegEintToLinalgGeneric>(&getContext());
patterns.insert<HLFHELinalgMatmulToLinalgGeneric<
mlir::concretelang::HLFHELinalg::MatMulEintIntOp>>(
patterns.insert<FHELinalgApplyLookupTableToLinalgGeneric>(&getContext());
patterns.insert<FHELinalgNegEintToLinalgGeneric>(&getContext());
patterns.insert<FHELinalgMatmulToLinalgGeneric<
mlir::concretelang::FHELinalg::MatMulEintIntOp>>(
&getContext(), [](mlir::OpBuilder &builder, mlir::Location loc,
mlir::Type type, mlir::Value arg0, mlir::Value arg1) {
return builder.create<mlir::concretelang::HLFHE::MulEintIntOp>(loc, type,
return builder.create<mlir::concretelang::FHE::MulEintIntOp>(loc, type,
arg0, arg1);
});
patterns.insert<HLFHELinalgMatmulToLinalgGeneric<
mlir::concretelang::HLFHELinalg::MatMulIntEintOp>>(
patterns.insert<FHELinalgMatmulToLinalgGeneric<
mlir::concretelang::FHELinalg::MatMulIntEintOp>>(
&getContext(), [](mlir::OpBuilder &builder, mlir::Location loc,
mlir::Type type, mlir::Value arg0, mlir::Value arg1) {
return builder.create<mlir::concretelang::HLFHE::MulEintIntOp>(loc, type,
return builder.create<mlir::concretelang::FHE::MulEintIntOp>(loc, type,
arg1, arg0);
});
patterns.insert<HLFHELinalgApplyMultiLookupTableToLinalgGeneric>(
patterns.insert<FHELinalgApplyMultiLookupTableToLinalgGeneric>(
&getContext());
patterns.insert<HLFHELinalgApplyMappedLookupTableToLinalgGeneric>(
patterns.insert<FHELinalgApplyMappedLookupTableToLinalgGeneric>(
&getContext());
patterns.insert<HLFHELinalgZeroToLinalgGenerate>(&getContext());
patterns.insert<FHELinalgZeroToLinalgGenerate>(&getContext());
if (mlir::applyPartialConversion(function, target, std::move(patterns))
.failed())
@@ -1029,8 +1029,8 @@ void HLFHETensorOpsToLinalg::runOnFunction() {
namespace mlir {
namespace concretelang {
std::unique_ptr<mlir::FunctionPass> createConvertHLFHETensorOpsToLinalg() {
return std::make_unique<HLFHETensorOpsToLinalg>();
std::unique_ptr<mlir::FunctionPass> createConvertFHETensorOpsToLinalg() {
return std::make_unique<FHETensorOpsToLinalg>();
}
} // namespace concretelang
} // namespace mlir

View File

@@ -0,0 +1,16 @@
add_mlir_dialect_library(FHEToTFHE
FHEToTFHE.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE
DEPENDS
FHEDialect
FHEToTFHEPatternsIncGen
LINK_LIBS PUBLIC
MLIRIR
MLIRTransforms
MLIRMath)
target_link_libraries(FHEToTFHE PUBLIC MLIRIR)

View File

@@ -6,31 +6,31 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "concretelang/Conversion/HLFHEToMidLFHE/Patterns.h"
#include "concretelang/Conversion/FHEToTFHE/Patterns.h"
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h"
#include "concretelang/Conversion/Utils/TensorOpTypeConversion.h"
#include "concretelang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "concretelang/Dialect/HLFHE/IR/HLFHETypes.h"
#include "concretelang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
#include "concretelang/Dialect/MidLFHE/IR/MidLFHETypes.h"
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
#include "concretelang/Dialect/RT/IR/RTOps.h"
namespace {
struct HLFHEToMidLFHEPass : public HLFHEToMidLFHEBase<HLFHEToMidLFHEPass> {
struct FHEToTFHEPass : public FHEToTFHEBase<FHEToTFHEPass> {
void runOnOperation() final;
};
} // namespace
using mlir::concretelang::HLFHE::EncryptedIntegerType;
using mlir::concretelang::MidLFHE::GLWECipherTextType;
using mlir::concretelang::FHE::EncryptedIntegerType;
using mlir::concretelang::TFHE::GLWECipherTextType;
/// HLFHEToMidLFHETypeConverter is a TypeConverter that transform
/// `HLFHE.eint<p>` to `MidLFHE.glwe<{_,_,_}{p}>`
class HLFHEToMidLFHETypeConverter : public mlir::TypeConverter {
/// FHEToTFHETypeConverter is a TypeConverter that transform
/// `FHE.eint<p>` to `TFHE.glwe<{_,_,_}{p}>`
class FHEToTFHETypeConverter : public mlir::TypeConverter {
public:
HLFHEToMidLFHETypeConverter() {
FHEToTFHETypeConverter() {
addConversion([](mlir::Type type) { return type; });
addConversion([](EncryptedIntegerType type) {
return mlir::concretelang::convertTypeEncryptedIntegerToGLWE(
@@ -50,17 +50,17 @@ public:
}
};
void HLFHEToMidLFHEPass::runOnOperation() {
void FHEToTFHEPass::runOnOperation() {
auto op = this->getOperation();
mlir::ConversionTarget target(getContext());
HLFHEToMidLFHETypeConverter converter;
FHEToTFHETypeConverter converter;
// Mark ops from the target dialect as legal operations
target.addLegalDialect<mlir::concretelang::MidLFHE::MidLFHEDialect>();
target.addLegalDialect<mlir::concretelang::TFHE::TFHEDialect>();
// Make sure that no ops from `HLFHE` remain after the lowering
target.addIllegalDialect<mlir::concretelang::HLFHE::HLFHEDialect>();
// Make sure that no ops from `FHE` remain after the lowering
target.addIllegalDialect<mlir::concretelang::FHE::FHEDialect>();
// Make sure that no ops `linalg.generic` that have illegal types
target
@@ -77,19 +77,19 @@ void HLFHEToMidLFHEPass::runOnOperation() {
return converter.isSignatureLegal(funcOp.getType()) &&
converter.isLegal(&funcOp.getBody());
});
// Add all patterns required to lower all ops from `HLFHE` to
// `MidLFHE`
// Add all patterns required to lower all ops from `FHE` to
// `TFHE`
mlir::OwningRewritePatternList patterns(&getContext());
populateWithGeneratedHLFHEToMidLFHE(patterns);
populateWithGeneratedFHEToTFHE(patterns);
patterns.add<RegionOpTypeConverterPattern<mlir::linalg::GenericOp,
HLFHEToMidLFHETypeConverter>>(
FHEToTFHETypeConverter>>(
&getContext(), converter);
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
HLFHEToMidLFHETypeConverter>>(
FHEToTFHETypeConverter>>(
&getContext(), converter);
patterns.add<RegionOpTypeConverterPattern<mlir::scf::ForOp,
HLFHEToMidLFHETypeConverter>>(
FHEToTFHETypeConverter>>(
&getContext(), converter);
mlir::concretelang::populateWithTensorTypeConverterPatterns(patterns, target,
@@ -110,8 +110,8 @@ void HLFHEToMidLFHEPass::runOnOperation() {
namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<ModuleOp>> createConvertHLFHEToMidLFHEPass() {
return std::make_unique<HLFHEToMidLFHEPass>();
std::unique_ptr<OperationPass<ModuleOp>> createConvertFHEToTFHEPass() {
return std::make_unique<FHEToTFHEPass>();
}
} // namespace concretelang
} // namespace mlir

View File

@@ -1,17 +0,0 @@
add_mlir_dialect_library(HLFHETensorOpsToLinalg
TensorOpsToLinalg.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/HLFHE
DEPENDS
HLFHEDialect
HLFHELinalgDialect
MLIRConversionPassIncGen
LINK_LIBS PUBLIC
MLIRIR
HLFHEDialect
HLFHELinalgDialect)
target_link_libraries(HLFHEDialect PUBLIC MLIRIR)

View File

@@ -1,16 +0,0 @@
add_mlir_dialect_library(HLFHEToMidLFHE
HLFHEToMidLFHE.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/HLFHE
DEPENDS
HLFHEDialect
HLFHEToMidLFHEPatternsIncGen
LINK_LIBS PUBLIC
MLIRIR
MLIRTransforms
MLIRMath)
target_link_libraries(HLFHEToMidLFHE PUBLIC MLIRIR)

Some files were not shown because too many files have changed in this diff Show More