refactor: remove chunked_eint

the pass transforming operations on chunked_eint will operate now on
eint
This commit is contained in:
youben11
2023-02-06 10:10:01 +01:00
committed by Ayoub Benaissa
parent bb87d29934
commit 7e60f87141
9 changed files with 24 additions and 177 deletions

View File

@@ -25,11 +25,6 @@ def FheIntegerInterface : TypeInterface<"FheIntegerInterface"> {
/*description=*/"Get whether the integer is unsigned.",
/*retTy=*/"bool",
/*methodName=*/"isUnsigned"
>,
InterfaceMethod<
/*description=*/"Get whether the integer is chunked (composed of multiple smaller integers).",
/*retTy=*/"bool",
/*methodName=*/"isChunked"
>
];
}

View File

@@ -33,7 +33,6 @@ def FHE_EncryptedIntegerType : FHE_Type<"EncryptedInteger",
let extraClassDeclaration = [{
bool isSigned() const { return false; }
bool isUnsigned() const { return true; }
bool isChunked() const { return false; }
}];
}
@@ -62,43 +61,12 @@ def FHE_EncryptedSignedIntegerType : FHE_Type<"EncryptedSignedInteger",
let extraClassDeclaration = [{
bool isSigned() const { return true; }
bool isUnsigned() const { return false; }
bool isChunked() const { return false; }
}];
}
def FHE_ChunkedEncryptedIntegerType : FHE_Type<"ChunkedEncryptedInteger",
[MemRefElementTypeInterface, FheIntegerInterface]> {
let mnemonic = "chunked_eint";
let summary = "An encrypted integer composed of multiple chunks";
let description = [{
An encrypted integer composed of multiple chunks.
Examples:
```mlir
!FHE.chunked_eint<64>
!FHE.chunked_eint<32>
```
}];
let parameters = (ins "unsigned":$width);
let hasCustomAssemblyFormat = 1;
let genVerifyDecl = true;
let extraClassDeclaration = [{
bool isSigned() const { return false; }
bool isUnsigned() const { return true; }
bool isChunked() const { return true; }
}];
}
def FHE_AnyEncryptedInteger : Type<Or<[
FHE_EncryptedIntegerType.predicate,
FHE_EncryptedSignedIntegerType.predicate,
FHE_ChunkedEncryptedIntegerType.predicate
FHE_EncryptedSignedIntegerType.predicate
]>>;
def FHE_EncryptedBooleanType : FHE_Type<"EncryptedBoolean",

View File

@@ -89,33 +89,3 @@ mlir::Type EncryptedSignedIntegerType::parse(mlir::AsmParser &p) {
return getChecked(loc, loc.getContext(), width);
}
mlir::LogicalResult ChunkedEncryptedIntegerType::verify(
llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, unsigned p) {
if (p == 0) {
emitError() << "FHE.chunked_eint doesn't support precision of 0";
return mlir::failure();
}
return mlir::success();
}
void ChunkedEncryptedIntegerType::print(mlir::AsmPrinter &p) const {
p << "<" << getWidth() << ">";
}
mlir::Type ChunkedEncryptedIntegerType::parse(mlir::AsmParser &p) {
if (p.parseLess())
return mlir::Type();
int width;
if (p.parseInteger(width))
return mlir::Type();
if (p.parseGreater())
return mlir::Type();
mlir::Location loc = p.getEncodedSourceLoc(p.getNameLoc());
return getChecked(loc, loc.getContext(), width);
}

View File

