feat(compiler): Add support for linalg.tensor_expand_shape and linalg.tensor_collapse_shape on encrypted tensors

This commit is contained in:
Quentin Bourgerie
2021-11-15 10:38:08 +01:00
parent c5e3d9add8
commit 100862e484
4 changed files with 203 additions and 0 deletions

View File

@@ -1,6 +1,7 @@
#ifndef ZAMALANG_CONVERSION_TENSOROPTYPECONVERSIONPATTERN_H_
#define ZAMALANG_CONVERSION_TENSOROPTYPECONVERSIONPATTERN_H_
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -32,6 +33,17 @@ populateWithTensorTypeConverterPatterns(mlir::RewritePatternSet &patterns,
patterns.getContext(), typeConverter);
addDynamicallyLegalTypeOp<mlir::tensor::FromElementsOp>(target,
typeConverter);
// TensorCollapseShapeOp
patterns
.add<GenericTypeConverterPattern<mlir::linalg::TensorCollapseShapeOp>>(
patterns.getContext(), typeConverter);
addDynamicallyLegalTypeOp<mlir::linalg::TensorCollapseShapeOp>(target,
typeConverter);
// TensorExpandShapeOp
patterns.add<GenericTypeConverterPattern<mlir::linalg::TensorExpandShapeOp>>(
patterns.getContext(), typeConverter);
addDynamicallyLegalTypeOp<mlir::linalg::TensorExpandShapeOp>(target,
typeConverter);
}
} // namespace zamalang
} // namespace mlir

View File

