diff --git a/compiler/include/zamalang/Support/Jit.h b/compiler/include/zamalang/Support/Jit.h index d959f1460..272eeac42 100644 --- a/compiler/include/zamalang/Support/Jit.h +++ b/compiler/include/zamalang/Support/Jit.h @@ -76,6 +76,8 @@ public: llvm::Expected> getResultDimensions(size_t pos); private: + // Verify if lambda can accept a n-th argument. + llvm::Error acceptNthArg(size_t n); llvm::Error setArg(size_t pos, size_t width, const void *data, llvm::ArrayRef shape); diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index 868cd2f35..0d810383a 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -151,12 +151,21 @@ JITLambda::Argument::create(KeySet &keySet) { return std::move(args); } +llvm::Error JITLambda::Argument::acceptNthArg(size_t pos) { + size_t arity = inputGates.size(); + if (pos >= arity) { + auto msg = "Call a function of arity " + llvm::Twine(arity) + + " with at least " + llvm::Twine(pos + 1) + " arguments"; + return llvm::make_error(msg, + llvm::inconvertibleErrorCode()); + } + return llvm::Error::success(); +} + llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) { - if (pos >= inputGates.size()) { - return llvm::make_error( - llvm::Twine("argument index out of bound: pos=") - .concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); + auto error = acceptNthArg(pos); + if (error) { + return error; } auto gate = inputGates[pos]; auto info = std::get<0>(gate); @@ -192,6 +201,10 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) { llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, const void *data, llvm::ArrayRef shape) { + auto error = acceptNthArg(pos); + if (error) { + return error; + } auto gate = inputGates[pos]; auto info = std::get<0>(gate); auto offset = std::get<1>(gate); diff --git a/compiler/tests/unittest/end_to_end_jit_lambda.cc b/compiler/tests/unittest/end_to_end_jit_lambda.cc index 8b027fbc9..023e9ce5d 100644 --- a/compiler/tests/unittest/end_to_end_jit_lambda.cc +++ b/compiler/tests/unittest/end_to_end_jit_lambda.cc @@ -74,7 +74,7 @@ TEST(Lambda_check_param, scalar_tensor_to_scalar) { ASSERT_EXPECTED_SUCCESS(lambda(1_u64, arg, ARRAY_SIZE(arg))); } -TEST(Lambda_check_param, DISABLED_scalar_tensor_to_scalar_superfluous_param) { +TEST(Lambda_check_param, scalar_tensor_to_scalar_superfluous_param) { Lambda lambda = checkedJit(R"XXX( func @main( %arg0: !HLFHE.eint<1>, %arg1: tensor<2x!HLFHE.eint<1>>) -> !HLFHE.eint<1>