mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
enhance(compiler): Remove the restriction of analysis on the HLFHELinalg.dot
While the manp analysis wasn't handle tensor the dot operation restrict the operands to come from a block argument. Since the tensor are handled in the manp pass this restriction has no more meaning.
This commit is contained in:
@@ -212,7 +212,8 @@ static std::string APIntToStringValUnsigned(const llvm::APInt &i) {
|
||||
// Calculates the square of the 2-norm of a tensor initialized with a
|
||||
// dense matrix of constant, signless integers. Aborts if the value
|
||||
// type or initialization of of `cstOp` is incorrect.
|
||||
static llvm::APInt denseCstTensorNorm2Sq(mlir::arith::ConstantOp cstOp) {
|
||||
static llvm::APInt denseCstTensorNorm2Sq(mlir::arith::ConstantOp cstOp,
|
||||
llvm::APInt eNorm) {
|
||||
mlir::DenseIntElementsAttr denseVals =
|
||||
cstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value");
|
||||
|
||||
@@ -230,8 +231,9 @@ static llvm::APInt denseCstTensorNorm2Sq(mlir::arith::ConstantOp cstOp) {
|
||||
llvm::APInt accu{1, 0, false};
|
||||
|
||||
for (llvm::APInt val : denseVals.getValues<llvm::APInt>()) {
|
||||
llvm::APInt valSq = APIntWidthExtendUSq(val);
|
||||
accu = APIntWidthExtendUAdd(accu, valSq);
|
||||
llvm::APInt valSqNorm = APIntWidthExtendUSq(val);
|
||||
llvm::APInt mulSqNorm = APIntWidthExtendUMul(valSqNorm, eNorm);
|
||||
accu = APIntWidthExtendUAdd(accu, mulSqNorm);
|
||||
}
|
||||
|
||||
return accu;
|
||||
@@ -241,7 +243,8 @@ static llvm::APInt denseCstTensorNorm2Sq(mlir::arith::ConstantOp cstOp) {
|
||||
// integers by conservatively assuming that the dynamic values are the
|
||||
// maximum for the integer width. Aborts if the tensor type `tTy` is
|
||||
// incorrect.
|
||||
static llvm::APInt denseDynTensorNorm2Sq(mlir::TensorType tTy) {
|
||||
static llvm::APInt denseDynTensorNorm2Sq(mlir::TensorType tTy,
|
||||
llvm::APInt eNorm) {
|
||||
assert(tTy && tTy.getElementType().isSignlessInteger() &&
|
||||
tTy.hasStaticShape() && tTy.getRank() == 1 &&
|
||||
"Plaintext operand must be a statically shaped 1D tensor of integers");
|
||||
@@ -254,6 +257,7 @@ static llvm::APInt denseDynTensorNorm2Sq(mlir::TensorType tTy) {
|
||||
|
||||
llvm::APInt maxVal = APInt::getMaxValue(elWidth);
|
||||
llvm::APInt maxValSq = APIntWidthExtendUSq(maxVal);
|
||||
llvm::APInt maxMulSqNorm = APIntWidthExtendUMul(maxValSq, eNorm);
|
||||
|
||||
// Calculate number of bits for APInt to store number of elements
|
||||
uint64_t nElts = (uint64_t)tTy.getNumElements();
|
||||
@@ -262,7 +266,7 @@ static llvm::APInt denseDynTensorNorm2Sq(mlir::TensorType tTy) {
|
||||
|
||||
llvm::APInt nEltsAP{nEltsBits, nElts, false};
|
||||
|
||||
return APIntWidthExtendUMul(maxValSq, nEltsAP);
|
||||
return APIntWidthExtendUMul(maxMulSqNorm, nEltsAP);
|
||||
}
|
||||
|
||||
// Calculates the squared Minimal Arithmetic Noise Padding of an
|
||||
@@ -270,9 +274,12 @@ static llvm::APInt denseDynTensorNorm2Sq(mlir::TensorType tTy) {
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::zamalang::HLFHELinalg::Dot op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
assert(op->getOpOperand(0).get().isa<mlir::BlockArgument>() &&
|
||||
"Only dot operations with tensors that are function arguments are "
|
||||
"currently supported");
|
||||
assert(operandMANPs.size() == 2 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted "
|
||||
"operands");
|
||||
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
|
||||
mlir::arith::ConstantOp cstOp =
|
||||
llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
|
||||
@@ -281,7 +288,7 @@ static llvm::APInt getSqMANP(
|
||||
if (cstOp) {
|
||||
// Dot product between a vector of encrypted integers and a vector
|
||||
// of plaintext constants -> return 2-norm of constant vector
|
||||
return denseCstTensorNorm2Sq(cstOp);
|
||||
return denseCstTensorNorm2Sq(cstOp, eNorm);
|
||||
} else {
|
||||
// Dot product between a vector of encrypted integers and a vector
|
||||
// of dynamic plaintext values -> conservatively assume that all
|
||||
@@ -292,7 +299,7 @@ static llvm::APInt getSqMANP(
|
||||
.getType()
|
||||
.dyn_cast_or_null<mlir::TensorType>();
|
||||
|
||||
return denseDynTensorNorm2Sq(tTy);
|
||||
return denseDynTensorNorm2Sq(tTy, eNorm);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -10,28 +10,6 @@ func @single_zero() -> !HLFHE.eint<2>
|
||||
|
||||
// -----
|
||||
|
||||
func @single_cst_dot(%t: tensor<4x!HLFHE.eint<2>>) -> !HLFHE.eint<2>
|
||||
{
|
||||
%cst = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi3>
|
||||
|
||||
// CHECK: %[[ret:.*]] = "HLFHELinalg.dot_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 6 : ui{{[[0-9]+}}} : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2>
|
||||
%0 = "HLFHELinalg.dot_eint_int"(%t, %cst) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2>
|
||||
|
||||
return %0 : !HLFHE.eint<2>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @single_dyn_dot(%t: tensor<4x!HLFHE.eint<2>>, %dyn: tensor<4xi3>) -> !HLFHE.eint<2>
|
||||
{
|
||||
// CHECK: %[[ret:.*]] = "HLFHELinalg.dot_eint_int"([[op0:.*]], %[[op1:.*]]) {MANP = 14 : ui{{[[0-9]+}}} : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2>
|
||||
%0 = "HLFHELinalg.dot_eint_int"(%t, %dyn) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2>
|
||||
|
||||
return %0 : !HLFHE.eint<2>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @single_cst_add_eint_int(%e: !HLFHE.eint<2>) -> !HLFHE.eint<2>
|
||||
{
|
||||
%cst = arith.constant 3 : i3
|
||||
|
||||
@@ -160,6 +160,63 @@ func @apply_multi_lookup_table_after_op(%t: tensor<8x!HLFHE.eint<2>>, %i: tensor
|
||||
|
||||
// -----
|
||||
|
||||
/////////////////////////////////////////////////
|
||||
// HLFHELinalg.dot_eint_int
|
||||
/////////////////////////////////////////////////
|
||||
|
||||
func @single_cst_dot(%t: tensor<4x!HLFHE.eint<2>>) -> !HLFHE.eint<2>
|
||||
{
|
||||
%cst = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi3>
|
||||
// sqrt(1^2*1 + 2^2*1 + 3^2*1 + 4^2*1) = 5.477225575
|
||||
// CHECK: %[[V0:.*]] = "HLFHELinalg.dot_eint_int"(%[[T:.*]], %[[CST:.*]]) {MANP = 6 : ui{{[[0-9]+}}} : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2>
|
||||
%0 = "HLFHELinalg.dot_eint_int"(%t, %cst) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2>
|
||||
return %0 : !HLFHE.eint<2>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @single_dyn_dot(%t: tensor<4x!HLFHE.eint<2>>, %dyn: tensor<4xi3>) -> !HLFHE.eint<2>
|
||||
{
|
||||
// sqrt(1*(2^3-1)^2*4) = 14
|
||||
// CHECK: %[[V0:.*]] = "HLFHELinalg.dot_eint_int"([[T:.*]], %[[DYN:.*]]) {MANP = 14 : ui{{[[0-9]+}}} : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2>
|
||||
%0 = "HLFHELinalg.dot_eint_int"(%t, %dyn) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2>
|
||||
|
||||
return %0 : !HLFHE.eint<2>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @single_cst_dot_after_op(%t: tensor<4x!HLFHE.eint<2>>, %i: tensor<4xi3>) -> !HLFHE.eint<2>
|
||||
{
|
||||
// sqrt((2^3)^2*1) = sqrt(64) = 8
|
||||
// CHECK: %[[V0:.*]] = "HLFHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 8 : ui{{[0-9]+}}}
|
||||
%0 = "HLFHELinalg.mul_eint_int"(%t, %i) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> tensor<4x!HLFHE.eint<2>>
|
||||
|
||||
%cst = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi3>
|
||||
// sqrt(1^2*64 + 2^2*64 + 3^2*64 + 4^2*64) = sqrt(1920) = 43.8178046
|
||||
// CHECK: %[[V1:.*]] = "HLFHELinalg.dot_eint_int"(%[[V0]], %[[CST:.*]]) {MANP = 44 : ui{{[[0-9]+}}}
|
||||
%1 = "HLFHELinalg.dot_eint_int"(%0, %cst) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2>
|
||||
|
||||
return %1 : !HLFHE.eint<2>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @single_dyn_dot_after_op(%t: tensor<4x!HLFHE.eint<2>>, %i: tensor<4xi3>) -> !HLFHE.eint<2>
|
||||
{
|
||||
// sqrt((2^3)^2*1) = sqrt(64) = 8
|
||||
// CHECK: %[[V0:.*]] = "HLFHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 8 : ui{{[0-9]+}}}
|
||||
%0 = "HLFHELinalg.mul_eint_int"(%t, %i) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> tensor<4x!HLFHE.eint<2>>
|
||||
|
||||
// sqrt(4*(2^3-1)^2*64) = sqrt(12544) = 112
|
||||
// CHECK: %[[V1:.*]] = "HLFHELinalg.dot_eint_int"(%[[V0]], %[[I]]) {MANP = 112 : ui{{[[0-9]+}}}
|
||||
%1 = "HLFHELinalg.dot_eint_int"(%0, %i) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2>
|
||||
|
||||
return %1 : !HLFHE.eint<2>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
/////////////////////////////////////////////////
|
||||
// HLFHELinalg.matmul_ent_int
|
||||
/////////////////////////////////////////////////
|
||||
|
||||
Reference in New Issue
Block a user