test(compiler): Add end-to-end test for batched operations

This adds a new end-to-end test `apply_lookup_table_batched`, which
forces batching of Concrete operations when invoking the compiler
engine, indirectly causing the `concrete.bootstrap_lwe` and
`concrete.keyswitch_lwe` operations generated from the
`FHELinalg.apply_lookup_table` operation of the test to be batched
into `concrete.batched_bootstrap_lwe` and
`concrete.batched_keyswitch_lwe` operations. The batched operations
trigger the generation of calls to batching wrapper functions further
down the pipeline, effectively testing the lowering and implementation
of batched operations altogether.
This commit is contained in:
Andi Drebes
2022-11-15 16:54:34 +01:00
parent 46366eec41
commit ef26c73cb8
2 changed files with 46 additions and 1 deletions

View File

@@ -1458,6 +1458,49 @@ TEST(End2EndJit_FHELinalg, apply_lookup_table) {
}
}
///////////////////////////////////////////////////////////////////////////////
// FHELinalg apply_lookup_table with batching /////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
TEST(End2EndJit_FHELinalg, apply_lookup_table_batched) {
checkedJit(lambda, R"XXX(
func.func @main(%t: tensor<3x3x!FHE.eint<2>>) -> tensor<3x3x!FHE.eint<3>> {
%lut = arith.constant dense<[1,3,5,7]> : tensor<4xi64>
%res = "FHELinalg.apply_lookup_table"(%t, %lut) : (tensor<3x3x!FHE.eint<2>>, tensor<4xi64>) -> tensor<3x3x!FHE.eint<3>>
return %res : tensor<3x3x!FHE.eint<3>>
}
)XXX",
"main", false, false, false, true);
const uint8_t t[3][3]{
{0, 1, 2},
{3, 0, 1},
{2, 3, 0},
};
const uint8_t expected[3][3]{
{1, 3, 5},
{7, 1, 3},
{5, 7, 1},
};
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>
tArg(llvm::ArrayRef<uint8_t>((const uint8_t *)t, 3 * 3), {3, 3});
llvm::Expected<std::vector<uint64_t>> res =
lambda.operator()<std::vector<uint64_t>>({&tArg});
ASSERT_EXPECTED_SUCCESS(res);
ASSERT_EQ(res->size(), (uint64_t)3 * 3);
for (size_t i = 0; i < 3; i++) {
for (size_t j = 0; j < 3; j++) {
EXPECT_EQ((*res)[i * 3 + j], expected[i][j])
<< ", at pos(" << i << "," << j << ")";
}
}
}
///////////////////////////////////////////////////////////////////////////////
// FHELinalg apply_multi_lookup_table /////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////

View File

@@ -20,7 +20,8 @@ inline llvm::Expected<
internalCheckedJit(llvm::StringRef src, llvm::StringRef func = "main",
bool useDefaultFHEConstraints = false,
bool dataflowParallelize = false,
bool loopParallelize = false) {
bool loopParallelize = false,
bool batchConcreteOps = false) {
auto options =
mlir::concretelang::CompilationOptions(std::string(func.data()));
@@ -39,6 +40,7 @@ internalCheckedJit(llvm::StringRef src, llvm::StringRef func = "main",
options.dataflowParallelize = dataflowParallelize;
#endif
#endif
options.batchConcreteOps = batchConcreteOps;
auto lambdaOrErr =
mlir::concretelang::ClientServer<mlir::concretelang::JITSupport>::create(