From 3d8d5c438a5ef3d414e837389801ae27b13bee21 Mon Sep 17 00:00:00 2001 From: youben11 Date: Fri, 6 Aug 2021 09:14:00 +0100 Subject: [PATCH] refactor(compiler): generalize LinalgGenericPattern also fixes an issue regarding populateWithGenerated, which can be duplicated across different pattern files. So I redefined a different function that is more unique to the pass that should be ran, and hide the populateWithGenerated from the global namespace --- .../Conversion/HLFHEToMidLFHE/Patterns.h | 6 ++ .../Utils/LinalgGenericTypeConverterPattern.h | 55 +++++++++++++++++++ .../HLFHEToMidLFHE/HLFHEToMidLFHE.cpp | 55 ++----------------- 3 files changed, 65 insertions(+), 51 deletions(-) create mode 100644 compiler/include/zamalang/Conversion/Utils/LinalgGenericTypeConverterPattern.h diff --git a/compiler/include/zamalang/Conversion/HLFHEToMidLFHE/Patterns.h b/compiler/include/zamalang/Conversion/HLFHEToMidLFHE/Patterns.h index 57f8bc77c..1005af379 100644 --- a/compiler/include/zamalang/Conversion/HLFHEToMidLFHE/Patterns.h +++ b/compiler/include/zamalang/Conversion/HLFHEToMidLFHE/Patterns.h @@ -62,6 +62,12 @@ createApplyLookupTableGLWEOpFromHLFHE(mlir::PatternRewriter rewriter, } // namespace zamalang } // namespace mlir +namespace { #include "zamalang/Conversion/HLFHEToMidLFHE/Patterns.h.inc" +} + +void populateWithGeneratedHLFHEToMidLFHE(mlir::RewritePatternSet &patterns) { + populateWithGenerated(patterns); +} #endif diff --git a/compiler/include/zamalang/Conversion/Utils/LinalgGenericTypeConverterPattern.h b/compiler/include/zamalang/Conversion/Utils/LinalgGenericTypeConverterPattern.h new file mode 100644 index 000000000..564d1ba28 --- /dev/null +++ b/compiler/include/zamalang/Conversion/Utils/LinalgGenericTypeConverterPattern.h @@ -0,0 +1,55 @@ +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/IR/PatternMatch.h" + +/// LinalgGenericTypeConverterPattern is a rewrite pattern that convert types +/// `linalg.generic` operation, using a specific `typeConverter` +template +struct LinalgGenericTypeConverterPattern + : public mlir::OpRewritePattern { + LinalgGenericTypeConverterPattern(mlir::MLIRContext *context, + TypeConverter &converter, + mlir::PatternBenefit benefit = 100) + : mlir::OpRewritePattern(context, benefit), + converter(converter) {} + + mlir::LogicalResult + matchAndRewrite(mlir::linalg::GenericOp op, + mlir::PatternRewriter &rewriter) const override { + + rewriter.startRootUpdate(op); + // Rewrite arguments + { + for (auto i = 0; i < op->getNumOperands(); i++) { + auto operand = op->getOperand(i); + mlir::Type type = converter.convertType(operand.getType()); + if (type != mlir::Type()) { + operand.setType(type); + } + } + } + // Rewrite results + { + for (auto i = 0; i < op->getNumResults(); i++) { + auto result = op->getResult(i); + mlir::Type type = converter.convertType(result.getType()); + if (type != mlir::Type()) { + result.setType(type); + } + } + } + // Rewrite block arguments + mlir::Region ®ion = op->getRegion(0); + mlir::Block *entry = ®ion.front(); + for (auto arg : entry->getArguments()) { + mlir::Type type = converter.convertType(arg.getType()); + if (type != mlir::Type()) { + arg.setType(type); + } + } + rewriter.finalizeRootUpdate(op); + return mlir::success(); + } + +private: + TypeConverter &converter; +}; diff --git a/compiler/lib/Conversion/HLFHEToMidLFHE/HLFHEToMidLFHE.cpp b/compiler/lib/Conversion/HLFHEToMidLFHE/HLFHEToMidLFHE.cpp index aeba4403b..acb31956f 100644 --- a/compiler/lib/Conversion/HLFHEToMidLFHE/HLFHEToMidLFHE.cpp +++ b/compiler/lib/Conversion/HLFHEToMidLFHE/HLFHEToMidLFHE.cpp @@ -5,6 +5,7 @@ #include "zamalang/Conversion/HLFHEToMidLFHE/Patterns.h" #include "zamalang/Conversion/Passes.h" +#include "zamalang/Conversion/Utils/LinalgGenericTypeConverterPattern.h" #include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h" #include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h" @@ -64,55 +65,6 @@ public: } }; -/// LinalgGenericPattern is a rewrite pattern that convert types from HLFHE to -/// MidLFHE of `linalg.generic` operation -struct LinalgGenericPattern - : public mlir::OpRewritePattern { - LinalgGenericPattern(mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 100) - : mlir::OpRewritePattern(context, benefit) {} - - mlir::LogicalResult - matchAndRewrite(mlir::linalg::GenericOp op, - mlir::PatternRewriter &rewriter) const override { - - HLFHEToMidLFHETypeConverter converter; - - rewriter.startRootUpdate(op); - // Rewrite arguments - { - for (auto i = 0; i < op->getNumOperands(); i++) { - auto operand = op->getOperand(i); - mlir::Type type = converter.convertType(operand.getType()); - if (type != mlir::Type()) { - operand.setType(type); - } - } - } - // Rewrite results - { - for (auto i = 0; i < op->getNumResults(); i++) { - auto result = op->getResult(i); - mlir::Type type = converter.convertType(result.getType()); - if (type != mlir::Type()) { - result.setType(type); - } - } - } - // Rewrite block arguments - mlir::Region ®ion = op->getRegion(0); - mlir::Block *entry = ®ion.front(); - for (auto arg : entry->getArguments()) { - mlir::Type type = converter.convertType(arg.getType()); - if (type != mlir::Type()) { - arg.setType(type); - } - } - rewriter.finalizeRootUpdate(op); - return mlir::success(); - } -}; - void HLFHEToMidLFHEPass::runOnOperation() { auto op = this->getOperation(); @@ -149,8 +101,9 @@ void HLFHEToMidLFHEPass::runOnOperation() { // `MidLFHE` mlir::OwningRewritePatternList patterns(&getContext()); - populateWithGenerated(patterns); - patterns.add(&getContext()); + populateWithGeneratedHLFHEToMidLFHE(patterns); + patterns.add>( + &getContext(), converter); mlir::populateFuncOpTypeConversionPattern(patterns, converter); // Apply conversion