// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include #include #include #include #include #include #include #include #include #include #include #include using namespace mlir::concretelang::FHE; namespace mlir { namespace concretelang { namespace { struct ApplyLookupTableEintOpPattern : public mlir::OpConversionPattern { ApplyLookupTableEintOpPattern(mlir::MLIRContext *context) : mlir::OpConversionPattern( context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} mlir::LogicalResult matchAndRewrite(FHE::ApplyLookupTableEintOp op, FHE::ApplyLookupTableEintOp::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { // When lowered to the TFHE dialect, the table will need to be properly // encoded by a function specific to the kind of table lookup executed. This // function expects the input lut to use 64 bit integers. For this reason, // every lut that use integers of smaller precision needs to be extended to // 64 bits first. bool outputIsSigned = op.getResult().getType().cast().isSigned(); auto inputLutType = op.getLut().getType(); mlir::Value extendedLut; if (inputLutType.getElementType().getIntOrFloatBitWidth() == 64) { extendedLut = adaptor.getLut(); } else { // This is implemented as a map since the `arith.extsi` is not // bufferizable :( mlir::Value init = rewriter.create( op.getLoc(), mlir::RankedTensorType::get(inputLutType.getShape(), rewriter.getI64Type()), mlir::ValueRange{}); extendedLut = rewriter .create( op.getLoc(), mlir::ValueRange{adaptor.getLut()}, init, [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange args) { mlir::Value extended; if (outputIsSigned) { extended = builder.create( loc, builder.getI64Type(), args[0]); } else { extended = builder.create( loc, builder.getI64Type(), args[0]); } builder.create( loc, mlir::ValueRange{extended}); }) ->getResult(0); } auto newOp = rewriter.replaceOpWithNewOp( op, op.getResult().getType(), op.getA(), extendedLut); // Propagating the Oid if any ... auto optimizerIdAttr = op->getAttr("TFHE.OId"); if (optimizerIdAttr != nullptr) newOp->setAttr("TFHE.OId", optimizerIdAttr); return mlir::success(); }; }; } // namespace class DynamicTLU : public DynamicTLUBase { public: void runOnOperation() override { mlir::func::FuncOp funcOp = getOperation(); mlir::ConversionTarget target(getContext()); target.addLegalDialect(); target.addLegalOp(); target.addLegalDialect(); target.addDynamicallyLegalOp( [&](FHE::ApplyLookupTableEintOp op) { return op.getLut() .getType() .getElementType() .getIntOrFloatBitWidth() == 64; }); mlir::RewritePatternSet patterns(funcOp->getContext()); patterns.add(funcOp->getContext()); if (mlir::applyPartialConversion(funcOp, target, std::move(patterns)) .failed()) { funcOp->emitError("Failed to extend dynamic luts."); this->signalPassFailure(); } } }; std::unique_ptr<::mlir::OperationPass<::mlir::func::FuncOp>> createDynamicTLUPass() { return std::make_unique(); } } // namespace concretelang } // namespace mlir