fix(Lambda): #253 fix the bug about Lambda parameter number verification

[----------] Global test environment tear-down
[==========] 6 tests from 1 test suite ran. (1513 ms total)
[  PASSED  ] 6 tests.

  YOU HAVE 3 DISABLED TESTS

Compared to previous commit, a fatal test is disabled
This commit is contained in:
rudy
2021-11-18 15:10:47 +01:00
committed by rudy-6-4
parent 9a09afaa80
commit 47b4b667bb
2 changed files with 12 additions and 12 deletions

View File

@@ -55,23 +55,23 @@ JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module,
}
llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef<void *> args) {
size_t nbReturn = 0;
if (!this->type.getReturnType().isa<mlir::LLVM::LLVMVoidType>()) {
nbReturn = 1;
}
if (this->type.getNumParams() != args.size() - nbReturn) {
if (!args.empty() && llvm::find(args, nullptr) != args.end()) {
return llvm::make_error<llvm::StringError>(
"invokeRaw: wrong number of argument",
"invoke: some arguments are null or missing",
llvm::inconvertibleErrorCode());
}
if (llvm::find(args, nullptr) != args.end()) {
return llvm::make_error<llvm::StringError>(
"invoke: some arguments are null", llvm::inconvertibleErrorCode());
}
return this->engine->invokePacked(this->name, args);
}
llvm::Error JITLambda::invoke(Argument &args) {
size_t expectedInputs = this->type.getNumParams();
size_t actualInputs = args.inputs.size();
if (expectedInputs != actualInputs) {
auto msg = "invokeRaw: received " + llvm::Twine(actualInputs) +
"arguments instead of " + llvm::Twine(expectedInputs);
return llvm::make_error<llvm::StringError>(msg,
llvm::inconvertibleErrorCode());
}
return std::move(invokeRaw(args.rawArg));
}
@@ -196,7 +196,7 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width,
auto info = std::get<0>(gate);
auto offset = std::get<1>(gate);
// Check if the width is compatible
// TODO - I found this rules empirically, they are a spec somewhere?
// TODO - I found this rules empirically, they are a spec somewhere?
if (info.shape.width <= 8 && width != 8) {
return llvm::make_error<llvm::StringError>(
llvm::Twine("argument width should be 8: pos=")

View File

@@ -86,7 +86,7 @@ TEST(Lambda_check_param, DISABLED_scalar_tensor_to_scalar_superfluous_param) {
ASSERT_EXPECTED_FAILURE(lambda(1_u64, arg, ARRAY_SIZE(arg), arg, ARRAY_SIZE(arg)));
}
TEST(Lambda_check_param, DISABLED_scalar_tensor_to_tensor_good_number_param) {
TEST(Lambda_check_param, scalar_tensor_to_tensor_good_number_param) {
Lambda lambda = checkedJit(R"XXX(
func @main(
%arg0: !HLFHE.eint<1>, %arg1: tensor<2x!HLFHE.eint<1>>) -> tensor<2x!HLFHE.eint<1>>