From 36f51ba0c2aa9fd5fc13ee7621fb6b1555db81d7 Mon Sep 17 00:00:00 2001 From: youben11 Date: Mon, 23 Jan 2023 13:38:08 +0100 Subject: [PATCH] feat: lower and exec boolean ops --- .../concretelang/Dialect/FHE/CMakeLists.txt | 1 + .../concretelang/Dialect/FHE/IR/FHEOps.td | 17 +- .../concretelang/Dialect/FHE/IR/FHETypes.td | 5 + .../Dialect/FHE/Transforms/Boolean.h | 23 +++ .../Dialect/FHE/Transforms/Boolean.td | 13 ++ .../Dialect/FHE/Transforms/CMakeLists.txt | 4 + .../include/concretelang/Support/Pipeline.h | 4 + .../FHEToTFHEScalar/FHEToTFHEScalar.cpp | 57 +++++++ compiler/lib/Dialect/FHE/Analysis/MANP.cpp | 21 +++ compiler/lib/Dialect/FHE/Analysis/utils.cpp | 11 +- compiler/lib/Dialect/FHE/CMakeLists.txt | 1 + compiler/lib/Dialect/FHE/IR/FHEOps.cpp | 6 +- .../lib/Dialect/FHE/Transforms/Boolean.cpp | 126 +++++++++++++++ .../lib/Dialect/FHE/Transforms/CMakeLists.txt | 12 ++ compiler/lib/Support/CMakeLists.txt | 1 + compiler/lib/Support/CompilerEngine.cpp | 6 + compiler/lib/Support/Pipeline.cpp | 10 ++ compiler/lib/Support/V0ClientParameters.cpp | 22 +++ .../Conversion/FHEToTFHEScalar/neg_eint.mlir | 8 + .../FHE/Transform/boolean_transforms.mlir | 80 ++++++++++ .../check_tests/Dialect/FHE/ops.invalid.mlir | 8 +- .../tests/check_tests/Dialect/FHE/ops.mlir | 8 +- .../tests_cpu/end_to_end_fhe.yaml | 148 ++++++++++++++++++ 23 files changed, 572 insertions(+), 20 deletions(-) create mode 100644 compiler/include/concretelang/Dialect/FHE/Transforms/Boolean.h create mode 100644 compiler/include/concretelang/Dialect/FHE/Transforms/Boolean.td create mode 100644 compiler/include/concretelang/Dialect/FHE/Transforms/CMakeLists.txt create mode 100644 compiler/lib/Dialect/FHE/Transforms/Boolean.cpp create mode 100644 compiler/lib/Dialect/FHE/Transforms/CMakeLists.txt create mode 100644 compiler/tests/check_tests/Dialect/FHE/Transform/boolean_transforms.mlir diff --git a/compiler/include/concretelang/Dialect/FHE/CMakeLists.txt b/compiler/include/concretelang/Dialect/FHE/CMakeLists.txt index 4f7494893..306b43968 100644 --- a/compiler/include/concretelang/Dialect/FHE/CMakeLists.txt +++ b/compiler/include/concretelang/Dialect/FHE/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(Analysis) add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td b/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td index d626ac809..5efd0d2c6 100644 --- a/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td +++ b/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td @@ -375,15 +375,17 @@ def FHE_GenGateOp : FHE_Op<"gen_gate", [NoSideEffect]> { Example: ```mlir // ok - "FHE.gen_gate"(%a, %b, %ttable): (!FHE.ebool, !FHE.ebool, tensor<4xi1>) -> (!FHE.ebool) + "FHE.gen_gate"(%a, %b, %ttable): (!FHE.ebool, !FHE.ebool, tensor<4xi64>) -> (!FHE.ebool) // error - "FHE.gen_gate"(%a, %b, %ttable): (!FHE.ebool, !FHE.ebool, tensor<4xi4>) -> (!FHE.ebool) - "FHE.gen_gate"(%a, %b, %ttable): (!FHE.ebool, !FHE.ebool, tensor<7xi1>) -> (!FHE.ebool) + "FHE.gen_gate"(%a, %b, %ttable): (!FHE.ebool, !FHE.ebool, tensor<7xi64>) -> (!FHE.ebool) ``` }]; - let arguments = (ins FHE_EncryptedBooleanType:$left, FHE_EncryptedBooleanType:$right, TensorOf<[I1]>:$truth_table); + // The reason the truth table is of AnyInteger and not I1 is that in lowering passes, the truth_table is meant to be passed + // to an LUT operation which requires the table to be of type I64. Whenever lowering passes are no more restrictive, this + // can be set to I1 to reflect the boolean logic. + let arguments = (ins FHE_EncryptedBooleanType:$left, FHE_EncryptedBooleanType:$right, TensorOf<[AnyInteger]>:$truth_table); let results = (outs FHE_EncryptedBooleanType); let hasVerifier = 1; } @@ -498,16 +500,17 @@ def FHE_ToBoolOp : FHE_Op<"to_bool", [NoSideEffect]> { let description = [{ Cast an unsigned integer to a boolean. - The input must necessarily be of width 1, in order to put that single - bit into a boolean. + The input must necessarily be of width 1 or 2. 2 being the current representation + of an encrypted boolean, leaving one bit for the carry. Examples: ```mlir // ok "FHE.to_bool"(%x) : (!FHE.eint<1>) -> !FHE.ebool + "FHE.to_bool"(%x) : (!FHE.eint<2>) -> !FHE.ebool // error - "FHE.to_bool"(%x) : (!FHE.eint<2>) -> !FHE.ebool + "FHE.to_bool"(%x) : (!FHE.eint<3>) -> !FHE.ebool ``` }]; diff --git a/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.td b/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.td index 50f5138aa..e4cbd0bae 100644 --- a/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.td +++ b/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.td @@ -78,6 +78,11 @@ def FHE_EncryptedBooleanType : FHE_Type<"EncryptedBoolean", let description = [{ An encrypted boolean. }]; + + let extraClassDeclaration = [{ + /// Returns the required number of bits to represent an encrypted boolean + static size_t getWidth() { return 2; } + }]; } #endif diff --git a/compiler/include/concretelang/Dialect/FHE/Transforms/Boolean.h b/compiler/include/concretelang/Dialect/FHE/Transforms/Boolean.h new file mode 100644 index 000000000..386c6a6a1 --- /dev/null +++ b/compiler/include/concretelang/Dialect/FHE/Transforms/Boolean.h @@ -0,0 +1,23 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_FHE_BOOLEAN_PASS_H +#define CONCRETELANG_FHE_BOOLEAN_PASS_H + +#include +#include + +#define GEN_PASS_CLASSES +#include + +namespace mlir { +namespace concretelang { + +std::unique_ptr> createFHEBooleanTransformPass(); + +} // namespace concretelang +} // namespace mlir + +#endif diff --git a/compiler/include/concretelang/Dialect/FHE/Transforms/Boolean.td b/compiler/include/concretelang/Dialect/FHE/Transforms/Boolean.td new file mode 100644 index 000000000..11e87d6a2 --- /dev/null +++ b/compiler/include/concretelang/Dialect/FHE/Transforms/Boolean.td @@ -0,0 +1,13 @@ +#ifndef CONCRETELANG_FHE_BOOLEAN_PASS +#define CONCRETELANG_FHE_BOOLEAN_PASS + +include "mlir/Pass/PassBase.td" + +def FHEBooleanTransform : Pass<"fhe-boolean-transform"> { + let summary = "Transform FHE boolean operations to integer operations"; + let constructor = "mlir::concretelang::createFHEBooleanTransformPass()"; + let options = []; + let dependentDialects = [ "mlir::concretelang::FHE::FHEDialect" ]; +} + +#endif diff --git a/compiler/include/concretelang/Dialect/FHE/Transforms/CMakeLists.txt b/compiler/include/concretelang/Dialect/FHE/Transforms/CMakeLists.txt new file mode 100644 index 000000000..1cac703ce --- /dev/null +++ b/compiler/include/concretelang/Dialect/FHE/Transforms/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS Boolean.td) +mlir_tablegen(Boolean.h.inc -gen-pass-decls -name Transforms) +add_public_tablegen_target(ConcretelangFHEBooleanPassIncGen) +add_dependencies(mlir-headers ConcretelangFHEBooleanPassIncGen) diff --git a/compiler/include/concretelang/Support/Pipeline.h b/compiler/include/concretelang/Support/Pipeline.h index aaa58634d..9067e3153 100644 --- a/compiler/include/concretelang/Support/Pipeline.h +++ b/compiler/include/concretelang/Support/Pipeline.h @@ -40,6 +40,10 @@ lowerFHELinalgToFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass, bool parallelize, bool batch); +mlir::LogicalResult +transformFHEBoolean(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass); + mlir::LogicalResult lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, llvm::Optional &fheContext, diff --git a/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp b/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp index 816850f4a..5e7727909 100644 --- a/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp +++ b/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp @@ -78,6 +78,11 @@ public: addConversion([](FHE::EncryptedIntegerType type) { return convertEint(type.getContext(), type); }); + addConversion([](FHE::EncryptedBooleanType type) { + return TFHE::GLWECipherTextType::get( + type.getContext(), -1, -1, -1, + mlir::concretelang::FHE::EncryptedBooleanType::getWidth()); + }); addConversion([](mlir::RankedTensorType type) { return maybeConvertEintTensor(type.getContext(), type); }); @@ -363,6 +368,51 @@ private: concretelang::ScalarLoweringParameters loweringParameters; }; +/// Rewriter for the `FHE::to_bool` operation. +struct ToBoolOpPattern : public mlir::OpRewritePattern { + ToBoolOpPattern(mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) + : mlir::OpRewritePattern(context, benefit) {} + + mlir::LogicalResult + matchAndRewrite(FHE::ToBoolOp op, + mlir::PatternRewriter &rewriter) const override { + auto width = op.input() + .getType() + .dyn_cast() + .getWidth(); + if (width == mlir::concretelang::FHE::EncryptedBooleanType::getWidth()) { + rewriter.replaceOp(op, op.input()); + return mlir::success(); + } + // TODO + op->emitError("only support conversion with width 2 for the moment"); + return mlir::failure(); + } +}; + +/// Rewriter for the `FHE::from_bool` operation. +struct FromBoolOpPattern : public mlir::OpRewritePattern { + FromBoolOpPattern(mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : mlir::OpRewritePattern(context, benefit) {} + + mlir::LogicalResult + matchAndRewrite(FHE::FromBoolOp op, + mlir::PatternRewriter &rewriter) const override { + auto width = op.getResult() + .getType() + .dyn_cast() + .getWidth(); + if (width == mlir::concretelang::FHE::EncryptedBooleanType::getWidth()) { + rewriter.replaceOp(op, op.input()); + return mlir::success(); + } + // TODO + op->emitError("only support conversion with width 2 for the moment"); + return mlir::failure(); + } +}; + } // namespace lowering struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase { @@ -431,6 +481,9 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase { // |_ `FHE::neg_eint` concretelang::GenericTypeAndOpConverterPattern, + // |_ `FHE::not` + concretelang::GenericTypeAndOpConverterPattern, // |_ `FHE::add_eint` concretelang::GenericTypeAndOpConverterPattern>( @@ -449,6 +502,10 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase { patterns.add(&getContext(), loweringParameters); + // Patterns for boolean conversion ops + patterns.add( + &getContext()); + // Patterns for the relics of the `FHELinalg` dialect operations. // |_ `linalg::generic` turned to nested `scf::for` patterns diff --git a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp index 2bff387d9..ea3f1d634 100644 --- a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp @@ -412,6 +412,22 @@ static llvm::APInt getSqMANP( return eNorm; } +/// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation +/// that is equivalent to an `FHE.not` operation. +static llvm::APInt getSqMANP( + mlir::concretelang::FHE::BoolNotOp op, + llvm::ArrayRef *> operandMANPs) { + + assert( + operandMANPs.size() == 1 && + operandMANPs[0]->getValue().getMANP().hasValue() && + "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); + + llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue(); + + return eNorm; +} + /// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation /// that is equivalent to an `FHE.mul_eint_int` operation. static llvm::APInt getSqMANP( @@ -1124,10 +1140,15 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { } else if (auto negEintOp = llvm::dyn_cast(op)) { norm2SqEquiv = getSqMANP(negEintOp, operands); + } else if (auto boolNotOp = + llvm::dyn_cast(op)) { + norm2SqEquiv = getSqMANP(boolNotOp, operands); } else if (auto mulEintIntOp = llvm::dyn_cast(op)) { norm2SqEquiv = getSqMANP(mulEintIntOp, operands); } else if (llvm::isa(op) || + llvm::isa(op) || + llvm::isa(op) || llvm::isa(op) || llvm::isa(op)) { norm2SqEquiv = llvm::APInt{1, 1, false}; diff --git a/compiler/lib/Dialect/FHE/Analysis/utils.cpp b/compiler/lib/Dialect/FHE/Analysis/utils.cpp index d0d3b270c..5f23cdb1c 100644 --- a/compiler/lib/Dialect/FHE/Analysis/utils.cpp +++ b/compiler/lib/Dialect/FHE/Analysis/utils.cpp @@ -15,6 +15,7 @@ namespace utils { bool isEncryptedValue(mlir::Value value) { return ( value.getType().isa() || + value.getType().isa() || (value.getType().isa() && value.getType() .cast() @@ -22,14 +23,20 @@ bool isEncryptedValue(mlir::Value value) { .isa())); } -/// Returns the bit width of `value` if `value` is an encrypted integer -/// or the bit width of the elements if `value` is a tensor of +/// Returns the bit width of `value` if `value` is an encrypted integer, +/// or the number of bits to represent a boolean if `value` is an encrypted +/// boolean, or the bit width of the elements if `value` is a tensor of /// encrypted integers. unsigned int getEintPrecision(mlir::Value value) { if (auto ty = value.getType() .dyn_cast_or_null< mlir::concretelang::FHE::EncryptedIntegerType>()) { return ty.getWidth(); + } + if (auto ty = value.getType() + .dyn_cast_or_null< + mlir::concretelang::FHE::EncryptedBooleanType>()) { + return mlir::concretelang::FHE::EncryptedBooleanType::getWidth(); } else if (auto tensorTy = value.getType().dyn_cast_or_null()) { if (auto ty = tensorTy.getElementType() diff --git a/compiler/lib/Dialect/FHE/CMakeLists.txt b/compiler/lib/Dialect/FHE/CMakeLists.txt index 4f7494893..306b43968 100644 --- a/compiler/lib/Dialect/FHE/CMakeLists.txt +++ b/compiler/lib/Dialect/FHE/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(Analysis) add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/compiler/lib/Dialect/FHE/IR/FHEOps.cpp b/compiler/lib/Dialect/FHE/IR/FHEOps.cpp index cfe13d8d6..6a4de8497 100644 --- a/compiler/lib/Dialect/FHE/IR/FHEOps.cpp +++ b/compiler/lib/Dialect/FHE/IR/FHEOps.cpp @@ -206,9 +206,9 @@ mlir::LogicalResult ToUnsignedOp::verify() { mlir::LogicalResult ToBoolOp::verify() { auto input = this->input().getType().cast(); - if (input.getWidth() != 1) { - this->emitOpError( - "should have 1 as the width of encrypted input to cast to a boolean"); + if (input.getWidth() != 1 && input.getWidth() != 2) { + this->emitOpError("should have 1 or 2 as the width of encrypted input to " + "cast to a boolean"); return mlir::failure(); } diff --git a/compiler/lib/Dialect/FHE/Transforms/Boolean.cpp b/compiler/lib/Dialect/FHE/Transforms/Boolean.cpp new file mode 100644 index 000000000..c559a39c0 --- /dev/null +++ b/compiler/lib/Dialect/FHE/Transforms/Boolean.cpp @@ -0,0 +1,126 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include +#include +#include + +#include +#include +#include +#include + +namespace mlir { +namespace concretelang { + +namespace { + +/// Rewrite an `FHE.gen_gate` operation as an LUT operation by composing a +/// single index from the two boolean inputs. +class GenGatePattern + : public mlir::OpRewritePattern { +public: + GenGatePattern(mlir::MLIRContext *context) + : mlir::OpRewritePattern( + context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} + + mlir::LogicalResult + matchAndRewrite(mlir::concretelang::FHE::GenGateOp op, + mlir::PatternRewriter &rewriter) const override { + auto eint2 = mlir::concretelang::FHE::EncryptedIntegerType::get( + rewriter.getContext(), 2); + auto left = rewriter + .create( + op.getLoc(), eint2, op.left()) + .getResult(); + auto right = rewriter + .create( + op.getLoc(), eint2, op.right()) + .getResult(); + auto cst_two = + rewriter.create(op.getLoc(), 2, 3) + .getResult(); + auto leftMulTwo = rewriter + .create( + op.getLoc(), left, cst_two) + .getResult(); + auto newIndex = rewriter + .create( + op.getLoc(), leftMulTwo, right) + .getResult(); + auto lut_result = + rewriter.create( + op.getLoc(), eint2, newIndex, op.truth_table()); + rewriter.replaceOpWithNewOp( + op, + mlir::concretelang::FHE::EncryptedBooleanType::get( + rewriter.getContext()), + lut_result); + return mlir::success(); + } +}; + +/// Rewrite an FHE GateOp (e.g. And/Or) into a GenGate with the given truth +/// table. +template +class GeneralizeGatePattern : public mlir::OpRewritePattern { +public: + GeneralizeGatePattern(mlir::MLIRContext *context, + llvm::SmallVector truth_table_vector) + : mlir::OpRewritePattern( + context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT), + truth_table_vector(truth_table_vector) {} + + mlir::LogicalResult + matchAndRewrite(GateOp op, mlir::PatternRewriter &rewriter) const override { + auto truth_table_attr = mlir::DenseElementsAttr::get( + mlir::RankedTensorType::get({4}, rewriter.getIntegerType(64)), + {llvm::APInt(1, this->truth_table_vector[0], false), + llvm::APInt(1, this->truth_table_vector[1], false), + llvm::APInt(1, this->truth_table_vector[2], false), + llvm::APInt(1, this->truth_table_vector[3], false)}); + auto truth_table = + rewriter.create(op.getLoc(), truth_table_attr); + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), op.left(), op.right(), truth_table); + return mlir::success(); + } + +private: + llvm::SmallVector truth_table_vector; +}; + +/// Perfoms the transformation of boolean operations +class FHEBooleanTransformPass + : public FHEBooleanTransformBase { +public: + void runOnOperation() override { + mlir::Operation *op = getOperation(); + + mlir::RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + patterns.add>( + &getContext(), llvm::SmallVector({0, 0, 0, 1})); + patterns.add>( + &getContext(), llvm::SmallVector({1, 1, 1, 0})); + patterns.add>( + &getContext(), llvm::SmallVector({0, 1, 1, 1})); + patterns.add>( + &getContext(), llvm::SmallVector({0, 1, 1, 0})); + + if (mlir::applyPatternsAndFoldGreedily(op, std::move(patterns)).failed()) { + this->signalPassFailure(); + } + } +}; + +} // end anonymous namespace + +std::unique_ptr> createFHEBooleanTransformPass() { + return std::make_unique(); +} + +} // namespace concretelang +} // namespace mlir diff --git a/compiler/lib/Dialect/FHE/Transforms/CMakeLists.txt b/compiler/lib/Dialect/FHE/Transforms/CMakeLists.txt new file mode 100644 index 000000000..b8972f347 --- /dev/null +++ b/compiler/lib/Dialect/FHE/Transforms/CMakeLists.txt @@ -0,0 +1,12 @@ +add_mlir_library( + FHEDialectTransforms + Boolean.cpp + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/FHE + DEPENDS + FHEDialect + mlir-headers + LINK_LIBS + PUBLIC + MLIRIR + FHEDialect) diff --git a/compiler/lib/Support/CMakeLists.txt b/compiler/lib/Support/CMakeLists.txt index a20fd16b1..e478ffcfd 100644 --- a/compiler/lib/Support/CMakeLists.txt +++ b/compiler/lib/Support/CMakeLists.txt @@ -21,6 +21,7 @@ add_mlir_library( FHELinalgDialect FHELinalgDialectTransforms FHETensorOpsToLinalg + FHEDialectTransforms FHEToTFHECrt FHEToTFHEScalar ExtractSDFGOps diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index dbb3c0f72..ecd67dd60 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -283,6 +283,12 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { if (target == Target::ROUND_TRIP) return std::move(res); + if (mlir::concretelang::pipeline::transformFHEBoolean(mlirContext, module, + enablePass) + .failed()) { + return errorDiag("Transforming FHE boolean ops failed"); + } + // FHE High level pass to determine FHE parameters if (auto err = this->determineFHEParameters(res)) return std::move(err); diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index fe60c3dff..da3b51675 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -35,6 +35,7 @@ #include #include #include +#include #include #include #include @@ -207,6 +208,15 @@ lowerFHELinalgToFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, return pm.run(module.getOperation()); } +mlir::LogicalResult +transformFHEBoolean(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass) { + mlir::PassManager pm(&context); + addPotentiallyNestedPass( + pm, mlir::concretelang::createFHEBooleanTransformPass(), enablePass); + return pm.run(module.getOperation()); +} + mlir::LogicalResult lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, llvm::Optional &fheContext, diff --git a/compiler/lib/Support/V0ClientParameters.cpp b/compiler/lib/Support/V0ClientParameters.cpp index 96b97f909..57bdd6768 100644 --- a/compiler/lib/Support/V0ClientParameters.cpp +++ b/compiler/lib/Support/V0ClientParameters.cpp @@ -85,6 +85,28 @@ llvm::Expected gateFromMLIRType(V0FHEContext fheContext, }, }; } + if (auto lweTy = type.dyn_cast_or_null< + mlir::concretelang::FHE::EncryptedBooleanType>()) { + size_t width = mlir::concretelang::FHE::EncryptedBooleanType::getWidth(); + return CircuitGate{ + /* .encryption = */ llvm::Optional({ + /* .secretKeyID = */ secretKeyID, + /* .variance = */ variance, + /* .encoding = */ + { + /* .precision = */ width, + /* .crt = */ std::vector(), + }, + }), + /*.shape = */ + { + /*.width = */ width, + /*.dimensions = */ std::vector(), + /*.size = */ 0, + /*.sign = */ false, + }, + }; + } auto tensor = type.dyn_cast_or_null(); if (tensor != nullptr) { auto gate = gateFromMLIRType(fheContext, secretKeyID, variance, diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/neg_eint.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/neg_eint.mlir index f21f99c76..d68f4d565 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/neg_eint.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/neg_eint.mlir @@ -8,3 +8,11 @@ func.func @neg_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { %1 = "FHE.neg_eint"(%arg0): (!FHE.eint<7>) -> (!FHE.eint<7>) return %1: !FHE.eint<7> } + +// CHECK-LABEL: func.func @not(%arg0: !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}> +func.func @not(%arg0: !FHE.ebool) -> !FHE.ebool { + // CHECK-NEXT: %0 = "TFHE.neg_glwe"(%arg0) {MANP = 1 : ui1} : (!TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}> + // CHECK-NEXT: return %0 : !TFHE.glwe<{_,_,_}{2}> + %1 = "FHE.not"(%arg0) : (!FHE.ebool) -> !FHE.ebool + return %1: !FHE.ebool +} diff --git a/compiler/tests/check_tests/Dialect/FHE/Transform/boolean_transforms.mlir b/compiler/tests/check_tests/Dialect/FHE/Transform/boolean_transforms.mlir new file mode 100644 index 000000000..1d74d21a2 --- /dev/null +++ b/compiler/tests/check_tests/Dialect/FHE/Transform/boolean_transforms.mlir @@ -0,0 +1,80 @@ +// RUN: concretecompiler --passes fhe-boolean-transform --action=dump-fhe %s 2>&1| FileCheck %s + +// CHECK-LABEL: func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi64>) -> !FHE.ebool +func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi64>) -> !FHE.ebool { + // CHECK-NEXT: %[[C0:.*]] = arith.constant 2 : i3 + // CHECK-NEXT: %[[V0:.*]] = "FHE.from_bool"(%arg0) : (!FHE.ebool) -> !FHE.eint<2> + // CHECK-NEXT: %[[V1:.*]] = "FHE.from_bool"(%arg1) : (!FHE.ebool) -> !FHE.eint<2> + // CHECK-NEXT: %[[V2:.*]] = "FHE.mul_eint_int"(%[[V0]], %[[C0]]) : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK-NEXT: %[[V3:.*]] = "FHE.add_eint"(%[[V2]], %[[V1]]) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + // CHECK-NEXT: %[[V4:.*]] = "FHE.apply_lookup_table"(%[[V3]], %arg2) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2> + // CHECK-NEXT: %[[V5:.*]] = "FHE.to_bool"(%[[V4]]) : (!FHE.eint<2>) -> !FHE.ebool + // CHECK-NEXT: return %[[V5]] : !FHE.ebool + + %1 = "FHE.gen_gate"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, tensor<4xi64>) -> !FHE.ebool + return %1: !FHE.ebool +} + +// CHECK-LABEL: func.func @and(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool +func.func @and(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool { + // CHECK-NEXT: %[[TT:.*]] = arith.constant dense<[0, 0, 0, 1]> : tensor<4xi64> + // CHECK-NEXT: %[[C0:.*]] = arith.constant 2 : i3 + // CHECK-NEXT: %[[V0:.*]] = "FHE.from_bool"(%arg0) : (!FHE.ebool) -> !FHE.eint<2> + // CHECK-NEXT: %[[V1:.*]] = "FHE.from_bool"(%arg1) : (!FHE.ebool) -> !FHE.eint<2> + // CHECK-NEXT: %[[V2:.*]] = "FHE.mul_eint_int"(%[[V0]], %[[C0]]) : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK-NEXT: %[[V3:.*]] = "FHE.add_eint"(%[[V2]], %[[V1]]) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + // CHECK-NEXT: %[[V4:.*]] = "FHE.apply_lookup_table"(%[[V3]], %[[TT]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2> + // CHECK-NEXT: %[[V5:.*]] = "FHE.to_bool"(%[[V4]]) : (!FHE.eint<2>) -> !FHE.ebool + // CHECK-NEXT: return %[[V5]] : !FHE.ebool + + %1 = "FHE.and"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool + return %1: !FHE.ebool +} + +// CHECK-LABEL: func.func @or(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool +func.func @or(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool { + // CHECK-NEXT: %[[TT:.*]] = arith.constant dense<[0, 1, 1, 1]> : tensor<4xi64> + // CHECK-NEXT: %[[C0:.*]] = arith.constant 2 : i3 + // CHECK-NEXT: %[[V0:.*]] = "FHE.from_bool"(%arg0) : (!FHE.ebool) -> !FHE.eint<2> + // CHECK-NEXT: %[[V1:.*]] = "FHE.from_bool"(%arg1) : (!FHE.ebool) -> !FHE.eint<2> + // CHECK-NEXT: %[[V2:.*]] = "FHE.mul_eint_int"(%[[V0]], %[[C0]]) : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK-NEXT: %[[V3:.*]] = "FHE.add_eint"(%[[V2]], %[[V1]]) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + // CHECK-NEXT: %[[V4:.*]] = "FHE.apply_lookup_table"(%[[V3]], %[[TT]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2> + // CHECK-NEXT: %[[V5:.*]] = "FHE.to_bool"(%[[V4]]) : (!FHE.eint<2>) -> !FHE.ebool + // CHECK-NEXT: return %[[V5]] : !FHE.ebool + + %1 = "FHE.or"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool + return %1: !FHE.ebool +} + +// CHECK-LABEL: func.func @nand(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool +func.func @nand(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool { + // CHECK-NEXT: %[[TT:.*]] = arith.constant dense<[1, 1, 1, 0]> : tensor<4xi64> + // CHECK-NEXT: %[[C0:.*]] = arith.constant 2 : i3 + // CHECK-NEXT: %[[V0:.*]] = "FHE.from_bool"(%arg0) : (!FHE.ebool) -> !FHE.eint<2> + // CHECK-NEXT: %[[V1:.*]] = "FHE.from_bool"(%arg1) : (!FHE.ebool) -> !FHE.eint<2> + // CHECK-NEXT: %[[V2:.*]] = "FHE.mul_eint_int"(%[[V0]], %[[C0]]) : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK-NEXT: %[[V3:.*]] = "FHE.add_eint"(%[[V2]], %[[V1]]) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + // CHECK-NEXT: %[[V4:.*]] = "FHE.apply_lookup_table"(%[[V3]], %[[TT]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2> + // CHECK-NEXT: %[[V5:.*]] = "FHE.to_bool"(%[[V4]]) : (!FHE.eint<2>) -> !FHE.ebool + // CHECK-NEXT: return %[[V5]] : !FHE.ebool + + %1 = "FHE.nand"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool + return %1: !FHE.ebool +} + +// CHECK-LABEL: func.func @xor(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool +func.func @xor(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool { + // CHECK-NEXT: %[[TT:.*]] = arith.constant dense<[0, 1, 1, 0]> : tensor<4xi64> + // CHECK-NEXT: %[[C0:.*]] = arith.constant 2 : i3 + // CHECK-NEXT: %[[V0:.*]] = "FHE.from_bool"(%arg0) : (!FHE.ebool) -> !FHE.eint<2> + // CHECK-NEXT: %[[V1:.*]] = "FHE.from_bool"(%arg1) : (!FHE.ebool) -> !FHE.eint<2> + // CHECK-NEXT: %[[V2:.*]] = "FHE.mul_eint_int"(%[[V0]], %[[C0]]) : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK-NEXT: %[[V3:.*]] = "FHE.add_eint"(%[[V2]], %[[V1]]) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + // CHECK-NEXT: %[[V4:.*]] = "FHE.apply_lookup_table"(%[[V3]], %[[TT]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2> + // CHECK-NEXT: %[[V5:.*]] = "FHE.to_bool"(%[[V4]]) : (!FHE.eint<2>) -> !FHE.ebool + // CHECK-NEXT: return %[[V5]] : !FHE.ebool + + %1 = "FHE.xor"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool + return %1: !FHE.ebool +} diff --git a/compiler/tests/check_tests/Dialect/FHE/ops.invalid.mlir b/compiler/tests/check_tests/Dialect/FHE/ops.invalid.mlir index af4aae43e..707ae5f83 100644 --- a/compiler/tests/check_tests/Dialect/FHE/ops.invalid.mlir +++ b/compiler/tests/check_tests/Dialect/FHE/ops.invalid.mlir @@ -16,9 +16,9 @@ func.func @zero_plaintext() -> tensor<4x9xi32> { // ----- -func.func @to_bool(%arg0: !FHE.eint<2>) -> !FHE.ebool { +func.func @to_bool(%arg0: !FHE.eint<3>) -> !FHE.ebool { // expected-error @+1 {{'FHE.to_bool' op}} - %1 = "FHE.to_bool"(%arg0): (!FHE.eint<2>) -> (!FHE.ebool) + %1 = "FHE.to_bool"(%arg0): (!FHE.eint<3>) -> (!FHE.ebool) return %1: !FHE.ebool } @@ -32,8 +32,8 @@ func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<5xi1>) - // ----- -func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi2>) -> !FHE.ebool { +func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<6xi64>) -> !FHE.ebool { // expected-error @+1 {{'FHE.gen_gate' op}} - %1 = "FHE.gen_gate"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, tensor<4xi2>) -> !FHE.ebool + %1 = "FHE.gen_gate"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, tensor<6xi64>) -> !FHE.ebool return %1: !FHE.ebool } diff --git a/compiler/tests/check_tests/Dialect/FHE/ops.mlir b/compiler/tests/check_tests/Dialect/FHE/ops.mlir index 567902bc2..014dd522d 100644 --- a/compiler/tests/check_tests/Dialect/FHE/ops.mlir +++ b/compiler/tests/check_tests/Dialect/FHE/ops.mlir @@ -232,12 +232,12 @@ func.func @from_bool(%arg0: !FHE.ebool) -> !FHE.eint<1> { return %1: !FHE.eint<1> } -// CHECK-LABEL: func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi1>) -> !FHE.ebool -func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi1>) -> !FHE.ebool { - // CHECK-NEXT: %[[V1:.*]] = "FHE.gen_gate"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, tensor<4xi1>) -> !FHE.ebool +// CHECK-LABEL: func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi64>) -> !FHE.ebool +func.func @gen_gate(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi64>) -> !FHE.ebool { + // CHECK-NEXT: %[[V1:.*]] = "FHE.gen_gate"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, tensor<4xi64>) -> !FHE.ebool // CHECK-NEXT: return %[[V1]] : !FHE.ebool - %1 = "FHE.gen_gate"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, tensor<4xi1>) -> !FHE.ebool + %1 = "FHE.gen_gate"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, tensor<4xi64>) -> !FHE.ebool return %1: !FHE.ebool } diff --git a/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhe.yaml b/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhe.yaml index 258b1b499..244270c95 100644 --- a/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhe.yaml +++ b/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhe.yaml @@ -183,3 +183,151 @@ tests: outputs: - tensor: [9,4,7,7,10,9,9,4,7,7,10,9] shape: [4,3] +--- +description: boolean_and +program: | + func.func @main(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool { + %1 = "FHE.and"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool + return %1: !FHE.ebool + } +tests: + - inputs: + - scalar: 0 + - scalar: 0 + outputs: + - scalar: 0 + - inputs: + - scalar: 0 + - scalar: 1 + outputs: + - scalar: 0 + - inputs: + - scalar: 1 + - scalar: 0 + outputs: + - scalar: 0 + - inputs: + - scalar: 1 + - scalar: 1 + outputs: + - scalar: 1 +--- +description: boolean_or +program: | + func.func @main(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool { + %1 = "FHE.or"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool + return %1: !FHE.ebool + } +tests: + - inputs: + - scalar: 0 + - scalar: 0 + outputs: + - scalar: 0 + - inputs: + - scalar: 0 + - scalar: 1 + outputs: + - scalar: 1 + - inputs: + - scalar: 1 + - scalar: 0 + outputs: + - scalar: 1 + - inputs: + - scalar: 1 + - scalar: 1 + outputs: + - scalar: 1 +--- +description: boolean_nand +program: | + func.func @main(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool { + %1 = "FHE.nand"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool + return %1: !FHE.ebool + } +tests: + - inputs: + - scalar: 0 + - scalar: 0 + outputs: + - scalar: 1 + - inputs: + - scalar: 0 + - scalar: 1 + outputs: + - scalar: 1 + - inputs: + - scalar: 1 + - scalar: 0 + outputs: + - scalar: 1 + - inputs: + - scalar: 1 + - scalar: 1 + outputs: + - scalar: 0 +--- +description: boolean_xor +program: | + func.func @main(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool { + %1 = "FHE.xor"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool + return %1: !FHE.ebool + } +tests: + - inputs: + - scalar: 0 + - scalar: 0 + outputs: + - scalar: 0 + - inputs: + - scalar: 0 + - scalar: 1 + outputs: + - scalar: 1 + - inputs: + - scalar: 1 + - scalar: 0 + outputs: + - scalar: 1 + - inputs: + - scalar: 1 + - scalar: 1 + outputs: + - scalar: 0 +--- +description: boolean_gen_gate +program: | + func.func @main(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: tensor<4xi64>) -> !FHE.ebool { + %1 = "FHE.gen_gate"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, tensor<4xi64>) -> !FHE.ebool + return %1: !FHE.ebool + } +tests: + - inputs: + - scalar: 0 + - scalar: 1 + - tensor: [1, 0, 1, 1] + shape: [4] + outputs: + - scalar: 0 + - inputs: + - scalar: 0 + - scalar: 0 + - tensor: [0, 1, 1, 1] + shape: [4] + outputs: + - scalar: 0 + - inputs: + - scalar: 1 + - scalar: 0 + - tensor: [0, 0, 1, 0] + shape: [4] + outputs: + - scalar: 1 + - inputs: + - scalar: 1 + - scalar: 1 + - tensor: [0, 0, 0, 1] + shape: [4] + outputs: + - scalar: 1