feat(tests): test lambda call with bad number of parameters or bad parameters

[----------] Global test environment tear-down
[==========] 6 tests from 1 test suite ran. (1402 ms total)
[  PASSED  ] 6 tests.

  YOU HAVE 3 DISABLED TESTS

2 tests are disabled because execution is fatal
This commit is contained in:
rudy
2021-11-18 15:09:18 +01:00
committed by rudy-6-4
parent dad4390518
commit cc9186d60d
4 changed files with 152 additions and 2 deletions

View File

@@ -61,6 +61,9 @@ build-end-to-end-jit-encrypted-tensor: build-initialized
build-end-to-end-jit-hlfhelinalg: build-initialized
cmake --build $(BUILD_DIR) --target end_to_end_jit_hlfhelinalg
build-end-to-end-jit-lamnda: build-initialized
cmake --build $(BUILD_DIR) --target end_to_end_jit_lambda
build-end-to-end-jit: build-end-to-end-jit-test build-end-to-end-jit-clear-tensor build-end-to-end-jit-encrypted-tensor build-end-to-end-jit-hlfhelinalg
test-end-to-end-jit-test: build-end-to-end-jit-test
@@ -75,6 +78,11 @@ test-end-to-end-jit-encrypted-tensor: build-end-to-end-jit-encrypted-tensor
test-end-to-end-jit-hlfhelinalg: build-end-to-end-jit-hlfhelinalg
$(BUILD_DIR)/bin/end_to_end_jit_hlfhelinalg
test-end-to-end-jit-lambda: build-initialized build-end-to-end-jit-lamnda
$(BUILD_DIR)/bin/end_to_end_jit_lambda
test-end-to-end-jit: test-end-to-end-jit-test test-end-to-end-jit-clear-tensor test-end-to-end-jit-encrypted-tensor test-end-to-end-jit-hlfhelinalg
# LLVM/MLIR dependencies

View File

@@ -26,12 +26,18 @@ add_executable(
globals.cc
)
add_executable(
end_to_end_jit_lambda
end_to_end_jit_lambda.cc
globals.cc
)
set_source_files_properties(
end_to_end_jit_test.cc
end_to_end_jit_clear_tensor.cc
end_to_end_jit_encrypted_tensor.cc
end_to_end_jit_hlfhelinalg.cc
globals.cc
end_to_end_jit_lambda.cc
PROPERTIES COMPILE_FLAGS "-fno-rtti"
)
@@ -60,9 +66,16 @@ target_link_libraries(
ZamalangSupport
)
target_link_libraries(
end_to_end_jit_lambda
gtest_main
ZamalangSupport
)
include(GoogleTest)
gtest_discover_tests(end_to_end_jit_test)
gtest_discover_tests(end_to_end_jit_clear_tensor)
gtest_discover_tests(end_to_end_jit_encrypted_tensor)
gtest_discover_tests(end_to_end_jit_hlfhelinalg)
gtest_discover_tests(end_to_end_jit_lambda)

View File