@@ -12,6 +12,7 @@
#include <llvm/ADT/SmallString.h>
#include <mlir/Analysis/DataFlowAnalysis.h>
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
#include <mlir/Dialect/Linalg/IR/LinalgOps.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/Dialect/Tensor/IR/Tensor.h>
#include <mlir/IR/Attributes.h>
@@ -737,6 +738,30 @@ static llvm::APInt getSqMANP(
operandMANPs[1]->getValue().getMANP().getValue());
}
// A tensor collapse-shape only regroups dimensions; it performs no arithmetic
// on the encrypted values, so the squared Minimal Arithmetic Noise Padding of
// the result is exactly that of the single encrypted operand.
static llvm::APInt getSqMANP(
    mlir::linalg::TensorCollapseShapeOp op,
    llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
  assert(
      operandMANPs.size() >= 1 &&
      operandMANPs[0]->getValue().getMANP().hasValue() &&
      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
  // Forward the operand's MANP unchanged.
  auto operandNorm = operandMANPs[0]->getValue().getMANP();
  return operandNorm.getValue();
}
// A tensor expand-shape only regroups dimensions; like the collapse case it
// adds no noise, so the operand's squared MANP passes through unchanged.
static llvm::APInt getSqMANP(
    mlir::linalg::TensorExpandShapeOp op,
    llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
  assert(
      operandMANPs.size() >= 1 &&
      operandMANPs[0]->getValue().getMANP().hasValue() &&
      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
  const MANPLatticeValue &operandState = operandMANPs[0]->getValue();
  return operandState.getMANP().getValue();
}
struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
using ForwardDataFlowAnalysis<MANPLatticeValue>::ForwardDataFlowAnalysis;
MANPAnalysis(mlir::MLIRContext *ctx, bool debug)
@@ -853,6 +878,32 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
isDummy = true;
}
}
// TensorCollapseShapeOp
else if (auto reshapeOp =
llvm::dyn_cast<mlir::linalg::TensorCollapseShapeOp>(op)) {
if (reshapeOp.result()
.getType()
.cast<mlir::TensorType>()
.getElementType()
.isa<mlir::zamalang::HLFHE::EncryptedIntegerType>()) {
norm2SqEquiv = getSqMANP(reshapeOp, operands);
} else {
isDummy = true;
}
}
// TensorExpandShapeOp
else if (auto reshapeOp =
llvm::dyn_cast<mlir::linalg::TensorExpandShapeOp>(op)) {
if (reshapeOp.result()
.getType()
.cast<mlir::TensorType>()
.getElementType()
.isa<mlir::zamalang::HLFHE::EncryptedIntegerType>()) {
norm2SqEquiv = getSqMANP(reshapeOp, operands);
} else {
isDummy = true;
}
}
else if (llvm::isa<mlir::arith::ConstantOp>(op)) {
isDummy = true;

View File

@@ -120,3 +120,41 @@ func @tensor_insert_slice_2(%a: !HLFHE.eint<5>) -> tensor<4x!HLFHE.eint<5>>
return %t0 : tensor<4x!HLFHE.eint<5>>
}
// -----
// Collapsing an encrypted tensor with no prior arithmetic: the analysis must
// annotate the reshape with the default MANP of 1.
func @tensor_collapse_shape_1(%a: tensor<2x2x4x!HLFHE.eint<6>>) -> tensor<2x8x!HLFHE.eint<6>> {
// CHECK: linalg.tensor_collapse_shape %[[A:.*]] [[X:.*]] {MANP = 1 : ui{{[0-9]+}}}
%0 = linalg.tensor_collapse_shape %a [[0],[1,2]] : tensor<2x2x4x!HLFHE.eint<6>> into tensor<2x8x!HLFHE.eint<6>>
return %0 : tensor<2x8x!HLFHE.eint<6>>
}
// -----
// After an add_eint_int raises the MANP to 9, the collapse must propagate the
// operand's MANP (9) unchanged to its result.
func @tensor_collapse_shape_2(%a: tensor<2x2x4x!HLFHE.eint<2>>, %b: tensor<2x2x4xi3>) -> tensor<2x8x!HLFHE.eint<2>>
{
// CHECK: "HLFHELinalg.add_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 9 : ui{{[0-9]+}}}
%0 = "HLFHELinalg.add_eint_int"(%a, %b) : (tensor<2x2x4x!HLFHE.eint<2>>, tensor<2x2x4xi3>) -> tensor<2x2x4x!HLFHE.eint<2>>
// CHECK-NEXT: linalg.tensor_collapse_shape %[[A:.*]] [[X:.*]] {MANP = 9 : ui{{[0-9]+}}}
%1 = linalg.tensor_collapse_shape %0 [[0],[1,2]] : tensor<2x2x4x!HLFHE.eint<2>> into tensor<2x8x!HLFHE.eint<2>>
return %1 : tensor<2x8x!HLFHE.eint<2>>
}
// -----
// Expanding an encrypted tensor with no prior arithmetic: the analysis must
// annotate the reshape with the default MANP of 1.
func @tensor_expand_shape_1(%a: tensor<2x8x!HLFHE.eint<6>>) -> tensor<2x2x4x!HLFHE.eint<6>> {
// CHECK: linalg.tensor_expand_shape %[[A:.*]] [[X:.*]] {MANP = 1 : ui{{[0-9]+}}}
%0 = linalg.tensor_expand_shape %a [[0],[1,2]] : tensor<2x8x!HLFHE.eint<6>> into tensor<2x2x4x!HLFHE.eint<6>>
return %0 : tensor<2x2x4x!HLFHE.eint<6>>
}
// -----
// After an add_eint_int raises the MANP to 9, the expand must propagate the
// operand's MANP (9) unchanged to its result.
func @tensor_expand_shape_2(%a: tensor<2x8x!HLFHE.eint<2>>, %b: tensor<2x8xi3>) -> tensor<2x2x4x!HLFHE.eint<2>>
{
// CHECK: "HLFHELinalg.add_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 9 : ui{{[0-9]+}}}
%0 = "HLFHELinalg.add_eint_int"(%a, %b) : (tensor<2x8x!HLFHE.eint<2>>, tensor<2x8xi3>) -> tensor<2x8x!HLFHE.eint<2>>
// CHECK-NEXT: linalg.tensor_expand_shape %[[A:.*]] [[X:.*]] {MANP = 9 : ui{{[0-9]+}}}
%1 = linalg.tensor_expand_shape %0 [[0],[1,2]] : tensor<2x8x!HLFHE.eint<2>> into tensor<2x2x4x!HLFHE.eint<2>>
return %1 : tensor<2x2x4x!HLFHE.eint<2>>
}

View File

