test(compiler): Add tests for JitCompilerEngine::Lambda with a generic result

This commit is contained in:
Andi Drebes
2021-11-03 17:07:38 +01:00
committed by Ayoub Benaissa
parent 9040e5ab00
commit 3d6ad06f46

View File

@@ -21,7 +21,7 @@ func @main(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
}
// Same as CompileAndRunHLFHE::add_eint above, but using
// `LambdaArgument` instances
// `LambdaArgument` instances as arguments
TEST(CompileAndRunHLFHE, add_eint_lambda_argument) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
@@ -42,6 +42,41 @@ func @main(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
ASSERT_EXPECTED_VALUE(lambda({&ila2, &ila7}), 9);
}
// Same as CompileAndRunHLFHE::add_eint above, but using
// `LambdaArgument` instances as arguments and as a result type
TEST(CompileAndRunHLFHE, add_eint_lambda_argument_res) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
%1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
return %1: !HLFHE.eint<7>
}
)XXX");
mlir::zamalang::IntLambdaArgument<> ila1(1);
mlir::zamalang::IntLambdaArgument<> ila2(2);
mlir::zamalang::IntLambdaArgument<> ila7(7);
mlir::zamalang::IntLambdaArgument<> ila9(9);
auto eval = [&](mlir::zamalang::IntLambdaArgument<> &arg0,
mlir::zamalang::IntLambdaArgument<> &arg1,
uint64_t expected) {
llvm::Expected<std::unique_ptr<mlir::zamalang::LambdaArgument>> res0 =
lambda.operator()<std::unique_ptr<mlir::zamalang::LambdaArgument>>(
{&arg0, &arg1});
ASSERT_EXPECTED_SUCCESS(res0);
ASSERT_TRUE((*res0)->isa<mlir::zamalang::IntLambdaArgument<>>());
ASSERT_EQ((*res0)->cast<mlir::zamalang::IntLambdaArgument<>>().getValue(),
expected);
};
eval(ila1, ila2, 3);
eval(ila7, ila9, 16);
eval(ila1, ila7, 8);
eval(ila1, ila9, 10);
eval(ila2, ila7, 9);
}
TEST(CompileAndRunHLFHE, add_u64) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%arg0: i64, %arg1: i64) -> i64 {
@@ -96,7 +131,7 @@ func @main(%t: tensor<10xi32>, %i: index) -> i32{
}
// Same as `CompileAndRunTensorStd::extract_32` above, but using
// `LambdaArgument` instances
// `LambdaArgument` instances as arguments
TEST(CompileAndRunTensorStd, extract_32_lambda_argument) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10xi32>, %i: index) -> i32{
@@ -241,6 +276,38 @@ func @main(%0: !HLFHE.eint<5>) -> tensor<1x!HLFHE.eint<5>> {
ASSERT_EQ(res->at(0), 10_u64);
}
// Same as `CompileAndRunTensorEncrypted::from_elements_5 but with
// `LambdaArgument` instances as arguments and as a result type
TEST(CompileAndRunTensorEncrypted, from_elements_5_lambda_argument_res) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%0: !HLFHE.eint<5>) -> tensor<1x!HLFHE.eint<5>> {
%t = tensor.from_elements %0 : tensor<1x!HLFHE.eint<5>>
return %t: tensor<1x!HLFHE.eint<5>>
}
)XXX");
mlir::zamalang::IntLambdaArgument<> arg(10);
llvm::Expected<std::unique_ptr<mlir::zamalang::LambdaArgument>> res =
lambda.operator()<std::unique_ptr<mlir::zamalang::LambdaArgument>>(
{&arg});
ASSERT_EXPECTED_SUCCESS(res);
ASSERT_TRUE((*res)
->isa<mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<>>>());
mlir::zamalang::TensorLambdaArgument<mlir::zamalang::IntLambdaArgument<>>
&resp = (*res)
->cast<mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<>>>();
ASSERT_EQ(resp.getDimensions().size(), (size_t)1);
ASSERT_EQ(resp.getDimensions().at(0), 1);
ASSERT_EXPECTED_VALUE(resp.getNumElements(), 1);
// ASSERT_EQ(resp.getValue()[0], 10_u64);
}
TEST(CompileAndRunTensorEncrypted, in_out_tensor_with_op_5) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%in: tensor<2x!HLFHE.eint<5>>) -> tensor<3x!HLFHE.eint<5>> {