mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
test(comiler): add tests for result tensor when calling Lambda::operator()
Lambda::operator()<unique_ptr<LambdaArgument>> was giving bad results when using tensors due to bad references. Now it's fixed and we want to spot that in the future.
This commit is contained in:
@@ -38,6 +38,46 @@ TEST(End2EndJit_HLFHELinalg, add_eint_int_term_to_term) {
|
||||
}
|
||||
}
|
||||
|
||||
// Same as add_eint_int_term_to_term test above, but returning a lambda argument
|
||||
TEST(End2EndJit_HLFHELinalg, add_eint_int_term_to_term_ret_lambda_argument) {
|
||||
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
// Returns the term to term addition of `%a0` with `%a1`
|
||||
func @main(%a0: tensor<4x!HLFHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!HLFHE.eint<6>> {
|
||||
%res = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<6>>, tensor<4xi7>) -> tensor<4x!HLFHE.eint<6>>
|
||||
return %res : tensor<4x!HLFHE.eint<6>>
|
||||
}
|
||||
)XXX");
|
||||
std::vector<uint8_t> a0{31, 6, 12, 9};
|
||||
std::vector<uint8_t> a1{32, 9, 2, 3};
|
||||
|
||||
mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<uint8_t>>
|
||||
arg0(a0);
|
||||
mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<uint8_t>>
|
||||
arg1(a1);
|
||||
|
||||
llvm::Expected<std::unique_ptr<mlir::zamalang::LambdaArgument>> res =
|
||||
lambda.operator()<std::unique_ptr<mlir::zamalang::LambdaArgument>>(
|
||||
{&arg0, &arg1});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
mlir::zamalang::TensorLambdaArgument<mlir::zamalang::IntLambdaArgument<>>
|
||||
&resp = (*res)
|
||||
->cast<mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<>>>();
|
||||
|
||||
ASSERT_EQ(resp.getDimensions().size(), (size_t)1);
|
||||
ASSERT_EQ(resp.getDimensions().at(0), 4);
|
||||
ASSERT_EXPECTED_VALUE(resp.getNumElements(), 4);
|
||||
|
||||
for (size_t i = 0; i < 4; i++) {
|
||||
ASSERT_EQ(resp.getValue()[i], a0[i] + a1[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_HLFHELinalg, add_eint_int_term_to_term_broadcast) {
|
||||
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
|
||||
@@ -305,7 +305,7 @@ func @main(%0: !HLFHE.eint<5>) -> tensor<1x!HLFHE.eint<5>> {
|
||||
ASSERT_EQ(resp.getDimensions().size(), (size_t)1);
|
||||
ASSERT_EQ(resp.getDimensions().at(0), 1);
|
||||
ASSERT_EXPECTED_VALUE(resp.getNumElements(), 1);
|
||||
// ASSERT_EQ(resp.getValue()[0], 10_u64);
|
||||
ASSERT_EQ(resp.getValue()[0], 10_u64);
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorEncrypted, in_out_tensor_with_op_5) {
|
||||
|
||||
Reference in New Issue
Block a user