feat(compiler): lower HLFHELinalg.neg_eint

This commit is contained in:
youben11
2021-11-09 11:37:30 +01:00
parent 99d6d11616
commit 6df9f09e48
3 changed files with 167 additions and 5 deletions

View File

@@ -25,8 +25,8 @@ struct DotToLinalgGeneric
// This rewrite pattern transforms any instance of
// `HLFHELinalg.dot_eint_int` to an instance of `linalg.generic` with an
// appropriate region using `HLFHE.mul_eint_int` and
// `HLFHELinalg.add_eint` operations, an appropriate specification for the
// iteration dimensions and appropriate operaztions managing the
// `HLFHE.add_eint` operations, an appropriate specification for the
// iteration dimensions and appropriate operations managing the
// accumulator of `linalg.generic`.
//
// Example:
@@ -145,7 +145,7 @@ getBroadcastedAffineMap(const mlir::RankedTensorType &resultType,
// operators `HLFHELinalgOp` that implements the broadcasting rules to an
// instance of `linalg.generic` with an appropriate region using `HLFHEOp`
// operation, an appropriate specification for the iteration dimensions and
// appropriate operaztions managing the accumulator of `linalg.generic`.
// appropriate operations managing the accumulator of `linalg.generic`.
//
// Example:
//
@@ -244,7 +244,7 @@ struct HLFHELinalgOpToLinalgGeneric
// operators `HLFHELinalg.apply_lookup_table` that implements the broadcasting
// rules to an instance of `linalg.generic` with an appropriate region using
// `HLFHE.apply_lookup_table` operation, an appropriate specification for the
// iteration dimensions and appropriate operaztions managing the accumulator of
// iteration dimensions and appropriate operations managing the accumulator of
// `linalg.generic`.
//
// Example:
@@ -341,6 +341,101 @@ struct HLFHELinalgApplyLookupTableToLinalgGeneric
};
};
// Rewrite pattern lowering `HLFHELinalg.neg_eint` to a `linalg.generic`
// whose region holds a single scalar `HLFHE.neg_eint`, using identity
// indexing maps and all-parallel iteration dimensions.
//
// Example:
//
// HLFHELinalg.neg_eint(%tensor):
//  tensor<DNx...xD1x!HLFHE.eint<p>> -> tensor<DNx...xD1x!HLFHE.eint<p'>>
//
// becomes:
//
// #maps_0 = [
//    affine_map<(aN, ..., a1) -> (aN, ..., a1)>,
//    affine_map<(aN, ..., a1) -> (aN, ..., a1)>
// ]
// #attributes_0 {
//     indexing_maps = #maps_0,
//     iterator_types = ["parallel",..],//N parallel
// }
// %init = linalg.init_tensor [DN,...,D1]
//            : tensor<DNx...xD1x!HLFHE.eint<p'>>
// %res = linalg.generic {
//          ins(%tensor: tensor<DNx...xD1x!HLFHE.eint<p>>)
//          outs(%init : tensor<DNx...xD1x!HLFHE.eint<p'>>)
//          {
//             ^bb0(%arg0: !HLFHE.eint<p>):
//               %0 = HLFHE.neg_eint(%arg0): !HLFHE.eint<p> -> !HLFHE.eint<p'>
//               linalg.yield %0 : !HLFHE.eint<p'>
//          }
//        }
//
struct HLFHELinalgNegEintToLinalgGeneric
    : public mlir::OpRewritePattern<mlir::zamalang::HLFHELinalg::NegEintOp> {
  HLFHELinalgNegEintToLinalgGeneric(::mlir::MLIRContext *context,
                                    mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern<mlir::zamalang::HLFHELinalg::NegEintOp>(
            context, benefit) {}

  ::mlir::LogicalResult
  matchAndRewrite(mlir::zamalang::HLFHELinalg::NegEintOp negEintOp,
                  ::mlir::PatternRewriter &rewriter) const override {
    mlir::Location loc = negEintOp.getLoc();

    // Operand and result tensor types; both are ranked by construction.
    auto inputTy = ((mlir::Type)negEintOp.tensor().getType())
                       .cast<mlir::RankedTensorType>();
    auto outputTy = ((mlir::Type)negEintOp->getResult(0).getType())
                        .cast<mlir::RankedTensorType>();

    // `linalg.generic` writes into an output tensor; materialize it with
    // `linalg.init_tensor`.
    mlir::Value outputTensor = rewriter.create<mlir::linalg::InitTensorOp>(
        loc, outputTy.getShape(), outputTy.getElementType());

    // Identity indexing maps for the input and the output (#maps_0 above).
    llvm::SmallVector<mlir::AffineMap, 2> indexingMaps{
        mlir::AffineMap::getMultiDimIdentityMap(inputTy.getShape().size(),
                                                this->getContext()),
        mlir::AffineMap::getMultiDimIdentityMap(outputTy.getShape().size(),
                                                this->getContext()),
    };

    // Every dimension is iterated in parallel.
    llvm::SmallVector<llvm::StringRef> iteratorTypes(
        outputTy.getShape().size(), "parallel");

    // Region body: negate the scalar element and yield the result.
    auto regionBuilder = [&](mlir::OpBuilder &nestedBuilder,
                             mlir::Location nestedLoc,
                             mlir::ValueRange blockArgs) {
      mlir::zamalang::HLFHE::NegEintOp scalarNeg =
          nestedBuilder.create<mlir::zamalang::HLFHE::NegEintOp>(
              loc, outputTy.getElementType(), blockArgs[0]);
      nestedBuilder.create<mlir::linalg::YieldOp>(loc, scalarNeg.getResult());
    };

    // Assemble the `linalg.generic` op and swap it in for the original op.
    llvm::SmallVector<mlir::Type, 1> resultTypes{outputTensor.getType()};
    llvm::SmallVector<mlir::Value, 1> inputs{negEintOp.tensor()};
    llvm::SmallVector<mlir::Value, 1> outputs{outputTensor};
    llvm::StringRef doc{""};
    llvm::StringRef call{""};
    mlir::linalg::GenericOp genericOp =
        rewriter.create<mlir::linalg::GenericOp>(loc, resultTypes, inputs,
                                                 outputs, indexingMaps,
                                                 iteratorTypes, doc, call,
                                                 regionBuilder);
    rewriter.replaceOp(negEintOp, {genericOp.getResult(0)});
    return ::mlir::success();
  }
};
namespace {
struct HLFHETensorOpsToLinalg
: public HLFHETensorOpsToLinalgBase<HLFHETensorOpsToLinalg> {
@@ -381,6 +476,7 @@ void HLFHETensorOpsToLinalg::runOnFunction() {
mlir::zamalang::HLFHE::MulEintIntOp>>(
&getContext());
patterns.insert<HLFHELinalgApplyLookupTableToLinalgGeneric>(&getContext());
patterns.insert<HLFHELinalgNegEintToLinalgGeneric>(&getContext());
if (mlir::applyPartialConversion(function, target, std::move(patterns))
.failed())

View File

@@ -0,0 +1,19 @@
// RUN: zamacompiler %s --action=dump-midlfhe --passes hlfhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-NEXT: module {
// CHECK-NEXT: func @neg_eint(%arg0: tensor<2x3x4x!HLFHE.eint<2>>) -> tensor<2x3x4x!HLFHE.eint<2>> {
// CHECK-NEXT: %0 = linalg.init_tensor [2, 3, 4] : tensor<2x3x4x!HLFHE.eint<2>>
// CHECK-NEXT: %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3x4x!HLFHE.eint<2>>) outs(%0 : tensor<2x3x4x!HLFHE.eint<2>>) {
// CHECK-NEXT: ^bb0(%arg1: !HLFHE.eint<2>, %arg2: !HLFHE.eint<2>): // no predecessors
// CHECK-NEXT: %2 = "HLFHE.neg_eint"(%arg1) : (!HLFHE.eint<2>) -> !HLFHE.eint<2>
// CHECK-NEXT: linalg.yield %2 : !HLFHE.eint<2>
// CHECK-NEXT: } -> tensor<2x3x4x!HLFHE.eint<2>>
// CHECK-NEXT: return %1 : tensor<2x3x4x!HLFHE.eint<2>>
// CHECK-NEXT: }
// CHECK-NEXT: }

// Checks that `HLFHELinalg.neg_eint` on a rank-3 tensor lowers to a
// `linalg.generic` wrapping a scalar `HLFHE.neg_eint` with identity
// indexing maps and parallel iterators for all three dimensions.
func @neg_eint(%arg0: tensor<2x3x4x!HLFHE.eint<2>>) -> tensor<2x3x4x!HLFHE.eint<2>> {
  %1 = "HLFHELinalg.neg_eint"(%arg0): (tensor<2x3x4x!HLFHE.eint<2>>) -> (tensor<2x3x4x!HLFHE.eint<2>>)
  return %1: tensor<2x3x4x!HLFHE.eint<2>>
}

View File

@@ -1023,4 +1023,51 @@ func @main(%arg0: tensor<4x!HLFHE.eint<7>>,
lambda(arg0, ARRAY_SIZE(arg0), arg1, ARRAY_SIZE(arg1));
ASSERT_EXPECTED_VALUE(res, 14);
}
}
///////////////////////////////////////////////////////////////////////////////
// HLFHELinalg neg_eint /////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////

// JIT-compiles a circuit negating a 3x3 tensor of encrypted integers and
// compares every decrypted element to the expected wrapped value (e.g. the
// negation of 1 decrypts to 7 here — wrap-around semantics of neg_eint).
TEST(End2EndJit_HLFHELinalg, neg_eint) {
  mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
// Returns the negation of a 3x3 matrix of encrypted integers of width 2.
//
//        ([0,1,2])   [0,7,6]
// negate ([3,4,5]) = [5,4,3]
//        ([6,7,0])   [2,1,0]
func @main(%t: tensor<3x3x!HLFHE.eint<2>>) -> tensor<3x3x!HLFHE.eint<2>> {
  %res = "HLFHELinalg.neg_eint"(%t) : (tensor<3x3x!HLFHE.eint<2>>) -> tensor<3x3x!HLFHE.eint<2>>
  return %res : tensor<3x3x!HLFHE.eint<2>>
}
)XXX");
  // Input matrix and its expected negation.
  const uint8_t input[3][3]{
      {0, 1, 2},
      {3, 4, 5},
      {6, 7, 0},
  };
  const uint8_t want[3][3]{
      {0, 7, 6},
      {5, 4, 3},
      {2, 1, 0},
  };
  // Wrap the flat input buffer as a 3x3 tensor argument for the lambda.
  mlir::zamalang::TensorLambdaArgument<
      mlir::zamalang::IntLambdaArgument<uint8_t>>
      inputArg(llvm::MutableArrayRef<uint8_t>((uint8_t *)input, 3 * 3), {3, 3});
  llvm::Expected<std::vector<uint64_t>> res =
      lambda.operator()<std::vector<uint64_t>>({&inputArg});
  ASSERT_EXPECTED_SUCCESS(res);
  ASSERT_EQ(res->size(), 3 * 3);
  // Results come back row-major; check them element by element.
  size_t flat = 0;
  for (size_t row = 0; row < 3; row++) {
    for (size_t col = 0; col < 3; col++, flat++) {
      EXPECT_EQ((*res)[flat], want[row][col])
          << ", at pos(" << row << "," << col << ")";
    }
  }
}