fix(Lambda): missing superfluous check in setArg

[----------] Global test environment tear-down
[==========] 7 tests from 1 test suite ran. (1513 ms total)
[  PASSED  ] 7 tests.

  YOU HAVE 2 DISABLED TESTS
This commit is contained in:
rudy
2021-11-19 15:33:34 +01:00
committed by rudy-6-4
parent 47b4b667bb
commit 2c56a26c75
3 changed files with 21 additions and 6 deletions

View File

@@ -76,6 +76,8 @@ public:
llvm::Expected<std::vector<int64_t>> getResultDimensions(size_t pos);
private:
// Verify if lambda can accept a n-th argument.
llvm::Error acceptNthArg(size_t n);
llvm::Error setArg(size_t pos, size_t width, const void *data,
llvm::ArrayRef<int64_t> shape);

View File

@@ -151,12 +151,21 @@ JITLambda::Argument::create(KeySet &keySet) {
return std::move(args);
}
llvm::Error JITLambda::Argument::acceptNthArg(size_t pos) {
size_t arity = inputGates.size();
if (pos >= arity) {
auto msg = "Call a function of arity " + llvm::Twine(arity) +
" with at least " + llvm::Twine(pos + 1) + " arguments";
return llvm::make_error<llvm::StringError>(msg,
llvm::inconvertibleErrorCode());
}
return llvm::Error::success();
}
llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) {
if (pos >= inputGates.size()) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument index out of bound: pos=")
.concat(llvm::Twine(pos)),
llvm::inconvertibleErrorCode());
auto error = acceptNthArg(pos);
if (error) {
return error;
}
auto gate = inputGates[pos];
auto info = std::get<0>(gate);
@@ -192,6 +201,10 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) {
llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width,
const void *data,
llvm::ArrayRef<int64_t> shape) {
auto error = acceptNthArg(pos);
if (error) {
return error;
}
auto gate = inputGates[pos];
auto info = std::get<0>(gate);
auto offset = std::get<1>(gate);

View File

@@ -74,7 +74,7 @@ TEST(Lambda_check_param, scalar_tensor_to_scalar) {
ASSERT_EXPECTED_SUCCESS(lambda(1_u64, arg, ARRAY_SIZE(arg)));
}
TEST(Lambda_check_param, DISABLED_scalar_tensor_to_scalar_superfluous_param) {
TEST(Lambda_check_param, scalar_tensor_to_scalar_superfluous_param) {
Lambda lambda = checkedJit(R"XXX(
func @main(
%arg0: !HLFHE.eint<1>, %arg1: tensor<2x!HLFHE.eint<1>>) -> !HLFHE.eint<1>