From 47b4b667bbdfc6c577764947c9ce3fcc047ee5fc Mon Sep 17 00:00:00 2001 From: rudy Date: Thu, 18 Nov 2021 15:10:47 +0100 Subject: [PATCH] fix(Lambda): #253 fix the bug about Lambda parameter number verification [----------] Global test environment tear-down [==========] 6 tests from 1 test suite ran. (1513 ms total) [ PASSED ] 6 tests. YOU HAVE 3 DISABLED TESTS Compared to previous commit, a fatal test is disabled --- compiler/lib/Support/Jit.cpp | 22 +++++++++---------- .../tests/unittest/end_to_end_jit_lambda.cc | 2 +- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index 549968b86..868cd2f35 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -55,23 +55,23 @@ JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module, } llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef args) { - size_t nbReturn = 0; - if (!this->type.getReturnType().isa()) { - nbReturn = 1; - } - if (this->type.getNumParams() != args.size() - nbReturn) { + if (!args.empty() && llvm::find(args, nullptr) != args.end()) { return llvm::make_error( - "invokeRaw: wrong number of argument", + "invoke: some arguments are null or missing", llvm::inconvertibleErrorCode()); } - if (llvm::find(args, nullptr) != args.end()) { - return llvm::make_error( - "invoke: some arguments are null", llvm::inconvertibleErrorCode()); - } return this->engine->invokePacked(this->name, args); } llvm::Error JITLambda::invoke(Argument &args) { + size_t expectedInputs = this->type.getNumParams(); + size_t actualInputs = args.inputs.size(); + if (expectedInputs != actualInputs) { + auto msg = "invokeRaw: received " + llvm::Twine(actualInputs) + + "arguments instead of " + llvm::Twine(expectedInputs); + return llvm::make_error(msg, + llvm::inconvertibleErrorCode()); + } return std::move(invokeRaw(args.rawArg)); } @@ -196,7 +196,7 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, auto info = std::get<0>(gate); auto offset = std::get<1>(gate); // Check if the width is compatible - // TODO - I found this rules empirically, they are a spec somewhere? + // TODO - I found this rules empirically, they are a spec somewhere? if (info.shape.width <= 8 && width != 8) { return llvm::make_error( llvm::Twine("argument width should be 8: pos=") diff --git a/compiler/tests/unittest/end_to_end_jit_lambda.cc b/compiler/tests/unittest/end_to_end_jit_lambda.cc index 583ae2b5b..8b027fbc9 100644 --- a/compiler/tests/unittest/end_to_end_jit_lambda.cc +++ b/compiler/tests/unittest/end_to_end_jit_lambda.cc @@ -86,7 +86,7 @@ TEST(Lambda_check_param, DISABLED_scalar_tensor_to_scalar_superfluous_param) { ASSERT_EXPECTED_FAILURE(lambda(1_u64, arg, ARRAY_SIZE(arg), arg, ARRAY_SIZE(arg))); } -TEST(Lambda_check_param, DISABLED_scalar_tensor_to_tensor_good_number_param) { +TEST(Lambda_check_param, scalar_tensor_to_tensor_good_number_param) { Lambda lambda = checkedJit(R"XXX( func @main( %arg0: !HLFHE.eint<1>, %arg1: tensor<2x!HLFHE.eint<1>>) -> tensor<2x!HLFHE.eint<1>>