feat(compiler): Add lowering pass from HLFHE.dot_eint_int to linalg.generic

This pass transforms any instance of `HLFHE.dot_eint_int` to an
instance of `linalg.generic` with an appropriate region using
`HLFHE.mul_eint_int` and `HLFHE.add_eint` operations and an
appropriate specification for the iteration dimensions.

Example:

  "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
    (memref<?x!HLFHE.eint<0>>, memref<?xi32>, memref<!HLFHE.eint<0>>) -> ()

becomes:

  linalg.generic {
    indexing_maps = [affine_map<(d0) -> (d0)>,
                     affine_map<(d0) -> (d0)>,
		     affine_map<(d0) -> ()>],
    iterator_types = ["reduction"]
  } ins(%arg0, %arg1 : memref<?x!HLFHE.eint<0>>, memref<?xi32>) outs(%arg2 : memref<!HLFHE.eint<0>>) {
    ^bb0(%arg3: !HLFHE.eint<0>, %arg4: i32, %arg5: !HLFHE.eint<0>):  // no predecessors
      %0 = "HLFHE.mul_eint_int"(%arg3, %arg4) : (!HLFHE.eint<0>, i32) -> !HLFHE.eint<0>
      %1 = "HLFHE.add_eint"(%0, %arg5) : (!HLFHE.eint<0>, !HLFHE.eint<0>) -> !HLFHE.eint<0>
      linalg.yield %1 : !HLFHE.eint<0>
  }
This commit is contained in:
Andi Drebes
2021-07-05 16:37:06 +02:00
committed by Quentin Bourgerie
parent b433627821
commit 9d1cdc6a0c
4 changed files with 157 additions and 0 deletions

View File

@@ -0,0 +1,14 @@
#ifndef ZAMALANG_DIALECT_HLFHE_TRANSFORMS_TENSOROPSTOLINALG_H
#define ZAMALANG_DIALECT_HLFHE_TRANSFORMS_TENSOROPSTOLINALG_H
#include <mlir/Pass/Pass.h>
namespace mlir {
namespace zamalang {
namespace HLFHE {
std::unique_ptr<mlir::Pass> createLowerTensorOpsToLinalgPass();
} // namespace HLFHE
} // namespace zamalang
} // namespace mlir
#endif

View File

@@ -1 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@@ -0,0 +1,14 @@
add_mlir_dialect_library(HLFHEDialectTransforms
TensorOpsToLinalg.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/zamalang/Dialect/HLFHE
DEPENDS
HLFHEDialect
LINK_LIBS PUBLIC
MLIRIR
HLFHEDialect)
target_link_libraries(HLFHEDialect PUBLIC MLIRIR)

View File

