refactor(compiler): generalize LinalgGenericPattern

also fixes an issue regarding populateWithGenerated, which can be
duplicated across different pattern files. So I redefined a different
function that is more unique to the pass that should be ran, and hide
the populateWithGenerated from the global namespace
This commit is contained in:
youben11
2021-08-06 09:14:00 +01:00
committed by Ayoub Benaissa
parent 7a2511b3d4
commit 3d8d5c438a
3 changed files with 65 additions and 51 deletions

View File

@@ -62,6 +62,12 @@ createApplyLookupTableGLWEOpFromHLFHE(mlir::PatternRewriter rewriter,
} // namespace zamalang
} // namespace mlir
namespace {
#include "zamalang/Conversion/HLFHEToMidLFHE/Patterns.h.inc"
}
void populateWithGeneratedHLFHEToMidLFHE(mlir::RewritePatternSet &patterns) {
populateWithGenerated(patterns);
}
#endif

View File

@@ -0,0 +1,55 @@
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/IR/PatternMatch.h"
/// LinalgGenericTypeConverterPattern is a rewrite pattern that convert types
/// `linalg.generic` operation, using a specific `typeConverter`
template <typename TypeConverter>
struct LinalgGenericTypeConverterPattern
: public mlir::OpRewritePattern<mlir::linalg::GenericOp> {
LinalgGenericTypeConverterPattern(mlir::MLIRContext *context,
TypeConverter &converter,
mlir::PatternBenefit benefit = 100)
: mlir::OpRewritePattern<mlir::linalg::GenericOp>(context, benefit),
converter(converter) {}
mlir::LogicalResult
matchAndRewrite(mlir::linalg::GenericOp op,
mlir::PatternRewriter &rewriter) const override {
rewriter.startRootUpdate(op);
// Rewrite arguments
{
for (auto i = 0; i < op->getNumOperands(); i++) {
auto operand = op->getOperand(i);
mlir::Type type = converter.convertType(operand.getType());
if (type != mlir::Type()) {
operand.setType(type);
}
}
}
// Rewrite results
{
for (auto i = 0; i < op->getNumResults(); i++) {
auto result = op->getResult(i);
mlir::Type type = converter.convertType(result.getType());
if (type != mlir::Type()) {
result.setType(type);
}
}
}
// Rewrite block arguments
mlir::Region &region = op->getRegion(0);
mlir::Block *entry = &region.front();
for (auto arg : entry->getArguments()) {
mlir::Type type = converter.convertType(arg.getType());
if (type != mlir::Type()) {
arg.setType(type);
}
}
rewriter.finalizeRootUpdate(op);
return mlir::success();
}
private:
TypeConverter &converter;
};

View File

@@ -5,6 +5,7 @@
#include "zamalang/Conversion/HLFHEToMidLFHE/Patterns.h"
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Conversion/Utils/LinalgGenericTypeConverterPattern.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
@@ -64,55 +65,6 @@ public:
}
};
/// LinalgGenericPattern is a rewrite pattern that convert types from HLFHE to
/// MidLFHE of `linalg.generic` operation
struct LinalgGenericPattern
: public mlir::OpRewritePattern<mlir::linalg::GenericOp> {
LinalgGenericPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 100)
: mlir::OpRewritePattern<mlir::linalg::GenericOp>(context, benefit) {}
mlir::LogicalResult
matchAndRewrite(mlir::linalg::GenericOp op,
mlir::PatternRewriter &rewriter) const override {
HLFHEToMidLFHETypeConverter converter;
rewriter.startRootUpdate(op);
// Rewrite arguments
{
for (auto i = 0; i < op->getNumOperands(); i++) {
auto operand = op->getOperand(i);
mlir::Type type = converter.convertType(operand.getType());
if (type != mlir::Type()) {
operand.setType(type);
}
}
}
// Rewrite results
{
for (auto i = 0; i < op->getNumResults(); i++) {
auto result = op->getResult(i);
mlir::Type type = converter.convertType(result.getType());
if (type != mlir::Type()) {
result.setType(type);
}
}
}
// Rewrite block arguments
mlir::Region &region = op->getRegion(0);
mlir::Block *entry = &region.front();
for (auto arg : entry->getArguments()) {
mlir::Type type = converter.convertType(arg.getType());
if (type != mlir::Type()) {
arg.setType(type);
}
}
rewriter.finalizeRootUpdate(op);
return mlir::success();
}
};
void HLFHEToMidLFHEPass::runOnOperation() {
auto op = this->getOperation();
@@ -149,8 +101,9 @@ void HLFHEToMidLFHEPass::runOnOperation() {
// `MidLFHE`
mlir::OwningRewritePatternList patterns(&getContext());
populateWithGenerated(patterns);
patterns.add<LinalgGenericPattern>(&getContext());
populateWithGeneratedHLFHEToMidLFHE(patterns);
patterns.add<LinalgGenericTypeConverterPattern<HLFHEToMidLFHETypeConverter>>(
&getContext(), converter);
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
// Apply conversion