mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat(compiler): Add support for tensor.extract operations in MANP pass
Add support for `tensor.extract` operations in the MANP pass. This
currently only supports extract operations on tensors of encrypted
integers, which are passed as function arguments, e.g.:
func @extract_ith(%t: tensor<10x!HLFHE.eint<5>>, %i: index) -> !HLFHE.eint<5>{
%c = tensor.extract %t[%i] : tensor<10x!HLFHE.eint<5>>
return %c : !HLFHE.eint<5>
}
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/BuiltinTypes.h>
|
||||
#include <zamalang/Dialect/HLFHE/Analysis/MANP.h>
|
||||
#include <zamalang/Dialect/HLFHE/IR/HLFHEDialect.h>
|
||||
#include <zamalang/Dialect/HLFHE/IR/HLFHEOps.h>
|
||||
@@ -11,6 +12,7 @@
|
||||
#include <llvm/ADT/SmallString.h>
|
||||
#include <mlir/Analysis/DataFlowAnalysis.h>
|
||||
#include <mlir/Dialect/StandardOps/IR/Ops.h>
|
||||
#include <mlir/Dialect/Tensor/IR/Tensor.h>
|
||||
#include <mlir/IR/Attributes.h>
|
||||
#include <mlir/IR/BuiltinAttributes.h>
|
||||
#include <mlir/Pass/PassManager.h>
|
||||
@@ -77,6 +79,19 @@ protected:
|
||||
llvm::Optional<llvm::APInt> manp;
|
||||
};
|
||||
|
||||
// Checks if `lhs` is equal to `rhs`, where both values are assumed to
|
||||
// be positive. The bit width of the smaller `APInt` is extended
|
||||
// before comparison via `APInt::operator==`.
|
||||
static bool APIntWidthExtendCompare(const llvm::APInt &lhs,
|
||||
const llvm::APInt &rhs) {
|
||||
if (lhs.getBitWidth() < rhs.getBitWidth())
|
||||
return lhs.zext(rhs.getBitWidth()) == rhs;
|
||||
else if (lhs.getBitWidth() > rhs.getBitWidth())
|
||||
return lhs == rhs.zext(lhs.getBitWidth());
|
||||
else
|
||||
return lhs == rhs;
|
||||
}
|
||||
|
||||
// Checks if `lhs` is less than `rhs`, where both values are assumed
|
||||
// to be positive. The bit width of the smaller `APInt` is extended
|
||||
// before comparison via `APInt::ult`.
|
||||
@@ -305,6 +320,37 @@ static llvm::APInt getSqMANP(
|
||||
return APIntWidthExtendUAdd(a, b);
|
||||
}
|
||||
|
||||
// Calculates the squared Minimal Arithmetic Noise Padding of a dot
|
||||
// operation that is equivalent to an `tensor.extract`
|
||||
// operation. Currently, this only supports extractions of elements
|
||||
// from tensors passed as function arguments, for which the MANP is
|
||||
// assumed to be 1.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::tensor::ExtractOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
mlir::zamalang::HLFHE::EncryptedIntegerType elTy =
|
||||
op.getOperand(0)
|
||||
.getType()
|
||||
.dyn_cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
.dyn_cast_or_null<mlir::zamalang::HLFHE::EncryptedIntegerType>();
|
||||
|
||||
assert(elTy && "Can only calculate MANP for tensor.extract operations on "
|
||||
"HLFHE.eint tensors");
|
||||
|
||||
assert(operandMANPs.size() >= 1 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
"MANP value for tensor is unknown");
|
||||
|
||||
llvm::APInt one{1, 1, false};
|
||||
|
||||
assert(APIntWidthExtendCompare(
|
||||
operandMANPs[0]->getValue().getMANP().getValue(), one) &&
|
||||
"MANP value for tensor is not 1 as expected");
|
||||
|
||||
return one;
|
||||
}
|
||||
|
||||
// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
// that is equivalent to an `HLFHE.sub_int_eint` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
@@ -404,6 +450,17 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
} else if (llvm::isa<mlir::zamalang::HLFHE::ZeroEintOp>(op) ||
|
||||
llvm::isa<mlir::zamalang::HLFHE::ApplyLookupTableEintOp>(op)) {
|
||||
norm2SqEquiv = llvm::APInt{1, 1, false};
|
||||
} else if (auto tensorExtractOp =
|
||||
llvm::dyn_cast<mlir::tensor::ExtractOp>(op)) {
|
||||
// Only handle extract operations that produce an encrypted integer
|
||||
if (tensorExtractOp->getResultTypes()
|
||||
.front()
|
||||
.dyn_cast_or_null<
|
||||
mlir::zamalang::HLFHE::EncryptedIntegerType>()) {
|
||||
norm2SqEquiv = getSqMANP(tensorExtractOp, operands);
|
||||
} else {
|
||||
isDummy = true;
|
||||
}
|
||||
} else if (llvm::isa<mlir::ConstantOp>(op)) {
|
||||
isDummy = true;
|
||||
} else if (llvm::isa<mlir::zamalang::HLFHE::HLFHEDialect>(
|
||||
|
||||
Reference in New Issue
Block a user