// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "concretelang/Dialect/FHE/IR/FHEOps.h" #include "concretelang/Dialect/FHE/Transforms/Max/Max.h" namespace arith = mlir::arith; namespace func = mlir::func; namespace FHE = mlir::concretelang::FHE; /// This rewrite pattern transforms all instances /// of `FHE.max_eint` to `max(x - y, 0) + y`. struct MaxEintPattern : public mlir::OpRewritePattern { MaxEintPattern(mlir::MLIRContext *context) : mlir::OpRewritePattern(context) {} mlir::LogicalResult matchAndRewrite(FHE::MaxEintOp maxEintOp, mlir::PatternRewriter &rewriter) const override { const mlir::Location loc = maxEintOp->getLoc(); const FHE::FheIntegerInterface outputTy = maxEintOp->getResult(0).getType().cast(); const int64_t outputBitWidth = outputTy.getWidth(); mlir::Value x = maxEintOp.x(); mlir::Value y = maxEintOp.y(); const auto xTy = x.getType().cast(); const auto yTy = y.getType().cast(); const auto signedTy = FHE::EncryptedSignedIntegerType::get( this->getContext(), outputBitWidth); if (xTy.isUnsigned()) { x = rewriter.create(loc, signedTy, x).getResult(); } if (yTy.isUnsigned()) { y = rewriter.create(loc, signedTy, y).getResult(); } const mlir::Value sub = rewriter.create(loc, x, y).getResult(); const int64_t lutSize = 1 << outputBitWidth; auto lutValues = std::vector(); for (int64_t i = 0; i < lutSize / 2; i++) { lutValues.push_back(i); } for (int64_t i = 0; i < lutSize / 2; i++) { lutValues.push_back(0); } const mlir::Attribute lutAttr = rewriter.getI64TensorAttr(lutValues); const mlir::Value lut = rewriter.create(loc, lutAttr).getResult(); const mlir::Value max = rewriter.create(loc, outputTy, sub, lut) .getResult(); const mlir::Value add = rewriter.create(loc, max, maxEintOp.y()).getResult(); rewriter.replaceOp(maxEintOp, {add}); return mlir::success(); }; }; namespace { struct FHEMaxTransform : public FHEMaxTransformBase { void runOnOperation() final; }; void FHEMaxTransform::runOnOperation() { auto target = mlir::ConversionTarget(this->getContext()); target.addLegalDialect(); target.addLegalDialect(); target.addIllegalOp(); auto patterns = mlir::RewritePatternSet(&this->getContext()); patterns.insert(&this->getContext()); mlir::Operation *op = this->getOperation(); if (mlir::applyPatternsAndFoldGreedily(op, std::move(patterns)).failed()) { this->signalPassFailure(); } } } // namespace namespace mlir { namespace concretelang { std::unique_ptr> createFHEMaxTransformPass() { return std::make_unique(); } } // namespace concretelang } // namespace mlir