mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
test(compiler): Add end-to-end test for batched operations
This adds a new end-to-end test `apply_lookup_table_batched`, which forces batching of Concrete operations when invoking the compiler engine, indirectly causing the `concrete.bootstrap_lwe` and `concrete.keyswitch_lwe` operations generated from the `FHELinalg.apply_lookup_table` operation of the test to be batched into `concrete.batched_bootstrap_lwe` and `concrete.batched_keyswitch_lwe` operations. The batched operations trigger the generation of calls to batching wrapper functions further down the pipeline, effectively testing the lowering and implementation of batched operations altogether.
This commit is contained in:
@@ -1458,6 +1458,49 @@ TEST(End2EndJit_FHELinalg, apply_lookup_table) {
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// FHELinalg apply_lookup_table with batching /////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(End2EndJit_FHELinalg, apply_lookup_table_batched) {
|
||||
checkedJit(lambda, R"XXX(
|
||||
func.func @main(%t: tensor<3x3x!FHE.eint<2>>) -> tensor<3x3x!FHE.eint<3>> {
|
||||
%lut = arith.constant dense<[1,3,5,7]> : tensor<4xi64>
|
||||
%res = "FHELinalg.apply_lookup_table"(%t, %lut) : (tensor<3x3x!FHE.eint<2>>, tensor<4xi64>) -> tensor<3x3x!FHE.eint<3>>
|
||||
return %res : tensor<3x3x!FHE.eint<3>>
|
||||
}
|
||||
)XXX",
|
||||
"main", false, false, false, true);
|
||||
const uint8_t t[3][3]{
|
||||
{0, 1, 2},
|
||||
{3, 0, 1},
|
||||
{2, 3, 0},
|
||||
};
|
||||
const uint8_t expected[3][3]{
|
||||
{1, 3, 5},
|
||||
{7, 1, 3},
|
||||
{5, 7, 1},
|
||||
};
|
||||
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
tArg(llvm::ArrayRef<uint8_t>((const uint8_t *)t, 3 * 3), {3, 3});
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&tArg});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), (uint64_t)3 * 3);
|
||||
|
||||
for (size_t i = 0; i < 3; i++) {
|
||||
for (size_t j = 0; j < 3; j++) {
|
||||
EXPECT_EQ((*res)[i * 3 + j], expected[i][j])
|
||||
<< ", at pos(" << i << "," << j << ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// FHELinalg apply_multi_lookup_table /////////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -20,7 +20,8 @@ inline llvm::Expected<
|
||||
internalCheckedJit(llvm::StringRef src, llvm::StringRef func = "main",
|
||||
bool useDefaultFHEConstraints = false,
|
||||
bool dataflowParallelize = false,
|
||||
bool loopParallelize = false) {
|
||||
bool loopParallelize = false,
|
||||
bool batchConcreteOps = false) {
|
||||
|
||||
auto options =
|
||||
mlir::concretelang::CompilationOptions(std::string(func.data()));
|
||||
@@ -39,6 +40,7 @@ internalCheckedJit(llvm::StringRef src, llvm::StringRef func = "main",
|
||||
options.dataflowParallelize = dataflowParallelize;
|
||||
#endif
|
||||
#endif
|
||||
options.batchConcreteOps = batchConcreteOps;
|
||||
|
||||
auto lambdaOrErr =
|
||||
mlir::concretelang::ClientServer<mlir::concretelang::JITSupport>::create(
|
||||
|
||||
Reference in New Issue
Block a user