diff --git a/compiler/Makefile b/compiler/Makefile index 9d4c89ded..e35cf5d72 100644 --- a/compiler/Makefile +++ b/compiler/Makefile @@ -61,6 +61,9 @@ build-end-to-end-jit-encrypted-tensor: build-initialized build-end-to-end-jit-hlfhelinalg: build-initialized cmake --build $(BUILD_DIR) --target end_to_end_jit_hlfhelinalg +build-end-to-end-jit-lamnda: build-initialized + cmake --build $(BUILD_DIR) --target end_to_end_jit_lambda + build-end-to-end-jit: build-end-to-end-jit-test build-end-to-end-jit-clear-tensor build-end-to-end-jit-encrypted-tensor build-end-to-end-jit-hlfhelinalg test-end-to-end-jit-test: build-end-to-end-jit-test @@ -75,6 +78,11 @@ test-end-to-end-jit-encrypted-tensor: build-end-to-end-jit-encrypted-tensor test-end-to-end-jit-hlfhelinalg: build-end-to-end-jit-hlfhelinalg $(BUILD_DIR)/bin/end_to_end_jit_hlfhelinalg +test-end-to-end-jit-lambda: build-initialized build-end-to-end-jit-lamnda + $(BUILD_DIR)/bin/end_to_end_jit_lambda + + + test-end-to-end-jit: test-end-to-end-jit-test test-end-to-end-jit-clear-tensor test-end-to-end-jit-encrypted-tensor test-end-to-end-jit-hlfhelinalg # LLVM/MLIR dependencies diff --git a/compiler/tests/unittest/CMakeLists.txt b/compiler/tests/unittest/CMakeLists.txt index 8eb96e1e2..b6ca3f966 100644 --- a/compiler/tests/unittest/CMakeLists.txt +++ b/compiler/tests/unittest/CMakeLists.txt @@ -26,12 +26,18 @@ add_executable( globals.cc ) +add_executable( + end_to_end_jit_lambda + end_to_end_jit_lambda.cc + globals.cc +) + set_source_files_properties( end_to_end_jit_test.cc end_to_end_jit_clear_tensor.cc end_to_end_jit_encrypted_tensor.cc end_to_end_jit_hlfhelinalg.cc - globals.cc + end_to_end_jit_lambda.cc PROPERTIES COMPILE_FLAGS "-fno-rtti" ) @@ -60,9 +66,16 @@ target_link_libraries( ZamalangSupport ) +target_link_libraries( + end_to_end_jit_lambda + gtest_main + ZamalangSupport +) + include(GoogleTest) gtest_discover_tests(end_to_end_jit_test) gtest_discover_tests(end_to_end_jit_clear_tensor) gtest_discover_tests(end_to_end_jit_encrypted_tensor) gtest_discover_tests(end_to_end_jit_hlfhelinalg) +gtest_discover_tests(end_to_end_jit_lambda) diff --git a/compiler/tests/unittest/end_to_end_jit_lambda.cc b/compiler/tests/unittest/end_to_end_jit_lambda.cc new file mode 100644 index 000000000..6400521c0 --- /dev/null +++ b/compiler/tests/unittest/end_to_end_jit_lambda.cc @@ -0,0 +1,114 @@ +#include + +#include "end_to_end_jit_test.h" + +const mlir::zamalang::V0FHEConstraint defaultV0Constraints{10, 7}; + +using Lambda = mlir::zamalang::JitCompilerEngine::Lambda; + + +TEST(Lambda_check_param, int_to_void_missing_param) { + Lambda lambda = checkedJit(R"XXX( + func @main(%arg0: !HLFHE.eint<1>) { + return + } + )XXX"); + ASSERT_EXPECTED_FAILURE(lambda()); +} + +TEST(Lambda_check_param, DISABLED_int_to_void_good) { + // DISABLED Note: it segfaults + Lambda lambda = checkedJit(R"XXX( + func @main(%arg0: !HLFHE.eint<1>) { + return + } + )XXX"); + ASSERT_EXPECTED_SUCCESS(lambda(1_u64)); +} + +TEST(Lambda_check_param, int_to_void_superfluous_param) { + Lambda lambda = checkedJit(R"XXX( + func @main(%arg0: !HLFHE.eint<1>) { + return + } + )XXX"); + ASSERT_EXPECTED_FAILURE(lambda(1_u64, 1_u64)); +} + +TEST(Lambda_check_param, scalar_parameters_number) { + Lambda lambda = checkedJit(R"XXX( + func @main( + %arg0: !HLFHE.eint<1>, %arg1: !HLFHE.eint<1>, + %arg2: !HLFHE.eint<1>) -> !HLFHE.eint<1> + { + return %arg0: !HLFHE.eint<1> + } + )XXX"); + ASSERT_EXPECTED_FAILURE(lambda()); + ASSERT_EXPECTED_FAILURE(lambda(1_u64)); + ASSERT_EXPECTED_FAILURE(lambda(1_u64, 2_u64)); + ASSERT_EXPECTED_SUCCESS(lambda(1_u64, 2_u64, 3_u64)); + ASSERT_EXPECTED_FAILURE(lambda(1_u64, 2_u64, 3_u64, 4_u64)); +} + +TEST(Lambda_check_param, scalar_tensor_to_scalar_missing_param) { + Lambda lambda = checkedJit(R"XXX( + func @main( + %arg0: !HLFHE.eint<1>, %arg1: tensor<2x!HLFHE.eint<1>>) -> !HLFHE.eint<1> + { + return %arg0: !HLFHE.eint<1> + } + )XXX"); + ASSERT_EXPECTED_FAILURE(lambda(1_u64)); +} + +TEST(Lambda_check_param, scalar_tensor_to_scalar) { + Lambda lambda = checkedJit(R"XXX( + func @main( + %arg0: !HLFHE.eint<1>, %arg1: tensor<2x!HLFHE.eint<1>>) -> !HLFHE.eint<1> + { + return %arg0: !HLFHE.eint<1> + } + )XXX"); + uint8_t arg[2] = {1 ,2}; + ASSERT_EXPECTED_SUCCESS(lambda(1_u64, arg, ARRAY_SIZE(arg))); +} + +TEST(Lambda_check_param, scalar_tensor_to_scalar_superfluous_param) { + // DISABLED Note: "terminate called after throwing an instance of 'std::bad_alloc'" + Lambda lambda = checkedJit(R"XXX( + func @main( + %arg0: !HLFHE.eint<1>, %arg1: tensor<2x!HLFHE.eint<1>>) -> !HLFHE.eint<1> + { + return %arg0: !HLFHE.eint<1> + } + )XXX"); + uint8_t arg[2] = {1 ,2}; + ASSERT_EXPECTED_FAILURE(lambda(1_u64, arg, ARRAY_SIZE(arg), arg, ARRAY_SIZE(arg))); +} + +TEST(Lambda_check_param, scalar_tensor_to_tensor_good_number_param) { + Lambda lambda = checkedJit(R"XXX( + func @main( + %arg0: !HLFHE.eint<1>, %arg1: tensor<2x!HLFHE.eint<1>>) -> tensor<2x!HLFHE.eint<1>> + { + return %arg1: tensor<2x!HLFHE.eint<1>> + } + )XXX"); + uint8_t arg[2] = {1 ,2}; + ASSERT_EXPECTED_SUCCESS( + lambda.operator()>(1_u64, arg, ARRAY_SIZE(arg)) + ); +} + +TEST(Lambda_check_param, DISABLED_check_parameters_scalar_too_big) { + // DISABLED Note: loss of precision without any warning or error. + Lambda lambda = checkedJit(R"XXX( + func @main(%arg0: !HLFHE.eint<1>) -> !HLFHE.eint<1> + { + return %arg0: !HLFHE.eint<1> + } + )XXX"); + uint16_t arg = 3; + ASSERT_EXPECTED_FAILURE(lambda(arg)); +} diff --git a/compiler/tests/unittest/end_to_end_jit_test.h b/compiler/tests/unittest/end_to_end_jit_test.h index 94e23de6a..613550997 100644 --- a/compiler/tests/unittest/end_to_end_jit_test.h +++ b/compiler/tests/unittest/end_to_end_jit_test.h @@ -21,7 +21,7 @@ template static bool assert_expected_success(llvm::Expected &val) { if (!((bool)val)) { - llvm::errs() << llvm::toString(std::move(val.takeError())); + llvm::errs() << llvm::toString(std::move(val.takeError())) << "\n"; return false; } @@ -35,6 +35,13 @@ static bool assert_expected_success(llvm::Expected &&val) { return assert_expected_success(val); } +// Checks that the value `val` is not in an error state. Returns +// `true` if the test passes, otherwise `false`. +template +static bool assert_expected_failure(llvm::Expected &&val) { + return !assert_expected_success(val); +} + // Checks that the value `val` of type `llvm::Expected` is not in // an error state. #define ASSERT_EXPECTED_SUCCESS(val) \ @@ -43,6 +50,14 @@ static bool assert_expected_success(llvm::Expected &&val) { GTEST_FATAL_FAILURE_("Expected contained in error state"); \ } while (0) +// Checks that the value `val` of type `llvm::Expected` is in +// an error state. +#define ASSERT_EXPECTED_FAILURE(val) \ + do { \ + if (assert_expected_success(val)) \ + GTEST_FATAL_FAILURE_("Expected contained not in error state"); \ + } while (0) + // Checks that the value `val` is not in an error state and is equal // to the value given in `exp`. Returns `true` if the test passes, // otherwise `false`.