From dba76a1e1b728b747e876210621d52e03a657cec Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Tue, 24 Aug 2021 09:50:48 +0200 Subject: [PATCH] enhance(compiler): Add tensor ops type rewriting on high level pipepline --- .../Utils/GenericOpTypeConversionPattern.h | 60 +++++++++++++++++++ .../Conversion/Utils/TensorOpTypeConversion.h | 27 +++++++++ .../HLFHEToMidLFHE/HLFHEToMidLFHE.cpp | 3 + .../MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp | 3 + 4 files changed, 93 insertions(+) create mode 100644 compiler/include/zamalang/Conversion/Utils/GenericOpTypeConversionPattern.h create mode 100644 compiler/include/zamalang/Conversion/Utils/TensorOpTypeConversion.h diff --git a/compiler/include/zamalang/Conversion/Utils/GenericOpTypeConversionPattern.h b/compiler/include/zamalang/Conversion/Utils/GenericOpTypeConversionPattern.h new file mode 100644 index 000000000..4cdad5f9b --- /dev/null +++ b/compiler/include/zamalang/Conversion/Utils/GenericOpTypeConversionPattern.h @@ -0,0 +1,60 @@ +#ifndef ZAMALANG_CONVERSION_GENERICOPTYPECONVERSIONPATTERN_H_ +#define ZAMALANG_CONVERSION_GENERICOPTYPECONVERSIONPATTERN_H_ + +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +namespace zamalang { +template +struct GenericTypeConverterPattern : public mlir::OpRewritePattern { + GenericTypeConverterPattern(mlir::MLIRContext *context, + mlir::TypeConverter &converter, + mlir::PatternBenefit benefit = 100) + : mlir::OpRewritePattern(context, benefit), converter(converter) {} + + mlir::LogicalResult + matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override { + + rewriter.startRootUpdate(op); + // Rewrite arguments + { + for (auto i = 0; i < op->getNumOperands(); i++) { + auto operand = op->getOperand(i); + mlir::Type type = converter.convertType(operand.getType()); + if (type != mlir::Type()) { + operand.setType(type); + } + } + } + // Rewrite results + { + for (auto i = 0; i < op->getNumResults(); i++) { + auto result = op->getResult(i); + mlir::Type type = converter.convertType(result.getType()); + if (type != mlir::Type()) { + result.setType(type); + } + } + } + rewriter.finalizeRootUpdate(op); + return mlir::success(); + } + +private: + mlir::TypeConverter &converter; +}; + +template +void addDynamicallyLegalTypeOp(mlir::ConversionTarget &target, + mlir::TypeConverter &typeConverter) { + target.addDynamicallyLegalOp([&](Op op) { + return typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); +} + +} // namespace zamalang +} // namespace mlir + +#endif \ No newline at end of file diff --git a/compiler/include/zamalang/Conversion/Utils/TensorOpTypeConversion.h b/compiler/include/zamalang/Conversion/Utils/TensorOpTypeConversion.h new file mode 100644 index 000000000..c66572636 --- /dev/null +++ b/compiler/include/zamalang/Conversion/Utils/TensorOpTypeConversion.h @@ -0,0 +1,27 @@ +#ifndef ZAMALANG_CONVERSION_TENSOROPTYPECONVERSIONPATTERN_H_ +#define ZAMALANG_CONVERSION_TENSOROPTYPECONVERSIONPATTERN_H_ + +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "zamalang/Conversion/Utils/GenericOpTypeConversionPattern.h" + +namespace mlir { +namespace zamalang { + +inline void +populateWithTensorTypeConverterPatterns(mlir::RewritePatternSet &patterns, + mlir::ConversionTarget &target, + mlir::TypeConverter &typeConverter) { + patterns.add>( + patterns.getContext(), typeConverter); + addDynamicallyLegalTypeOp(target, typeConverter); + patterns.add>( + patterns.getContext(), typeConverter); + addDynamicallyLegalTypeOp(target, + typeConverter); +} +} // namespace zamalang +} // namespace mlir + +#endif \ No newline at end of file diff --git a/compiler/lib/Conversion/HLFHEToMidLFHE/HLFHEToMidLFHE.cpp b/compiler/lib/Conversion/HLFHEToMidLFHE/HLFHEToMidLFHE.cpp index 5774c95db..cf614d712 100644 --- a/compiler/lib/Conversion/HLFHEToMidLFHE/HLFHEToMidLFHE.cpp +++ b/compiler/lib/Conversion/HLFHEToMidLFHE/HLFHEToMidLFHE.cpp @@ -6,6 +6,7 @@ #include "zamalang/Conversion/HLFHEToMidLFHE/Patterns.h" #include "zamalang/Conversion/Passes.h" #include "zamalang/Conversion/Utils/LinalgGenericTypeConverterPattern.h" +#include "zamalang/Conversion/Utils/TensorOpTypeConversion.h" #include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h" #include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h" @@ -77,6 +78,8 @@ void HLFHEToMidLFHEPass::runOnOperation() { populateWithGeneratedHLFHEToMidLFHE(patterns); patterns.add>( &getContext(), converter); + mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target, + converter); mlir::populateFuncOpTypeConversionPattern(patterns, converter); // Apply conversion diff --git a/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp b/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp index e55155d68..9e63d906f 100644 --- a/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp +++ b/compiler/lib/Conversion/MidLFHEToLowLFHE/MidLFHEToLowLFHE.cpp @@ -6,6 +6,7 @@ #include "zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h" #include "zamalang/Conversion/Passes.h" #include "zamalang/Conversion/Utils/LinalgGenericTypeConverterPattern.h" +#include "zamalang/Conversion/Utils/TensorOpTypeConversion.h" #include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h" #include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h" @@ -77,6 +78,8 @@ void MidLFHEToLowLFHEPass::runOnOperation() { patterns .add>( &getContext(), converter); + mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target, + converter); mlir::populateFuncOpTypeConversionPattern(patterns, converter); // Apply conversion