@@ -0,0 +1,114 @@
#include <gtest/gtest.h>
#include "end_to_end_jit_test.h"
const mlir::zamalang::V0FHEConstraint defaultV0Constraints{10, 7};
using Lambda = mlir::zamalang::JitCompilerEngine::Lambda;
TEST(Lambda_check_param, int_to_void_missing_param) {
Lambda lambda = checkedJit(R"XXX(
func @main(%arg0: !HLFHE.eint<1>) {
return
}
)XXX");
ASSERT_EXPECTED_FAILURE(lambda());
}
TEST(Lambda_check_param, DISABLED_int_to_void_good) {
// DISABLED Note: it segfaults
Lambda lambda = checkedJit(R"XXX(
func @main(%arg0: !HLFHE.eint<1>) {
return
}
)XXX");
ASSERT_EXPECTED_SUCCESS(lambda(1_u64));
}
TEST(Lambda_check_param, int_to_void_superfluous_param) {
Lambda lambda = checkedJit(R"XXX(
func @main(%arg0: !HLFHE.eint<1>) {
return
}
)XXX");
ASSERT_EXPECTED_FAILURE(lambda(1_u64, 1_u64));
}
TEST(Lambda_check_param, scalar_parameters_number) {
Lambda lambda = checkedJit(R"XXX(
func @main(
%arg0: !HLFHE.eint<1>, %arg1: !HLFHE.eint<1>,
%arg2: !HLFHE.eint<1>) -> !HLFHE.eint<1>
{
return %arg0: !HLFHE.eint<1>
}
)XXX");
ASSERT_EXPECTED_FAILURE(lambda());
ASSERT_EXPECTED_FAILURE(lambda(1_u64));
ASSERT_EXPECTED_FAILURE(lambda(1_u64, 2_u64));
ASSERT_EXPECTED_SUCCESS(lambda(1_u64, 2_u64, 3_u64));
ASSERT_EXPECTED_FAILURE(lambda(1_u64, 2_u64, 3_u64, 4_u64));
}
TEST(Lambda_check_param, scalar_tensor_to_scalar_missing_param) {
Lambda lambda = checkedJit(R"XXX(
func @main(
%arg0: !HLFHE.eint<1>, %arg1: tensor<2x!HLFHE.eint<1>>) -> !HLFHE.eint<1>
{
return %arg0: !HLFHE.eint<1>
}
)XXX");
ASSERT_EXPECTED_FAILURE(lambda(1_u64));
}
TEST(Lambda_check_param, scalar_tensor_to_scalar) {
Lambda lambda = checkedJit(R"XXX(
func @main(
%arg0: !HLFHE.eint<1>, %arg1: tensor<2x!HLFHE.eint<1>>) -> !HLFHE.eint<1>
{
return %arg0: !HLFHE.eint<1>
}
)XXX");
uint8_t arg[2] = {1 ,2};
ASSERT_EXPECTED_SUCCESS(lambda(1_u64, arg, ARRAY_SIZE(arg)));
}
TEST(Lambda_check_param, scalar_tensor_to_scalar_superfluous_param) {
// DISABLED Note: "terminate called after throwing an instance of 'std::bad_alloc'"
Lambda lambda = checkedJit(R"XXX(
func @main(
%arg0: !HLFHE.eint<1>, %arg1: tensor<2x!HLFHE.eint<1>>) -> !HLFHE.eint<1>
{
return %arg0: !HLFHE.eint<1>
}
)XXX");
uint8_t arg[2] = {1 ,2};
ASSERT_EXPECTED_FAILURE(lambda(1_u64, arg, ARRAY_SIZE(arg), arg, ARRAY_SIZE(arg)));
}
TEST(Lambda_check_param, scalar_tensor_to_tensor_good_number_param) {
Lambda lambda = checkedJit(R"XXX(
func @main(
%arg0: !HLFHE.eint<1>, %arg1: tensor<2x!HLFHE.eint<1>>) -> tensor<2x!HLFHE.eint<1>>
{
return %arg1: tensor<2x!HLFHE.eint<1>>
}
)XXX");
uint8_t arg[2] = {1 ,2};
ASSERT_EXPECTED_SUCCESS(
lambda.operator()<std::vector<uint8_t>>(1_u64, arg, ARRAY_SIZE(arg))
);
}
TEST(Lambda_check_param, DISABLED_check_parameters_scalar_too_big) {
// DISABLED Note: loss of precision without any warning or error.
Lambda lambda = checkedJit(R"XXX(
func @main(%arg0: !HLFHE.eint<1>) -> !HLFHE.eint<1>
{
return %arg0: !HLFHE.eint<1>
}
)XXX");
uint16_t arg = 3;
ASSERT_EXPECTED_FAILURE(lambda(arg));
}

View File

@@ -21,7 +21,7 @@
template <typename T>
static bool assert_expected_success(llvm::Expected<T> &val) {
if (!((bool)val)) {
llvm::errs() << llvm::toString(std::move(val.takeError()));
llvm::errs() << llvm::toString(std::move(val.takeError())) << "\n";
return false;
}
@@ -35,6 +35,13 @@ static bool assert_expected_success(llvm::Expected<T> &&val) {
return assert_expected_success(val);
}
// Checks that the value `val` is not in an error state. Returns
// `true` if the test passes, otherwise `false`.
template <typename T>
static bool assert_expected_failure(llvm::Expected<T> &&val) {
return !assert_expected_success(val);
}
// Checks that the value `val` of type `llvm::Expected<T>` is not in
// an error state.
#define ASSERT_EXPECTED_SUCCESS(val) \
@@ -43,6 +50,14 @@ static bool assert_expected_success(llvm::Expected<T> &&val) {
GTEST_FATAL_FAILURE_("Expected<T> contained in error state"); \
} while (0)
// Checks that the value `val` of type `llvm::Expected<T>` is in
// an error state.
#define ASSERT_EXPECTED_FAILURE(val) \
do { \
if (assert_expected_success(val)) \
GTEST_FATAL_FAILURE_("Expected<T> contained not in error state"); \
} while (0)
// Checks that the value `val` is not in an error state and is equal
// to the value given in `exp`. Returns `true` if the test passes,
// otherwise `false`.