From 7c83994c7f010ebdae4e0600d233bf0ab90a6d59 Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 11 Nov 2021 15:22:14 +0100 Subject: [PATCH] fix(compiler): don't try to broadcast when dims are equals with value 1 --- .../TensorOpsToLinalg.cpp | 2 +- .../unittest/end_to_end_jit_hlfhelinalg.cc | 43 +++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp index 2bca0b32b..9cb5ecac4 100644 --- a/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp +++ b/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp @@ -131,7 +131,7 @@ getBroadcastedAffineMap(const mlir::RankedTensorType &resultType, affineExprs.reserve(resultShape.size()); size_t deltaNumDim = resultShape.size() - operandShape.size(); for (auto i = 0; i < operandShape.size(); i++) { - if (operandShape[i] == 1) { + if (operandShape[i] == 1 && resultShape[i + deltaNumDim] != 1) { affineExprs.push_back(rewriter.getAffineConstantExpr(0)); } else { affineExprs.push_back(rewriter.getAffineDimExpr(i + deltaNumDim)); diff --git a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc index 68a0a8797..9992e7c42 100644 --- a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc +++ b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc @@ -523,6 +523,49 @@ TEST(End2EndJit_HLFHELinalg, add_eint_matrix_line_missing_dim) { } } +TEST(End2EndJit_HLFHELinalg, add_eint_tensor_dim_equals_1) { + + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + // Broadcasting shouldn't happen when some dimensions are equals to 1 + func @main(%arg0: tensor<3x1x2x!HLFHE.eint<5>>, %arg1: tensor<3x1x2x!HLFHE.eint<5>>) -> tensor<3x1x2x!HLFHE.eint<5>> { + %1 = "HLFHELinalg.add_eint"(%arg0, %arg1) : (tensor<3x1x2x!HLFHE.eint<5>>, tensor<3x1x2x!HLFHE.eint<5>>) -> tensor<3x1x2x!HLFHE.eint<5>> + return %1 : tensor<3x1x2x!HLFHE.eint<5>> + } +)XXX"); + const uint8_t a0[3][1][2]{ + {{1, 2}}, + {{4, 5}}, + {{7, 8}}, + }; + const uint8_t a1[3][1][2]{ + {{8, 10}}, + {{12, 14}}, + {{16, 18}}, + }; + + mlir::zamalang::TensorLambdaArgument< + mlir::zamalang::IntLambdaArgument> + arg0(llvm::ArrayRef((const uint8_t *)a0, 3 * 2), {3, 1, 2}); + mlir::zamalang::TensorLambdaArgument< + mlir::zamalang::IntLambdaArgument> + arg1(llvm::ArrayRef((const uint8_t *)a1, 3 * 2), {3, 1, 2}); + + llvm::Expected> res = + lambda.operator()>({&arg0, &arg1}); + + ASSERT_EXPECTED_SUCCESS(res); + + ASSERT_EQ(res->size(), 3 * 1 * 2); + + for (size_t i = 0; i < 3; i++) { + for (size_t j = 0; j < 1; j++) { + for (size_t k = 0; k < 2; k++) { + EXPECT_EQ((*res)[i * 2 + j + k], a0[i][j][k] + a1[i][j][k]); + } + } + } +} + /////////////////////////////////////////////////////////////////////////////// // HLFHELinalg sub_int_eint /////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////