mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
fix(Lambda): #253 fix the bug about Lambda parameter number verification
[----------] Global test environment tear-down [==========] 6 tests from 1 test suite ran. (1513 ms total) [ PASSED ] 6 tests. YOU HAVE 3 DISABLED TESTS Compared to previous commit, a fatal test is disabled
This commit is contained in:
@@ -55,23 +55,23 @@ JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module,
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef<void *> args) {
|
||||
size_t nbReturn = 0;
|
||||
if (!this->type.getReturnType().isa<mlir::LLVM::LLVMVoidType>()) {
|
||||
nbReturn = 1;
|
||||
}
|
||||
if (this->type.getNumParams() != args.size() - nbReturn) {
|
||||
if (!args.empty() && llvm::find(args, nullptr) != args.end()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"invokeRaw: wrong number of argument",
|
||||
"invoke: some arguments are null or missing",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
if (llvm::find(args, nullptr) != args.end()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"invoke: some arguments are null", llvm::inconvertibleErrorCode());
|
||||
}
|
||||
return this->engine->invokePacked(this->name, args);
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::invoke(Argument &args) {
|
||||
size_t expectedInputs = this->type.getNumParams();
|
||||
size_t actualInputs = args.inputs.size();
|
||||
if (expectedInputs != actualInputs) {
|
||||
auto msg = "invokeRaw: received " + llvm::Twine(actualInputs) +
|
||||
"arguments instead of " + llvm::Twine(expectedInputs);
|
||||
return llvm::make_error<llvm::StringError>(msg,
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
return std::move(invokeRaw(args.rawArg));
|
||||
}
|
||||
|
||||
@@ -196,7 +196,7 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width,
|
||||
auto info = std::get<0>(gate);
|
||||
auto offset = std::get<1>(gate);
|
||||
// Check if the width is compatible
|
||||
// TODO - I found this rules empirically, they are a spec somewhere?
|
||||
// TODO - I found this rules empirically, they are a spec somewhere?
|
||||
if (info.shape.width <= 8 && width != 8) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
llvm::Twine("argument width should be 8: pos=")
|
||||
|
||||
@@ -86,7 +86,7 @@ TEST(Lambda_check_param, DISABLED_scalar_tensor_to_scalar_superfluous_param) {
|
||||
ASSERT_EXPECTED_FAILURE(lambda(1_u64, arg, ARRAY_SIZE(arg), arg, ARRAY_SIZE(arg)));
|
||||
}
|
||||
|
||||
TEST(Lambda_check_param, DISABLED_scalar_tensor_to_tensor_good_number_param) {
|
||||
TEST(Lambda_check_param, scalar_tensor_to_tensor_good_number_param) {
|
||||
Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(
|
||||
%arg0: !HLFHE.eint<1>, %arg1: tensor<2x!HLFHE.eint<1>>) -> tensor<2x!HLFHE.eint<1>>
|
||||
|
||||
Reference in New Issue
Block a user