@@ -29,25 +29,13 @@ bool verifyEncryptedIntegerInputAndResultConsistency(
return false;
}
if (input.isChunked() != result.isChunked()) {
op.emitOpError("should have the same composition (chunked or not) of "
"encrypted input and result");
return false;
}
return true;
}
bool verifyEncryptedIntegerAndIntegerInputsConsistency(mlir::Operation &op,
FheIntegerInterface &a,
IntegerType &b) {
if (a.isChunked()) {
if (a.getWidth() != b.getWidth()) {
op.emitOpError("should have the width of plain input equal to width of "
"encrypted input");
return false;
}
} else if (a.getWidth() + 1 != b.getWidth()) {
if (a.getWidth() + 1 != b.getWidth()) {
op.emitOpError("should have the width of plain input equal to width of "
"encrypted input + 1");
return false;
@@ -69,12 +57,6 @@ bool verifyEncryptedIntegerInputsConsistency(mlir::Operation &op,
return false;
}
if (a.isChunked() != b.isChunked()) {
op.emitOpError("should have the same composition (chunked or not) of "
"encrypted inputs");
return false;
}
return true;
}

View File

@@ -49,10 +49,10 @@ namespace typing {
/// Converts `FHE::ChunkedEncryptedInteger` into a tensor of
/// `FHE::EncryptedInteger`.
mlir::RankedTensorType
convertChunkedEint(mlir::MLIRContext *context,
FHE::ChunkedEncryptedIntegerType chunkedEint,
unsigned int chunkSize, unsigned int chunkWidth) {
mlir::RankedTensorType convertChunkedEint(mlir::MLIRContext *context,
FHE::EncryptedIntegerType chunkedEint,
unsigned int chunkSize,
unsigned int chunkWidth) {
auto eint = FHE::EncryptedIntegerType::get(context, chunkSize);
auto bigIntWidth = chunkedEint.getWidth();
assert(bigIntWidth % chunkWidth == 0 &&
@@ -68,9 +68,13 @@ class TypeConverter : public mlir::TypeConverter {
public:
TypeConverter(unsigned int chunkSize, unsigned int chunkWidth) {
addConversion([](mlir::Type type) { return type; });
addConversion([chunkSize,
chunkWidth](FHE::ChunkedEncryptedIntegerType type) {
return convertChunkedEint(type.getContext(), type, chunkSize, chunkWidth);
addConversion([chunkSize, chunkWidth](FHE::EncryptedIntegerType type) {
if (type.getWidth() > chunkSize) {
return (mlir::Type)convertChunkedEint(type.getContext(), type,
chunkSize, chunkWidth);
} else {
return (mlir::Type)type;
}
});
}
};

View File

@@ -296,11 +296,13 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
return StreamStringError("Rewriting of encrypted mul failed");
}
if (mlir::concretelang::pipeline::transformFHEBigInt(
mlirContext, module, enablePass, options.chunkSize,
options.chunkWidth)
.failed()) {
return errorDiag("Transforming FHE big integer ops failed");
if (options.chunkIntegers) {
if (mlir::concretelang::pipeline::transformFHEBigInt(
mlirContext, module, enablePass, options.chunkSize,
options.chunkWidth)
.failed()) {
return errorDiag("Transforming FHE big integer ops failed");
}
}
// FHE High level pass to determine FHE parameters

View File

@@ -98,46 +98,6 @@ gateFromMLIRType(V0FHEContext fheContext, LweSecretKeyID secretKeyID,
/*.chunkInfo = */ chunkInfo,
};
}
// TODO: this a duplicate of the last if: should be removed when we remove
// chinked eint
if (auto lweTy = type.dyn_cast_or_null<
mlir::concretelang::FHE::ChunkedEncryptedIntegerType>()) {
bool sign = lweTy.isSignedInteger();
std::vector<int64_t> crt;
if (fheContext.parameter.largeInteger.has_value()) {
crt = fheContext.parameter.largeInteger.value().crtDecomposition;
}
size_t width;
uint64_t size;
std::vector<int64_t> dims;
if (chunkInfo.hasValue()) {
width = chunkInfo->size;
assert(lweTy.getWidth() % chunkInfo->width == 0);
size = lweTy.getWidth() / chunkInfo->width;
dims.push_back(size);
} else {
width = (size_t)lweTy.getWidth();
}
return CircuitGate{
/* .encryption = */ llvm::Optional<EncryptionGate>({
/* .secretKeyID = */ secretKeyID,
/* .variance = */ variance,
/* .encoding = */
{
/* .precision = */ width,
/* .crt = */ crt,
},
}),
/*.shape = */
{
/*.width = */ width,
/*.dimensions = */ dims,
/*.size = */ size,
/*.sign = */ sign,
},
/*.chunkInfo = */ chunkInfo,
};
}
if (auto lweTy = type.dyn_cast_or_null<
mlir::concretelang::FHE::EncryptedBooleanType>()) {
size_t width = mlir::concretelang::FHE::EncryptedBooleanType::getWidth();

View File

@@ -1,7 +1,7 @@
// RUN: concretecompiler --chunk-size 4 --chunk-width 2 --passes fhe-big-int-transform --action=dump-fhe %s 2>&1| FileCheck %s
// RUN: concretecompiler --chunk-integers --chunk-size 4 --chunk-width 2 --passes fhe-big-int-transform --action=dump-fhe %s 2>&1| FileCheck %s
// CHECK-LABEL: func.func @add_chunked_eint(%arg0: tensor<32x!FHE.eint<4>>, %arg1: tensor<32x!FHE.eint<4>>) -> tensor<32x!FHE.eint<4>>
func.func @add_chunked_eint(%arg0: !FHE.chunked_eint<64>, %arg1: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> {
func.func @add_chunked_eint(%arg0: !FHE.eint<64>, %arg1: !FHE.eint<64>) -> !FHE.eint<64> {
// CHECK-NEXT: %[[V0:.*]] = "FHE.zero"() : () -> !FHE.eint<4>
// CHECK-NEXT: %[[V1:.*]] = "FHE.zero_tensor"() : () -> tensor<32x!FHE.eint<4>>
// CHECK-NEXT: %[[c4_i5:.*]] = arith.constant 4 : i5
@@ -19,6 +19,6 @@ func.func @add_chunked_eint(%arg0: !FHE.chunked_eint<64>, %arg1: !FHE.chunked_ei
// CHECK-NEXT: }
// CHECK-NEXT: return %2 : tensor<32x!FHE.eint<4>>
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.chunked_eint<64>, !FHE.chunked_eint<64>) -> (!FHE.chunked_eint<64>)
return %1: !FHE.chunked_eint<64>
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<64>, !FHE.eint<64>) -> (!FHE.eint<64>)
return %1: !FHE.eint<64>
}

View File

@@ -1,34 +0,0 @@
// RUN: concretecompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: func.func @mul_chunked_eint_int(%arg0: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64>
func.func @mul_chunked_eint_int(%arg0: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> {
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i64
// CHECK-NEXT: %[[V2:.*]] = "FHE.mul_eint_int"(%arg0, %[[V1]]) : (!FHE.chunked_eint<64>, i64) -> !FHE.chunked_eint<64>
// CHECK-NEXT: return %[[V2]] : !FHE.chunked_eint<64>
%0 = arith.constant 1 : i64
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.chunked_eint<64>, i64) -> (!FHE.chunked_eint<64>)
return %1: !FHE.chunked_eint<64>
}
// CHECK-LABEL: func.func @add_chunked_eint_int(%arg0: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64>
func.func @add_chunked_eint_int(%arg0: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> {
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i64
// CHECK-NEXT: %[[V2:.*]] = "FHE.add_eint_int"(%arg0, %[[V1]]) : (!FHE.chunked_eint<64>, i64) -> !FHE.chunked_eint<64>
// CHECK-NEXT: return %[[V2]] : !FHE.chunked_eint<64>
%0 = arith.constant 1 : i64
%1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.chunked_eint<64>, i64) -> (!FHE.chunked_eint<64>)
return %1: !FHE.chunked_eint<64>
}
// CHECK-LABEL: func.func @add_chunked_eint(%arg0: !FHE.chunked_eint<64>, %arg1: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64>
func.func @add_chunked_eint(%arg0: !FHE.chunked_eint<64>, %arg1: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> {
// CHECK-NEXT: %[[V1:.*]] = "FHE.add_eint"(%arg0, %arg1) : (!FHE.chunked_eint<64>, !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64>
// CHECK-NEXT: return %[[V1]] : !FHE.chunked_eint<64>
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.chunked_eint<64>, !FHE.chunked_eint<64>) -> (!FHE.chunked_eint<64>)
return %1: !FHE.chunked_eint<64>
}
// TODO: max/min