From ba5456068002d13796d79434f617da9cdecb944d Mon Sep 17 00:00:00 2001
From: Quentin Bourgerie
Date: Mon, 25 Oct 2021 15:37:48 +0200
Subject: [PATCH] feat(compiler): Lower HLFHELinalg binary operators

---
 .../HLFHETensorOpsToLinalg/CMakeLists.txt     |   4 +-
 .../TensorOpsToLinalg.cpp                     | 126 ++++++++++
 compiler/lib/Support/CompilerEngine.cpp       |   2 +
 .../end_to_end_jit_encrypted_tensor.cc        | 218 ++++++++++++++++++
 4 files changed, 349 insertions(+), 1 deletion(-)

diff --git a/compiler/lib/Conversion/HLFHETensorOpsToLinalg/CMakeLists.txt b/compiler/lib/Conversion/HLFHETensorOpsToLinalg/CMakeLists.txt
index f138b86cd..9af0bfef9 100644
--- a/compiler/lib/Conversion/HLFHETensorOpsToLinalg/CMakeLists.txt
+++ b/compiler/lib/Conversion/HLFHETensorOpsToLinalg/CMakeLists.txt
@@ -6,10 +6,12 @@ add_mlir_dialect_library(HLFHETensorOpsToLinalg
 
   DEPENDS
   HLFHEDialect
+  HLFHELinalgDialect
   MLIRConversionPassIncGen
 
   LINK_LIBS PUBLIC
   MLIRIR
-  HLFHEDialect)
+  HLFHEDialect
+  HLFHELinalgDialect)
 
 target_link_libraries(HLFHEDialect PUBLIC MLIRIR)

diff --git a/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp
index 46013f17e..05d09413d 100644
--- a/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp
+++ b/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp
@@ -13,6 +13,8 @@
 #include "zamalang/Conversion/Passes.h"
 #include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
 #include "zamalang/Dialect/HLFHE/IR/HLFHEOps.h"
+#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.h"
+#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h"
 
 struct DotToLinalgGeneric : public ::mlir::RewritePattern {
   DotToLinalgGeneric(::mlir::MLIRContext *context)
@@ -121,6 +123,125 @@ struct DotToLinalgGeneric : public ::mlir::RewritePattern {
   };
 };
 
+// Returns the affine map that broadcasts the shape of `operandType` into the
+// shape of `resultType`: dimensions of size 1 map to the constant 0, every
+// other dimension maps to the corresponding (right-aligned) result dimension.
+mlir::AffineMap
+getBroadcastedAffineMap(const mlir::RankedTensorType &resultType,
+                        const mlir::RankedTensorType &operandType,
+                        ::mlir::PatternRewriter &rewriter) {
+  mlir::SmallVector<mlir::AffineExpr> affineExprs;
+  auto resultShape = resultType.getShape();
+  auto operandShape = operandType.getShape();
+  affineExprs.reserve(resultShape.size());
+  size_t deltaNumDim = resultShape.size() - operandShape.size();
+  for (size_t i = 0; i < operandShape.size(); i++) {
+    if (operandShape[i] == 1) {
+      affineExprs.push_back(rewriter.getAffineConstantExpr(0));
+    } else {
+      affineExprs.push_back(rewriter.getAffineDimExpr(i + deltaNumDim));
+    }
+  }
+  return mlir::AffineMap::get(resultShape.size(), 0, affineExprs,
+                              rewriter.getContext());
+}
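+
+// For example (illustrative values, derived from the logic above):
+// broadcasting a tensor<3x1> operand into a tensor<3x3> result yields
+//   affine_map<(d0, d1) -> (d0, 0)>
+// while broadcasting a rank-1 tensor<3> operand (deltaNumDim == 1) into the
+// same result yields
+//   affine_map<(d0, d1) -> (d1)>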
+
+// This template rewrite pattern transforms any instance of operators
+// `HLFHELinalgOp` that implement the broadcasting rules into an instance of
+// `linalg.generic` with an appropriate region using the scalar `HLFHEOp`
+// operation, an appropriate specification for the iteration dimensions and
+// appropriate operations managing the accumulator of `linalg.generic`.
+//
+// Example:
+//
+// %res = HLFHELinalg.op(%lhs, %rhs):
+//   (tensor<D$Ax...xD1x!HLFHE.eint<p>>, tensor<D$B'x...xD1'xT>)
+//   -> tensor<D$R''x...xD1''x!HLFHE.eint<p>>
+//
+// becomes:
+//
+// #maps_0 = [
+//   affine_map<(a$R'', ..., a$A, ..., a1) ->
+//     (dim(lhs, $A) == 1 ? 0 : a$A, ..., dim(lhs, 1) == 1 ? 0 : a1)>,
+//   affine_map<(a$R'', ..., a$B', ..., a1) ->
+//     (dim(rhs, $B') == 1 ? 0 : a$B', ..., dim(rhs, 1) == 1 ? 0 : a1)>,
+//   affine_map<(a$R'', ..., a1) -> (a$R'', ..., a1)>
+// ]
+// #attributes_0 {
+//   indexing_maps = #maps_0,
+//   iterator_types = ["parallel", ..., "parallel"] // $R'' times "parallel"
+// }
+// %init = linalg.init_tensor [D$R'', ..., D1'']
+//   : tensor<D$R''x...xD1''x!HLFHE.eint<p>>
+// %res = linalg.generic #attributes_0
+//   ins(%lhs, %rhs : tensor<D$Ax...xD1x!HLFHE.eint<p>>, tensor<D$B'x...xD1'xT>)
+//   outs(%init : tensor<D$R''x...xD1''x!HLFHE.eint<p>>)
+//   {
+//     ^bb0(%arg0: !HLFHE.eint<p>, %arg1: T):
+//       %0 = HLFHE.op(%arg0, %arg1) : (!HLFHE.eint<p>, T) -> !HLFHE.eint<p>
+//       linalg.yield %0 : !HLFHE.eint<p>
+//   }
+//
+template <typename HLFHELinalgOp, typename HLFHEOp>
+struct HLFHELinalgOpToLinalgGeneric
+    : public mlir::OpRewritePattern<HLFHELinalgOp> {
+  HLFHELinalgOpToLinalgGeneric(::mlir::MLIRContext *context,
+                               mlir::PatternBenefit benefit = 1)
+      : ::mlir::OpRewritePattern<HLFHELinalgOp>(context, benefit) {}
+
+  ::mlir::LogicalResult
+  matchAndRewrite(HLFHELinalgOp linalgOp,
+                  ::mlir::PatternRewriter &rewriter) const override {
+    mlir::RankedTensorType resultTy =
+        ((mlir::Type)linalgOp->getResult(0).getType())
+            .cast<mlir::RankedTensorType>();
+    mlir::RankedTensorType lhsTy =
+        ((mlir::Type)linalgOp.lhs().getType()).cast<mlir::RankedTensorType>();
+    mlir::RankedTensorType rhsTy =
+        ((mlir::Type)linalgOp.rhs().getType()).cast<mlir::RankedTensorType>();
+
+    // linalg.init_tensor for the initial value
+    mlir::Value init = rewriter.create<mlir::linalg::InitTensorOp>(
+        linalgOp.getLoc(), resultTy.getShape(), resultTy.getElementType());
+
+    // Create the affine maps (#maps_0 in the example above)
+    llvm::SmallVector<mlir::AffineMap> maps{
+        getBroadcastedAffineMap(resultTy, lhsTy, rewriter),
+        getBroadcastedAffineMap(resultTy, rhsTy, rewriter),
+        getBroadcastedAffineMap(resultTy, resultTy, rewriter),
+    };
+
+    // Create the iterator_types: every result dimension is parallel
+    llvm::SmallVector<llvm::StringRef> iteratorTypes(
+        resultTy.getShape().size(), "parallel");
+
+    // Create the body of the `linalg.generic` op: apply the scalar `HLFHEOp`
+    // to the two block arguments and yield the result
+    auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
+                           mlir::Location nestedLoc,
+                           mlir::ValueRange blockArgs) {
+      HLFHEOp hlfheOp = nestedBuilder.create<HLFHEOp>(
+          linalgOp.getLoc(), blockArgs[0], blockArgs[1]);
+
+      nestedBuilder.create<mlir::linalg::YieldOp>(linalgOp.getLoc(),
+                                                  hlfheOp.getResult());
+    };
+
+    // Create the `linalg.generic` op
+    llvm::SmallVector<mlir::Type> resTypes{init.getType()};
+    llvm::SmallVector<mlir::Value> ins{linalgOp.lhs(), linalgOp.rhs()};
+    llvm::SmallVector<mlir::Value> outs{init};
+    llvm::StringRef doc{""};
+    llvm::StringRef call{""};
+
+    mlir::linalg::GenericOp genericOp =
+        rewriter.create<mlir::linalg::GenericOp>(linalgOp.getLoc(), resTypes,
+                                                 ins, outs, maps, iteratorTypes,
+                                                 doc, call, bodyBuilder);
+
+    rewriter.replaceOp(linalgOp, {genericOp.getResult(0)});
+
+    return ::mlir::success();
+  };
+};
+
 namespace {
 struct HLFHETensorOpsToLinalg
     : public HLFHETensorOpsToLinalgBase<HLFHETensorOpsToLinalg> {
@@ -139,9 +260,14 @@ void HLFHETensorOpsToLinalg::runOnFunction() {
   target.addLegalDialect<mlir::linalg::LinalgDialect>();
   target.addLegalDialect<mlir::memref::MemRefDialect>();
   target.addIllegalOp<mlir::zamalang::HLFHE::Dot>();
+  target.addIllegalDialect<mlir::zamalang::HLFHELinalg::HLFHELinalgDialect>();
 
   mlir::OwningRewritePatternList patterns(&getContext());
   patterns.insert<DotToLinalgGeneric>(&getContext());
+  patterns.insert<
+      HLFHELinalgOpToLinalgGeneric<mlir::zamalang::HLFHELinalg::AddEintOp,
+                                   mlir::zamalang::HLFHE::AddEintOp>>(
+      &getContext());
 
   if (mlir::applyPartialConversion(function, target, std::move(patterns))
           .failed())

diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp
index 5a41c7e46..99a495dbb 100644
--- a/compiler/lib/Support/CompilerEngine.cpp
+++ b/compiler/lib/Support/CompilerEngine.cpp
@@ -6,6 +6,7 @@
 #include 
 #include 
+#include <zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.h>
 #include 
 #include 
 #include 
@@ -15,6 +16,7 @@ namespace mlir {
 namespace zamalang {
 
 void CompilerEngine::loadDialects() {
+  context->getOrLoadDialect<mlir::zamalang::HLFHELinalg::HLFHELinalgDialect>();
   context->getOrLoadDialect();
   context->getOrLoadDialect();
   context->getOrLoadDialect();

diff --git a/compiler/tests/unittest/end_to_end_jit_encrypted_tensor.cc b/compiler/tests/unittest/end_to_end_jit_encrypted_tensor.cc
index c10774ba8..dc362a46f 100644
--- a/compiler/tests/unittest/end_to_end_jit_encrypted_tensor.cc
+++ b/compiler/tests/unittest/end_to_end_jit_encrypted_tensor.cc
@@ -168,4 +168,222 @@ func @main(%t0: tensor<2x10x!HLFHE.eint<6>>, %t1: tensor<2x2x!HLFHE.eint<6>>) ->
     }
   }
 }
-}
\ No newline at end of file
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// HLFHELinalg add_eint ///////////////////////////////////////////////////////
+///////////////////////////////////////////////////////////////////////////////
+
+TEST(End2EndJit_HLFHELinalg, add_eint_term_to_term) {
+  mlir::zamalang::CompilerEngine engine;
+  auto mlirStr = R"XXX(
+// Returns the term to term addition of `%a0` with `%a1`
+func @main(%a0: tensor<4x!HLFHE.eint<4>>, %a1: tensor<4x!HLFHE.eint<4>>) -> tensor<4x!HLFHE.eint<4>> {
+  %res = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<4x!HLFHE.eint<4>>, tensor<4x!HLFHE.eint<4>>) -> tensor<4x!HLFHE.eint<4>>
+  return %res : tensor<4x!HLFHE.eint<4>>
+}
+)XXX";
+  const uint8_t a0[4]{31, 6, 12, 9};
+  const uint8_t a1[4]{32, 9, 2, 3};
+
+  ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
+
+  auto maybeArgument = engine.buildArgument();
+  ASSERT_LLVM_ERROR(maybeArgument.takeError());
+  auto argument = std::move(maybeArgument.get());
+  // Set the %a0 and %a1 arguments
+  ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)a0, 4));
+  ASSERT_LLVM_ERROR(argument->setArg(1, (uint8_t *)a1, 4));
+  // Invoke the function
+  ASSERT_LLVM_ERROR(engine.invoke(*argument));
+  // Get and assert the result
+  uint64_t result[4];
+  ASSERT_LLVM_ERROR(argument->getResult(0, (uint64_t *)result, 4));
+  for (size_t i = 0; i < 4; i++) {
+    EXPECT_EQ(result[i], a0[i] + a1[i])
+        << "result differ at pos " << i << ", expect " << a0[i] + a1[i]
+        << " got " << result[i];
+  }
+}
+
+TEST(End2EndJit_HLFHELinalg, add_eint_term_to_term_broadcast) {
+  mlir::zamalang::CompilerEngine engine;
+  auto mlirStr = R"XXX(
+// Returns the term to term addition of `%a0` with `%a1`, following the
+// broadcasting rules
+func @main(%a0: tensor<4x1x4x!HLFHE.eint<4>>, %a1: tensor<1x4x4x!HLFHE.eint<4>>) -> tensor<4x4x4x!HLFHE.eint<4>> {
+  %res = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<4x1x4x!HLFHE.eint<4>>, tensor<1x4x4x!HLFHE.eint<4>>) -> tensor<4x4x4x!HLFHE.eint<4>>
+  return %res : tensor<4x4x4x!HLFHE.eint<4>>
+}
+)XXX";
+  const uint8_t a0[4][1][4]{
+      {{1, 2, 3, 4}},
+      {{5, 6, 7, 8}},
+      {{9, 10, 11, 12}},
+      {{13, 14, 15, 16}},
+  };
+  const uint8_t a1[1][4][4]{
+      {
+          {1, 2, 3, 4},
+          {5, 6, 7, 8},
+          {9, 10, 11, 12},
+          {13, 14, 15, 16},
+      },
+  };
+
+  ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
+
+  auto maybeArgument = engine.buildArgument();
+  ASSERT_LLVM_ERROR(maybeArgument.takeError());
+  auto argument = std::move(maybeArgument.get());
+  // Set the %a0 and %a1 arguments
+  ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)a0, {4, 1, 4}));
+  ASSERT_LLVM_ERROR(argument->setArg(1, (uint8_t *)a1, {1, 4, 4}));
+  // Invoke the function
+  ASSERT_LLVM_ERROR(engine.invoke(*argument));
+  // Get and assert the result
+  uint64_t result[4][4][4];
+  ASSERT_LLVM_ERROR(argument->getResult(0, (uint64_t *)result, 4 * 4 * 4));
+  for (size_t i = 0; i < 4; i++) {
+    for (size_t j = 0; j < 4; j++) {
+      for (size_t k = 0; k < 4; k++) {
+        EXPECT_EQ(result[i][j][k], a0[i][0][k] + a1[0][j][k])
+            << "result differ at pos (" << i << "," << j << "," << k
+            << "), expect " << a0[i][0][k] + a1[0][j][k] << " got "
+            << result[i][j][k];
+      }
+    }
+  }
+}
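+
+// As an illustrative sketch (the tests assert computed values, not the exact
+// IR), the broadcast `add_eint` above is expected to lower to roughly:
+//
+//   #map0 = affine_map<(d0, d1, d2) -> (d0, 0, d2)>
+//   #map1 = affine_map<(d0, d1, d2) -> (0, d1, d2)>
+//   #map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//   %init = linalg.init_tensor [4, 4, 4] : tensor<4x4x4x!HLFHE.eint<4>>
+//   %res = linalg.generic
+//     {indexing_maps = [#map0, #map1, #map2],
+//      iterator_types = ["parallel", "parallel", "parallel"]}
+//     ins(%a0, %a1 : tensor<4x1x4x!HLFHE.eint<4>>, tensor<1x4x4x!HLFHE.eint<4>>)
+//     outs(%init : tensor<4x4x4x!HLFHE.eint<4>>) {
+//   ^bb0(%arg0: !HLFHE.eint<4>, %arg1: !HLFHE.eint<4>):
+//     %0 = "HLFHE.add_eint"(%arg0, %arg1)
+//       : (!HLFHE.eint<4>, !HLFHE.eint<4>) -> !HLFHE.eint<4>
+//     linalg.yield %0 : !HLFHE.eint<4>
+//   } -> tensor<4x4x4x!HLFHE.eint<4>>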
+
+TEST(End2EndJit_HLFHELinalg, add_eint_matrix_column) {
+  mlir::zamalang::CompilerEngine engine;
+  auto mlirStr = R"XXX(
+// Returns the addition of a 3x3 matrix of encrypted integers and a 3x1
+// matrix (a column) of encrypted integers.
+//
+// [1,2,3]   [1]   [2,3,4]
+// [4,5,6] + [2] = [6,7,8]
+// [7,8,9]   [3]   [10,11,12]
+//
+// Dimension #1 of operand #2 is stretched, as it equals 1.
+func @main(%a0: tensor<3x3x!HLFHE.eint<4>>, %a1: tensor<3x1x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>> {
+  %res = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x1x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
+  return %res : tensor<3x3x!HLFHE.eint<4>>
+}
+)XXX";
+  const uint8_t a0[3][3]{
+      {1, 2, 3},
+      {4, 5, 6},
+      {7, 8, 9},
+  };
+  const uint8_t a1[3][1]{
+      {1},
+      {2},
+      {3},
+  };
+
+  ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
+
+  auto maybeArgument = engine.buildArgument();
+  ASSERT_LLVM_ERROR(maybeArgument.takeError());
+  auto argument = std::move(maybeArgument.get());
+  // Set the %a0 and %a1 arguments
+  ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)a0, {3, 3}));
+  ASSERT_LLVM_ERROR(argument->setArg(1, (uint8_t *)a1, {3, 1}));
+  // Invoke the function
+  ASSERT_LLVM_ERROR(engine.invoke(*argument));
+  // Get and assert the result
+  uint64_t result[3][3];
+  ASSERT_LLVM_ERROR(argument->getResult(0, (uint64_t *)result, 3 * 3));
+  for (size_t i = 0; i < 3; i++) {
+    for (size_t j = 0; j < 3; j++) {
+      EXPECT_EQ(result[i][j], a0[i][j] + a1[i][0])
+          << "result differ at pos (" << i << "," << j << "), expect "
+          << a0[i][j] + a1[i][0] << " got " << result[i][j] << "\n";
+    }
+  }
+}
+
+TEST(End2EndJit_HLFHELinalg, add_eint_matrix_line) {
+  mlir::zamalang::CompilerEngine engine;
+  auto mlirStr = R"XXX(
+// Returns the addition of a 3x3 matrix of encrypted integers and a 1x3
+// matrix (a line) of encrypted integers.
+//
+// [1,2,3]             [2,4,6]
+// [4,5,6] + [1,2,3] = [5,7,9]
+// [7,8,9]             [8,10,12]
+//
+// Dimension #2 of operand #2 is stretched, as it equals 1.
+func @main(%a0: tensor<3x3x!HLFHE.eint<4>>, %a1: tensor<1x3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>> {
+  %res = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<1x3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
+  return %res : tensor<3x3x!HLFHE.eint<4>>
+}
+)XXX";
+  const uint8_t a0[3][3]{
+      {1, 2, 3},
+      {4, 5, 6},
+      {7, 8, 9},
+  };
+  const uint8_t a1[1][3]{
+      {1, 2, 3},
+  };
+
+  ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
+
+  auto maybeArgument = engine.buildArgument();
+  ASSERT_LLVM_ERROR(maybeArgument.takeError());
+  auto argument = std::move(maybeArgument.get());
+  // Set the %a0 and %a1 arguments
+  ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)a0, {3, 3}));
+  ASSERT_LLVM_ERROR(argument->setArg(1, (uint8_t *)a1, {1, 3}));
+  // Invoke the function
+  ASSERT_LLVM_ERROR(engine.invoke(*argument));
+  // Get and assert the result
+  uint64_t result[3][3];
+  ASSERT_LLVM_ERROR(argument->getResult(0, (uint64_t *)result, 3 * 3));
+  for (size_t i = 0; i < 3; i++) {
+    for (size_t j = 0; j < 3; j++) {
+      EXPECT_EQ(result[i][j], a0[i][j] + a1[0][j])
+          << "result differ at pos (" << i << "," << j << "), expect "
+          << a0[i][j] + a1[0][j] << " got " << result[i][j] << "\n";
+    }
+  }
+}
+
+TEST(End2EndJit_HLFHELinalg, add_eint_matrix_line_missing_dim) {
+  mlir::zamalang::CompilerEngine engine;
+  auto mlirStr = R"XXX(
+// Same behavior as the previous test, except that dimension #2 of operand
+// #2 is missing.
+func @main(%a0: tensor<3x3x!HLFHE.eint<4>>, %a1: tensor<3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>> {
+  %res = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
+  return %res : tensor<3x3x!HLFHE.eint<4>>
+}
+)XXX";
+  const uint8_t a0[3][3]{
+      {1, 2, 3},
+      {4, 5, 6},
+      {7, 8, 9},
+  };
+  const uint8_t a1[1][3]{
+      {1, 2, 3},
+  };
+
+  ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
+
+  auto maybeArgument = engine.buildArgument();
+  ASSERT_LLVM_ERROR(maybeArgument.takeError());
+  auto argument = std::move(maybeArgument.get());
+  // Set the %a0 and %a1 arguments
+  ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)a0, {3, 3}));
+  ASSERT_LLVM_ERROR(argument->setArg(1, (uint8_t *)a1, {3}));
+  // Invoke the function
+  ASSERT_LLVM_ERROR(engine.invoke(*argument));
+  // Get and assert the result
+  uint64_t result[3][3];
+  ASSERT_LLVM_ERROR(argument->getResult(0, (uint64_t *)result, 3 * 3));
+  for (size_t i = 0; i < 3; i++) {
+    for (size_t j = 0; j < 3; j++) {
+      EXPECT_EQ(result[i][j], a0[i][j] + a1[0][j])
+          << "result differ at pos (" << i << "," << j << "), expect "
+          << a0[i][j] + a1[0][j] << " got " << result[i][j] << "\n";
+    }
+  }
+}
\ No newline at end of file
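
Note: `HLFHELinalgOpToLinalgGeneric` is op-agnostic, so the remaining
HLFHELinalg binary operators named by the subject line can be lowered by
registering further instantiations next to the `add_eint` one. A minimal
sketch, assuming the two dialects define matching `AddEintIntOp`,
`SubIntEintOp` and `MulEintIntOp` classes (these names are not confirmed by
this patch):

    // Hypothetical extra registrations in HLFHETensorOpsToLinalg::runOnFunction
    patterns.insert<
        HLFHELinalgOpToLinalgGeneric<mlir::zamalang::HLFHELinalg::AddEintIntOp,
                                     mlir::zamalang::HLFHE::AddEintIntOp>,
        HLFHELinalgOpToLinalgGeneric<mlir::zamalang::HLFHELinalg::SubIntEintOp,
                                     mlir::zamalang::HLFHE::SubIntEintOp>,
        HLFHELinalgOpToLinalgGeneric<mlir::zamalang::HLFHELinalg::MulEintIntOp,
                                     mlir::zamalang::HLFHE::MulEintIntOp>>(
        &getContext());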