diff --git a/compiler/lib/Dialect/FHE/Transforms/Boolean.cpp b/compiler/lib/Dialect/FHE/Transforms/Boolean.cpp index a3a2e1300..a51da4fd1 100644 --- a/compiler/lib/Dialect/FHE/Transforms/Boolean.cpp +++ b/compiler/lib/Dialect/FHE/Transforms/Boolean.cpp @@ -92,6 +92,58 @@ private: llvm::SmallVector truth_table_vector; }; +/// Rewrite an `FHE.mux` op, into a series of boolean and arithmetic operations +/// mux(cond, c1, c2) => c1 and not cond + c2 and cond +class MuxOpPattern + : public mlir::OpRewritePattern { +public: + MuxOpPattern(mlir::MLIRContext *context) + : mlir::OpRewritePattern( + context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} + + mlir::LogicalResult + matchAndRewrite(mlir::concretelang::FHE::MuxOp op, + mlir::PatternRewriter &rewriter) const override { + auto eint2 = mlir::concretelang::FHE::EncryptedIntegerType::get( + rewriter.getContext(), 2); + auto boolType = mlir::concretelang::FHE::EncryptedBooleanType::get( + rewriter.getContext()); + + // truth table for c1 and not cond + auto truth_table_attr = mlir::DenseElementsAttr::get( + mlir::RankedTensorType::get({4}, rewriter.getIntegerType(64)), + {llvm::APInt(1, 0, false), llvm::APInt(1, 0, false), + llvm::APInt(1, 1, false), llvm::APInt(1, 0, false)}); + auto truth_table = + rewriter.create(op.getLoc(), truth_table_attr); + auto c1AndNotCond = + rewriter + .create( + op.getLoc(), boolType, op.c1(), op.cond(), truth_table) + .getResult(); + auto c2AndCond = rewriter + .create( + op.getLoc(), boolType, op.c2(), op.cond()) + .getResult(); + + auto c1AndNotCondBool = rewriter + .create( + op.getLoc(), eint2, c1AndNotCond) + .getResult(); + auto c2AndCondBool = rewriter + .create( + op.getLoc(), eint2, c2AndCond) + .getResult(); + auto result = rewriter + .create( + op.getLoc(), c1AndNotCondBool, c2AndCondBool) + .getResult(); + rewriter.replaceOpWithNewOp(op, boolType, + result); + return mlir::success(); + } +}; + /// Perfoms the transformation of boolean operations class FHEBooleanTransformPass : public FHEBooleanTransformBase { @@ -101,6 +153,7 @@ public: mlir::RewritePatternSet patterns(&getContext()); patterns.add(&getContext()); + patterns.add(&getContext()); patterns.add>( &getContext(), llvm::SmallVector({0, 0, 0, 1})); patterns.add>( diff --git a/compiler/tests/check_tests/Dialect/FHE/Transform/boolean_transforms.mlir b/compiler/tests/check_tests/Dialect/FHE/Transform/boolean_transforms.mlir index 1d74d21a2..a47a7bee5 100644 --- a/compiler/tests/check_tests/Dialect/FHE/Transform/boolean_transforms.mlir +++ b/compiler/tests/check_tests/Dialect/FHE/Transform/boolean_transforms.mlir @@ -78,3 +78,30 @@ func.func @xor(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool { %1 = "FHE.xor"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool return %1: !FHE.ebool } + +// CHECK-LABEL: func.func @mux(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: !FHE.ebool) -> !FHE.ebool +func.func @mux(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: !FHE.ebool) -> !FHE.ebool { + // CHECK-NEXT: %[[TT1:.*]] = arith.constant dense<[0, 0, 1, 0]> : tensor<4xi64> + // CHECK-NEXT: %[[C1:.*]] = arith.constant 2 : i3 + // CHECK-NEXT: %[[TT2:.*]] = arith.constant dense<[0, 0, 0, 1]> : tensor<4xi64> + // CHECK-NEXT: %[[V1:.*]] = "FHE.from_bool"(%arg1) : (!FHE.ebool) -> !FHE.eint<2> + // CHECK-NEXT: %[[V2:.*]] = "FHE.from_bool"(%arg0) : (!FHE.ebool) -> !FHE.eint<2> + // CHECK-NEXT: %[[V3:.*]] = "FHE.mul_eint_int"(%[[V1]], %[[C1]]) : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK-NEXT: %[[V4:.*]] = "FHE.add_eint"(%[[V3]], %[[V2]]) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + // CHECK-NEXT: %[[V5:.*]] = "FHE.apply_lookup_table"(%[[V4]], %[[TT1]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2> + // CHECK-NEXT: %[[V6:.*]] = "FHE.to_bool"(%[[V5]]) : (!FHE.eint<2>) -> !FHE.ebool + // CHECK-NEXT: %[[V7:.*]] = "FHE.from_bool"(%arg2) : (!FHE.ebool) -> !FHE.eint<2> + // CHECK-NEXT: %[[V8:.*]] = "FHE.from_bool"(%arg0) : (!FHE.ebool) -> !FHE.eint<2> + // CHECK-NEXT: %[[V9:.*]] = "FHE.mul_eint_int"(%[[V7]], %[[C1]]) : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK-NEXT: %[[V10:.*]] = "FHE.add_eint"(%[[V9]], %[[V8]]) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + // CHECK-NEXT: %[[V11:.*]] = "FHE.apply_lookup_table"(%[[V10]], %[[TT2]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2> + // CHECK-NEXT: %[[V12:.*]] = "FHE.to_bool"(%[[V11]]) : (!FHE.eint<2>) -> !FHE.ebool + // CHECK-NEXT: %[[V13:.*]] = "FHE.from_bool"(%[[V6]]) : (!FHE.ebool) -> !FHE.eint<2> + // CHECK-NEXT: %[[V14:.*]] = "FHE.from_bool"(%[[V12]]) : (!FHE.ebool) -> !FHE.eint<2> + // CHECK-NEXT: %[[V15:.*]] = "FHE.add_eint"(%[[V13]], %[[V14]]) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + // CHECK-NEXT: %[[V16:.*]] = "FHE.to_bool"(%[[V15]]) : (!FHE.eint<2>) -> !FHE.ebool + // CHECK-NEXT: return %[[V16]] : !FHE.ebool + + %1 = "FHE.mux"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, !FHE.ebool) -> !FHE.ebool + return %1: !FHE.ebool +} diff --git a/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhe.yaml b/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhe.yaml index 334cea23d..99fb81bc1 100644 --- a/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhe.yaml +++ b/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhe.yaml @@ -351,3 +351,59 @@ tests: shape: [4] outputs: - scalar: 1 +--- +description: boolean_mux +program: | + func.func @main(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: !FHE.ebool) -> !FHE.ebool { + %1 = "FHE.mux"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, !FHE.ebool) -> !FHE.ebool + return %1: !FHE.ebool + } +tests: + - inputs: + - scalar: 0 + - scalar: 0 + - scalar: 0 + outputs: + - scalar: 0 + - inputs: + - scalar: 1 + - scalar: 0 + - scalar: 0 + outputs: + - scalar: 0 + - inputs: + - scalar: 0 + - scalar: 1 + - scalar: 0 + outputs: + - scalar: 1 + - inputs: + - scalar: 1 + - scalar: 1 + - scalar: 0 + outputs: + - scalar: 0 + - inputs: + - scalar: 0 + - scalar: 0 + - scalar: 1 + outputs: + - scalar: 0 + - inputs: + - scalar: 1 + - scalar: 0 + - scalar: 1 + outputs: + - scalar: 1 + - inputs: + - scalar: 0 + - scalar: 1 + - scalar: 1 + outputs: + - scalar: 1 + - inputs: + - scalar: 1 + - scalar: 1 + - scalar: 1 + outputs: + - scalar: 1