enhance(compiler): Add tensor ops type rewriting on high level pipepline

This commit is contained in:
Quentin Bourgerie
2021-08-24 09:50:48 +02:00
parent ce776c0eba
commit dba76a1e1b
4 changed files with 93 additions and 0 deletions

View File

@@ -0,0 +1,60 @@
#ifndef ZAMALANG_CONVERSION_GENERICOPTYPECONVERSIONPATTERN_H_
#define ZAMALANG_CONVERSION_GENERICOPTYPECONVERSIONPATTERN_H_
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
namespace zamalang {
template <typename Op>
struct GenericTypeConverterPattern : public mlir::OpRewritePattern<Op> {
GenericTypeConverterPattern(mlir::MLIRContext *context,
mlir::TypeConverter &converter,
mlir::PatternBenefit benefit = 100)
: mlir::OpRewritePattern<Op>(context, benefit), converter(converter) {}
mlir::LogicalResult
matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
rewriter.startRootUpdate(op);
// Rewrite arguments
{
for (auto i = 0; i < op->getNumOperands(); i++) {
auto operand = op->getOperand(i);
mlir::Type type = converter.convertType(operand.getType());
if (type != mlir::Type()) {
operand.setType(type);
}
}
}
// Rewrite results
{
for (auto i = 0; i < op->getNumResults(); i++) {
auto result = op->getResult(i);
mlir::Type type = converter.convertType(result.getType());
if (type != mlir::Type()) {
result.setType(type);
}
}
}
rewriter.finalizeRootUpdate(op);
return mlir::success();
}
private:
mlir::TypeConverter &converter;
};
template <typename Op>
void addDynamicallyLegalTypeOp(mlir::ConversionTarget &target,
mlir::TypeConverter &typeConverter) {
target.addDynamicallyLegalOp<Op>([&](Op op) {
return typeConverter.isLegal(op->getOperandTypes()) &&
typeConverter.isLegal(op->getResultTypes());
});
}
} // namespace zamalang
} // namespace mlir
#endif

View File

@@ -0,0 +1,27 @@
#ifndef ZAMALANG_CONVERSION_TENSOROPTYPECONVERSIONPATTERN_H_
#define ZAMALANG_CONVERSION_TENSOROPTYPECONVERSIONPATTERN_H_
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "zamalang/Conversion/Utils/GenericOpTypeConversionPattern.h"
namespace mlir {
namespace zamalang {
inline void
populateWithTensorTypeConverterPatterns(mlir::RewritePatternSet &patterns,
mlir::ConversionTarget &target,
mlir::TypeConverter &typeConverter) {
patterns.add<GenericTypeConverterPattern<mlir::tensor::ExtractOp>>(
patterns.getContext(), typeConverter);
addDynamicallyLegalTypeOp<mlir::tensor::ExtractOp>(target, typeConverter);
patterns.add<GenericTypeConverterPattern<mlir::tensor::FromElementsOp>>(
patterns.getContext(), typeConverter);
addDynamicallyLegalTypeOp<mlir::tensor::FromElementsOp>(target,
typeConverter);
}
} // namespace zamalang
} // namespace mlir
#endif

View File

@@ -6,6 +6,7 @@
#include "zamalang/Conversion/HLFHEToMidLFHE/Patterns.h"
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Conversion/Utils/LinalgGenericTypeConverterPattern.h"
#include "zamalang/Conversion/Utils/TensorOpTypeConversion.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
@@ -77,6 +78,8 @@ void HLFHEToMidLFHEPass::runOnOperation() {
populateWithGeneratedHLFHEToMidLFHE(patterns);
patterns.add<LinalgGenericTypeConverterPattern<HLFHEToMidLFHETypeConverter>>(
&getContext(), converter);
mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target,
converter);
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
// Apply conversion

View File

@@ -6,6 +6,7 @@
#include "zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h"
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Conversion/Utils/LinalgGenericTypeConverterPattern.h"
#include "zamalang/Conversion/Utils/TensorOpTypeConversion.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
@@ -77,6 +78,8 @@ void MidLFHEToLowLFHEPass::runOnOperation() {
patterns
.add<LinalgGenericTypeConverterPattern<MidLFHEToLowLFHETypeConverter>>(
&getContext(), converter);
mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target,
converter);
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
// Apply conversion