fix(compiler): don't try to broadcast when both dims are equal to 1

This commit is contained in:
youben11
2021-11-11 15:22:14 +01:00
committed by Ayoub Benaissa
parent 99cce18c6a
commit 7c83994c7f
2 changed files with 44 additions and 1 deletions

View File

@@ -131,7 +131,7 @@ getBroadcastedAffineMap(const mlir::RankedTensorType &resultType,
affineExprs.reserve(resultShape.size());
size_t deltaNumDim = resultShape.size() - operandShape.size();
for (auto i = 0; i < operandShape.size(); i++) {
if (operandShape[i] == 1) {
if (operandShape[i] == 1 && resultShape[i + deltaNumDim] != 1) {
affineExprs.push_back(rewriter.getAffineConstantExpr(0));
} else {
affineExprs.push_back(rewriter.getAffineDimExpr(i + deltaNumDim));

View File

@@ -523,6 +523,49 @@ TEST(End2EndJit_HLFHELinalg, add_eint_matrix_line_missing_dim) {
}
}
// End-to-end check of HLFHELinalg.add_eint on two operands that already have
// the same 3x1x2 shape as the result. The JIT-compiled function is invoked
// with two clear tensors and every output element is expected to be the sum
// of the corresponding input elements, i.e. the size-1 middle dimension must
// be indexed normally rather than handled as a broadcast dimension.
TEST(End2EndJit_HLFHELinalg, add_eint_tensor_dim_equals_1) {
  mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
// Broadcasting shouldn't happen when some dimensions are equals to 1
func @main(%arg0: tensor<3x1x2x!HLFHE.eint<5>>, %arg1: tensor<3x1x2x!HLFHE.eint<5>>) -> tensor<3x1x2x!HLFHE.eint<5>> {
%1 = "HLFHELinalg.add_eint"(%arg0, %arg1) : (tensor<3x1x2x!HLFHE.eint<5>>, tensor<3x1x2x!HLFHE.eint<5>>) -> tensor<3x1x2x!HLFHE.eint<5>>
return %1 : tensor<3x1x2x!HLFHE.eint<5>>
}
)XXX");

  // Extents of the tensors; they must agree with the tensor<3x1x2x...> types
  // in the MLIR source above and with the {3, 1, 2} shapes passed below.
  constexpr size_t dim0 = 3;
  constexpr size_t dim1 = 1;
  constexpr size_t dim2 = 2;
  constexpr size_t numElements = dim0 * dim1 * dim2;

  const uint8_t a0[dim0][dim1][dim2]{
      {{1, 2}},
      {{4, 5}},
      {{7, 8}},
  };
  const uint8_t a1[dim0][dim1][dim2]{
      {{8, 10}},
      {{12, 14}},
      {{16, 18}},
  };

  // The nested arrays are contiguous, so the address of the first element
  // gives a flat view of all numElements values without a cast.
  mlir::zamalang::TensorLambdaArgument<
      mlir::zamalang::IntLambdaArgument<uint8_t>>
      arg0(llvm::ArrayRef<uint8_t>(&a0[0][0][0], numElements), {3, 1, 2});
  mlir::zamalang::TensorLambdaArgument<
      mlir::zamalang::IntLambdaArgument<uint8_t>>
      arg1(llvm::ArrayRef<uint8_t>(&a1[0][0][0], numElements), {3, 1, 2});

  llvm::Expected<std::vector<uint64_t>> res =
      lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});

  ASSERT_EXPECTED_SUCCESS(res);
  ASSERT_EQ(res->size(), numElements);

  for (size_t i = 0; i < dim0; i++) {
    for (size_t j = 0; j < dim1; j++) {
      for (size_t k = 0; k < dim2; k++) {
        // Row-major linearization of (i, j, k): each index is scaled by the
        // product of the extents to its right. The previous `i * 2 + j + k`
        // left `j` unscaled and was only correct because dim1 == 1.
        const size_t flatIndex = (i * dim1 + j) * dim2 + k;
        const uint64_t expected =
            static_cast<uint64_t>(a0[i][j][k]) + a1[i][j][k];
        EXPECT_EQ((*res)[flatIndex], expected);
      }
    }
  }
}
///////////////////////////////////////////////////////////////////////////////
// HLFHELinalg sub_int_eint ///////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////