From 100862e4848ad876d2eb32990c00e31e68a09d9b Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Mon, 15 Nov 2021 10:38:08 +0100 Subject: [PATCH] feat(compiler): Add the support of linalg.tensor_expand_shape and linalg.tensor_collapse_shape on encrypted tensors --- .../Conversion/Utils/TensorOpTypeConversion.h | 12 +++ compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp | 51 +++++++++ .../Dialect/HLFHE/Analysis/MANP_tensor.mlir | 38 +++++++ .../unittest/end_to_end_jit_hlfhelinalg.cc | 102 ++++++++++++++++++ 4 files changed, 203 insertions(+) diff --git a/compiler/include/zamalang/Conversion/Utils/TensorOpTypeConversion.h b/compiler/include/zamalang/Conversion/Utils/TensorOpTypeConversion.h index 5f939fce8..e65f8c2eb 100644 --- a/compiler/include/zamalang/Conversion/Utils/TensorOpTypeConversion.h +++ b/compiler/include/zamalang/Conversion/Utils/TensorOpTypeConversion.h @@ -1,6 +1,7 @@ #ifndef ZAMALANG_CONVERSION_TENSOROPTYPECONVERSIONPATTERN_H_ #define ZAMALANG_CONVERSION_TENSOROPTYPECONVERSIONPATTERN_H_ +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Transforms/DialectConversion.h" @@ -32,6 +33,17 @@ populateWithTensorTypeConverterPatterns(mlir::RewritePatternSet &patterns, patterns.getContext(), typeConverter); addDynamicallyLegalTypeOp(target, typeConverter); + // TensorCollapseShapeOp + patterns + .add>( + patterns.getContext(), typeConverter); + addDynamicallyLegalTypeOp(target, + typeConverter); + // TensorExpandShapeOp + patterns.add>( + patterns.getContext(), typeConverter); + addDynamicallyLegalTypeOp(target, + typeConverter); } } // namespace zamalang } // namespace mlir diff --git a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp index 47b311080..fe4a792f9 100644 --- a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -737,6 +738,30 @@ static llvm::APInt getSqMANP( operandMANPs[1]->getValue().getMANP().getValue()); } +static llvm::APInt getSqMANP( + mlir::linalg::TensorCollapseShapeOp op, + llvm::ArrayRef *> operandMANPs) { + + assert( + operandMANPs.size() >= 1 && + operandMANPs[0]->getValue().getMANP().hasValue() && + "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); + + return operandMANPs[0]->getValue().getMANP().getValue(); +} + +static llvm::APInt getSqMANP( + mlir::linalg::TensorExpandShapeOp op, + llvm::ArrayRef *> operandMANPs) { + + assert( + operandMANPs.size() >= 1 && + operandMANPs[0]->getValue().getMANP().hasValue() && + "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); + + return operandMANPs[0]->getValue().getMANP().getValue(); +} + struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { using ForwardDataFlowAnalysis::ForwardDataFlowAnalysis; MANPAnalysis(mlir::MLIRContext *ctx, bool debug) @@ -853,6 +878,32 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { isDummy = true; } } + // TensorCollapseShapeOp + else if (auto reshapeOp = + llvm::dyn_cast(op)) { + if (reshapeOp.result() + .getType() + .cast() + .getElementType() + .isa()) { + norm2SqEquiv = getSqMANP(reshapeOp, operands); + } else { + isDummy = true; + } + } + // TensorExpandShapeOp + else if (auto reshapeOp = + llvm::dyn_cast(op)) { + if (reshapeOp.result() + .getType() + .cast() + .getElementType() + .isa()) { + norm2SqEquiv = getSqMANP(reshapeOp, operands); + } else { + isDummy = true; + } + } else if (llvm::isa(op)) { isDummy = true; diff --git a/compiler/tests/Dialect/HLFHE/Analysis/MANP_tensor.mlir b/compiler/tests/Dialect/HLFHE/Analysis/MANP_tensor.mlir index 548f09a62..906197cf6 100644 --- a/compiler/tests/Dialect/HLFHE/Analysis/MANP_tensor.mlir +++ b/compiler/tests/Dialect/HLFHE/Analysis/MANP_tensor.mlir @@ -120,3 +120,41 @@ func @tensor_insert_slice_2(%a: !HLFHE.eint<5>) -> tensor<4x!HLFHE.eint<5>> return %t0 : tensor<4x!HLFHE.eint<5>> } + +// ----- + +func @tensor_collapse_shape_1(%a: tensor<2x2x4x!HLFHE.eint<6>>) -> tensor<2x8x!HLFHE.eint<6>> { + // CHECK: linalg.tensor_collapse_shape %[[A:.*]] [[X:.*]] {MANP = 1 : ui{{[0-9]+}}} + %0 = linalg.tensor_collapse_shape %a [[0],[1,2]] : tensor<2x2x4x!HLFHE.eint<6>> into tensor<2x8x!HLFHE.eint<6>> + return %0 : tensor<2x8x!HLFHE.eint<6>> +} + +// ----- + +func @tensor_collapse_shape_2(%a: tensor<2x2x4x!HLFHE.eint<2>>, %b: tensor<2x2x4xi3>) -> tensor<2x8x!HLFHE.eint<2>> +{ + // CHECK: "HLFHELinalg.add_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 9 : ui{{[0-9]+}}} + %0 = "HLFHELinalg.add_eint_int"(%a, %b) : (tensor<2x2x4x!HLFHE.eint<2>>, tensor<2x2x4xi3>) -> tensor<2x2x4x!HLFHE.eint<2>> + // CHECK-NEXT: linalg.tensor_collapse_shape %[[A:.*]] [[X:.*]] {MANP = 9 : ui{{[0-9]+}}} + %1 = linalg.tensor_collapse_shape %0 [[0],[1,2]] : tensor<2x2x4x!HLFHE.eint<2>> into tensor<2x8x!HLFHE.eint<2>> + return %1 : tensor<2x8x!HLFHE.eint<2>> +} + +// ----- + +func @tensor_expand_shape_1(%a: tensor<2x8x!HLFHE.eint<6>>) -> tensor<2x2x4x!HLFHE.eint<6>> { + // CHECK: linalg.tensor_expand_shape %[[A:.*]] [[X:.*]] {MANP = 1 : ui{{[0-9]+}}} + %0 = linalg.tensor_expand_shape %a [[0],[1,2]] : tensor<2x8x!HLFHE.eint<6>> into tensor<2x2x4x!HLFHE.eint<6>> + return %0 : tensor<2x2x4x!HLFHE.eint<6>> +} + +// ----- + +func @tensor_expand_shape_2(%a: tensor<2x8x!HLFHE.eint<2>>, %b: tensor<2x8xi3>) -> tensor<2x2x4x!HLFHE.eint<2>> +{ + // CHECK: "HLFHELinalg.add_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 9 : ui{{[0-9]+}}} + %0 = "HLFHELinalg.add_eint_int"(%a, %b) : (tensor<2x8x!HLFHE.eint<2>>, tensor<2x8xi3>) -> tensor<2x8x!HLFHE.eint<2>> + // CHECK-NEXT: linalg.tensor_expand_shape %[[A:.*]] [[X:.*]] {MANP = 9 : ui{{[0-9]+}}} + %1 = linalg.tensor_expand_shape %0 [[0],[1,2]] : tensor<2x8x!HLFHE.eint<2>> into tensor<2x2x4x!HLFHE.eint<2>> + return %1 : tensor<2x2x4x!HLFHE.eint<2>> +} \ No newline at end of file diff --git a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc index d9ae4ed0e..e332ba086 100644 --- a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc +++ b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc @@ -1273,3 +1273,105 @@ TEST(End2EndJit_HLFHELinalg, matmul_eint_int) { } } } + +/////////////////////////////////////////////////////////////////////////////// +// linalg.tensor_collapse_shape /////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(End2EndJit_Linalg, tensor_collapse_shape) { + + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( +func @main(%a: tensor<2x2x4x!HLFHE.eint<6>>) -> tensor<2x8x!HLFHE.eint<6>> { + %0 = linalg.tensor_collapse_shape %a [[0],[1,2]] : tensor<2x2x4x!HLFHE.eint<6>> into tensor<2x8x!HLFHE.eint<6>> + return %0 : tensor<2x8x!HLFHE.eint<6>> +} +)XXX"); + static uint8_t A[2][2][4]{ + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{10, 11, 12, 13}, {14, 15, 16, 17}}, + }; + static uint8_t expected[2][8]{ + {1, 2, 3, 4, 5, 6, 7, 8}, + {10, 11, 12, 13, 14, 15, 16, 17}, + }; + + mlir::zamalang::TensorLambdaArgument< + mlir::zamalang::IntLambdaArgument> + aArg(llvm::MutableArrayRef((uint8_t *)A, 2 * 2 * 4), {2, 2, 4}); + + llvm::Expected> res = + lambda.operator()>( + {&aArg}); + + ASSERT_EXPECTED_SUCCESS(res); + + mlir::zamalang::TensorLambdaArgument> + &resp = (*res) + ->cast>>(); + + ASSERT_EQ(resp.getDimensions().size(), (size_t)2); + ASSERT_EQ(resp.getDimensions().at(0), 2); + ASSERT_EQ(resp.getDimensions().at(1), 8); + ASSERT_EXPECTED_VALUE(resp.getNumElements(), 2 * 8); + + for (size_t i = 0; i < 2; i++) { + for (size_t j = 0; j < 8; j++) { + EXPECT_EQ(resp.getValue()[i * 8 + j], expected[i][j]) + << ", at pos(" << i << "," << j << ")"; + } + } +} + +/////////////////////////////////////////////////////////////////////////////// +// linalg.tensor_expand_shape /////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(End2EndJit_Linalg, tensor_expand_shape) { + + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( +func @main(%a: tensor<2x8x!HLFHE.eint<6>>) -> tensor<2x2x4x!HLFHE.eint<6>> { + %0 = linalg.tensor_expand_shape %a [[0],[1,2]] : tensor<2x8x!HLFHE.eint<6>> into tensor<2x2x4x!HLFHE.eint<6>> + return %0 : tensor<2x2x4x!HLFHE.eint<6>> +} +)XXX"); + + static uint8_t A[2][8]{ + {1, 2, 3, 4, 5, 6, 7, 8}, + {10, 11, 12, 13, 14, 15, 16, 17}, + }; + static uint8_t expected[2][2][4]{ + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{10, 11, 12, 13}, {14, 15, 16, 17}}, + }; + + mlir::zamalang::TensorLambdaArgument< + mlir::zamalang::IntLambdaArgument> + aArg(llvm::MutableArrayRef((uint8_t *)A, 2 * 8), {2, 8}); + + llvm::Expected> res = + lambda.operator()>( + {&aArg}); + + ASSERT_EXPECTED_SUCCESS(res); + + mlir::zamalang::TensorLambdaArgument> + &resp = (*res) + ->cast>>(); + + ASSERT_EQ(resp.getDimensions().size(), (size_t)3); + ASSERT_EQ(resp.getDimensions().at(0), 2); + ASSERT_EQ(resp.getDimensions().at(1), 2); + ASSERT_EQ(resp.getDimensions().at(2), 4); + ASSERT_EXPECTED_VALUE(resp.getNumElements(), 2 * 2 * 4); + + for (size_t i = 0; i < 2; i++) { + for (size_t j = 0; j < 2; j++) { + for (size_t k = 0; k < 4; k++) { + EXPECT_EQ(resp.getValue()[i * 8 + j * 4 + k], expected[i][j][k]) + << ", at pos(" << i << "," << j << "," << k << ")"; + } + } + } +} \ No newline at end of file