mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
enhance(compiler): Support of tensor operators in MANP pass (close #169)
This commit is contained in:
@@ -11,6 +11,7 @@
|
||||
#include <llvm/ADT/SmallString.h>
|
||||
#include <mlir/Analysis/DataFlowAnalysis.h>
|
||||
#include <mlir/Dialect/StandardOps/IR/Ops.h>
|
||||
#include <mlir/Dialect/Tensor/IR/Tensor.h>
|
||||
#include <mlir/IR/Attributes.h>
|
||||
#include <mlir/IR/BuiltinAttributes.h>
|
||||
#include <mlir/Pass/PassManager.h>
|
||||
@@ -121,6 +122,16 @@ static llvm::APInt APIntWidthExtendUMul(const llvm::APInt &lhs,
|
||||
return lhs.zext(targetWidth) * rhs.zext(targetWidth);
|
||||
}
|
||||
|
||||
// Returns the maximum value beetwen `lhs` and `rhs`, where both values are
|
||||
// assumed to be positive. The bit width of the smaller `APInt` is extended
|
||||
// before comparison via `APInt::ult`.
|
||||
static llvm::APInt APIntUMax(const llvm::APInt &lhs, const llvm::APInt &rhs) {
|
||||
if (APIntWidthExtendULT(lhs, rhs)) {
|
||||
return rhs;
|
||||
}
|
||||
return lhs;
|
||||
}
|
||||
|
||||
// Calculates the square of `i`. The bit width `i` is extended in
|
||||
// order to guarantee that the product fits into the resulting
|
||||
// `APInt`.
|
||||
@@ -372,6 +383,58 @@ static llvm::APInt getSqMANP(
|
||||
return APIntWidthExtendUMul(sqNorm, eNorm);
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::tensor::ExtractOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
|
||||
return eNorm;
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::tensor::FromElementsOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
|
||||
auto max = std::max_element(
|
||||
operandMANPs.begin(), operandMANPs.end(),
|
||||
[](mlir::LatticeElement<MANPLatticeValue> *const a,
|
||||
mlir::LatticeElement<MANPLatticeValue> *const b) {
|
||||
return APIntWidthExtendULT(a->getValue().getMANP().getValue(),
|
||||
b->getValue().getMANP().getValue());
|
||||
});
|
||||
return (*max)->getValue().getMANP().getValue();
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::tensor::ExtractSliceOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
return operandMANPs[0]->getValue().getMANP().getValue();
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::tensor::InsertSliceOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs.size() >= 2 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[1]->getValue().getMANP().hasValue() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
return APIntUMax(operandMANPs[0]->getValue().getMANP().getValue(),
|
||||
operandMANPs[1]->getValue().getMANP().getValue());
|
||||
}
|
||||
|
||||
struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
using ForwardDataFlowAnalysis<MANPLatticeValue>::ForwardDataFlowAnalysis;
|
||||
MANPAnalysis(mlir::MLIRContext *ctx, bool debug)
|
||||
@@ -387,6 +450,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
bool isDummy = false;
|
||||
llvm::APInt norm2SqEquiv;
|
||||
|
||||
// HLFHE Operaors
|
||||
if (auto dotOp = llvm::dyn_cast<mlir::zamalang::HLFHE::Dot>(op)) {
|
||||
norm2SqEquiv = getSqMANP(dotOp, operands);
|
||||
} else if (auto addEintIntOp =
|
||||
@@ -404,7 +468,58 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
} else if (llvm::isa<mlir::zamalang::HLFHE::ZeroEintOp>(op) ||
|
||||
llvm::isa<mlir::zamalang::HLFHE::ApplyLookupTableEintOp>(op)) {
|
||||
norm2SqEquiv = llvm::APInt{1, 1, false};
|
||||
} else if (llvm::isa<mlir::ConstantOp>(op)) {
|
||||
}
|
||||
// Tensor Operators
|
||||
// ExtractOp
|
||||
else if (auto extractOp = llvm::dyn_cast<mlir::tensor::ExtractOp>(op)) {
|
||||
if (extractOp.result()
|
||||
.getType()
|
||||
.isa<mlir::zamalang::HLFHE::EncryptedIntegerType>()) {
|
||||
norm2SqEquiv = getSqMANP(extractOp, operands);
|
||||
} else {
|
||||
isDummy = true;
|
||||
}
|
||||
}
|
||||
// ExtractSliceOp
|
||||
else if (auto extractSliceOp =
|
||||
llvm::dyn_cast<mlir::tensor::ExtractSliceOp>(op)) {
|
||||
if (extractSliceOp.result()
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
.isa<mlir::zamalang::HLFHE::EncryptedIntegerType>()) {
|
||||
norm2SqEquiv = getSqMANP(extractSliceOp, operands);
|
||||
} else {
|
||||
isDummy = true;
|
||||
}
|
||||
}
|
||||
// InsertSliceOp
|
||||
else if (auto insertSliceOp =
|
||||
llvm::dyn_cast<mlir::tensor::InsertSliceOp>(op)) {
|
||||
if (insertSliceOp.result()
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
.isa<mlir::zamalang::HLFHE::EncryptedIntegerType>()) {
|
||||
norm2SqEquiv = getSqMANP(insertSliceOp, operands);
|
||||
} else {
|
||||
isDummy = true;
|
||||
}
|
||||
}
|
||||
// FromElementOp
|
||||
else if (auto fromOp = llvm::dyn_cast<mlir::tensor::FromElementsOp>(op)) {
|
||||
if (fromOp.result()
|
||||
.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
.isa<mlir::zamalang::HLFHE::EncryptedIntegerType>()) {
|
||||
norm2SqEquiv = getSqMANP(fromOp, operands);
|
||||
} else {
|
||||
isDummy = true;
|
||||
}
|
||||
}
|
||||
|
||||
else if (llvm::isa<mlir::ConstantOp>(op)) {
|
||||
isDummy = true;
|
||||
} else if (llvm::isa<mlir::zamalang::HLFHE::HLFHEDialect>(
|
||||
*op->getDialect())) {
|
||||
@@ -488,6 +603,14 @@ protected:
|
||||
mlir::zamalang::HLFHE::EncryptedIntegerType eTy =
|
||||
res.getType()
|
||||
.dyn_cast_or_null<mlir::zamalang::HLFHE::EncryptedIntegerType>();
|
||||
if (eTy == nullptr) {
|
||||
auto tensorTy = res.getType().dyn_cast_or_null<mlir::TensorType>();
|
||||
if (tensorTy != nullptr) {
|
||||
eTy = tensorTy.getElementType()
|
||||
.dyn_cast_or_null<
|
||||
mlir::zamalang::HLFHE::EncryptedIntegerType>();
|
||||
}
|
||||
}
|
||||
|
||||
if (eTy) {
|
||||
bool upd = false;
|
||||
|
||||
122
compiler/tests/Dialect/HLFHE/Analysis/MANP_tensor.mlir
Normal file
122
compiler/tests/Dialect/HLFHE/Analysis/MANP_tensor.mlir
Normal file
@@ -0,0 +1,122 @@
|
||||
// RUN: zamacompiler --split-input-file --entry-dialect=hlfhe --action=dump-hlfhe-manp %s 2>&1 | FileCheck %s
|
||||
|
||||
func @tensor_from_elements_1(%a: !HLFHE.eint<2>, %b: !HLFHE.eint<2>, %c: !HLFHE.eint<2>, %d: !HLFHE.eint<2>) -> tensor<4x!HLFHE.eint<2>>
|
||||
{
|
||||
// The MANP value is 1 as all operands are function arguments
|
||||
// CHECK: %[[ret:.*]] = tensor.from_elements %[[a:.*]], %[[b:.*]], %[[c:.*]], %[[d:.*]] {MANP = 1 : ui{{[[0-9]+}}} : tensor<4x!HLFHE.eint<2>>
|
||||
%0 = tensor.from_elements %a, %b, %c, %d : tensor<4x!HLFHE.eint<2>>
|
||||
|
||||
return %0 : tensor<4x!HLFHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @tensor_from_elements_2(%a: !HLFHE.eint<2>, %b: !HLFHE.eint<2>, %c: !HLFHE.eint<2>, %d: !HLFHE.eint<2>) -> tensor<4x!HLFHE.eint<2>>
|
||||
{
|
||||
%cst = constant 3 : i3
|
||||
|
||||
// CHECK: %[[V0:.*]] = "HLFHE.add_eint_int"(%[[a:.*]], %[[cst:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
|
||||
%0 = "HLFHE.add_eint_int"(%a, %cst) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
|
||||
|
||||
// The MANP value is 4, i.e. the max of all of its operands
|
||||
// CHECK: %[[V1:.*]] = tensor.from_elements %[[V0:.*]], %[[b:.*]], %[[c:.*]], %[[d:.*]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!HLFHE.eint<2>>
|
||||
%1 = tensor.from_elements %0, %b, %c, %d : tensor<4x!HLFHE.eint<2>>
|
||||
|
||||
return %1 : tensor<4x!HLFHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @tensor_extract_1(%t: tensor<4x!HLFHE.eint<2>>) -> !HLFHE.eint<2>
|
||||
{
|
||||
%cst = constant 1 : index
|
||||
|
||||
// The MANP value is 1 as the tensor operand is a function argument
|
||||
// CHECK: %[[ret:.*]] = tensor.extract %[[t:.*]][%[[c1:.*]]] {MANP = 1 : ui{{[[0-9]+}}} : tensor<4x!HLFHE.eint<2>>
|
||||
%0 = tensor.extract %t[%cst] : tensor<4x!HLFHE.eint<2>>
|
||||
|
||||
return %0 : !HLFHE.eint<2>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @tensor_extract_2(%a: !HLFHE.eint<2>) -> !HLFHE.eint<2>
|
||||
{
|
||||
%c1 = constant 1 : index
|
||||
%c3 = constant 3 : i3
|
||||
|
||||
// CHECK: %[[V0:.*]] = "HLFHE.add_eint_int"(%[[a:.*]], %[[c1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
|
||||
%0 = "HLFHE.add_eint_int"(%a, %c3) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
|
||||
// CHECK: %[[V1:.*]] = tensor.from_elements %[[V0:.*]], %[[a:.*]], %[[a:.*]], %[[a:.*]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!HLFHE.eint<2>>
|
||||
%1 = tensor.from_elements %0, %a, %a, %a : tensor<4x!HLFHE.eint<2>>
|
||||
// CHECK: %[[ret:.*]] = tensor.extract %[[t:.*]][%[[c3:.*]]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!HLFHE.eint<2>>
|
||||
%2 = tensor.extract %1[%c1] : tensor<4x!HLFHE.eint<2>>
|
||||
|
||||
return %2 : !HLFHE.eint<2>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @tensor_extract_slice_1(%t: tensor<2x10x!HLFHE.eint<2>>) -> tensor<1x5x!HLFHE.eint<2>>
|
||||
{
|
||||
// CHECK: %[[V0:.*]] = tensor.extract_slice %[[t:.*]][1, 5] [1, 5] [1, 1] {MANP = 1 : ui{{[[0-9]+}}} : tensor<2x10x!HLFHE.eint<2>> to tensor<1x5x!HLFHE.eint<2>>
|
||||
%0 = tensor.extract_slice %t[1, 5] [1, 5] [1, 1] : tensor<2x10x!HLFHE.eint<2>> to tensor<1x5x!HLFHE.eint<2>>
|
||||
|
||||
return %0 : tensor<1x5x!HLFHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @tensor_extract_slice_2(%a: !HLFHE.eint<2>) -> tensor<2x!HLFHE.eint<2>>
|
||||
{
|
||||
%c3 = constant 3 : i3
|
||||
|
||||
// CHECK: %[[V0:.*]] = "HLFHE.add_eint_int"(%[[a:.*]], %[[c1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
|
||||
%0 = "HLFHE.add_eint_int"(%a, %c3) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
|
||||
// CHECK: %[[V1:.*]] = tensor.from_elements %[[V0:.*]], %[[a:.*]], %[[a:.*]], %[[a:.*]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!HLFHE.eint<2>>
|
||||
%1 = tensor.from_elements %0, %a, %a, %a : tensor<4x!HLFHE.eint<2>>
|
||||
// CHECK: tensor.extract_slice %[[V1]][2] [2] [1] {MANP = 4 : ui{{[0-9]+}}} : tensor<4x!HLFHE.eint<2>> to tensor<2x!HLFHE.eint<2>>
|
||||
%2 = tensor.extract_slice %1[2] [2] [1] : tensor<4x!HLFHE.eint<2>> to tensor<2x!HLFHE.eint<2>>
|
||||
|
||||
return %2 : tensor<2x!HLFHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @tensor_insert_slice_1(%t0: tensor<2x10x!HLFHE.eint<2>>, %t1: tensor<2x2x!HLFHE.eint<2>>) -> tensor<2x10x!HLFHE.eint<2>>
|
||||
{
|
||||
// %[[V0:.*]] = tensor.insert_slice %[[t1:.*]] into %[[t0:.*]][0, 5] [2, 2] [1, 1] {MANP = 1 : ui{{[[0-9]+}}} : tensor<2x2x!HLFHE.eint<2>> into tensor<2x10x!HLFHE.eint<2>>
|
||||
%0 = tensor.insert_slice %t1 into %t0[0, 5] [2, 2] [1, 1] : tensor<2x2x!HLFHE.eint<2>> into tensor<2x10x!HLFHE.eint<2>>
|
||||
|
||||
return %0 : tensor<2x10x!HLFHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @tensor_insert_slice_2(%a: !HLFHE.eint<5>) -> tensor<4x!HLFHE.eint<5>>
|
||||
{
|
||||
%c3 = constant 3 : i6
|
||||
%c6 = constant 6 : i6
|
||||
|
||||
// CHECK: %[[V0:.*]] = "HLFHE.add_eint_int"(%[[a:.*]], %[[c3:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!HLFHE.eint<5>, i6) -> !HLFHE.eint<5>
|
||||
%v0 = "HLFHE.add_eint_int"(%a, %c3) : (!HLFHE.eint<5>, i6) -> !HLFHE.eint<5>
|
||||
// CHECK: %[[V1:.*]] = "HLFHE.add_eint_int"(%[[a:.*]], %[[c6:.*]]) {MANP = 7 : ui{{[0-9]+}}} : (!HLFHE.eint<5>, i6) -> !HLFHE.eint<5>
|
||||
%v1 = "HLFHE.add_eint_int"(%a, %c6) : (!HLFHE.eint<5>, i6) -> !HLFHE.eint<5>
|
||||
|
||||
// CHECK: %[[T0:.*]] = tensor.from_elements %[[V0:.*]], %[[V0:.*]], %[[V0:.*]], %[[V0:.*]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!HLFHE.eint<5>>
|
||||
%t0 = tensor.from_elements %v0, %v0, %v0, %v0 : tensor<4x!HLFHE.eint<5>>
|
||||
|
||||
// CHECK: %[[T1:.*]] = tensor.from_elements %[[V1:.*]], %[[V1:.*]] {MANP = 7 : ui{{[[0-9]+}}} : tensor<2x!HLFHE.eint<5>>
|
||||
%t1 = tensor.from_elements %v1, %v1 : tensor<2x!HLFHE.eint<5>>
|
||||
|
||||
// CHECK: %[[T2:.*]] = tensor.insert_slice %[[T1]] into %[[T0]][0] [2] [1] {MANP = 7 : ui{{[[0-9]+}}} : tensor<2x!HLFHE.eint<5>> into tensor<4x!HLFHE.eint<5>>
|
||||
%t2 = tensor.insert_slice %t1 into %t0[0] [2] [1] : tensor<2x!HLFHE.eint<5>> into tensor<4x!HLFHE.eint<5>>
|
||||
|
||||
// CHECK: %[[T3:.*]] = tensor.from_elements %[[V0:.*]], %[[V0:.*]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<2x!HLFHE.eint<5>>
|
||||
%t3 = tensor.from_elements %v0, %v0 : tensor<2x!HLFHE.eint<5>>
|
||||
|
||||
// CHECK: %[[T4:.*]] = tensor.insert_slice %[[T3]] into %[[T2]][0] [2] [1] {MANP = 7 : ui{{[[0-9]+}}} : tensor<2x!HLFHE.eint<5>> into tensor<4x!HLFHE.eint<5>>
|
||||
%t4 = tensor.insert_slice %t3 into %t2[0] [2] [1] : tensor<2x!HLFHE.eint<5>> into tensor<4x!HLFHE.eint<5>>
|
||||
|
||||
return %t0 : tensor<4x!HLFHE.eint<5>>
|
||||
}
|
||||
Reference in New Issue
Block a user