feat(compiler): Add lowering for HLFHELinalg.zero to linalg.generate

Add a rewrite pattern that transforms an instance of
`HLFHELinalg.zero` into an instance of `linalg.generate` with an
appropriate region yielding a zero value.

Example:

  %out = "HLFHELinalg.zero"() : () -> tensor<MxNx!HLFHE.eint<p>>

becomes:

  %0 = tensor.generate   {
    ^bb0(%arg2: index, %arg3: index):
       %zero = "HLFHE.zero"() : () -> !HLFHE.eint<p>
       tensor.yield %zero : !HLFHE.eint<p>
  } : tensor<MxNx!HLFHE.eint<p>>
This commit is contained in:
Andi Drebes
2021-11-23 15:47:28 +01:00
parent bf9a831c3d
commit 6fb907295d

View File

@@ -736,6 +736,54 @@ private:
createMulOp;
};
// This rewrite pattern transforms any instance of operators
// `HLFHELinalg.zero` to an instance of `linalg.generate` with an
// appropriate region yielding a zero value.
//
// Example:
//
// %out = "HLFHELinalg.zero"() : () -> tensor<MxNx!HLFHE.eint<p>>
//
// becomes:
//
// %0 = tensor.generate {
// ^bb0(%arg2: index, %arg3: index):
// %zero = "HLFHE.zero"() : () -> !HLFHE.eint<p>
// tensor.yield %zero : !HLFHE.eint<p>
// } : tensor<MxNx!HLFHE.eint<p>>
//
struct HLFHELinalgZeroToLinalgGenerate
: public mlir::OpRewritePattern<mlir::zamalang::HLFHELinalg::ZeroOp> {
HLFHELinalgZeroToLinalgGenerate(::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<mlir::zamalang::HLFHELinalg::ZeroOp>(context,
benefit) {
}
::mlir::LogicalResult
matchAndRewrite(mlir::zamalang::HLFHELinalg::ZeroOp zeroOp,
::mlir::PatternRewriter &rewriter) const override {
mlir::RankedTensorType resultTy =
zeroOp->getResult(0).getType().cast<mlir::RankedTensorType>();
auto generateBody = [&](mlir::OpBuilder &nestedBuilder,
mlir::Location nestedLoc,
mlir::ValueRange blockArgs) {
mlir::Value zeroScalar =
nestedBuilder.create<mlir::zamalang::HLFHE::ZeroEintOp>(
zeroOp.getLoc(), resultTy.getElementType());
nestedBuilder.create<mlir::tensor::YieldOp>(zeroOp.getLoc(), zeroScalar);
};
mlir::tensor::GenerateOp generateOp =
rewriter.create<mlir::tensor::GenerateOp>(
zeroOp.getLoc(), resultTy, mlir::ValueRange{}, generateBody);
rewriter.replaceOp(zeroOp, {generateOp.getResult()});
return ::mlir::success();
};
};
namespace {
struct HLFHETensorOpsToLinalg
: public HLFHETensorOpsToLinalgBase<HLFHETensorOpsToLinalg> {
@@ -793,6 +841,7 @@ void HLFHETensorOpsToLinalg::runOnFunction() {
});
patterns.insert<HLFHELinalgApplyMultiLookupTableToLinalgGeneric>(
&getContext());
patterns.insert<HLFHELinalgZeroToLinalgGenerate>(&getContext());
if (mlir::applyPartialConversion(function, target, std::move(patterns))
.failed())