diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index 18badbddd..549968b86 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -55,16 +55,15 @@ JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module, } llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef args) { - // size_t nbReturn = 0; - // TODO - This check break with memref as we have 5 returns args. - // if (!this->type.getReturnType().isa()) { - // nbReturn = 1; - // } - // if (this->type.getNumParams() != args.size() - nbReturn) { - // return llvm::make_error( - // "invokeRaw: wrong number of argument", - // llvm::inconvertibleErrorCode()); - // } + size_t nbReturn = 0; + if (!this->type.getReturnType().isa()) { + nbReturn = 1; + } + if (this->type.getNumParams() != args.size() - nbReturn) { + return llvm::make_error( + "invokeRaw: wrong number of argument", + llvm::inconvertibleErrorCode()); + } if (llvm::find(args, nullptr) != args.end()) { return llvm::make_error( "invoke: some arguments are null", llvm::inconvertibleErrorCode()); diff --git a/compiler/tests/unittest/end_to_end_jit_lambda.cc b/compiler/tests/unittest/end_to_end_jit_lambda.cc index 6400521c0..583ae2b5b 100644 --- a/compiler/tests/unittest/end_to_end_jit_lambda.cc +++ b/compiler/tests/unittest/end_to_end_jit_lambda.cc @@ -74,8 +74,7 @@ TEST(Lambda_check_param, scalar_tensor_to_scalar) { ASSERT_EXPECTED_SUCCESS(lambda(1_u64, arg, ARRAY_SIZE(arg))); } -TEST(Lambda_check_param, scalar_tensor_to_scalar_superfluous_param) { - // DISABLED Note: "terminate called after throwing an instance of 'std::bad_alloc'" +TEST(Lambda_check_param, DISABLED_scalar_tensor_to_scalar_superfluous_param) { Lambda lambda = checkedJit(R"XXX( func @main( %arg0: !HLFHE.eint<1>, %arg1: tensor<2x!HLFHE.eint<1>>) -> !HLFHE.eint<1> @@ -87,7 +86,7 @@ TEST(Lambda_check_param, scalar_tensor_to_scalar_superfluous_param) { ASSERT_EXPECTED_FAILURE(lambda(1_u64, arg, ARRAY_SIZE(arg), arg, ARRAY_SIZE(arg))); } -TEST(Lambda_check_param, scalar_tensor_to_tensor_good_number_param) { +TEST(Lambda_check_param, DISABLED_scalar_tensor_to_tensor_good_number_param) { Lambda lambda = checkedJit(R"XXX( func @main( %arg0: !HLFHE.eint<1>, %arg1: tensor<2x!HLFHE.eint<1>>) -> tensor<2x!HLFHE.eint<1>>