mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
fix(compiler): don't try to broadcast when dims are equals with value 1
This commit is contained in:
@@ -131,7 +131,7 @@ getBroadcastedAffineMap(const mlir::RankedTensorType &resultType,
|
||||
affineExprs.reserve(resultShape.size());
|
||||
size_t deltaNumDim = resultShape.size() - operandShape.size();
|
||||
for (auto i = 0; i < operandShape.size(); i++) {
|
||||
if (operandShape[i] == 1) {
|
||||
if (operandShape[i] == 1 && resultShape[i + deltaNumDim] != 1) {
|
||||
affineExprs.push_back(rewriter.getAffineConstantExpr(0));
|
||||
} else {
|
||||
affineExprs.push_back(rewriter.getAffineDimExpr(i + deltaNumDim));
|
||||
|
||||
@@ -523,6 +523,49 @@ TEST(End2EndJit_HLFHELinalg, add_eint_matrix_line_missing_dim) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_HLFHELinalg, add_eint_tensor_dim_equals_1) {
|
||||
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
// Broadcasting shouldn't happen when some dimensions are equals to 1
|
||||
func @main(%arg0: tensor<3x1x2x!HLFHE.eint<5>>, %arg1: tensor<3x1x2x!HLFHE.eint<5>>) -> tensor<3x1x2x!HLFHE.eint<5>> {
|
||||
%1 = "HLFHELinalg.add_eint"(%arg0, %arg1) : (tensor<3x1x2x!HLFHE.eint<5>>, tensor<3x1x2x!HLFHE.eint<5>>) -> tensor<3x1x2x!HLFHE.eint<5>>
|
||||
return %1 : tensor<3x1x2x!HLFHE.eint<5>>
|
||||
}
|
||||
)XXX");
|
||||
const uint8_t a0[3][1][2]{
|
||||
{{1, 2}},
|
||||
{{4, 5}},
|
||||
{{7, 8}},
|
||||
};
|
||||
const uint8_t a1[3][1][2]{
|
||||
{{8, 10}},
|
||||
{{12, 14}},
|
||||
{{16, 18}},
|
||||
};
|
||||
|
||||
mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<uint8_t>>
|
||||
arg0(llvm::ArrayRef<uint8_t>((const uint8_t *)a0, 3 * 2), {3, 1, 2});
|
||||
mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<uint8_t>>
|
||||
arg1(llvm::ArrayRef<uint8_t>((const uint8_t *)a1, 3 * 2), {3, 1, 2});
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), 3 * 1 * 2);
|
||||
|
||||
for (size_t i = 0; i < 3; i++) {
|
||||
for (size_t j = 0; j < 1; j++) {
|
||||
for (size_t k = 0; k < 2; k++) {
|
||||
EXPECT_EQ((*res)[i * 2 + j + k], a0[i][j][k] + a1[i][j][k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// HLFHELinalg sub_int_eint ///////////////////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
Reference in New Issue
Block a user