mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: support lowering of convolution end to end
This commit is contained in:
@@ -154,6 +154,10 @@ lowerFHEToTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createConvertFHETensorOpsToLinalg(), enablePass);
|
||||
// FHETensorOpsToLinalg does generate linalg named ops that need to be lowered
|
||||
// to linalg.generic operations
|
||||
addPotentiallyNestedPass(pm, mlir::createLinalgGeneralizationPass(),
|
||||
enablePass);
|
||||
addPotentiallyNestedPass(pm, mlir::concretelang::createConvertFHEToTFHEPass(),
|
||||
enablePass);
|
||||
|
||||
@@ -241,6 +245,8 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
if (parallelizeLoops)
|
||||
addPotentiallyNestedPass(pm, mlir::createConvertSCFToOpenMPPass(),
|
||||
enablePass);
|
||||
// Lower affine
|
||||
addPotentiallyNestedPass(pm, mlir::createLowerAffinePass(), enablePass);
|
||||
|
||||
// Lower Dataflow tasks to DRF
|
||||
addPotentiallyNestedPass(
|
||||
|
||||
@@ -1436,6 +1436,382 @@ TEST(End2EndJit_FHELinalg, matmul_int_eint) {
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// FHELinalg conv2d ///////////////////////////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(End2EndJit_FHELinalg, conv2d_simple_input44_kernel22) {
|
||||
|
||||
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%input: tensor<1x1x4x4x!FHE.eint<6>>, %weight: tensor<1x1x2x2xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> {
|
||||
%0 = "FHELinalg.conv2d"(%input, %weight){
|
||||
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
|
||||
} : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
|
||||
return %0 : tensor<1x1x2x2x!FHE.eint<6>>
|
||||
}
|
||||
)XXX");
|
||||
const uint8_t A[1][1][4][4]{{{
|
||||
{1, 2, 3, 4},
|
||||
{1, 2, 3, 4},
|
||||
{1, 2, 3, 4},
|
||||
{1, 2, 3, 4},
|
||||
}}};
|
||||
const uint8_t B[1][1][2][2]{{{
|
||||
{1, 2},
|
||||
{2, 1},
|
||||
}}};
|
||||
const uint8_t expected[2][2]{
|
||||
{9, 21},
|
||||
{9, 21},
|
||||
};
|
||||
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
aArg(llvm::ArrayRef<uint8_t>((const uint8_t *)A, 4 * 4), {1, 1, 4, 4});
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
bArg(llvm::ArrayRef<uint8_t>((const uint8_t *)B, 2 * 2), {1, 1, 2, 2});
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&aArg, &bArg});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), (uint64_t)2 * 2);
|
||||
|
||||
for (size_t i = 0; i < 2; i++) {
|
||||
for (size_t j = 0; j < 2; j++) {
|
||||
EXPECT_EQ((*res)[i * 2 + j], expected[i][j])
|
||||
<< ", at pos(" << i << "," << j << ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_FHELinalg, conv2d_simple_input44_const_kernel22) {
|
||||
|
||||
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%input: tensor<1x1x4x4x!FHE.eint<6>>) -> tensor<1x1x2x2x!FHE.eint<6>> {
|
||||
%weight = arith.constant dense<[[[[1, 2], [2, 1]]]]> : tensor<1x1x2x2xi7>
|
||||
%0 = "FHELinalg.conv2d"(%input, %weight){
|
||||
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
|
||||
} : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
|
||||
return %0 : tensor<1x1x2x2x!FHE.eint<6>>
|
||||
}
|
||||
)XXX");
|
||||
const uint8_t A[1][1][4][4]{{{
|
||||
{1, 2, 3, 4},
|
||||
{1, 2, 3, 4},
|
||||
{1, 2, 3, 4},
|
||||
{1, 2, 3, 4},
|
||||
}}};
|
||||
const uint8_t expected[2][2]{
|
||||
{9, 21},
|
||||
{9, 21},
|
||||
};
|
||||
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
aArg(llvm::ArrayRef<uint8_t>((const uint8_t *)A, 4 * 4), {1, 1, 4, 4});
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&aArg});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), (uint64_t)2 * 2);
|
||||
|
||||
for (size_t i = 0; i < 2; i++) {
|
||||
for (size_t j = 0; j < 2; j++) {
|
||||
EXPECT_EQ((*res)[i * 2 + j], expected[i][j])
|
||||
<< ", at pos(" << i << "," << j << ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_FHELinalg, conv2d_simple_input44_kernel22_const_bias) {
|
||||
|
||||
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%input: tensor<1x1x4x4x!FHE.eint<6>>, %weight: tensor<1x1x2x2xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> {
|
||||
%bias = arith.constant dense<[1]> : tensor<1xi7>
|
||||
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
|
||||
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
|
||||
} : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>, tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
|
||||
return %0 : tensor<1x1x2x2x!FHE.eint<6>>
|
||||
}
|
||||
)XXX");
|
||||
const uint8_t A[1][1][4][4]{{{
|
||||
{1, 2, 3, 4},
|
||||
{1, 2, 3, 4},
|
||||
{1, 2, 3, 4},
|
||||
{1, 2, 3, 4},
|
||||
}}};
|
||||
const uint8_t B[1][1][2][2]{{{
|
||||
{1, 2},
|
||||
{2, 1},
|
||||
}}};
|
||||
const uint8_t expected[2][2]{
|
||||
{10, 22},
|
||||
{10, 22},
|
||||
};
|
||||
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
aArg(llvm::ArrayRef<uint8_t>((const uint8_t *)A, 4 * 4), {1, 1, 4, 4});
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
bArg(llvm::ArrayRef<uint8_t>((const uint8_t *)B, 2 * 2), {1, 1, 2, 2});
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&aArg, &bArg});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), (uint64_t)2 * 2);
|
||||
|
||||
for (size_t i = 0; i < 2; i++) {
|
||||
for (size_t j = 0; j < 2; j++) {
|
||||
EXPECT_EQ((*res)[i * 2 + j], expected[i][j])
|
||||
<< ", at pos(" << i << "," << j << ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_FHELinalg, conv2d_batched_input44_kernel22) {
|
||||
|
||||
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%input: tensor<3x1x4x4x!FHE.eint<6>>, %weight: tensor<1x1x2x2xi7>) -> tensor<3x1x2x2x!FHE.eint<6>> {
|
||||
%0 = "FHELinalg.conv2d"(%input, %weight){
|
||||
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
|
||||
} : (tensor<3x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>) -> tensor<3x1x2x2x!FHE.eint<6>>
|
||||
return %0 : tensor<3x1x2x2x!FHE.eint<6>>
|
||||
}
|
||||
)XXX");
|
||||
const uint8_t A[3][1][4][4]{{{
|
||||
{1, 2, 3, 4},
|
||||
{1, 2, 3, 4},
|
||||
{1, 2, 3, 4},
|
||||
{1, 2, 3, 4},
|
||||
}},
|
||||
{{
|
||||
{3, 2, 3, 4},
|
||||
{2, 3, 3, 4},
|
||||
{3, 2, 3, 4},
|
||||
{2, 3, 3, 4},
|
||||
}},
|
||||
{{
|
||||
{1, 2, 3, 4},
|
||||
{1, 2, 4, 2},
|
||||
{1, 2, 3, 4},
|
||||
{1, 2, 4, 2},
|
||||
}}};
|
||||
const uint8_t B[1][1][2][2]{{{
|
||||
{1, 2},
|
||||
{2, 1},
|
||||
}}};
|
||||
const uint8_t expected[3][1][2][2]{{{
|
||||
{9, 21},
|
||||
{9, 21},
|
||||
}},
|
||||
{{
|
||||
{14, 21},
|
||||
{14, 21},
|
||||
}},
|
||||
{{
|
||||
{9, 21},
|
||||
{9, 21},
|
||||
}}};
|
||||
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
aArg(llvm::ArrayRef<uint8_t>((const uint8_t *)A, 3 * 4 * 4),
|
||||
{3, 1, 4, 4});
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
bArg(llvm::ArrayRef<uint8_t>((const uint8_t *)B, 2 * 2), {1, 1, 2, 2});
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&aArg, &bArg});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), (uint64_t)3 * 2 * 2);
|
||||
|
||||
for (size_t batch = 0; batch < 3; batch++) {
|
||||
for (size_t i = 0; i < 2; i++) {
|
||||
for (size_t j = 0; j < 2; j++) {
|
||||
EXPECT_EQ((*res)[batch * 4 + i * 2 + j], expected[batch][0][i][j])
|
||||
<< ", at pos(" << batch << "," << i << "," << j << ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_FHELinalg, conv2d_simple_input44_kernel2122) {
|
||||
|
||||
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%input: tensor<1x1x4x4x!FHE.eint<6>>, %weight: tensor<2x1x2x2xi7>) -> tensor<1x2x2x2x!FHE.eint<6>> {
|
||||
%0 = "FHELinalg.conv2d"(%input, %weight){
|
||||
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
|
||||
} : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<2x1x2x2xi7>) -> tensor<1x2x2x2x!FHE.eint<6>>
|
||||
return %0 : tensor<1x2x2x2x!FHE.eint<6>>
|
||||
}
|
||||
)XXX");
|
||||
const uint8_t A[1][1][4][4]{{{
|
||||
{1, 2, 3, 4},
|
||||
{1, 2, 3, 4},
|
||||
{1, 2, 3, 4},
|
||||
{1, 2, 3, 4},
|
||||
}}};
|
||||
const uint8_t B[2][1][2][2]{{{
|
||||
{1, 2},
|
||||
{2, 1},
|
||||
}},
|
||||
{{
|
||||
{2, 2},
|
||||
{2, 2},
|
||||
}}};
|
||||
const uint8_t expected[2][2][2]{{
|
||||
{9, 21},
|
||||
{9, 21},
|
||||
},
|
||||
{
|
||||
{12, 28},
|
||||
{12, 28},
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
aArg(llvm::ArrayRef<uint8_t>((const uint8_t *)A, 4 * 4), {1, 1, 4, 4});
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
bArg(llvm::ArrayRef<uint8_t>((const uint8_t *)B, 2 * 2 * 2),
|
||||
{2, 1, 2, 2});
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&aArg, &bArg});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), (uint64_t)2 * 2 * 2);
|
||||
|
||||
for (size_t channel = 0; channel < 2; channel++) {
|
||||
for (size_t i = 0; i < 2; i++) {
|
||||
for (size_t j = 0; j < 2; j++) {
|
||||
EXPECT_EQ((*res)[channel * 4 + i * 2 + j], expected[channel][i][j])
|
||||
<< ", at pos(" << channel << "," << i << "," << j << ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_FHELinalg, conv2d_simple_input1244_kernel1222) {
|
||||
|
||||
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%input: tensor<1x2x4x4x!FHE.eint<6>>, %weight: tensor<1x2x2x2xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> {
|
||||
%0 = "FHELinalg.conv2d"(%input, %weight){
|
||||
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
|
||||
} : (tensor<1x2x4x4x!FHE.eint<6>>, tensor<1x2x2x2xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
|
||||
return %0 : tensor<1x1x2x2x!FHE.eint<6>>
|
||||
}
|
||||
)XXX");
|
||||
const uint8_t A[1][2][4][4]{{{
|
||||
{1, 2, 3, 4},
|
||||
{1, 2, 3, 4},
|
||||
{1, 2, 3, 4},
|
||||
{1, 2, 3, 4},
|
||||
},
|
||||
{
|
||||
{1, 2, 3, 4},
|
||||
{1, 2, 3, 4},
|
||||
{1, 2, 3, 4},
|
||||
{1, 2, 3, 4},
|
||||
}}};
|
||||
const uint8_t B[1][2][2][2]{{{
|
||||
{1, 2},
|
||||
{2, 1},
|
||||
},
|
||||
{
|
||||
{1, 2},
|
||||
{2, 1},
|
||||
}}};
|
||||
const uint8_t expected[2][2]{
|
||||
{18, 42},
|
||||
{18, 42},
|
||||
};
|
||||
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
aArg(llvm::ArrayRef<uint8_t>((const uint8_t *)A, 2 * 4 * 4),
|
||||
{1, 2, 4, 4});
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
bArg(llvm::ArrayRef<uint8_t>((const uint8_t *)B, 2 * 2 * 2),
|
||||
{1, 2, 2, 2});
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&aArg, &bArg});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), (uint64_t)2 * 2);
|
||||
|
||||
for (size_t i = 0; i < 2; i++) {
|
||||
for (size_t j = 0; j < 2; j++) {
|
||||
EXPECT_EQ((*res)[i * 2 + j], expected[i][j])
|
||||
<< ", at pos(" << i << "," << j << ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_FHELinalg, conv2d_simple_input44_kernel22_dilation2) {
|
||||
|
||||
mlir::concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%input: tensor<1x1x4x4x!FHE.eint<6>>, %weight: tensor<1x1x2x2xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> {
|
||||
%0 = "FHELinalg.conv2d"(%input, %weight){
|
||||
strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[2,2]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
|
||||
} : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
|
||||
return %0 : tensor<1x1x2x2x!FHE.eint<6>>
|
||||
}
|
||||
)XXX");
|
||||
const uint8_t A[1][1][4][4]{{{
|
||||
{1, 2, 3, 4},
|
||||
{1, 2, 3, 4},
|
||||
{1, 2, 3, 4},
|
||||
{1, 2, 3, 4},
|
||||
}}};
|
||||
const uint8_t B[1][1][2][2]{{{
|
||||
{1, 2},
|
||||
{2, 1},
|
||||
}}};
|
||||
const uint8_t expected[2][2]{
|
||||
{12, 18},
|
||||
{12, 18},
|
||||
};
|
||||
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
aArg(llvm::ArrayRef<uint8_t>((const uint8_t *)A, 4 * 4), {1, 1, 4, 4});
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
bArg(llvm::ArrayRef<uint8_t>((const uint8_t *)B, 2 * 2), {1, 1, 2, 2});
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&aArg, &bArg});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), (uint64_t)2 * 2);
|
||||
|
||||
for (size_t i = 0; i < 2; i++) {
|
||||
for (size_t j = 0; j < 2; j++) {
|
||||
EXPECT_EQ((*res)[i * 2 + j], expected[i][j])
|
||||
<< ", at pos(" << i << "," << j << ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// linalg.tensor_collapse_shape ///////////////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
Reference in New Issue
Block a user