// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include #include #include #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" template struct FunctionConstantOpConversion : public mlir::OpRewritePattern { FunctionConstantOpConversion(mlir::MLIRContext *ctx, TypeConverterType &converter, mlir::PatternBenefit benefit = 1) : ::mlir::OpRewritePattern(ctx, benefit), converter(converter) {} mlir::LogicalResult matchAndRewrite(mlir::func::ConstantOp op, mlir::PatternRewriter &rewriter) const override { auto symTab = mlir::SymbolTable::getNearestSymbolTable(op); auto funcOp = mlir::SymbolTable::lookupSymbolIn(symTab, op.getValue()); assert(funcOp && "Function symbol missing in symbol table for function constant op."); mlir::FunctionType funType = mlir::cast(funcOp) .getFunctionType() .cast(); typename TypeConverterType::SignatureConversion result( funType.getNumInputs()); mlir::SmallVector newResults; if (failed(converter.convertSignatureArgs(funType.getInputs(), result)) || failed(converter.convertTypes(funType.getResults(), newResults))) return mlir::failure(); auto newType = mlir::FunctionType::get( rewriter.getContext(), result.getConvertedTypes(), newResults); rewriter.updateRootInPlace(op, [&] { op.getResult().setType(newType); }); return mlir::success(); } static bool isLegal(mlir::func::ConstantOp fun, TypeConverterType &converter) { auto symTab = mlir::SymbolTable::getNearestSymbolTable(fun); auto funcOp = mlir::SymbolTable::lookupSymbolIn(symTab, fun.getValue()); assert(funcOp && "Function symbol missing in symbol table for function constant op."); mlir::FunctionType funType = mlir::cast(funcOp) .getFunctionType() .cast(); typename TypeConverterType::SignatureConversion result( funType.getNumInputs()); mlir::SmallVector newResults; if (failed(converter.convertSignatureArgs(funType.getInputs(), result)) || failed(converter.convertTypes(funType.getResults(), newResults))) return false; auto newType = mlir::FunctionType::get( fun.getContext(), result.getConvertedTypes(), newResults); return newType == fun.getType(); } private: TypeConverterType &converter; };