diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index 36fe58e35..8961ecf24 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -154,6 +154,10 @@ lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, addPotentiallyNestedPass( pm, mlir::concretelang::createConvertFHETensorOpsToLinalg(), enablePass); + // FHETensorOpsToLinalg does generate linalg named ops that need to be lowered + // to linalg.generic operations + addPotentiallyNestedPass(pm, mlir::createLinalgGeneralizationPass(), + enablePass); addPotentiallyNestedPass(pm, mlir::concretelang::createConvertFHEToTFHEPass(), enablePass); @@ -241,6 +245,8 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module, if (parallelizeLoops) addPotentiallyNestedPass(pm, mlir::createConvertSCFToOpenMPPass(), enablePass); + // Lower affine + addPotentiallyNestedPass(pm, mlir::createLowerAffinePass(), enablePass); // Lower Dataflow tasks to DRF addPotentiallyNestedPass( diff --git a/compiler/tests/unittest/end_to_end_jit_fhelinalg.cc b/compiler/tests/unittest/end_to_end_jit_fhelinalg.cc index a0dff637b..63b7acedb 100644 --- a/compiler/tests/unittest/end_to_end_jit_fhelinalg.cc +++ b/compiler/tests/unittest/end_to_end_jit_fhelinalg.cc @@ -1436,6 +1436,382 @@ TEST(End2EndJit_FHELinalg, matmul_int_eint) { } } +/////////////////////////////////////////////////////////////////////////////// +// FHELinalg conv2d /////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(End2EndJit_FHELinalg, conv2d_simple_input44_kernel22) { + + mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + func @main(%input: tensor<1x1x4x4x!FHE.eint<6>>, %weight: tensor<1x1x2x2xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> { + %0 = "FHELinalg.conv2d"(%input, %weight){ + strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64> + } : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> + return %0 : tensor<1x1x2x2x!FHE.eint<6>> + } +)XXX"); + const uint8_t A[1][1][4][4]{{{ + {1, 2, 3, 4}, + {1, 2, 3, 4}, + {1, 2, 3, 4}, + {1, 2, 3, 4}, + }}}; + const uint8_t B[1][1][2][2]{{{ + {1, 2}, + {2, 1}, + }}}; + const uint8_t expected[2][2]{ + {9, 21}, + {9, 21}, + }; + + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + aArg(llvm::ArrayRef((const uint8_t *)A, 4 * 4), {1, 1, 4, 4}); + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + bArg(llvm::ArrayRef((const uint8_t *)B, 2 * 2), {1, 1, 2, 2}); + + llvm::Expected> res = + lambda.operator()>({&aArg, &bArg}); + + ASSERT_EXPECTED_SUCCESS(res); + + ASSERT_EQ(res->size(), (uint64_t)2 * 2); + + for (size_t i = 0; i < 2; i++) { + for (size_t j = 0; j < 2; j++) { + EXPECT_EQ((*res)[i * 2 + j], expected[i][j]) + << ", at pos(" << i << "," << j << ")"; + } + } +} + +TEST(End2EndJit_FHELinalg, conv2d_simple_input44_const_kernel22) { + + mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + func @main(%input: tensor<1x1x4x4x!FHE.eint<6>>) -> tensor<1x1x2x2x!FHE.eint<6>> { + %weight = arith.constant dense<[[[[1, 2], [2, 1]]]]> : tensor<1x1x2x2xi7> + %0 = "FHELinalg.conv2d"(%input, %weight){ + strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64> + } : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> + return %0 : tensor<1x1x2x2x!FHE.eint<6>> + } +)XXX"); + const uint8_t A[1][1][4][4]{{{ + {1, 2, 3, 4}, + {1, 2, 3, 4}, + {1, 2, 3, 4}, + {1, 2, 3, 4}, + }}}; + const uint8_t expected[2][2]{ + {9, 21}, + {9, 21}, + }; + + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + aArg(llvm::ArrayRef((const uint8_t *)A, 4 * 4), {1, 1, 4, 4}); + + llvm::Expected> res = + lambda.operator()>({&aArg}); + + ASSERT_EXPECTED_SUCCESS(res); + + ASSERT_EQ(res->size(), (uint64_t)2 * 2); + + for (size_t i = 0; i < 2; i++) { + for (size_t j = 0; j < 2; j++) { + EXPECT_EQ((*res)[i * 2 + j], expected[i][j]) + << ", at pos(" << i << "," << j << ")"; + } + } +} + +TEST(End2EndJit_FHELinalg, conv2d_simple_input44_kernel22_const_bias) { + + mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + func @main(%input: tensor<1x1x4x4x!FHE.eint<6>>, %weight: tensor<1x1x2x2xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> { + %bias = arith.constant dense<[1]> : tensor<1xi7> + %0 = "FHELinalg.conv2d"(%input, %weight, %bias){ + strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64> + } : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>, tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> + return %0 : tensor<1x1x2x2x!FHE.eint<6>> + } +)XXX"); + const uint8_t A[1][1][4][4]{{{ + {1, 2, 3, 4}, + {1, 2, 3, 4}, + {1, 2, 3, 4}, + {1, 2, 3, 4}, + }}}; + const uint8_t B[1][1][2][2]{{{ + {1, 2}, + {2, 1}, + }}}; + const uint8_t expected[2][2]{ + {10, 22}, + {10, 22}, + }; + + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + aArg(llvm::ArrayRef((const uint8_t *)A, 4 * 4), {1, 1, 4, 4}); + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + bArg(llvm::ArrayRef((const uint8_t *)B, 2 * 2), {1, 1, 2, 2}); + + llvm::Expected> res = + lambda.operator()>({&aArg, &bArg}); + + ASSERT_EXPECTED_SUCCESS(res); + + ASSERT_EQ(res->size(), (uint64_t)2 * 2); + + for (size_t i = 0; i < 2; i++) { + for (size_t j = 0; j < 2; j++) { + EXPECT_EQ((*res)[i * 2 + j], expected[i][j]) + << ", at pos(" << i << "," << j << ")"; + } + } +} + +TEST(End2EndJit_FHELinalg, conv2d_batched_input44_kernel22) { + + mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + func @main(%input: tensor<3x1x4x4x!FHE.eint<6>>, %weight: tensor<1x1x2x2xi7>) -> tensor<3x1x2x2x!FHE.eint<6>> { + %0 = "FHELinalg.conv2d"(%input, %weight){ + strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64> + } : (tensor<3x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>) -> tensor<3x1x2x2x!FHE.eint<6>> + return %0 : tensor<3x1x2x2x!FHE.eint<6>> + } +)XXX"); + const uint8_t A[3][1][4][4]{{{ + {1, 2, 3, 4}, + {1, 2, 3, 4}, + {1, 2, 3, 4}, + {1, 2, 3, 4}, + }}, + {{ + {3, 2, 3, 4}, + {2, 3, 3, 4}, + {3, 2, 3, 4}, + {2, 3, 3, 4}, + }}, + {{ + {1, 2, 3, 4}, + {1, 2, 4, 2}, + {1, 2, 3, 4}, + {1, 2, 4, 2}, + }}}; + const uint8_t B[1][1][2][2]{{{ + {1, 2}, + {2, 1}, + }}}; + const uint8_t expected[3][1][2][2]{{{ + {9, 21}, + {9, 21}, + }}, + {{ + {14, 21}, + {14, 21}, + }}, + {{ + {9, 21}, + {9, 21}, + }}}; + + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + aArg(llvm::ArrayRef((const uint8_t *)A, 3 * 4 * 4), + {3, 1, 4, 4}); + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + bArg(llvm::ArrayRef((const uint8_t *)B, 2 * 2), {1, 1, 2, 2}); + + llvm::Expected> res = + lambda.operator()>({&aArg, &bArg}); + + ASSERT_EXPECTED_SUCCESS(res); + + ASSERT_EQ(res->size(), (uint64_t)3 * 2 * 2); + + for (size_t batch = 0; batch < 3; batch++) { + for (size_t i = 0; i < 2; i++) { + for (size_t j = 0; j < 2; j++) { + EXPECT_EQ((*res)[batch * 4 + i * 2 + j], expected[batch][0][i][j]) + << ", at pos(" << batch << "," << i << "," << j << ")"; + } + } + } +} + +TEST(End2EndJit_FHELinalg, conv2d_simple_input44_kernel2122) { + + mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + func @main(%input: tensor<1x1x4x4x!FHE.eint<6>>, %weight: tensor<2x1x2x2xi7>) -> tensor<1x2x2x2x!FHE.eint<6>> { + %0 = "FHELinalg.conv2d"(%input, %weight){ + strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64> + } : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<2x1x2x2xi7>) -> tensor<1x2x2x2x!FHE.eint<6>> + return %0 : tensor<1x2x2x2x!FHE.eint<6>> + } +)XXX"); + const uint8_t A[1][1][4][4]{{{ + {1, 2, 3, 4}, + {1, 2, 3, 4}, + {1, 2, 3, 4}, + {1, 2, 3, 4}, + }}}; + const uint8_t B[2][1][2][2]{{{ + {1, 2}, + {2, 1}, + }}, + {{ + {2, 2}, + {2, 2}, + }}}; + const uint8_t expected[2][2][2]{{ + {9, 21}, + {9, 21}, + }, + { + {12, 28}, + {12, 28}, + } + + }; + + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + aArg(llvm::ArrayRef((const uint8_t *)A, 4 * 4), {1, 1, 4, 4}); + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + bArg(llvm::ArrayRef((const uint8_t *)B, 2 * 2 * 2), + {2, 1, 2, 2}); + + llvm::Expected> res = + lambda.operator()>({&aArg, &bArg}); + + ASSERT_EXPECTED_SUCCESS(res); + + ASSERT_EQ(res->size(), (uint64_t)2 * 2 * 2); + + for (size_t channel = 0; channel < 2; channel++) { + for (size_t i = 0; i < 2; i++) { + for (size_t j = 0; j < 2; j++) { + EXPECT_EQ((*res)[channel * 4 + i * 2 + j], expected[channel][i][j]) + << ", at pos(" << channel << "," << i << "," << j << ")"; + } + } + } +} + +TEST(End2EndJit_FHELinalg, conv2d_simple_input1244_kernel1222) { + + mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + func @main(%input: tensor<1x2x4x4x!FHE.eint<6>>, %weight: tensor<1x2x2x2xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> { + %0 = "FHELinalg.conv2d"(%input, %weight){ + strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64> + } : (tensor<1x2x4x4x!FHE.eint<6>>, tensor<1x2x2x2xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> + return %0 : tensor<1x1x2x2x!FHE.eint<6>> + } +)XXX"); + const uint8_t A[1][2][4][4]{{{ + {1, 2, 3, 4}, + {1, 2, 3, 4}, + {1, 2, 3, 4}, + {1, 2, 3, 4}, + }, + { + {1, 2, 3, 4}, + {1, 2, 3, 4}, + {1, 2, 3, 4}, + {1, 2, 3, 4}, + }}}; + const uint8_t B[1][2][2][2]{{{ + {1, 2}, + {2, 1}, + }, + { + {1, 2}, + {2, 1}, + }}}; + const uint8_t expected[2][2]{ + {18, 42}, + {18, 42}, + }; + + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + aArg(llvm::ArrayRef((const uint8_t *)A, 2 * 4 * 4), + {1, 2, 4, 4}); + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + bArg(llvm::ArrayRef((const uint8_t *)B, 2 * 2 * 2), + {1, 2, 2, 2}); + + llvm::Expected> res = + lambda.operator()>({&aArg, &bArg}); + + ASSERT_EXPECTED_SUCCESS(res); + + ASSERT_EQ(res->size(), (uint64_t)2 * 2); + + for (size_t i = 0; i < 2; i++) { + for (size_t j = 0; j < 2; j++) { + EXPECT_EQ((*res)[i * 2 + j], expected[i][j]) + << ", at pos(" << i << "," << j << ")"; + } + } +} + +TEST(End2EndJit_FHELinalg, conv2d_simple_input44_kernel22_dilation2) { + + mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + func @main(%input: tensor<1x1x4x4x!FHE.eint<6>>, %weight: tensor<1x1x2x2xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> { + %0 = "FHELinalg.conv2d"(%input, %weight){ + strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[2,2]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64> + } : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> + return %0 : tensor<1x1x2x2x!FHE.eint<6>> + } +)XXX"); + const uint8_t A[1][1][4][4]{{{ + {1, 2, 3, 4}, + {1, 2, 3, 4}, + {1, 2, 3, 4}, + {1, 2, 3, 4}, + }}}; + const uint8_t B[1][1][2][2]{{{ + {1, 2}, + {2, 1}, + }}}; + const uint8_t expected[2][2]{ + {12, 18}, + {12, 18}, + }; + + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + aArg(llvm::ArrayRef((const uint8_t *)A, 4 * 4), {1, 1, 4, 4}); + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + bArg(llvm::ArrayRef((const uint8_t *)B, 2 * 2), {1, 1, 2, 2}); + + llvm::Expected> res = + lambda.operator()>({&aArg, &bArg}); + + ASSERT_EXPECTED_SUCCESS(res); + + ASSERT_EQ(res->size(), (uint64_t)2 * 2); + + for (size_t i = 0; i < 2; i++) { + for (size_t j = 0; j < 2; j++) { + EXPECT_EQ((*res)[i * 2 + j], expected[i][j]) + << ", at pos(" << i << "," << j << ")"; + } + } +} + /////////////////////////////////////////////////////////////////////////////// // linalg.tensor_collapse_shape /////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////