feat: lower and exec boolean mux

This commit is contained in:
youben11
2023-01-26 10:48:23 +01:00
committed by Ayoub Benaissa
parent d0ae2563fa
commit 59d35619a8
3 changed files with 136 additions and 0 deletions

View File

@@ -92,6 +92,58 @@ private:
llvm::SmallVector<uint64_t, 4> truth_table_vector;
};
/// Rewrite an `FHE.mux` op, into a series of boolean and arithmetic operations
/// mux(cond, c1, c2) => c1 and not cond + c2 and cond
class MuxOpPattern
: public mlir::OpRewritePattern<mlir::concretelang::FHE::MuxOp> {
public:
MuxOpPattern(mlir::MLIRContext *context)
: mlir::OpRewritePattern<mlir::concretelang::FHE::MuxOp>(
context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
mlir::LogicalResult
matchAndRewrite(mlir::concretelang::FHE::MuxOp op,
mlir::PatternRewriter &rewriter) const override {
auto eint2 = mlir::concretelang::FHE::EncryptedIntegerType::get(
rewriter.getContext(), 2);
auto boolType = mlir::concretelang::FHE::EncryptedBooleanType::get(
rewriter.getContext());
// truth table for c1 and not cond
auto truth_table_attr = mlir::DenseElementsAttr::get(
mlir::RankedTensorType::get({4}, rewriter.getIntegerType(64)),
{llvm::APInt(1, 0, false), llvm::APInt(1, 0, false),
llvm::APInt(1, 1, false), llvm::APInt(1, 0, false)});
auto truth_table =
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), truth_table_attr);
auto c1AndNotCond =
rewriter
.create<mlir::concretelang::FHE::GenGateOp>(
op.getLoc(), boolType, op.c1(), op.cond(), truth_table)
.getResult();
auto c2AndCond = rewriter
.create<mlir::concretelang::FHE::BoolAndOp>(
op.getLoc(), boolType, op.c2(), op.cond())
.getResult();
auto c1AndNotCondBool = rewriter
.create<mlir::concretelang::FHE::FromBoolOp>(
op.getLoc(), eint2, c1AndNotCond)
.getResult();
auto c2AndCondBool = rewriter
.create<mlir::concretelang::FHE::FromBoolOp>(
op.getLoc(), eint2, c2AndCond)
.getResult();
auto result = rewriter
.create<mlir::concretelang::FHE::AddEintOp>(
op.getLoc(), c1AndNotCondBool, c2AndCondBool)
.getResult();
rewriter.replaceOpWithNewOp<mlir::concretelang::FHE::ToBoolOp>(op, boolType,
result);
return mlir::success();
}
};
/// Perfoms the transformation of boolean operations
class FHEBooleanTransformPass
: public FHEBooleanTransformBase<FHEBooleanTransformPass> {
@@ -101,6 +153,7 @@ public:
mlir::RewritePatternSet patterns(&getContext());
patterns.add<GenGatePattern>(&getContext());
patterns.add<MuxOpPattern>(&getContext());
patterns.add<GeneralizeGatePattern<mlir::concretelang::FHE::BoolAndOp>>(
&getContext(), llvm::SmallVector<uint64_t, 4>({0, 0, 0, 1}));
patterns.add<GeneralizeGatePattern<mlir::concretelang::FHE::BoolNandOp>>(

View File

@@ -78,3 +78,30 @@ func.func @xor(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool {
%1 = "FHE.xor"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool
return %1: !FHE.ebool
}
// CHECK-LABEL: func.func @mux(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: !FHE.ebool) -> !FHE.ebool
func.func @mux(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: !FHE.ebool) -> !FHE.ebool {
// CHECK-NEXT: %[[TT1:.*]] = arith.constant dense<[0, 0, 1, 0]> : tensor<4xi64>
// CHECK-NEXT: %[[C1:.*]] = arith.constant 2 : i3
// CHECK-NEXT: %[[TT2:.*]] = arith.constant dense<[0, 0, 0, 1]> : tensor<4xi64>
// CHECK-NEXT: %[[V1:.*]] = "FHE.from_bool"(%arg1) : (!FHE.ebool) -> !FHE.eint<2>
// CHECK-NEXT: %[[V2:.*]] = "FHE.from_bool"(%arg0) : (!FHE.ebool) -> !FHE.eint<2>
// CHECK-NEXT: %[[V3:.*]] = "FHE.mul_eint_int"(%[[V1]], %[[C1]]) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
// CHECK-NEXT: %[[V4:.*]] = "FHE.add_eint"(%[[V3]], %[[V2]]) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
// CHECK-NEXT: %[[V5:.*]] = "FHE.apply_lookup_table"(%[[V4]], %[[TT1]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
// CHECK-NEXT: %[[V6:.*]] = "FHE.to_bool"(%[[V5]]) : (!FHE.eint<2>) -> !FHE.ebool
// CHECK-NEXT: %[[V7:.*]] = "FHE.from_bool"(%arg2) : (!FHE.ebool) -> !FHE.eint<2>
// CHECK-NEXT: %[[V8:.*]] = "FHE.from_bool"(%arg0) : (!FHE.ebool) -> !FHE.eint<2>
// CHECK-NEXT: %[[V9:.*]] = "FHE.mul_eint_int"(%[[V7]], %[[C1]]) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
// CHECK-NEXT: %[[V10:.*]] = "FHE.add_eint"(%[[V9]], %[[V8]]) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
// CHECK-NEXT: %[[V11:.*]] = "FHE.apply_lookup_table"(%[[V10]], %[[TT2]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
// CHECK-NEXT: %[[V12:.*]] = "FHE.to_bool"(%[[V11]]) : (!FHE.eint<2>) -> !FHE.ebool
// CHECK-NEXT: %[[V13:.*]] = "FHE.from_bool"(%[[V6]]) : (!FHE.ebool) -> !FHE.eint<2>
// CHECK-NEXT: %[[V14:.*]] = "FHE.from_bool"(%[[V12]]) : (!FHE.ebool) -> !FHE.eint<2>
// CHECK-NEXT: %[[V15:.*]] = "FHE.add_eint"(%[[V13]], %[[V14]]) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
// CHECK-NEXT: %[[V16:.*]] = "FHE.to_bool"(%[[V15]]) : (!FHE.eint<2>) -> !FHE.ebool
// CHECK-NEXT: return %[[V16]] : !FHE.ebool
%1 = "FHE.mux"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, !FHE.ebool) -> !FHE.ebool
return %1: !FHE.ebool
}

View File

@@ -351,3 +351,59 @@ tests:
shape: [4]
outputs:
- scalar: 1
---
description: boolean_mux
program: |
func.func @main(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: !FHE.ebool) -> !FHE.ebool {
%1 = "FHE.mux"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, !FHE.ebool) -> !FHE.ebool
return %1: !FHE.ebool
}
tests:
- inputs:
- scalar: 0
- scalar: 0
- scalar: 0
outputs:
- scalar: 0
- inputs:
- scalar: 1
- scalar: 0
- scalar: 0
outputs:
- scalar: 0
- inputs:
- scalar: 0
- scalar: 1
- scalar: 0
outputs:
- scalar: 1
- inputs:
- scalar: 1
- scalar: 1
- scalar: 0
outputs:
- scalar: 0
- inputs:
- scalar: 0
- scalar: 0
- scalar: 1
outputs:
- scalar: 0
- inputs:
- scalar: 1
- scalar: 0
- scalar: 1
outputs:
- scalar: 1
- inputs:
- scalar: 0
- scalar: 1
- scalar: 1
outputs:
- scalar: 1
- inputs:
- scalar: 1
- scalar: 1
- scalar: 1
outputs:
- scalar: 1