mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: lower and exec boolean mux
This commit is contained in:
@@ -92,6 +92,58 @@ private:
|
||||
llvm::SmallVector<uint64_t, 4> truth_table_vector;
|
||||
};
|
||||
|
||||
/// Rewrite an `FHE.mux` op, into a series of boolean and arithmetic operations
|
||||
/// mux(cond, c1, c2) => c1 and not cond + c2 and cond
|
||||
class MuxOpPattern
|
||||
: public mlir::OpRewritePattern<mlir::concretelang::FHE::MuxOp> {
|
||||
public:
|
||||
MuxOpPattern(mlir::MLIRContext *context)
|
||||
: mlir::OpRewritePattern<mlir::concretelang::FHE::MuxOp>(
|
||||
context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::FHE::MuxOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto eint2 = mlir::concretelang::FHE::EncryptedIntegerType::get(
|
||||
rewriter.getContext(), 2);
|
||||
auto boolType = mlir::concretelang::FHE::EncryptedBooleanType::get(
|
||||
rewriter.getContext());
|
||||
|
||||
// truth table for c1 and not cond
|
||||
auto truth_table_attr = mlir::DenseElementsAttr::get(
|
||||
mlir::RankedTensorType::get({4}, rewriter.getIntegerType(64)),
|
||||
{llvm::APInt(1, 0, false), llvm::APInt(1, 0, false),
|
||||
llvm::APInt(1, 1, false), llvm::APInt(1, 0, false)});
|
||||
auto truth_table =
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), truth_table_attr);
|
||||
auto c1AndNotCond =
|
||||
rewriter
|
||||
.create<mlir::concretelang::FHE::GenGateOp>(
|
||||
op.getLoc(), boolType, op.c1(), op.cond(), truth_table)
|
||||
.getResult();
|
||||
auto c2AndCond = rewriter
|
||||
.create<mlir::concretelang::FHE::BoolAndOp>(
|
||||
op.getLoc(), boolType, op.c2(), op.cond())
|
||||
.getResult();
|
||||
|
||||
auto c1AndNotCondBool = rewriter
|
||||
.create<mlir::concretelang::FHE::FromBoolOp>(
|
||||
op.getLoc(), eint2, c1AndNotCond)
|
||||
.getResult();
|
||||
auto c2AndCondBool = rewriter
|
||||
.create<mlir::concretelang::FHE::FromBoolOp>(
|
||||
op.getLoc(), eint2, c2AndCond)
|
||||
.getResult();
|
||||
auto result = rewriter
|
||||
.create<mlir::concretelang::FHE::AddEintOp>(
|
||||
op.getLoc(), c1AndNotCondBool, c2AndCondBool)
|
||||
.getResult();
|
||||
rewriter.replaceOpWithNewOp<mlir::concretelang::FHE::ToBoolOp>(op, boolType,
|
||||
result);
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Perfoms the transformation of boolean operations
|
||||
class FHEBooleanTransformPass
|
||||
: public FHEBooleanTransformBase<FHEBooleanTransformPass> {
|
||||
@@ -101,6 +153,7 @@ public:
|
||||
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
patterns.add<GenGatePattern>(&getContext());
|
||||
patterns.add<MuxOpPattern>(&getContext());
|
||||
patterns.add<GeneralizeGatePattern<mlir::concretelang::FHE::BoolAndOp>>(
|
||||
&getContext(), llvm::SmallVector<uint64_t, 4>({0, 0, 0, 1}));
|
||||
patterns.add<GeneralizeGatePattern<mlir::concretelang::FHE::BoolNandOp>>(
|
||||
|
||||
@@ -78,3 +78,30 @@ func.func @xor(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool {
|
||||
%1 = "FHE.xor"(%arg0, %arg1) : (!FHE.ebool, !FHE.ebool) -> !FHE.ebool
|
||||
return %1: !FHE.ebool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @mux(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: !FHE.ebool) -> !FHE.ebool
|
||||
func.func @mux(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: !FHE.ebool) -> !FHE.ebool {
|
||||
// CHECK-NEXT: %[[TT1:.*]] = arith.constant dense<[0, 0, 1, 0]> : tensor<4xi64>
|
||||
// CHECK-NEXT: %[[C1:.*]] = arith.constant 2 : i3
|
||||
// CHECK-NEXT: %[[TT2:.*]] = arith.constant dense<[0, 0, 0, 1]> : tensor<4xi64>
|
||||
// CHECK-NEXT: %[[V1:.*]] = "FHE.from_bool"(%arg1) : (!FHE.ebool) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "FHE.from_bool"(%arg0) : (!FHE.ebool) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V3:.*]] = "FHE.mul_eint_int"(%[[V1]], %[[C1]]) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V4:.*]] = "FHE.add_eint"(%[[V3]], %[[V2]]) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V5:.*]] = "FHE.apply_lookup_table"(%[[V4]], %[[TT1]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V6:.*]] = "FHE.to_bool"(%[[V5]]) : (!FHE.eint<2>) -> !FHE.ebool
|
||||
// CHECK-NEXT: %[[V7:.*]] = "FHE.from_bool"(%arg2) : (!FHE.ebool) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V8:.*]] = "FHE.from_bool"(%arg0) : (!FHE.ebool) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V9:.*]] = "FHE.mul_eint_int"(%[[V7]], %[[C1]]) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V10:.*]] = "FHE.add_eint"(%[[V9]], %[[V8]]) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V11:.*]] = "FHE.apply_lookup_table"(%[[V10]], %[[TT2]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V12:.*]] = "FHE.to_bool"(%[[V11]]) : (!FHE.eint<2>) -> !FHE.ebool
|
||||
// CHECK-NEXT: %[[V13:.*]] = "FHE.from_bool"(%[[V6]]) : (!FHE.ebool) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V14:.*]] = "FHE.from_bool"(%[[V12]]) : (!FHE.ebool) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V15:.*]] = "FHE.add_eint"(%[[V13]], %[[V14]]) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: %[[V16:.*]] = "FHE.to_bool"(%[[V15]]) : (!FHE.eint<2>) -> !FHE.ebool
|
||||
// CHECK-NEXT: return %[[V16]] : !FHE.ebool
|
||||
|
||||
%1 = "FHE.mux"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, !FHE.ebool) -> !FHE.ebool
|
||||
return %1: !FHE.ebool
|
||||
}
|
||||
|
||||
@@ -351,3 +351,59 @@ tests:
|
||||
shape: [4]
|
||||
outputs:
|
||||
- scalar: 1
|
||||
---
|
||||
description: boolean_mux
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.ebool, %arg1: !FHE.ebool, %arg2: !FHE.ebool) -> !FHE.ebool {
|
||||
%1 = "FHE.mux"(%arg0, %arg1, %arg2) : (!FHE.ebool, !FHE.ebool, !FHE.ebool) -> !FHE.ebool
|
||||
return %1: !FHE.ebool
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 0
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 0
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
- scalar: 0
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 0
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 1
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 1
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
- scalar: 1
|
||||
- scalar: 0
|
||||
outputs:
|
||||
- scalar: 0
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 0
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 0
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
- scalar: 0
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 1
|
||||
- inputs:
|
||||
- scalar: 0
|
||||
- scalar: 1
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 1
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
- scalar: 1
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 1
|
||||
|
||||
Reference in New Issue
Block a user