diff --git a/compiler/tests/unittest/EndToEndFixture.h b/compiler/tests/unittest/EndToEndFixture.h index b8c181ec5..705dd791e 100644 --- a/compiler/tests/unittest/EndToEndFixture.h +++ b/compiler/tests/unittest/EndToEndFixture.h @@ -52,9 +52,8 @@ valueDescriptionToLambdaArgument(ValueDescription desc); llvm::Error checkResult(ScalarDesc &desc, mlir::concretelang::LambdaArgument *res); -llvm::Error -checkResult(ValueDescription &desc, - std::unique_ptr &res); +llvm::Error checkResult(ValueDescription &desc, + mlir::concretelang::LambdaArgument &res); std::vector loadEndToEndDesc(std::string path); diff --git a/compiler/tests/unittest/end_to_end_jit_fhe.cc b/compiler/tests/unittest/end_to_end_jit_fhe.cc index 0e88a6793..17bbe2eb9 100644 --- a/compiler/tests/unittest/end_to_end_jit_fhe.cc +++ b/compiler/tests/unittest/end_to_end_jit_fhe.cc @@ -7,60 +7,64 @@ #include "concretelang/Support/JitLambdaSupport.h" #include "concretelang/Support/LibraryLambdaSupport.h" +template +void compile_and_run(EndToEndDesc desc, LambdaSupport support) { + + /* 1 - Compile the program */ + auto compilationResult = support.compile(desc.program); + ASSERT_EXPECTED_SUCCESS(compilationResult); + + /* 2 - Load the client parameters and build the keySet */ + auto clientParameters = support.loadClientParameters(**compilationResult); + ASSERT_EXPECTED_SUCCESS(clientParameters); + + auto keySet = support.keySet(*clientParameters, getTestKeySetCache()); + ASSERT_EXPECTED_SUCCESS(keySet); + + /* 3 - Load the server lambda */ + auto serverLambda = support.loadServerLambda(**compilationResult); + ASSERT_EXPECTED_SUCCESS(serverLambda); + + /* For each test entries */ + for (auto test : desc.tests) { + std::vector inputArguments; + inputArguments.reserve(test.inputs.size()); + for (auto input : test.inputs) { + auto arg = valueDescriptionToLambdaArgument(input); + ASSERT_EXPECTED_SUCCESS(arg); + inputArguments.push_back(arg.get()); + } + + /* 4 - Create the public arguments */ + auto publicArguments = + support.exportArguments(*clientParameters, **keySet, inputArguments); + ASSERT_EXPECTED_SUCCESS(publicArguments); + + /* 5 - Call the server lambda */ + auto publicResult = support.serverCall(*serverLambda, **publicArguments); + ASSERT_EXPECTED_SUCCESS(publicResult); + + /* 6 - Decrypt the public result */ + auto result = mlir::concretelang::typedResult< + std::unique_ptr>(**keySet, + **publicResult); + + /* 7 - Check result */ + ASSERT_EXPECTED_SUCCESS(result); + ASSERT_LLVM_ERROR(checkResult(test.outputs[0], **result)); + + for (auto arg : inputArguments) { + delete arg; + } + } +} + // Macro to define and end to end TestSuite that run test thanks the // LambdaSupport according a EndToEndDesc -#define INSTANTIATE_END_TO_END_COMPILE_AND_RUN(TestSuite, LambdaSupport) \ +#define INSTANTIATE_END_TO_END_COMPILE_AND_RUN(TestSuite, lambdaSupport) \ TEST_P(TestSuite, compile_and_run) { \ - \ auto desc = GetParam(); \ - \ - auto support = LambdaSupport; \ - \ - /* 1 - Compile the program */ \ - auto compilationResult = support.compile(desc.program); \ - ASSERT_EXPECTED_SUCCESS(compilationResult); \ - \ - /* 2 - Load the client parameters and build the keySet */ \ - auto clientParameters = support.loadClientParameters(**compilationResult); \ - ASSERT_EXPECTED_SUCCESS(clientParameters); \ - \ - auto keySet = support.keySet(*clientParameters, getTestKeySetCache()); \ - ASSERT_EXPECTED_SUCCESS(keySet); \ - \ - /* 3 - Load the server lambda */ \ - auto serverLambda = support.loadServerLambda(**compilationResult); \ - ASSERT_EXPECTED_SUCCESS(serverLambda); \ - \ - /* For each test entries */ \ - for (auto test : desc.tests) { \ - std::vector inputArguments; \ - inputArguments.reserve(test.inputs.size()); \ - for (auto input : test.inputs) { \ - auto arg = valueDescToLambdaArgument(input); \ - ASSERT_EXPECTED_SUCCESS(arg); \ - inputArguments.push_back(arg.get()); \ - } \ - /* 4 - Create the public arguments */ \ - auto publicArguments = support.exportArguments( \ - *clientParameters, **keySet, inputArguments); \ - ASSERT_EXPECTED_SUCCESS(publicArguments); \ - \ - /* 5 - Call the server lambda */ \ - auto publicResult = \ - support.serverCall(*serverLambda, **publicArguments); \ - ASSERT_EXPECTED_SUCCESS(publicResult); \ - \ - /* 6 - Decrypt the public result */ \ - auto result = mlir::concretelang::typedResult< \ - std::unique_ptr>( \ - **keySet, **publicResult); \ - \ - ASSERT_EXPECTED_SUCCESS(result); \ - \ - for (auto arg : inputArguments) { \ - delete arg; \ - } \ - } \ + compile_and_run(desc, lambdaSupport); \ } #define INSTANTIATE_END_TO_END_TEST_SUITE_FROM_FILE(prefix, suite, \