From 6df9f09e48ce8ffdc54c34e2d63effcb67269d69 Mon Sep 17 00:00:00 2001
From: youben11
Date: Tue, 9 Nov 2021 11:37:30 +0100
Subject: [PATCH] feat(compiler): lower HLFHELinalg.neg_eint

---
 .../TensorOpsToLinalg.cpp                     | 104 +++++++++++++++++-
 .../HLFHELinalgToLinalg/neg_eint.mlir         |  19 ++++
 .../unittest/end_to_end_jit_hlfhelinalg.cc    |  49 ++++++++-
 3 files changed, 167 insertions(+), 5 deletions(-)
 create mode 100644 compiler/tests/Conversion/HLFHELinalgToLinalg/neg_eint.mlir

diff --git a/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp
index a580624ba..2bca0b32b 100644
--- a/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp
+++ b/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp
@@ -25,8 +25,8 @@ struct DotToLinalgGeneric
   // This rewrite pattern transforms any instance of
   // `HLFHELinalg.dot_eint_int` to an instance of `linalg.generic` with an
   // appropriate region using `HLFHE.mul_eint_int` and
-  // `HLFHELinalg.add_eint` operations, an appropriate specification for the
-  // iteration dimensions and appropriate operaztions managing the
+  // `HLFHE.add_eint` operations, an appropriate specification for the
+  // iteration dimensions and appropriate operations managing the
   // accumulator of `linalg.generic`.
   //
   // Example:
@@ -145,7 +145,7 @@ getBroadcastedAffineMap(const mlir::RankedTensorType &resultType,
 // operators `HLFHELinalgOp` that implements the broadasting rules to an
 // instance of `linalg.generic` with an appropriate region using `HLFHEOp`
 // operation, an appropriate specification for the iteration dimensions and
-// appropriate operaztions managing the accumulator of `linalg.generic`.
+// appropriate operations managing the accumulator of `linalg.generic`.
 //
 // Example:
 //
@@ -244,7 +244,7 @@ struct HLFHELinalgOpToLinalgGeneric
 // operators `HLFHELinalg.apply_lookup_table` that implements the broadasting
 // rules to an instance of `linalg.generic` with an appropriate region using
 // `HLFHE.apply_lookup_table` operation, an appropriate specification for the
-// iteration dimensions and appropriate operaztions managing the accumulator of
+// iteration dimensions and appropriate operations managing the accumulator of
 // `linalg.generic`.
 //
 // Example:
@@ -341,6 +341,101 @@ struct HLFHELinalgApplyLookupTableToLinalgGeneric
   };
 };
 
