// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include #include #include #include #include #include #include namespace mlir { namespace concretelang { namespace { /// Rewrite an `FHE.gen_gate` operation as an LUT operation by composing a /// single index from the two boolean inputs. class GenGatePattern : public mlir::OpRewritePattern { public: GenGatePattern(mlir::MLIRContext *context) : mlir::OpRewritePattern( context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} mlir::LogicalResult matchAndRewrite(mlir::concretelang::FHE::GenGateOp op, mlir::PatternRewriter &rewriter) const override { auto eint2 = mlir::concretelang::FHE::EncryptedIntegerType::get( rewriter.getContext(), 2); auto left = rewriter .create( op.getLoc(), eint2, op.left()) .getResult(); auto right = rewriter .create( op.getLoc(), eint2, op.right()) .getResult(); auto cst_two = rewriter.create(op.getLoc(), 2, 3) .getResult(); auto leftMulTwo = rewriter .create( op.getLoc(), left, cst_two) .getResult(); auto newIndex = rewriter .create( op.getLoc(), leftMulTwo, right) .getResult(); auto lut_result = rewriter.create( op.getLoc(), eint2, newIndex, op.truth_table()); rewriter.replaceOpWithNewOp( op, mlir::concretelang::FHE::EncryptedBooleanType::get( rewriter.getContext()), lut_result); return mlir::success(); } }; /// Rewrite an FHE GateOp (e.g. And/Or) into a GenGate with the given truth /// table. template class GeneralizeGatePattern : public mlir::OpRewritePattern { public: GeneralizeGatePattern(mlir::MLIRContext *context, llvm::SmallVector truth_table_vector) : mlir::OpRewritePattern( context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT), truth_table_vector(truth_table_vector) {} mlir::LogicalResult matchAndRewrite(GateOp op, mlir::PatternRewriter &rewriter) const override { auto truth_table_attr = mlir::DenseElementsAttr::get( mlir::RankedTensorType::get({4}, rewriter.getIntegerType(64)), {llvm::APInt(1, this->truth_table_vector[0], false), llvm::APInt(1, this->truth_table_vector[1], false), llvm::APInt(1, this->truth_table_vector[2], false), llvm::APInt(1, this->truth_table_vector[3], false)}); auto truth_table = rewriter.create(op.getLoc(), truth_table_attr); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), op.left(), op.right(), truth_table); return mlir::success(); } private: llvm::SmallVector truth_table_vector; }; /// Rewrite an `FHE.mux` op, into a series of boolean and arithmetic operations /// mux(cond, c1, c2) => c1 and not cond + c2 and cond class MuxOpPattern : public mlir::OpRewritePattern { public: MuxOpPattern(mlir::MLIRContext *context) : mlir::OpRewritePattern( context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} mlir::LogicalResult matchAndRewrite(mlir::concretelang::FHE::MuxOp op, mlir::PatternRewriter &rewriter) const override { auto eint2 = mlir::concretelang::FHE::EncryptedIntegerType::get( rewriter.getContext(), 2); auto boolType = mlir::concretelang::FHE::EncryptedBooleanType::get( rewriter.getContext()); // truth table for c1 and not cond auto truth_table_attr = mlir::DenseElementsAttr::get( mlir::RankedTensorType::get({4}, rewriter.getIntegerType(64)), {llvm::APInt(1, 0, false), llvm::APInt(1, 0, false), llvm::APInt(1, 1, false), llvm::APInt(1, 0, false)}); auto truth_table = rewriter.create(op.getLoc(), truth_table_attr); auto c1AndNotCond = rewriter .create( op.getLoc(), boolType, op.c1(), op.cond(), truth_table) .getResult(); auto c2AndCond = rewriter .create( op.getLoc(), boolType, op.c2(), op.cond()) .getResult(); auto c1AndNotCondBool = rewriter .create( op.getLoc(), eint2, c1AndNotCond) .getResult(); auto c2AndCondBool = rewriter .create( op.getLoc(), eint2, c2AndCond) .getResult(); auto result = rewriter .create( op.getLoc(), c1AndNotCondBool, c2AndCondBool) .getResult(); rewriter.replaceOpWithNewOp(op, boolType, result); return mlir::success(); } }; /// Perfoms the transformation of boolean operations class FHEBooleanTransformPass : public FHEBooleanTransformBase { public: void runOnOperation() override { mlir::Operation *op = getOperation(); mlir::RewritePatternSet patterns(&getContext()); patterns.add(&getContext()); patterns.add(&getContext()); patterns.add>( &getContext(), llvm::SmallVector({0, 0, 0, 1})); patterns.add>( &getContext(), llvm::SmallVector({1, 1, 1, 0})); patterns.add>( &getContext(), llvm::SmallVector({0, 1, 1, 1})); patterns.add>( &getContext(), llvm::SmallVector({0, 1, 1, 0})); if (mlir::applyPatternsAndFoldGreedily(op, std::move(patterns)).failed()) { this->signalPassFailure(); } } }; } // end anonymous namespace std::unique_ptr> createFHEBooleanTransformPass() { return std::make_unique(); } } // namespace concretelang } // namespace mlir