mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(compiler): Add lowering for HLFHELinalg.zero to linalg.generate
Add a rewrite pattern that transforms an instance of
`HLFHELinalg.zero` into an instance of `linalg.generate` with an
appropriate region yielding a zero value.
Example:
%out = "HLFHELinalg.zero"() : () -> tensor<MxNx!HLFHE.eint<p>>
becomes:
%0 = tensor.generate {
^bb0(%arg2: index, %arg3: index):
%zero = "HLFHE.zero"() : () -> !HLFHE.eint<p>
tensor.yield %zero : !HLFHE.eint<p>
} : tensor<MxNx!HLFHE.eint<p>>
This commit is contained in:
@@ -736,6 +736,54 @@ private:
|
||||
createMulOp;
|
||||
};
|
||||
|
||||
// This rewrite pattern transforms any instance of operators
|
||||
// `HLFHELinalg.zero` to an instance of `linalg.generate` with an
|
||||
// appropriate region yielding a zero value.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// %out = "HLFHELinalg.zero"() : () -> tensor<MxNx!HLFHE.eint<p>>
|
||||
//
|
||||
// becomes:
|
||||
//
|
||||
// %0 = tensor.generate {
|
||||
// ^bb0(%arg2: index, %arg3: index):
|
||||
// %zero = "HLFHE.zero"() : () -> !HLFHE.eint<p>
|
||||
// tensor.yield %zero : !HLFHE.eint<p>
|
||||
// } : tensor<MxNx!HLFHE.eint<p>>
|
||||
//
|
||||
struct HLFHELinalgZeroToLinalgGenerate
|
||||
: public mlir::OpRewritePattern<mlir::zamalang::HLFHELinalg::ZeroOp> {
|
||||
HLFHELinalgZeroToLinalgGenerate(::mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<mlir::zamalang::HLFHELinalg::ZeroOp>(context,
|
||||
benefit) {
|
||||
}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::zamalang::HLFHELinalg::ZeroOp zeroOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::RankedTensorType resultTy =
|
||||
zeroOp->getResult(0).getType().cast<mlir::RankedTensorType>();
|
||||
|
||||
auto generateBody = [&](mlir::OpBuilder &nestedBuilder,
|
||||
mlir::Location nestedLoc,
|
||||
mlir::ValueRange blockArgs) {
|
||||
mlir::Value zeroScalar =
|
||||
nestedBuilder.create<mlir::zamalang::HLFHE::ZeroEintOp>(
|
||||
zeroOp.getLoc(), resultTy.getElementType());
|
||||
nestedBuilder.create<mlir::tensor::YieldOp>(zeroOp.getLoc(), zeroScalar);
|
||||
};
|
||||
mlir::tensor::GenerateOp generateOp =
|
||||
rewriter.create<mlir::tensor::GenerateOp>(
|
||||
zeroOp.getLoc(), resultTy, mlir::ValueRange{}, generateBody);
|
||||
|
||||
rewriter.replaceOp(zeroOp, {generateOp.getResult()});
|
||||
|
||||
return ::mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
namespace {
|
||||
struct HLFHETensorOpsToLinalg
|
||||
: public HLFHETensorOpsToLinalgBase<HLFHETensorOpsToLinalg> {
|
||||
@@ -793,6 +841,7 @@ void HLFHETensorOpsToLinalg::runOnFunction() {
|
||||
});
|
||||
patterns.insert<HLFHELinalgApplyMultiLookupTableToLinalgGeneric>(
|
||||
&getContext());
|
||||
patterns.insert<HLFHELinalgZeroToLinalgGenerate>(&getContext());
|
||||
|
||||
if (mlir::applyPartialConversion(function, target, std::move(patterns))
|
||||
.failed())
|
||||
|
||||
Reference in New Issue
Block a user