test(comiler): add tests for result tensor when calling Lambda::operator()

Lambda::operator()<unique_ptr<LambdaArgument>> was giving bad results
when using tensors due to bad references. Now it's fixed and we want to
spot that in the future.
This commit is contained in:
youben11
2021-11-04 15:46:40 +01:00
committed by Ayoub Benaissa
parent c92f047721
commit badc8e44bf
2 changed files with 41 additions and 1 deletions

View File

@@ -38,6 +38,46 @@ TEST(End2EndJit_HLFHELinalg, add_eint_int_term_to_term) {
}
}
// Same as add_eint_int_term_to_term test above, but returning a lambda argument
TEST(End2EndJit_HLFHELinalg, add_eint_int_term_to_term_ret_lambda_argument) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
// Returns the term to term addition of `%a0` with `%a1`
func @main(%a0: tensor<4x!HLFHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!HLFHE.eint<6>> {
%res = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<6>>, tensor<4xi7>) -> tensor<4x!HLFHE.eint<6>>
return %res : tensor<4x!HLFHE.eint<6>>
}
)XXX");
std::vector<uint8_t> a0{31, 6, 12, 9};
std::vector<uint8_t> a1{32, 9, 2, 3};
mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<uint8_t>>
arg0(a0);
mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<uint8_t>>
arg1(a1);
llvm::Expected<std::unique_ptr<mlir::zamalang::LambdaArgument>> res =
lambda.operator()<std::unique_ptr<mlir::zamalang::LambdaArgument>>(
{&arg0, &arg1});
ASSERT_EXPECTED_SUCCESS(res);
mlir::zamalang::TensorLambdaArgument<mlir::zamalang::IntLambdaArgument<>>
&resp = (*res)
->cast<mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<>>>();
ASSERT_EQ(resp.getDimensions().size(), (size_t)1);
ASSERT_EQ(resp.getDimensions().at(0), 4);
ASSERT_EXPECTED_VALUE(resp.getNumElements(), 4);
for (size_t i = 0; i < 4; i++) {
ASSERT_EQ(resp.getValue()[i], a0[i] + a1[i]);
}
}
TEST(End2EndJit_HLFHELinalg, add_eint_int_term_to_term_broadcast) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(

View File

@@ -305,7 +305,7 @@ func @main(%0: !HLFHE.eint<5>) -> tensor<1x!HLFHE.eint<5>> {
ASSERT_EQ(resp.getDimensions().size(), (size_t)1);
ASSERT_EQ(resp.getDimensions().at(0), 1);
ASSERT_EXPECTED_VALUE(resp.getNumElements(), 1);
// ASSERT_EQ(resp.getValue()[0], 10_u64);
ASSERT_EQ(resp.getValue()[0], 10_u64);
}
TEST(CompileAndRunTensorEncrypted, in_out_tensor_with_op_5) {