+// This template rewrite pattern transforms any instance of
+// operators `HLFHELinalg.neg_eint` to an instance of `linalg.generic` with an
+// appropriate region using `HLFHE.neg_eint` operation, an appropriate
+// specification for the iteration dimensions and appropriate operations
+// managing the accumulator of `linalg.generic`.
+//
+// Example:
+//
+//   HLFHELinalg.neg_eint(%tensor):
+//     tensor<DNx...xD1x!HLFHE.eint<p>> -> tensor<DNx...xD1x!HLFHE.eint<p>>
+//
+// becomes:
+//
+//   #maps_0 = [
+//     affine_map<(aN, ..., a1) -> (aN, ..., a1)>,
+//     affine_map<(aN, ..., a1) -> (aN, ..., a1)>
+//   ]
+//   #attributes_0 {
+//     indexing_maps = #maps_0,
+//     iterator_types = ["parallel", ...] // N parallel
+//   }
+//   %init = linalg.init_tensor [DN,...,D1]
+//     : tensor<DNx...xD1x!HLFHE.eint<p>>
+//   %res = linalg.generic {
+//     ins(%tensor : tensor<DNx...xD1x!HLFHE.eint<p>>)
+//     outs(%init : tensor<DNx...xD1x!HLFHE.eint<p>>)
+//     {
+//       ^bb0(%arg0: !HLFHE.eint<p>):
+//         %0 = HLFHE.neg_eint(%arg0): !HLFHE.eint<p> -> !HLFHE.eint<p>
+//         linalg.yield %0 : !HLFHE.eint<p>
+//     }
+//   }
+//
+struct HLFHELinalgNegEintToLinalgGeneric
+    : public mlir::OpRewritePattern<mlir::zamalang::HLFHELinalg::NegEintOp> {
+  HLFHELinalgNegEintToLinalgGeneric(::mlir::MLIRContext *context,
+                                    mlir::PatternBenefit benefit = 1)
+      : ::mlir::OpRewritePattern<mlir::zamalang::HLFHELinalg::NegEintOp>(
+            context, benefit) {}
+
+  ::mlir::LogicalResult
+  matchAndRewrite(mlir::zamalang::HLFHELinalg::NegEintOp negEintOp,
+                  ::mlir::PatternRewriter &rewriter) const override {
+    mlir::RankedTensorType resultTy =
+        ((mlir::Type)negEintOp->getResult(0).getType())
+            .cast<mlir::RankedTensorType>();
+    mlir::RankedTensorType tensorTy =
+        ((mlir::Type)negEintOp.tensor().getType())
+            .cast<mlir::RankedTensorType>();
+
+    // linalg.init_tensor for initial value
+    mlir::Value init = rewriter.create<mlir::linalg::InitTensorOp>(
+        negEintOp.getLoc(), resultTy.getShape(), resultTy.getElementType());
+
+    // Create the affine #maps_0
+    llvm::SmallVector<mlir::AffineMap> maps{
+        mlir::AffineMap::getMultiDimIdentityMap(tensorTy.getShape().size(),
+                                                this->getContext()),
+        mlir::AffineMap::getMultiDimIdentityMap(resultTy.getShape().size(),
+                                                this->getContext()),
+    };
+
+    // Create the iterator_types
+    llvm::SmallVector<llvm::StringRef> iteratorTypes(
+        resultTy.getShape().size(), "parallel");
+
+    // Create the body of the `linalg.generic` op
+    auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
+                           mlir::Location nestedLoc,
+                           mlir::ValueRange blockArgs) {
+      mlir::zamalang::HLFHE::NegEintOp hlfheOp =
+          nestedBuilder.create<mlir::zamalang::HLFHE::NegEintOp>(
+              negEintOp.getLoc(), resultTy.getElementType(), blockArgs[0]);
+
+      nestedBuilder.create<mlir::linalg::YieldOp>(negEintOp.getLoc(),
+                                                  hlfheOp.getResult());
+    };
+
+    // Create the `linalg.generic` op
+    llvm::SmallVector<mlir::Type> resTypes{init.getType()};
+    llvm::SmallVector<mlir::Value> ins{negEintOp.tensor()};
+    llvm::SmallVector<mlir::Value> outs{init};
+    llvm::StringRef doc{""};
+    llvm::StringRef call{""};
+
+    mlir::linalg::GenericOp genericOp =
+        rewriter.create<mlir::linalg::GenericOp>(negEintOp.getLoc(), resTypes,
+                                                 ins, outs, maps, iteratorTypes,
+                                                 doc, call, bodyBuilder);
+
+    rewriter.replaceOp(negEintOp, {genericOp.getResult(0)});
+
+    return ::mlir::success();
+  };
+};
+
 namespace {
 struct HLFHETensorOpsToLinalg
     : public HLFHETensorOpsToLinalgBase<HLFHETensorOpsToLinalg> {
@@ -381,6 +476,7 @@ void HLFHETensorOpsToLinalg::runOnFunction() {
           mlir::zamalang::HLFHE::MulEintIntOp>>(
       &getContext());
   patterns.insert<HLFHELinalgApplyLookupTableToLinalgGeneric>(&getContext());
+  patterns.insert<HLFHELinalgNegEintToLinalgGeneric>(&getContext());
 
   if (mlir::applyPartialConversion(function, target, std::move(patterns))
           .failed())
diff --git a/compiler/tests/Conversion/HLFHELinalgToLinalg/neg_eint.mlir b/compiler/tests/Conversion/HLFHELinalgToLinalg/neg_eint.mlir
new file mode 100644
index 000000000..58037ebee
--- /dev/null
+++ b/compiler/tests/Conversion/HLFHELinalgToLinalg/neg_eint.mlir
@@ -0,0 +1,19 @@
+// RUN: zamacompiler %s --action=dump-midlfhe --passes hlfhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
+
+// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-NEXT: module {
+// CHECK-NEXT: func @neg_eint(%arg0: tensor<2x3x4x!HLFHE.eint<2>>) -> tensor<2x3x4x!HLFHE.eint<2>> {
+// CHECK-NEXT: %0 = linalg.init_tensor [2, 3, 4] : tensor<2x3x4x!HLFHE.eint<2>>
+// CHECK-NEXT: %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3x4x!HLFHE.eint<2>>) outs(%0 : tensor<2x3x4x!HLFHE.eint<2>>) {
+// CHECK-NEXT: ^bb0(%arg1: !HLFHE.eint<2>, %arg2: !HLFHE.eint<2>): // no predecessors
+// CHECK-NEXT: %2 = "HLFHE.neg_eint"(%arg1) : (!HLFHE.eint<2>) -> !HLFHE.eint<2>
+// CHECK-NEXT: linalg.yield %2 : !HLFHE.eint<2>
+// CHECK-NEXT: } -> tensor<2x3x4x!HLFHE.eint<2>>
+// CHECK-NEXT: return %1 : tensor<2x3x4x!HLFHE.eint<2>>
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+func @neg_eint(%arg0: tensor<2x3x4x!HLFHE.eint<2>>) -> tensor<2x3x4x!HLFHE.eint<2>> {
+  %1 = "HLFHELinalg.neg_eint"(%arg0): (tensor<2x3x4x!HLFHE.eint<2>>) -> (tensor<2x3x4x!HLFHE.eint<2>>)
+  return %1: tensor<2x3x4x!HLFHE.eint<2>>
+}
\ No newline at end of file
diff --git a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc
index f805a2d9a..68a0a8797 100644
--- a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc
+++ b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc
@@ -1023,4 +1023,51 @@ func @main(%arg0: tensor<4x!HLFHE.eint<7>>,
 
   lambda(arg0, ARRAY_SIZE(arg0), arg1, ARRAY_SIZE(arg1));
   ASSERT_EXPECTED_VALUE(res, 14);
-}
\ No newline at end of file
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// HLFHELinalg neg_eint /////////////////////////////////////////////
+///////////////////////////////////////////////////////////////////////////////
+
+TEST(End2EndJit_HLFHELinalg, neg_eint) {
+
+  mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
+    // Returns the negation of a 3x3 matrix of encrypted integers of width 2.
+    //
+    //        ([0,1,2])   [0,7,6]
+    // negate ([3,4,5]) = [5,4,3]
+    //        ([6,7,0])   [2,1,0]
+    func @main(%t: tensor<3x3x!HLFHE.eint<2>>) -> tensor<3x3x!HLFHE.eint<2>> {
+      %res = "HLFHELinalg.neg_eint"(%t) : (tensor<3x3x!HLFHE.eint<2>>) -> tensor<3x3x!HLFHE.eint<2>>
+      return %res : tensor<3x3x!HLFHE.eint<2>>
+    }
+)XXX");
+  const uint8_t t[3][3]{
+      {0, 1, 2},
+      {3, 4, 5},
+      {6, 7, 0},
+  };
+  const uint8_t expected[3][3]{
+      {0, 7, 6},
+      {5, 4, 3},
+      {2, 1, 0},
+  };
+
+  mlir::zamalang::TensorLambdaArgument<
+      mlir::zamalang::IntLambdaArgument<uint8_t>>
+      tArg(llvm::MutableArrayRef<uint8_t>((uint8_t *)t, 3 * 3), {3, 3});
+
+  llvm::Expected<std::vector<uint64_t>> res =
+      lambda.operator()<std::vector<uint64_t>>({&tArg});
+
+  ASSERT_EXPECTED_SUCCESS(res);
+
+  ASSERT_EQ(res->size(), 3 * 3);
+
+  for (size_t i = 0; i < 3; i++) {
+    for (size_t j = 0; j < 3; j++) {
+      EXPECT_EQ((*res)[i * 3 + j], expected[i][j])
+          << ", at pos(" << i << "," << j << ")";
+    }
+  }
+}