@@ -0,0 +1,128 @@
#include "zamalang/Dialect/HLFHE/Transforms/TensorOpsToLinalg.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/OperationSupport.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "llvm/ADT/SmallVector.h"
#include <mlir/Dialect/Linalg/IR/LinalgOps.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Transforms/DialectConversion.h>
#include <zamalang/Dialect/HLFHE/IR/HLFHEOps.h>
struct DotToLinalgGeneric : public ::mlir::RewritePattern {
DotToLinalgGeneric(::mlir::MLIRContext *context)
: ::mlir::RewritePattern("HLFHE.dot_eint_int", 1, context,
{"linalg.generic"}) {}
// This rewrite pattern transforms any instance of
// `HLFHE.dot_eint_int` to an instance of `linalg.generic` with an
// appropriate region using `HLFHE.mul_eint_int` and
// `HLFHE.add_eint` operations and an appropriate specification for
// the iteration dimensions.
//
// Example:
//
// "HLFHE.dot_eint_int"(%arg0, %arg1, %arg2) :
// (memref<?x!HLFHE.eint<0>>,
// memref<?xi32>,
// memref<!HLFHE.eint<0>>) -> ()
//
// becomes:
//
// linalg.generic {
// indexing_maps = [affine_map<(d0) -> (d0)>,
// affine_map<(d0) -> (d0)>,
// affine_map<(d0) -> ()>],
// iterator_types = ["reduction"]
// } ins(%arg0, %arg1 : memref<?x!HLFHE.eint<0>>, memref<?xi32>)
// outs(%arg2: memref<!HLFHE.eint<0>>)
// {
// ^bb0(%arg3: !HLFHE.eint<0>, %arg4: i32, %arg5: !HLFHE.eint<0>):
// %0 = "HLFHE.mul_eint_int"(%arg3, %arg4) : (!HLFHE.eint<0>, i32) ->
// !HLFHE.eint<0> %1 = "HLFHE.add_eint"(%0, %arg5) : (!HLFHE.eint<0>,
// !HLFHE.eint<0>) -> !HLFHE.eint<0> linalg.yield %1 : !HLFHE.eint<0>
// }
//
::mlir::LogicalResult
matchAndRewrite(::mlir::Operation *op0,
::mlir::PatternRewriter &rewriter) const override {
::mlir::zamalang::HLFHE::Dot &&dotOp =
::llvm::dyn_cast_or_null<::mlir::zamalang::HLFHE::Dot>(op0);
mlir::TypeRange resTypes{};
llvm::SmallVector<mlir::Value, 2> ins{dotOp.lhs(), dotOp.rhs()};
llvm::SmallVector<mlir::Value, 1> outs{dotOp.out()};
llvm::SmallVector<mlir::AffineMap, 3> maps{
mlir::AffineMap::getMultiDimIdentityMap(1, this->getContext()),
mlir::AffineMap::getMultiDimIdentityMap(1, this->getContext()),
mlir::AffineMap::get(1, 0, this->getContext())};
llvm::SmallVector<llvm::StringRef, 1> itTypes{"reduction"};
llvm::StringRef doc{""};
llvm::StringRef call{""};
auto regBuilder = [&](mlir::OpBuilder &nestedBuilder,
mlir::Location nestedLoc,
mlir::ValueRange blockArgs) {
mlir::zamalang::HLFHE::MulEintIntOp mul =
nestedBuilder.create<mlir::zamalang::HLFHE::MulEintIntOp>(
dotOp.getLoc(), blockArgs[0], blockArgs[1]);
mlir::zamalang::HLFHE::AddEintOp add =
nestedBuilder.create<mlir::zamalang::HLFHE::AddEintOp>(
dotOp.getLoc(), mul, blockArgs[2]);
nestedBuilder.create<mlir::linalg::YieldOp>(dotOp.getLoc(),
add.getResult());
};
mlir::linalg::GenericOp gop = rewriter.create<mlir::linalg::GenericOp>(
dotOp.getLoc(), resTypes, ins, outs, maps, itTypes, doc, call,
regBuilder);
rewriter.replaceOp(op0, {gop.getODSResults(0)});
return ::mlir::success();
};
};
namespace {
struct LowerTensorOpsToLinalgPass
: public mlir::PassWrapper<LowerTensorOpsToLinalgPass, mlir::FunctionPass> {
void getDependentDialects(mlir::DialectRegistry &registry) const override {
registry.insert<mlir::linalg::LinalgDialect>();
}
void runOnFunction() final;
};
void LowerTensorOpsToLinalgPass::runOnFunction() {
mlir::FuncOp function = this->getFunction();
mlir::ConversionTarget target(getContext());
target.addLegalDialect<mlir::linalg::LinalgDialect>();
target.addLegalDialect<mlir::StandardOpsDialect>();
target.addLegalDialect<mlir::memref::MemRefDialect>();
target.addLegalDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
target.addIllegalOp<mlir::zamalang::HLFHE::Dot>();
mlir::OwningRewritePatternList patterns(&getContext());
patterns.insert<DotToLinalgGeneric>(&getContext());
if (mlir::applyPartialConversion(function, target, std::move(patterns))
.failed())
this->signalPassFailure();
}
} // namespace
namespace mlir {
namespace zamalang {
namespace HLFHE {
std::unique_ptr<mlir::Pass> createLowerTensorOpsToLinalgPass() {
return std::make_unique<LowerTensorOpsToLinalgPass>();
}
} // namespace HLFHE
} // namespace zamalang
} // namespace mlir