feat: support lowering of convolution end to end

This commit is contained in:
youben11
2022-02-17 09:27:16 +01:00
committed by Ayoub Benaissa
parent 86379096df
commit e82360a9fe
2 changed files with 382 additions and 0 deletions

View File

@@ -154,6 +154,10 @@ lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
addPotentiallyNestedPass(
pm, mlir::concretelang::createConvertFHETensorOpsToLinalg(), enablePass);
// FHETensorOpsToLinalg does generate linalg named ops that need to be lowered
// to linalg.generic operations
addPotentiallyNestedPass(pm, mlir::createLinalgGeneralizationPass(),
enablePass);
addPotentiallyNestedPass(pm, mlir::concretelang::createConvertFHEToTFHEPass(),
enablePass);
@@ -241,6 +245,8 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
if (parallelizeLoops)
addPotentiallyNestedPass(pm, mlir::createConvertSCFToOpenMPPass(),
enablePass);
// Lower affine
addPotentiallyNestedPass(pm, mlir::createLowerAffinePass(), enablePass);
// Lower Dataflow tasks to DRF
addPotentiallyNestedPass(

View File

@@ -1436,6 +1436,382 @@ TEST(End2EndJit_FHELinalg, matmul_int_eint) {
}
}
///////////////////////////////////////////////////////////////////////////////
// FHELinalg conv2d ///////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
TEST(End2EndJit_FHELinalg, conv2d_simple_input44_kernel22) {
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%input: tensor<1x1x4x4x!FHE.eint<6>>, %weight: tensor<1x1x2x2xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> {
%0 = "FHELinalg.conv2d"(%input, %weight){
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
} : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
return %0 : tensor<1x1x2x2x!FHE.eint<6>>
}
)XXX");
const uint8_t A[1][1][4][4]{{{
{1, 2, 3, 4},
{1, 2, 3, 4},
{1, 2, 3, 4},
{1, 2, 3, 4},
}}};
const uint8_t B[1][1][2][2]{{{
{1, 2},
{2, 1},
}}};
const uint8_t expected[2][2]{
{9, 21},
{9, 21},
};
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>
aArg(llvm::ArrayRef<uint8_t>((const uint8_t *)A, 4 * 4), {1, 1, 4, 4});
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>
bArg(llvm::ArrayRef<uint8_t>((const uint8_t *)B, 2 * 2), {1, 1, 2, 2});
llvm::Expected<std::vector<uint64_t>> res =
lambda.operator()<std::vector<uint64_t>>({&aArg, &bArg});
ASSERT_EXPECTED_SUCCESS(res);
ASSERT_EQ(res->size(), (uint64_t)2 * 2);
for (size_t i = 0; i < 2; i++) {
for (size_t j = 0; j < 2; j++) {
EXPECT_EQ((*res)[i * 2 + j], expected[i][j])
<< ", at pos(" << i << "," << j << ")";
}
}
}
TEST(End2EndJit_FHELinalg, conv2d_simple_input44_const_kernel22) {
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%input: tensor<1x1x4x4x!FHE.eint<6>>) -> tensor<1x1x2x2x!FHE.eint<6>> {
%weight = arith.constant dense<[[[[1, 2], [2, 1]]]]> : tensor<1x1x2x2xi7>
%0 = "FHELinalg.conv2d"(%input, %weight){
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
} : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
return %0 : tensor<1x1x2x2x!FHE.eint<6>>
}
)XXX");
const uint8_t A[1][1][4][4]{{{
{1, 2, 3, 4},
{1, 2, 3, 4},
{1, 2, 3, 4},
{1, 2, 3, 4},
}}};
const uint8_t expected[2][2]{
{9, 21},
{9, 21},
};
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>
aArg(llvm::ArrayRef<uint8_t>((const uint8_t *)A, 4 * 4), {1, 1, 4, 4});
llvm::Expected<std::vector<uint64_t>> res =
lambda.operator()<std::vector<uint64_t>>({&aArg});
ASSERT_EXPECTED_SUCCESS(res);
ASSERT_EQ(res->size(), (uint64_t)2 * 2);
for (size_t i = 0; i < 2; i++) {
for (size_t j = 0; j < 2; j++) {
EXPECT_EQ((*res)[i * 2 + j], expected[i][j])
<< ", at pos(" << i << "," << j << ")";
}
}
}
TEST(End2EndJit_FHELinalg, conv2d_simple_input44_kernel22_const_bias) {
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%input: tensor<1x1x4x4x!FHE.eint<6>>, %weight: tensor<1x1x2x2xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> {
%bias = arith.constant dense<[1]> : tensor<1xi7>
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
} : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>, tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
return %0 : tensor<1x1x2x2x!FHE.eint<6>>
}
)XXX");
const uint8_t A[1][1][4][4]{{{
{1, 2, 3, 4},
{1, 2, 3, 4},
{1, 2, 3, 4},
{1, 2, 3, 4},
}}};
const uint8_t B[1][1][2][2]{{{
{1, 2},
{2, 1},
}}};
const uint8_t expected[2][2]{
{10, 22},
{10, 22},
};
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>
aArg(llvm::ArrayRef<uint8_t>((const uint8_t *)A, 4 * 4), {1, 1, 4, 4});
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>
bArg(llvm::ArrayRef<uint8_t>((const uint8_t *)B, 2 * 2), {1, 1, 2, 2});
llvm::Expected<std::vector<uint64_t>> res =
lambda.operator()<std::vector<uint64_t>>({&aArg, &bArg});
ASSERT_EXPECTED_SUCCESS(res);
ASSERT_EQ(res->size(), (uint64_t)2 * 2);
for (size_t i = 0; i < 2; i++) {
for (size_t j = 0; j < 2; j++) {
EXPECT_EQ((*res)[i * 2 + j], expected[i][j])
<< ", at pos(" << i << "," << j << ")";
}
}
}
TEST(End2EndJit_FHELinalg, conv2d_batched_input44_kernel22) {
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%input: tensor<3x1x4x4x!FHE.eint<6>>, %weight: tensor<1x1x2x2xi7>) -> tensor<3x1x2x2x!FHE.eint<6>> {
%0 = "FHELinalg.conv2d"(%input, %weight){
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
} : (tensor<3x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>) -> tensor<3x1x2x2x!FHE.eint<6>>
return %0 : tensor<3x1x2x2x!FHE.eint<6>>
}
)XXX");
const uint8_t A[3][1][4][4]{{{
{1, 2, 3, 4},
{1, 2, 3, 4},
{1, 2, 3, 4},
{1, 2, 3, 4},
}},
{{
{3, 2, 3, 4},
{2, 3, 3, 4},
{3, 2, 3, 4},
{2, 3, 3, 4},
}},
{{
{1, 2, 3, 4},
{1, 2, 4, 2},
{1, 2, 3, 4},
{1, 2, 4, 2},
}}};
const uint8_t B[1][1][2][2]{{{
{1, 2},
{2, 1},
}}};
const uint8_t expected[3][1][2][2]{{{
{9, 21},
{9, 21},
}},
{{
{14, 21},
{14, 21},
}},
{{
{9, 21},
{9, 21},
}}};
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>
aArg(llvm::ArrayRef<uint8_t>((const uint8_t *)A, 3 * 4 * 4),
{3, 1, 4, 4});
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>
bArg(llvm::ArrayRef<uint8_t>((const uint8_t *)B, 2 * 2), {1, 1, 2, 2});
llvm::Expected<std::vector<uint64_t>> res =
lambda.operator()<std::vector<uint64_t>>({&aArg, &bArg});
ASSERT_EXPECTED_SUCCESS(res);
ASSERT_EQ(res->size(), (uint64_t)3 * 2 * 2);
for (size_t batch = 0; batch < 3; batch++) {
for (size_t i = 0; i < 2; i++) {
for (size_t j = 0; j < 2; j++) {
EXPECT_EQ((*res)[batch * 4 + i * 2 + j], expected[batch][0][i][j])
<< ", at pos(" << batch << "," << i << "," << j << ")";
}
}
}
}
TEST(End2EndJit_FHELinalg, conv2d_simple_input44_kernel2122) {
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%input: tensor<1x1x4x4x!FHE.eint<6>>, %weight: tensor<2x1x2x2xi7>) -> tensor<1x2x2x2x!FHE.eint<6>> {
%0 = "FHELinalg.conv2d"(%input, %weight){
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
} : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<2x1x2x2xi7>) -> tensor<1x2x2x2x!FHE.eint<6>>
return %0 : tensor<1x2x2x2x!FHE.eint<6>>
}
)XXX");
const uint8_t A[1][1][4][4]{{{
{1, 2, 3, 4},
{1, 2, 3, 4},
{1, 2, 3, 4},
{1, 2, 3, 4},
}}};
const uint8_t B[2][1][2][2]{{{
{1, 2},
{2, 1},
}},
{{
{2, 2},
{2, 2},
}}};
const uint8_t expected[2][2][2]{{
{9, 21},
{9, 21},
},
{
{12, 28},
{12, 28},
}
};
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>
aArg(llvm::ArrayRef<uint8_t>((const uint8_t *)A, 4 * 4), {1, 1, 4, 4});
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>
bArg(llvm::ArrayRef<uint8_t>((const uint8_t *)B, 2 * 2 * 2),
{2, 1, 2, 2});
llvm::Expected<std::vector<uint64_t>> res =
lambda.operator()<std::vector<uint64_t>>({&aArg, &bArg});
ASSERT_EXPECTED_SUCCESS(res);
ASSERT_EQ(res->size(), (uint64_t)2 * 2 * 2);
for (size_t channel = 0; channel < 2; channel++) {
for (size_t i = 0; i < 2; i++) {
for (size_t j = 0; j < 2; j++) {
EXPECT_EQ((*res)[channel * 4 + i * 2 + j], expected[channel][i][j])
<< ", at pos(" << channel << "," << i << "," << j << ")";
}
}
}
}
TEST(End2EndJit_FHELinalg, conv2d_simple_input1244_kernel1222) {
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%input: tensor<1x2x4x4x!FHE.eint<6>>, %weight: tensor<1x2x2x2xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> {
%0 = "FHELinalg.conv2d"(%input, %weight){
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
} : (tensor<1x2x4x4x!FHE.eint<6>>, tensor<1x2x2x2xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
return %0 : tensor<1x1x2x2x!FHE.eint<6>>
}
)XXX");
const uint8_t A[1][2][4][4]{{{
{1, 2, 3, 4},
{1, 2, 3, 4},
{1, 2, 3, 4},
{1, 2, 3, 4},
},
{
{1, 2, 3, 4},
{1, 2, 3, 4},
{1, 2, 3, 4},
{1, 2, 3, 4},
}}};
const uint8_t B[1][2][2][2]{{{
{1, 2},
{2, 1},
},
{
{1, 2},
{2, 1},
}}};
const uint8_t expected[2][2]{
{18, 42},
{18, 42},
};
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>
aArg(llvm::ArrayRef<uint8_t>((const uint8_t *)A, 2 * 4 * 4),
{1, 2, 4, 4});
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>
bArg(llvm::ArrayRef<uint8_t>((const uint8_t *)B, 2 * 2 * 2),
{1, 2, 2, 2});
llvm::Expected<std::vector<uint64_t>> res =
lambda.operator()<std::vector<uint64_t>>({&aArg, &bArg});
ASSERT_EXPECTED_SUCCESS(res);
ASSERT_EQ(res->size(), (uint64_t)2 * 2);
for (size_t i = 0; i < 2; i++) {
for (size_t j = 0; j < 2; j++) {
EXPECT_EQ((*res)[i * 2 + j], expected[i][j])
<< ", at pos(" << i << "," << j << ")";
}
}
}
TEST(End2EndJit_FHELinalg, conv2d_simple_input44_kernel22_dilation2) {
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%input: tensor<1x1x4x4x!FHE.eint<6>>, %weight: tensor<1x1x2x2xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> {
%0 = "FHELinalg.conv2d"(%input, %weight){
strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[2,2]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
} : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
return %0 : tensor<1x1x2x2x!FHE.eint<6>>
}
)XXX");
const uint8_t A[1][1][4][4]{{{
{1, 2, 3, 4},
{1, 2, 3, 4},
{1, 2, 3, 4},
{1, 2, 3, 4},
}}};
const uint8_t B[1][1][2][2]{{{
{1, 2},
{2, 1},
}}};
const uint8_t expected[2][2]{
{12, 18},
{12, 18},
};
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>
aArg(llvm::ArrayRef<uint8_t>((const uint8_t *)A, 4 * 4), {1, 1, 4, 4});
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>
bArg(llvm::ArrayRef<uint8_t>((const uint8_t *)B, 2 * 2), {1, 1, 2, 2});
llvm::Expected<std::vector<uint64_t>> res =
lambda.operator()<std::vector<uint64_t>>({&aArg, &bArg});
ASSERT_EXPECTED_SUCCESS(res);
ASSERT_EQ(res->size(), (uint64_t)2 * 2);
for (size_t i = 0; i < 2; i++) {
for (size_t j = 0; j < 2; j++) {
EXPECT_EQ((*res)[i * 2 + j], expected[i][j])
<< ", at pos(" << i << "," << j << ")";
}
}
}
///////////////////////////////////////////////////////////////////////////////
// linalg.tensor_collapse_shape ///////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////