From ef26c73cb8c01ca1d5167559e92a75a3b8949ec3 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Tue, 15 Nov 2022 16:54:34 +0100 Subject: [PATCH] test(compiler): Add end-to-end test for batched operations This adds a new end-to-end test `apply_lookup_table_batched`, which forces batching of Concrete operations when invoking the compiler engine, indirectly causing the `concrete.bootstrap_lwe` and `concrete.keyswitch_lwe` operations generated from the `FHELinalg.apply_lookup_table` operation of the test to be batched into `concrete.batched_bootstrap_lwe` and `concrete.batched_keyswitch_lwe` operations. The batched operations trigger the generation of calls to batching wrapper functions further down the pipeline, effectively testing the lowering and implementation of batched operations altogether. --- .../end_to_end_jit_fhelinalg.cc | 43 +++++++++++++++++++ .../end_to_end_tests/end_to_end_jit_test.h | 4 +- 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/compiler/tests/end_to_end_tests/end_to_end_jit_fhelinalg.cc b/compiler/tests/end_to_end_tests/end_to_end_jit_fhelinalg.cc index 20b15d236..f5af96463 100644 --- a/compiler/tests/end_to_end_tests/end_to_end_jit_fhelinalg.cc +++ b/compiler/tests/end_to_end_tests/end_to_end_jit_fhelinalg.cc @@ -1458,6 +1458,49 @@ TEST(End2EndJit_FHELinalg, apply_lookup_table) { } } +/////////////////////////////////////////////////////////////////////////////// +// FHELinalg apply_lookup_table with batching ///////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(End2EndJit_FHELinalg, apply_lookup_table_batched) { + checkedJit(lambda, R"XXX( + func.func @main(%t: tensor<3x3x!FHE.eint<2>>) -> tensor<3x3x!FHE.eint<3>> { + %lut = arith.constant dense<[1,3,5,7]> : tensor<4xi64> + %res = "FHELinalg.apply_lookup_table"(%t, %lut) : (tensor<3x3x!FHE.eint<2>>, tensor<4xi64>) -> tensor<3x3x!FHE.eint<3>> + return %res : tensor<3x3x!FHE.eint<3>> + } +)XXX", + "main", false, false, false, true); + const uint8_t t[3][3]{ + {0, 1, 2}, + {3, 0, 1}, + {2, 3, 0}, + }; + const uint8_t expected[3][3]{ + {1, 3, 5}, + {7, 1, 3}, + {5, 7, 1}, + }; + + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + tArg(llvm::ArrayRef((const uint8_t *)t, 3 * 3), {3, 3}); + + llvm::Expected> res = + lambda.operator()>({&tArg}); + + ASSERT_EXPECTED_SUCCESS(res); + + ASSERT_EQ(res->size(), (uint64_t)3 * 3); + + for (size_t i = 0; i < 3; i++) { + for (size_t j = 0; j < 3; j++) { + EXPECT_EQ((*res)[i * 3 + j], expected[i][j]) + << ", at pos(" << i << "," << j << ")"; + } + } +} + /////////////////////////////////////////////////////////////////////////////// // FHELinalg apply_multi_lookup_table ///////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////// diff --git a/compiler/tests/end_to_end_tests/end_to_end_jit_test.h b/compiler/tests/end_to_end_tests/end_to_end_jit_test.h index 18e15b767..d35a47be0 100644 --- a/compiler/tests/end_to_end_tests/end_to_end_jit_test.h +++ b/compiler/tests/end_to_end_tests/end_to_end_jit_test.h @@ -20,7 +20,8 @@ inline llvm::Expected< internalCheckedJit(llvm::StringRef src, llvm::StringRef func = "main", bool useDefaultFHEConstraints = false, bool dataflowParallelize = false, - bool loopParallelize = false) { + bool loopParallelize = false, + bool batchConcreteOps = false) { auto options = mlir::concretelang::CompilationOptions(std::string(func.data())); @@ -39,6 +40,7 @@ internalCheckedJit(llvm::StringRef src, llvm::StringRef func = "main", options.dataflowParallelize = dataflowParallelize; #endif #endif + options.batchConcreteOps = batchConcreteOps; auto lambdaOrErr = mlir::concretelang::ClientServer::create(