diff --git a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc index 6142f1a37..ae3591a2a 100644 --- a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc +++ b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc @@ -38,6 +38,46 @@ TEST(End2EndJit_HLFHELinalg, add_eint_int_term_to_term) { } } +// Same as add_eint_int_term_to_term test above, but returning a lambda argument +TEST(End2EndJit_HLFHELinalg, add_eint_int_term_to_term_ret_lambda_argument) { + + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + // Returns the term to term addition of `%a0` with `%a1` + func @main(%a0: tensor<4x!HLFHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!HLFHE.eint<6>> { + %res = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<6>>, tensor<4xi7>) -> tensor<4x!HLFHE.eint<6>> + return %res : tensor<4x!HLFHE.eint<6>> + } +)XXX"); + std::vector a0{31, 6, 12, 9}; + std::vector a1{32, 9, 2, 3}; + + mlir::zamalang::TensorLambdaArgument< + mlir::zamalang::IntLambdaArgument> + arg0(a0); + mlir::zamalang::TensorLambdaArgument< + mlir::zamalang::IntLambdaArgument> + arg1(a1); + + llvm::Expected> res = + lambda.operator()>( + {&arg0, &arg1}); + + ASSERT_EXPECTED_SUCCESS(res); + + mlir::zamalang::TensorLambdaArgument> + &resp = (*res) + ->cast>>(); + + ASSERT_EQ(resp.getDimensions().size(), (size_t)1); + ASSERT_EQ(resp.getDimensions().at(0), 4); + ASSERT_EXPECTED_VALUE(resp.getNumElements(), 4); + + for (size_t i = 0; i < 4; i++) { + ASSERT_EQ(resp.getValue()[i], a0[i] + a1[i]); + } +} + TEST(End2EndJit_HLFHELinalg, add_eint_int_term_to_term_broadcast) { mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( diff --git a/compiler/tests/unittest/end_to_end_jit_test.cc b/compiler/tests/unittest/end_to_end_jit_test.cc index adef214a0..faf408dd1 100644 --- a/compiler/tests/unittest/end_to_end_jit_test.cc +++ b/compiler/tests/unittest/end_to_end_jit_test.cc @@ -305,7 +305,7 @@ func @main(%0: !HLFHE.eint<5>) -> tensor<1x!HLFHE.eint<5>> { ASSERT_EQ(resp.getDimensions().size(), (size_t)1); ASSERT_EQ(resp.getDimensions().at(0), 1); ASSERT_EXPECTED_VALUE(resp.getNumElements(), 1); - // ASSERT_EQ(resp.getValue()[0], 10_u64); + ASSERT_EQ(resp.getValue()[0], 10_u64); } TEST(CompileAndRunTensorEncrypted, in_out_tensor_with_op_5) {