@@ -1273,3 +1273,105 @@ TEST(End2EndJit_HLFHELinalg, matmul_eint_int) {
}
}
}
///////////////////////////////////////////////////////////////////////////////
// linalg.tensor_collapse_shape ///////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// JIT-compiles a collapse of an encrypted 2x2x4 tensor into 2x8 and checks
// that every element survives the reshape unchanged.
TEST(End2EndJit_Linalg, tensor_collapse_shape) {
  mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%a: tensor<2x2x4x!HLFHE.eint<6>>) -> tensor<2x8x!HLFHE.eint<6>> {
%0 = linalg.tensor_collapse_shape %a [[0],[1,2]] : tensor<2x2x4x!HLFHE.eint<6>> into tensor<2x8x!HLFHE.eint<6>>
return %0 : tensor<2x8x!HLFHE.eint<6>>
}
)XXX");
  // Row-major input and the layout expected after collapsing the last two
  // dimensions.
  static uint8_t input[2][2][4]{
      {{1, 2, 3, 4}, {5, 6, 7, 8}},
      {{10, 11, 12, 13}, {14, 15, 16, 17}},
  };
  static uint8_t want[2][8]{
      {1, 2, 3, 4, 5, 6, 7, 8},
      {10, 11, 12, 13, 14, 15, 16, 17},
  };
  mlir::zamalang::TensorLambdaArgument<
      mlir::zamalang::IntLambdaArgument<uint8_t>>
      inputArg(llvm::MutableArrayRef<uint8_t>((uint8_t *)input, 2 * 2 * 4),
               {2, 2, 4});
  llvm::Expected<std::unique_ptr<mlir::zamalang::LambdaArgument>> result =
      lambda.operator()<std::unique_ptr<mlir::zamalang::LambdaArgument>>(
          {&inputArg});
  ASSERT_EXPECTED_SUCCESS(result);
  auto &out = (*result)
                  ->cast<mlir::zamalang::TensorLambdaArgument<
                      mlir::zamalang::IntLambdaArgument<>>>();
  // The result must be rank 2 with shape 2x8.
  ASSERT_EQ(out.getDimensions().size(), (size_t)2);
  ASSERT_EQ(out.getDimensions().at(0), 2);
  ASSERT_EQ(out.getDimensions().at(1), 8);
  ASSERT_EXPECTED_VALUE(out.getNumElements(), 2 * 8);
  for (size_t i = 0; i < 2; i++)
    for (size_t j = 0; j < 8; j++)
      EXPECT_EQ(out.getValue()[i * 8 + j], want[i][j])
          << ", at pos(" << i << "," << j << ")";
}
///////////////////////////////////////////////////////////////////////////////
// linalg.tensor_expand_shape /////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// JIT-compiles an expansion of an encrypted 2x8 tensor into 2x2x4 and checks
// that every element survives the reshape unchanged.
TEST(End2EndJit_Linalg, tensor_expand_shape) {
  mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%a: tensor<2x8x!HLFHE.eint<6>>) -> tensor<2x2x4x!HLFHE.eint<6>> {
%0 = linalg.tensor_expand_shape %a [[0],[1,2]] : tensor<2x8x!HLFHE.eint<6>> into tensor<2x2x4x!HLFHE.eint<6>>
return %0 : tensor<2x2x4x!HLFHE.eint<6>>
}
)XXX");
  // Row-major input and its expected 3-D regrouping.
  static uint8_t input[2][8]{
      {1, 2, 3, 4, 5, 6, 7, 8},
      {10, 11, 12, 13, 14, 15, 16, 17},
  };
  static uint8_t want[2][2][4]{
      {{1, 2, 3, 4}, {5, 6, 7, 8}},
      {{10, 11, 12, 13}, {14, 15, 16, 17}},
  };
  mlir::zamalang::TensorLambdaArgument<
      mlir::zamalang::IntLambdaArgument<uint8_t>>
      inputArg(llvm::MutableArrayRef<uint8_t>((uint8_t *)input, 2 * 8),
               {2, 8});
  llvm::Expected<std::unique_ptr<mlir::zamalang::LambdaArgument>> result =
      lambda.operator()<std::unique_ptr<mlir::zamalang::LambdaArgument>>(
          {&inputArg});
  ASSERT_EXPECTED_SUCCESS(result);
  auto &out = (*result)
                  ->cast<mlir::zamalang::TensorLambdaArgument<
                      mlir::zamalang::IntLambdaArgument<>>>();
  // The result must be rank 3 with shape 2x2x4.
  ASSERT_EQ(out.getDimensions().size(), (size_t)3);
  ASSERT_EQ(out.getDimensions().at(0), 2);
  ASSERT_EQ(out.getDimensions().at(1), 2);
  ASSERT_EQ(out.getDimensions().at(2), 4);
  ASSERT_EXPECTED_VALUE(out.getNumElements(), 2 * 2 * 4);
  // Flat index: i strides a full 2x4 slab (8), j strides a row (4).
  for (size_t i = 0; i < 2; i++)
    for (size_t j = 0; j < 2; j++)
      for (size_t k = 0; k < 4; k++)
        EXPECT_EQ(out.getValue()[i * 8 + j * 4 + k], want[i][j][k])
            << ", at pos(" << i << "," << j << "," << k << ")";
}