mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
test(compiler): Add tests for JitCompilerEngine::Lambda with a generic result
This commit is contained in:
committed by
Ayoub Benaissa
parent
9040e5ab00
commit
3d6ad06f46
@@ -21,7 +21,7 @@ func @main(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
|
||||
}
|
||||
|
||||
// Same as CompileAndRunHLFHE::add_eint above, but using
|
||||
// `LambdaArgument` instances
|
||||
// `LambdaArgument` instances as arguments
|
||||
TEST(CompileAndRunHLFHE, add_eint_lambda_argument) {
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
|
||||
@@ -42,6 +42,41 @@ func @main(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
|
||||
ASSERT_EXPECTED_VALUE(lambda({&ila2, &ila7}), 9);
|
||||
}
|
||||
|
||||
// Same as CompileAndRunHLFHE::add_eint above, but using
|
||||
// `LambdaArgument` instances as arguments and as a result type
|
||||
TEST(CompileAndRunHLFHE, add_eint_lambda_argument_res) {
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
|
||||
%1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
|
||||
return %1: !HLFHE.eint<7>
|
||||
}
|
||||
)XXX");
|
||||
|
||||
mlir::zamalang::IntLambdaArgument<> ila1(1);
|
||||
mlir::zamalang::IntLambdaArgument<> ila2(2);
|
||||
mlir::zamalang::IntLambdaArgument<> ila7(7);
|
||||
mlir::zamalang::IntLambdaArgument<> ila9(9);
|
||||
|
||||
auto eval = [&](mlir::zamalang::IntLambdaArgument<> &arg0,
|
||||
mlir::zamalang::IntLambdaArgument<> &arg1,
|
||||
uint64_t expected) {
|
||||
llvm::Expected<std::unique_ptr<mlir::zamalang::LambdaArgument>> res0 =
|
||||
lambda.operator()<std::unique_ptr<mlir::zamalang::LambdaArgument>>(
|
||||
{&arg0, &arg1});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res0);
|
||||
ASSERT_TRUE((*res0)->isa<mlir::zamalang::IntLambdaArgument<>>());
|
||||
ASSERT_EQ((*res0)->cast<mlir::zamalang::IntLambdaArgument<>>().getValue(),
|
||||
expected);
|
||||
};
|
||||
|
||||
eval(ila1, ila2, 3);
|
||||
eval(ila7, ila9, 16);
|
||||
eval(ila1, ila7, 8);
|
||||
eval(ila1, ila9, 10);
|
||||
eval(ila2, ila7, 9);
|
||||
}
|
||||
|
||||
TEST(CompileAndRunHLFHE, add_u64) {
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%arg0: i64, %arg1: i64) -> i64 {
|
||||
@@ -96,7 +131,7 @@ func @main(%t: tensor<10xi32>, %i: index) -> i32{
|
||||
}
|
||||
|
||||
// Same as `CompileAndRunTensorStd::extract_32` above, but using
|
||||
// `LambdaArgument` instances
|
||||
// `LambdaArgument` instances as arguments
|
||||
TEST(CompileAndRunTensorStd, extract_32_lambda_argument) {
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%t: tensor<10xi32>, %i: index) -> i32{
|
||||
@@ -241,6 +276,38 @@ func @main(%0: !HLFHE.eint<5>) -> tensor<1x!HLFHE.eint<5>> {
|
||||
ASSERT_EQ(res->at(0), 10_u64);
|
||||
}
|
||||
|
||||
// Same as `CompileAndRunTensorEncrypted::from_elements_5 but with
|
||||
// `LambdaArgument` instances as arguments and as a result type
|
||||
TEST(CompileAndRunTensorEncrypted, from_elements_5_lambda_argument_res) {
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%0: !HLFHE.eint<5>) -> tensor<1x!HLFHE.eint<5>> {
|
||||
%t = tensor.from_elements %0 : tensor<1x!HLFHE.eint<5>>
|
||||
return %t: tensor<1x!HLFHE.eint<5>>
|
||||
}
|
||||
)XXX");
|
||||
|
||||
mlir::zamalang::IntLambdaArgument<> arg(10);
|
||||
|
||||
llvm::Expected<std::unique_ptr<mlir::zamalang::LambdaArgument>> res =
|
||||
lambda.operator()<std::unique_ptr<mlir::zamalang::LambdaArgument>>(
|
||||
{&arg});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
ASSERT_TRUE((*res)
|
||||
->isa<mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<>>>());
|
||||
|
||||
mlir::zamalang::TensorLambdaArgument<mlir::zamalang::IntLambdaArgument<>>
|
||||
&resp = (*res)
|
||||
->cast<mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<>>>();
|
||||
|
||||
ASSERT_EQ(resp.getDimensions().size(), (size_t)1);
|
||||
ASSERT_EQ(resp.getDimensions().at(0), 1);
|
||||
ASSERT_EXPECTED_VALUE(resp.getNumElements(), 1);
|
||||
// ASSERT_EQ(resp.getValue()[0], 10_u64);
|
||||
}
|
||||
|
||||
TEST(CompileAndRunTensorEncrypted, in_out_tensor_with_op_5) {
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%in: tensor<2x!HLFHE.eint<5>>) -> tensor<3x!HLFHE.eint<5>> {
|
||||
|
||||
Reference in New Issue
Block a user