diff --git a/compiler/include/concretelang/Dialect/FHE/IR/FHEInterfaces.td b/compiler/include/concretelang/Dialect/FHE/IR/FHEInterfaces.td index 587a088f6..cb0afe564 100644 --- a/compiler/include/concretelang/Dialect/FHE/IR/FHEInterfaces.td +++ b/compiler/include/concretelang/Dialect/FHE/IR/FHEInterfaces.td @@ -25,11 +25,6 @@ def FheIntegerInterface : TypeInterface<"FheIntegerInterface"> { /*description=*/"Get whether the integer is unsigned.", /*retTy=*/"bool", /*methodName=*/"isUnsigned" - >, - InterfaceMethod< - /*description=*/"Get whether the integer is chunked (composed of multiple smaller integers).", - /*retTy=*/"bool", - /*methodName=*/"isChunked" > ]; } diff --git a/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.td b/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.td index fe98730d3..e4cbd0bae 100644 --- a/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.td +++ b/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.td @@ -33,7 +33,6 @@ def FHE_EncryptedIntegerType : FHE_Type<"EncryptedInteger", let extraClassDeclaration = [{ bool isSigned() const { return false; } bool isUnsigned() const { return true; } - bool isChunked() const { return false; } }]; } @@ -62,43 +61,12 @@ def FHE_EncryptedSignedIntegerType : FHE_Type<"EncryptedSignedInteger", let extraClassDeclaration = [{ bool isSigned() const { return true; } bool isUnsigned() const { return false; } - bool isChunked() const { return false; } - }]; -} - -def FHE_ChunkedEncryptedIntegerType : FHE_Type<"ChunkedEncryptedInteger", - [MemRefElementTypeInterface, FheIntegerInterface]> { - let mnemonic = "chunked_eint"; - - let summary = "An encrypted integer composed of multiple chunks"; - - let description = [{ - An encrypted integer composed of multiple chunks. - - Examples: - ```mlir - !FHE.chunked_eint<64> - !FHE.chunked_eint<32> - ``` - }]; - - let parameters = (ins "unsigned":$width); - - let hasCustomAssemblyFormat = 1; - - let genVerifyDecl = true; - - let extraClassDeclaration = [{ - bool isSigned() const { return false; } - bool isUnsigned() const { return true; } - bool isChunked() const { return true; } }]; } def FHE_AnyEncryptedInteger : Type>; def FHE_EncryptedBooleanType : FHE_Type<"EncryptedBoolean", diff --git a/compiler/lib/Dialect/FHE/IR/FHEDialect.cpp b/compiler/lib/Dialect/FHE/IR/FHEDialect.cpp index b8c2c078e..3bea0d3a9 100644 --- a/compiler/lib/Dialect/FHE/IR/FHEDialect.cpp +++ b/compiler/lib/Dialect/FHE/IR/FHEDialect.cpp @@ -89,33 +89,3 @@ mlir::Type EncryptedSignedIntegerType::parse(mlir::AsmParser &p) { return getChecked(loc, loc.getContext(), width); } - -mlir::LogicalResult ChunkedEncryptedIntegerType::verify( - llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, unsigned p) { - if (p == 0) { - emitError() << "FHE.chunked_eint doesn't support precision of 0"; - return mlir::failure(); - } - return mlir::success(); -} - -void ChunkedEncryptedIntegerType::print(mlir::AsmPrinter &p) const { - p << "<" << getWidth() << ">"; -} - -mlir::Type ChunkedEncryptedIntegerType::parse(mlir::AsmParser &p) { - if (p.parseLess()) - return mlir::Type(); - - int width; - - if (p.parseInteger(width)) - return mlir::Type(); - - if (p.parseGreater()) - return mlir::Type(); - - mlir::Location loc = p.getEncodedSourceLoc(p.getNameLoc()); - - return getChecked(loc, loc.getContext(), width); -} diff --git a/compiler/lib/Dialect/FHE/IR/FHEOps.cpp b/compiler/lib/Dialect/FHE/IR/FHEOps.cpp index bc03735c0..bee590615 100644 --- a/compiler/lib/Dialect/FHE/IR/FHEOps.cpp +++ b/compiler/lib/Dialect/FHE/IR/FHEOps.cpp @@ -29,25 +29,13 @@ bool verifyEncryptedIntegerInputAndResultConsistency( return false; } - if (input.isChunked() != result.isChunked()) { - op.emitOpError("should have the same composition (chunked or not) of " - "encrypted input and result"); - return false; - } - return true; } bool verifyEncryptedIntegerAndIntegerInputsConsistency(mlir::Operation &op, FheIntegerInterface &a, IntegerType &b) { - if (a.isChunked()) { - if (a.getWidth() != b.getWidth()) { - op.emitOpError("should have the width of plain input equal to width of " - "encrypted input"); - return false; - } - } else if (a.getWidth() + 1 != b.getWidth()) { + if (a.getWidth() + 1 != b.getWidth()) { op.emitOpError("should have the width of plain input equal to width of " "encrypted input + 1"); return false; @@ -69,12 +57,6 @@ bool verifyEncryptedIntegerInputsConsistency(mlir::Operation &op, return false; } - if (a.isChunked() != b.isChunked()) { - op.emitOpError("should have the same composition (chunked or not) of " - "encrypted inputs"); - return false; - } - return true; } diff --git a/compiler/lib/Dialect/FHE/Transforms/BigInt.cpp b/compiler/lib/Dialect/FHE/Transforms/BigInt.cpp index 34c0ae632..680c547d7 100644 --- a/compiler/lib/Dialect/FHE/Transforms/BigInt.cpp +++ b/compiler/lib/Dialect/FHE/Transforms/BigInt.cpp @@ -49,10 +49,10 @@ namespace typing { /// Converts `FHE::ChunkedEncryptedInteger` into a tensor of /// `FHE::EncryptedInteger`. -mlir::RankedTensorType -convertChunkedEint(mlir::MLIRContext *context, - FHE::ChunkedEncryptedIntegerType chunkedEint, - unsigned int chunkSize, unsigned int chunkWidth) { +mlir::RankedTensorType convertChunkedEint(mlir::MLIRContext *context, + FHE::EncryptedIntegerType chunkedEint, + unsigned int chunkSize, + unsigned int chunkWidth) { auto eint = FHE::EncryptedIntegerType::get(context, chunkSize); auto bigIntWidth = chunkedEint.getWidth(); assert(bigIntWidth % chunkWidth == 0 && @@ -68,9 +68,13 @@ class TypeConverter : public mlir::TypeConverter { public: TypeConverter(unsigned int chunkSize, unsigned int chunkWidth) { addConversion([](mlir::Type type) { return type; }); - addConversion([chunkSize, - chunkWidth](FHE::ChunkedEncryptedIntegerType type) { - return convertChunkedEint(type.getContext(), type, chunkSize, chunkWidth); + addConversion([chunkSize, chunkWidth](FHE::EncryptedIntegerType type) { + if (type.getWidth() > chunkSize) { + return (mlir::Type)convertChunkedEint(type.getContext(), type, + chunkSize, chunkWidth); + } else { + return (mlir::Type)type; + } }); } }; diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 10a37291c..6c9199d60 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -296,11 +296,13 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { return StreamStringError("Rewriting of encrypted mul failed"); } - if (mlir::concretelang::pipeline::transformFHEBigInt( - mlirContext, module, enablePass, options.chunkSize, - options.chunkWidth) - .failed()) { - return errorDiag("Transforming FHE big integer ops failed"); + if (options.chunkIntegers) { + if (mlir::concretelang::pipeline::transformFHEBigInt( + mlirContext, module, enablePass, options.chunkSize, + options.chunkWidth) + .failed()) { + return errorDiag("Transforming FHE big integer ops failed"); + } } // FHE High level pass to determine FHE parameters diff --git a/compiler/lib/Support/V0ClientParameters.cpp b/compiler/lib/Support/V0ClientParameters.cpp index 4042cac80..00125da8a 100644 --- a/compiler/lib/Support/V0ClientParameters.cpp +++ b/compiler/lib/Support/V0ClientParameters.cpp @@ -98,46 +98,6 @@ gateFromMLIRType(V0FHEContext fheContext, LweSecretKeyID secretKeyID, /*.chunkInfo = */ chunkInfo, }; } - // TODO: this a duplicate of the last if: should be removed when we remove - // chinked eint - if (auto lweTy = type.dyn_cast_or_null< - mlir::concretelang::FHE::ChunkedEncryptedIntegerType>()) { - bool sign = lweTy.isSignedInteger(); - std::vector crt; - if (fheContext.parameter.largeInteger.has_value()) { - crt = fheContext.parameter.largeInteger.value().crtDecomposition; - } - size_t width; - uint64_t size; - std::vector dims; - if (chunkInfo.hasValue()) { - width = chunkInfo->size; - assert(lweTy.getWidth() % chunkInfo->width == 0); - size = lweTy.getWidth() / chunkInfo->width; - dims.push_back(size); - } else { - width = (size_t)lweTy.getWidth(); - } - return CircuitGate{ - /* .encryption = */ llvm::Optional({ - /* .secretKeyID = */ secretKeyID, - /* .variance = */ variance, - /* .encoding = */ - { - /* .precision = */ width, - /* .crt = */ crt, - }, - }), - /*.shape = */ - { - /*.width = */ width, - /*.dimensions = */ dims, - /*.size = */ size, - /*.sign = */ sign, - }, - /*.chunkInfo = */ chunkInfo, - }; - } if (auto lweTy = type.dyn_cast_or_null< mlir::concretelang::FHE::EncryptedBooleanType>()) { size_t width = mlir::concretelang::FHE::EncryptedBooleanType::getWidth(); diff --git a/compiler/tests/check_tests/Dialect/FHE/Transform/big_int_transform.mlir b/compiler/tests/check_tests/Dialect/FHE/Transform/big_int_transform.mlir index 0f0fbc177..ff9e1d778 100644 --- a/compiler/tests/check_tests/Dialect/FHE/Transform/big_int_transform.mlir +++ b/compiler/tests/check_tests/Dialect/FHE/Transform/big_int_transform.mlir @@ -1,7 +1,7 @@ -// RUN: concretecompiler --chunk-size 4 --chunk-width 2 --passes fhe-big-int-transform --action=dump-fhe %s 2>&1| FileCheck %s +// RUN: concretecompiler --chunk-integers --chunk-size 4 --chunk-width 2 --passes fhe-big-int-transform --action=dump-fhe %s 2>&1| FileCheck %s // CHECK-LABEL: func.func @add_chunked_eint(%arg0: tensor<32x!FHE.eint<4>>, %arg1: tensor<32x!FHE.eint<4>>) -> tensor<32x!FHE.eint<4>> -func.func @add_chunked_eint(%arg0: !FHE.chunked_eint<64>, %arg1: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> { +func.func @add_chunked_eint(%arg0: !FHE.eint<64>, %arg1: !FHE.eint<64>) -> !FHE.eint<64> { // CHECK-NEXT: %[[V0:.*]] = "FHE.zero"() : () -> !FHE.eint<4> // CHECK-NEXT: %[[V1:.*]] = "FHE.zero_tensor"() : () -> tensor<32x!FHE.eint<4>> // CHECK-NEXT: %[[c4_i5:.*]] = arith.constant 4 : i5 @@ -19,6 +19,6 @@ func.func @add_chunked_eint(%arg0: !FHE.chunked_eint<64>, %arg1: !FHE.chunked_ei // CHECK-NEXT: } // CHECK-NEXT: return %2 : tensor<32x!FHE.eint<4>> - %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.chunked_eint<64>, !FHE.chunked_eint<64>) -> (!FHE.chunked_eint<64>) - return %1: !FHE.chunked_eint<64> + %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<64>, !FHE.eint<64>) -> (!FHE.eint<64>) + return %1: !FHE.eint<64> } diff --git a/compiler/tests/check_tests/Dialect/FHE/big_int.mlir b/compiler/tests/check_tests/Dialect/FHE/big_int.mlir deleted file mode 100644 index b571b7e64..000000000 --- a/compiler/tests/check_tests/Dialect/FHE/big_int.mlir +++ /dev/null @@ -1,34 +0,0 @@ -// RUN: concretecompiler --action=roundtrip %s 2>&1| FileCheck %s - -// CHECK-LABEL: func.func @mul_chunked_eint_int(%arg0: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> -func.func @mul_chunked_eint_int(%arg0: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> { - // CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i64 - // CHECK-NEXT: %[[V2:.*]] = "FHE.mul_eint_int"(%arg0, %[[V1]]) : (!FHE.chunked_eint<64>, i64) -> !FHE.chunked_eint<64> - // CHECK-NEXT: return %[[V2]] : !FHE.chunked_eint<64> - - %0 = arith.constant 1 : i64 - %1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.chunked_eint<64>, i64) -> (!FHE.chunked_eint<64>) - return %1: !FHE.chunked_eint<64> -} - -// CHECK-LABEL: func.func @add_chunked_eint_int(%arg0: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> -func.func @add_chunked_eint_int(%arg0: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> { - // CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i64 - // CHECK-NEXT: %[[V2:.*]] = "FHE.add_eint_int"(%arg0, %[[V1]]) : (!FHE.chunked_eint<64>, i64) -> !FHE.chunked_eint<64> - // CHECK-NEXT: return %[[V2]] : !FHE.chunked_eint<64> - - %0 = arith.constant 1 : i64 - %1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.chunked_eint<64>, i64) -> (!FHE.chunked_eint<64>) - return %1: !FHE.chunked_eint<64> -} - -// CHECK-LABEL: func.func @add_chunked_eint(%arg0: !FHE.chunked_eint<64>, %arg1: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> -func.func @add_chunked_eint(%arg0: !FHE.chunked_eint<64>, %arg1: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> { - // CHECK-NEXT: %[[V1:.*]] = "FHE.add_eint"(%arg0, %arg1) : (!FHE.chunked_eint<64>, !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> - // CHECK-NEXT: return %[[V1]] : !FHE.chunked_eint<64> - - %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.chunked_eint<64>, !FHE.chunked_eint<64>) -> (!FHE.chunked_eint<64>) - return %1: !FHE.chunked_eint<64> -} - -// TODO: max/min