mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(compiler): Add the support of linalg.tensor_expand_shape and linalg.tensor_collapse_shape on encrypted tensors
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
#ifndef ZAMALANG_CONVERSION_TENSOROPTYPECONVERSIONPATTERN_H_
|
||||
#define ZAMALANG_CONVERSION_TENSOROPTYPECONVERSIONPATTERN_H_
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
@@ -32,6 +33,17 @@ populateWithTensorTypeConverterPatterns(mlir::RewritePatternSet &patterns,
|
||||
patterns.getContext(), typeConverter);
|
||||
addDynamicallyLegalTypeOp<mlir::tensor::FromElementsOp>(target,
|
||||
typeConverter);
|
||||
// TensorCollapseShapeOp
|
||||
patterns
|
||||
.add<GenericTypeConverterPattern<mlir::linalg::TensorCollapseShapeOp>>(
|
||||
patterns.getContext(), typeConverter);
|
||||
addDynamicallyLegalTypeOp<mlir::linalg::TensorCollapseShapeOp>(target,
|
||||
typeConverter);
|
||||
// TensorExpandShapeOp
|
||||
patterns.add<GenericTypeConverterPattern<mlir::linalg::TensorExpandShapeOp>>(
|
||||
patterns.getContext(), typeConverter);
|
||||
addDynamicallyLegalTypeOp<mlir::linalg::TensorExpandShapeOp>(target,
|
||||
typeConverter);
|
||||
}
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include <llvm/ADT/SmallString.h>
|
||||
#include <mlir/Analysis/DataFlowAnalysis.h>
|
||||
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
|
||||
#include <mlir/Dialect/Linalg/IR/LinalgOps.h>
|
||||
#include <mlir/Dialect/StandardOps/IR/Ops.h>
|
||||
#include <mlir/Dialect/Tensor/IR/Tensor.h>
|
||||
#include <mlir/IR/Attributes.h>
|
||||
@@ -737,6 +738,30 @@ static llvm::APInt getSqMANP(
|
||||
operandMANPs[1]->getValue().getMANP().getValue());
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::linalg::TensorCollapseShapeOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs.size() >= 1 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
return operandMANPs[0]->getValue().getMANP().getValue();
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::linalg::TensorExpandShapeOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs.size() >= 1 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
return operandMANPs[0]->getValue().getMANP().getValue();
|
||||
}
|
||||
|
||||
struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
using ForwardDataFlowAnalysis<MANPLatticeValue>::ForwardDataFlowAnalysis;
|
||||
MANPAnalysis(mlir::MLIRContext *ctx, bool debug)
|
||||
@@ -853,6 +878,32 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
isDummy = true;
|
||||
}
|
||||
}
|
||||
// TensorCollapseShapeOp
|
||||
else if (auto reshapeOp =
|
||||
llvm::dyn_cast<mlir::linalg::TensorCollapseShapeOp>(op)) {
|
||||
if (reshapeOp.result()
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
.isa<mlir::zamalang::HLFHE::EncryptedIntegerType>()) {
|
||||
norm2SqEquiv = getSqMANP(reshapeOp, operands);
|
||||
} else {
|
||||
isDummy = true;
|
||||
}
|
||||
}
|
||||
// TensorExpandShapeOp
|
||||
else if (auto reshapeOp =
|
||||
llvm::dyn_cast<mlir::linalg::TensorExpandShapeOp>(op)) {
|
||||
if (reshapeOp.result()
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
.isa<mlir::zamalang::HLFHE::EncryptedIntegerType>()) {
|
||||
norm2SqEquiv = getSqMANP(reshapeOp, operands);
|
||||
} else {
|
||||
isDummy = true;
|
||||
}
|
||||
}
|
||||
|
||||
else if (llvm::isa<mlir::arith::ConstantOp>(op)) {
|
||||
isDummy = true;
|
||||
|
||||
@@ -120,3 +120,41 @@ func @tensor_insert_slice_2(%a: !HLFHE.eint<5>) -> tensor<4x!HLFHE.eint<5>>
|
||||
|
||||
return %t0 : tensor<4x!HLFHE.eint<5>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @tensor_collapse_shape_1(%a: tensor<2x2x4x!HLFHE.eint<6>>) -> tensor<2x8x!HLFHE.eint<6>> {
|
||||
// CHECK: linalg.tensor_collapse_shape %[[A:.*]] [[X:.*]] {MANP = 1 : ui{{[0-9]+}}}
|
||||
%0 = linalg.tensor_collapse_shape %a [[0],[1,2]] : tensor<2x2x4x!HLFHE.eint<6>> into tensor<2x8x!HLFHE.eint<6>>
|
||||
return %0 : tensor<2x8x!HLFHE.eint<6>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @tensor_collapse_shape_2(%a: tensor<2x2x4x!HLFHE.eint<2>>, %b: tensor<2x2x4xi3>) -> tensor<2x8x!HLFHE.eint<2>>
|
||||
{
|
||||
// CHECK: "HLFHELinalg.add_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 9 : ui{{[0-9]+}}}
|
||||
%0 = "HLFHELinalg.add_eint_int"(%a, %b) : (tensor<2x2x4x!HLFHE.eint<2>>, tensor<2x2x4xi3>) -> tensor<2x2x4x!HLFHE.eint<2>>
|
||||
// CHECK-NEXT: linalg.tensor_collapse_shape %[[A:.*]] [[X:.*]] {MANP = 9 : ui{{[0-9]+}}}
|
||||
%1 = linalg.tensor_collapse_shape %0 [[0],[1,2]] : tensor<2x2x4x!HLFHE.eint<2>> into tensor<2x8x!HLFHE.eint<2>>
|
||||
return %1 : tensor<2x8x!HLFHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @tensor_expand_shape_1(%a: tensor<2x8x!HLFHE.eint<6>>) -> tensor<2x2x4x!HLFHE.eint<6>> {
|
||||
// CHECK: linalg.tensor_expand_shape %[[A:.*]] [[X:.*]] {MANP = 1 : ui{{[0-9]+}}}
|
||||
%0 = linalg.tensor_expand_shape %a [[0],[1,2]] : tensor<2x8x!HLFHE.eint<6>> into tensor<2x2x4x!HLFHE.eint<6>>
|
||||
return %0 : tensor<2x2x4x!HLFHE.eint<6>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @tensor_expand_shape_2(%a: tensor<2x8x!HLFHE.eint<2>>, %b: tensor<2x8xi3>) -> tensor<2x2x4x!HLFHE.eint<2>>
|
||||
{
|
||||
// CHECK: "HLFHELinalg.add_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 9 : ui{{[0-9]+}}}
|
||||
%0 = "HLFHELinalg.add_eint_int"(%a, %b) : (tensor<2x8x!HLFHE.eint<2>>, tensor<2x8xi3>) -> tensor<2x8x!HLFHE.eint<2>>
|
||||
// CHECK-NEXT: linalg.tensor_expand_shape %[[A:.*]] [[X:.*]] {MANP = 9 : ui{{[0-9]+}}}
|
||||
%1 = linalg.tensor_expand_shape %0 [[0],[1,2]] : tensor<2x8x!HLFHE.eint<2>> into tensor<2x2x4x!HLFHE.eint<2>>
|
||||
return %1 : tensor<2x2x4x!HLFHE.eint<2>>
|
||||
}
|
||||
@@ -1273,3 +1273,105 @@ TEST(End2EndJit_HLFHELinalg, matmul_eint_int) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// linalg.tensor_collapse_shape ///////////////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(End2EndJit_Linalg, tensor_collapse_shape) {
|
||||
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%a: tensor<2x2x4x!HLFHE.eint<6>>) -> tensor<2x8x!HLFHE.eint<6>> {
|
||||
%0 = linalg.tensor_collapse_shape %a [[0],[1,2]] : tensor<2x2x4x!HLFHE.eint<6>> into tensor<2x8x!HLFHE.eint<6>>
|
||||
return %0 : tensor<2x8x!HLFHE.eint<6>>
|
||||
}
|
||||
)XXX");
|
||||
static uint8_t A[2][2][4]{
|
||||
{{1, 2, 3, 4}, {5, 6, 7, 8}},
|
||||
{{10, 11, 12, 13}, {14, 15, 16, 17}},
|
||||
};
|
||||
static uint8_t expected[2][8]{
|
||||
{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
{10, 11, 12, 13, 14, 15, 16, 17},
|
||||
};
|
||||
|
||||
mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<uint8_t>>
|
||||
aArg(llvm::MutableArrayRef<uint8_t>((uint8_t *)A, 2 * 2 * 4), {2, 2, 4});
|
||||
|
||||
llvm::Expected<std::unique_ptr<mlir::zamalang::LambdaArgument>> res =
|
||||
lambda.operator()<std::unique_ptr<mlir::zamalang::LambdaArgument>>(
|
||||
{&aArg});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
mlir::zamalang::TensorLambdaArgument<mlir::zamalang::IntLambdaArgument<>>
|
||||
&resp = (*res)
|
||||
->cast<mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<>>>();
|
||||
|
||||
ASSERT_EQ(resp.getDimensions().size(), (size_t)2);
|
||||
ASSERT_EQ(resp.getDimensions().at(0), 2);
|
||||
ASSERT_EQ(resp.getDimensions().at(1), 8);
|
||||
ASSERT_EXPECTED_VALUE(resp.getNumElements(), 2 * 8);
|
||||
|
||||
for (size_t i = 0; i < 2; i++) {
|
||||
for (size_t j = 0; j < 8; j++) {
|
||||
EXPECT_EQ(resp.getValue()[i * 8 + j], expected[i][j])
|
||||
<< ", at pos(" << i << "," << j << ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// linalg.tensor_expand_shape ///////////////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(End2EndJit_Linalg, tensor_expand_shape) {
|
||||
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%a: tensor<2x8x!HLFHE.eint<6>>) -> tensor<2x2x4x!HLFHE.eint<6>> {
|
||||
%0 = linalg.tensor_expand_shape %a [[0],[1,2]] : tensor<2x8x!HLFHE.eint<6>> into tensor<2x2x4x!HLFHE.eint<6>>
|
||||
return %0 : tensor<2x2x4x!HLFHE.eint<6>>
|
||||
}
|
||||
)XXX");
|
||||
|
||||
static uint8_t A[2][8]{
|
||||
{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
{10, 11, 12, 13, 14, 15, 16, 17},
|
||||
};
|
||||
static uint8_t expected[2][2][4]{
|
||||
{{1, 2, 3, 4}, {5, 6, 7, 8}},
|
||||
{{10, 11, 12, 13}, {14, 15, 16, 17}},
|
||||
};
|
||||
|
||||
mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<uint8_t>>
|
||||
aArg(llvm::MutableArrayRef<uint8_t>((uint8_t *)A, 2 * 8), {2, 8});
|
||||
|
||||
llvm::Expected<std::unique_ptr<mlir::zamalang::LambdaArgument>> res =
|
||||
lambda.operator()<std::unique_ptr<mlir::zamalang::LambdaArgument>>(
|
||||
{&aArg});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
mlir::zamalang::TensorLambdaArgument<mlir::zamalang::IntLambdaArgument<>>
|
||||
&resp = (*res)
|
||||
->cast<mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<>>>();
|
||||
|
||||
ASSERT_EQ(resp.getDimensions().size(), (size_t)3);
|
||||
ASSERT_EQ(resp.getDimensions().at(0), 2);
|
||||
ASSERT_EQ(resp.getDimensions().at(1), 2);
|
||||
ASSERT_EQ(resp.getDimensions().at(2), 4);
|
||||
ASSERT_EXPECTED_VALUE(resp.getNumElements(), 2 * 2 * 4);
|
||||
|
||||
for (size_t i = 0; i < 2; i++) {
|
||||
for (size_t j = 0; j < 2; j++) {
|
||||
for (size_t k = 0; k < 4; k++) {
|
||||
EXPECT_EQ(resp.getValue()[i * 8 + j * 4 + k], expected[i][j][k])
|
||||
<< ", at pos(" << i << "," << j << "," << k << ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user