From 6fb907295dfeacb8de5fbf10d538d532736d5e85 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Tue, 23 Nov 2021 15:47:28 +0100 Subject: [PATCH] feat(compiler): Add lowering for HLFHELinalg.zero to linalg.generate Add a rewrite pattern that transforms an instance of `HLFHELinalg.zero` into an instance of `linalg.generate` with an appropriate region yielding a zero value. Example: %out = "HLFHELinalg.zero"() : () -> tensor> becomes: %0 = tensor.generate { ^bb0(%arg2: index, %arg3: index): %zero = "HLFHE.zero"() : () -> !HLFHE.eint

tensor.yield %zero : !HLFHE.eint

} : tensor> --- .../TensorOpsToLinalg.cpp | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp index a4db2c5de..6c09ff285 100644 --- a/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp +++ b/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp @@ -736,6 +736,54 @@ private: createMulOp; }; +// This rewrite pattern transforms any instance of operators +// `HLFHELinalg.zero` to an instance of `linalg.generate` with an +// appropriate region yielding a zero value. +// +// Example: +// +// %out = "HLFHELinalg.zero"() : () -> tensor> +// +// becomes: +// +// %0 = tensor.generate { +// ^bb0(%arg2: index, %arg3: index): +// %zero = "HLFHE.zero"() : () -> !HLFHE.eint

+// tensor.yield %zero : !HLFHE.eint

+// } : tensor> +// +struct HLFHELinalgZeroToLinalgGenerate + : public mlir::OpRewritePattern { + HLFHELinalgZeroToLinalgGenerate(::mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern(context, + benefit) { + } + + ::mlir::LogicalResult + matchAndRewrite(mlir::zamalang::HLFHELinalg::ZeroOp zeroOp, + ::mlir::PatternRewriter &rewriter) const override { + mlir::RankedTensorType resultTy = + zeroOp->getResult(0).getType().cast(); + + auto generateBody = [&](mlir::OpBuilder &nestedBuilder, + mlir::Location nestedLoc, + mlir::ValueRange blockArgs) { + mlir::Value zeroScalar = + nestedBuilder.create( + zeroOp.getLoc(), resultTy.getElementType()); + nestedBuilder.create(zeroOp.getLoc(), zeroScalar); + }; + mlir::tensor::GenerateOp generateOp = + rewriter.create( + zeroOp.getLoc(), resultTy, mlir::ValueRange{}, generateBody); + + rewriter.replaceOp(zeroOp, {generateOp.getResult()}); + + return ::mlir::success(); + }; +}; + namespace { struct HLFHETensorOpsToLinalg : public HLFHETensorOpsToLinalgBase { @@ -793,6 +841,7 @@ void HLFHETensorOpsToLinalg::runOnFunction() { }); patterns.insert( &getContext()); + patterns.insert(&getContext()); if (mlir::applyPartialConversion(function, target, std::move(patterns)) .failed())