From e949e7e2a71d4c5bbe8296a327815e9543a07e8f Mon Sep 17 00:00:00 2001 From: Umut Date: Thu, 2 Mar 2023 16:45:01 +0100 Subject: [PATCH] feat: introduce FHELinalg.mul_eint --- .../Dialect/FHELinalg/IR/FHELinalgOps.td | 52 ++++++++++++++ .../TensorOpsToLinalg.cpp | 4 ++ .../FHE/Analysis/ConcreteOptimizer.cpp | 72 +++++++++++++++++++ .../lib/Dialect/FHE/Analysis/MANP.cpp | 27 +++++++ .../Dialect/FHELinalg/ops.invalid.mlir | 40 +++++++++++ .../check_tests/Dialect/FHELinalg/ops.mlir | 54 ++++++++++++++ .../tests_cpu/end_to_end_fhelinalg.yaml | 32 +++++++++ 7 files changed, 281 insertions(+) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td index f509b16f4..6b6bfbe32 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td @@ -367,6 +367,58 @@ def FHELinalg_MulEintIntOp : FHELinalg_Op<"mul_eint_int", [TensorBroadcastingRul let hasFolder = 1; } +def FHELinalg_MulEintOp : FHELinalg_Op<"mul_eint", [TensorBroadcastingRules, TensorBinaryEint]> { + let summary = "Returns a tensor that contains the multiplication of two tensors of encrypted integers."; + + let description = [{ + Performs a multiplication following the broadcasting rules between two tensors of encrypted integers. + The width of the encrypted integers must be equal. + + Examples: + ```mlir + // Returns the term to term multiplication of `%a0` with `%a1` + "FHELinalg.mul_eint"(%a0, %a1) : (tensor<4x!FHE.eint<8>>, tensor<4x!FHE.eint<8>>) -> tensor<4x!FHE.eint<8>> + + // Returns the term to term multiplication of `%a0` with `%a1`, where dimensions equal to one are stretched. 
+ "FHELinalg.mul_eint"(%a0, %a1) : (tensor<4x1x4x!FHE.eint<8>>, tensor<1x4x4x!FHE.eint<8>>) -> tensor<4x4x4x!FHE.eint<8>> + + // Returns the multiplication of a 3x3 matrix of encrypted integers and a 3x1 matrix (a column) of encrypted integers. + // + // [1,2,3] [1] [1,2,3] + // [4,5,6] * [2] = [8,10,12] + // [7,8,9] [3] [21,24,27] + // + // The dimension #1 of operand #2 is stretched as it is equal to 1. + "FHELinalg.mul_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<8>>, tensor<3x1x!FHE.eint<8>>) -> tensor<3x3x!FHE.eint<8>> + + // Returns the multiplication of a 3x3 matrix of encrypted integers and a 1x3 matrix (a line) of encrypted integers. + // + // [1,2,3] [1,4,9] + // [4,5,6] * [1,2,3] = [4,10,18] + // [7,8,9] [7,16,27] + // + // The dimension #2 of operand #2 is stretched as it is equal to 1. + "FHELinalg.mul_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<8>>, tensor<1x3x!FHE.eint<8>>) -> tensor<3x3x!FHE.eint<8>> + + // Same behavior as the previous one, as the dimension #2 of operand #2 is missing. 
+ "FHELinalg.mul_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<8>>, tensor<3x!FHE.eint<8>>) -> tensor<3x3x!FHE.eint<8>> + ``` + }]; + + let arguments = (ins + Type.predicate, HasStaticShapePred]>>:$lhs, + Type.predicate, HasStaticShapePred]>>:$rhs + ); + + let results = (outs Type.predicate, HasStaticShapePred]>>); + + let builders = [ + OpBuilder<(ins "Value":$rhs, "Value":$lhs), [{ + build($_builder, $_state, rhs.getType(), rhs, lhs); + }]> + ]; +} + def FHELinalg_ApplyLookupTableEintOp : FHELinalg_Op<"apply_lookup_table", []> { let summary = "Returns a tensor that contains the result of the lookup on a table."; diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp index 57732ccc4..c313c8979 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp @@ -2080,6 +2080,10 @@ void FHETensorOpsToLinalg::runOnOperation() { FHELinalgOpToLinalgGeneric>( &getContext()); + patterns.insert< + FHELinalgOpToLinalgGeneric>( + &getContext()); patterns.insert(&getContext()); patterns.insert(&getContext()); patterns.insert resultShape = getShape(result); + + Operation *xOp = mulOp.getLhs().getDefiningOp(); + Operation *yOp = mulOp.getRhs().getDefiningOp(); + + const double fixedCost = NEGLIGIBLE_COMPLEXITY; + const double lweDimCostFactor = NEGLIGIBLE_COMPLEXITY; + + llvm::APInt xSmanp = llvm::APInt{1, 1, false}; + if (xOp != nullptr) { + const auto xSmanpAttr = xOp->getAttrOfType("SMANP"); + assert(xSmanpAttr && "Missing SMANP value on a crypto operation"); + xSmanp = xSmanpAttr.getValue(); + } + + llvm::APInt ySmanp = llvm::APInt{1, 1, false}; + if (yOp != nullptr) { + const auto ySmanpAttr = yOp->getAttrOfType("SMANP"); + assert(ySmanpAttr && "Missing SMANP value on a crypto operation"); 
+ ySmanp = ySmanpAttr.getValue(); + } + + auto loc = loc_to_string(mulOp.getLoc()); + auto comment = std::string(mulOp->getName().getStringRef()) + " " + loc; + + // (x + y) and (x - y) + const double addSubManp = + sqrt(xSmanp.roundToDouble() + ySmanp.roundToDouble()); + + // tlu(v) + const double tluManp = 1; + + // tlu(v1) - tlu(v2) + const double tluSubManp = sqrt(tluManp + tluManp); + + // for tlus + const std::vector unknownFunction; + + // tlu(x + y) + auto addNode = + dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, + addSubManp, slice(resultShape), comment); + auto lhsTluNode = dag->add_lut(addNode, slice(unknownFunction), precision); + + // tlu(x - y) + auto subNode = + dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, + addSubManp, slice(resultShape), comment); + auto rhsTluNode = dag->add_lut(subNode, slice(unknownFunction), precision); + + // tlu(x + y) - tlu(x - y) + const std::vector subInputs = { + lhsTluNode, rhsTluNode}; + index[result] = + dag->add_levelled_op(slice(subInputs), lweDimCostFactor, fixedCost, + tluSubManp, slice(resultShape), comment); + } + void addMax(optimizer::Dag &dag, FHE::MaxEintOp &maxOp, Inputs &inputs, int precision) { mlir::Value result = maxOp.getResult(); @@ -420,6 +488,10 @@ struct FunctionToDag { return llvm::dyn_cast(op); } + mlir::concretelang::FHELinalg::MulEintOp asMulTensor(mlir::Operation &op) { + return llvm::dyn_cast(op); + } + mlir::concretelang::FHE::MaxEintOp asMax(mlir::Operation &op) { return llvm::dyn_cast(op); } diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/MANP.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/MANP.cpp index eb69d55dd..9b2f48093 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/MANP.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/MANP.cpp @@ -687,6 +687,29 @@ static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::MulEintIntOp op, return 
APIntWidthExtendUMul(sqNorm, eNorm); } +/// Calculates the squared Minimal Arithmetic Noise Padding +/// of `FHELinalg.mul_eint` operation. +static llvm::APInt getSqMANP(mlir::concretelang::FHELinalg::MulEintOp op, + llvm::ArrayRef<const MANPLattice *> operandMANPs) { + assert(operandMANPs.size() == 2 && + operandMANPs[0]->getValue().getMANP().has_value() && + operandMANPs[1]->getValue().getMANP().has_value() && + "Missing squared Minimal Arithmetic Noise Padding for encrypted " + "operands"); + + // x * y = ((x + y)^2 / 4) - ((x - y)^2 / 4) == tlu(x + y) - tlu(x - y) + + const llvm::APInt x = operandMANPs[0]->getValue().getMANP().value(); + const llvm::APInt y = operandMANPs[1]->getValue().getMANP().value(); + + const llvm::APInt beforeTLUs = APIntWidthExtendUAdd(x, y); + const llvm::APInt tlu = {1, 1, false}; + const llvm::APInt result = APIntWidthExtendUAdd(tlu, tlu); + + // this is not optimal as it can increase the resulting noise unnecessarily + return APIntUMax(beforeTLUs, result); +} + static llvm::APInt computeVectorNorm( llvm::ArrayRef<int64_t> shape, int64_t axis, mlir::DenseIntElementsAttr denseValues, llvm::APInt encryptedOperandNorm, @@ -1307,6 +1330,10 @@ public: llvm::dyn_cast<mlir::concretelang::FHELinalg::MulEintIntOp>( op)) { norm2SqEquiv = getSqMANP(mulEintIntOp, operands); + } else if (auto mulEintOp = + llvm::dyn_cast<mlir::concretelang::FHELinalg::MulEintOp>( + op)) { + norm2SqEquiv = getSqMANP(mulEintOp, operands); } else if (auto matmulEintIntOp = llvm::dyn_cast< mlir::concretelang::FHELinalg::MatMulEintIntOp>(op)) { norm2SqEquiv = getSqMANP(matmulEintIntOp, operands); diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/ops.invalid.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/ops.invalid.mlir index a2c9f99bc..cf85112b4 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/ops.invalid.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/ops.invalid.mlir @@ -120,6 +120,46 @@ func.func @main(%a0: tensor<2x3x4x!FHE.eint<2>>, %a1: 
tensor<2x3x4xi4>) -> tenso // ----- +///////////////////////////////////////////////// +// FHELinalg.mul_eint +///////////////////////////////////////////////// + +// Incompatible dimension of operands +func.func @main(%a0: tensor<2x2x3x4x!FHE.eint<2>>, %a1: tensor<2x2x2x4x!FHE.eint<2>>) -> tensor<2x2x3x4x!FHE.eint<2>> { + // expected-error @+1 {{'FHELinalg.mul_eint' op has the dimension #2 of the operand #1 incompatible with other operands, got 2 expect 1 or 3}} + %1 = "FHELinalg.mul_eint"(%a0, %a1) : (tensor<2x2x3x4x!FHE.eint<2>>, tensor<2x2x2x4x!FHE.eint<2>>) -> tensor<2x2x3x4x!FHE.eint<2>> + return %1 : tensor<2x2x3x4x!FHE.eint<2>> +} + +// ----- + +// Incompatible dimension of result +func.func @main(%a0: tensor<2x2x3x4x!FHE.eint<2>>, %a1: tensor<2x2x2x4x!FHE.eint<2>>) -> tensor<2x10x3x4x!FHE.eint<2>> { + // expected-error @+1 {{'FHELinalg.mul_eint' op has the dimension #3 of the result incompatible with operands dimension, got 10 expect 2}} + %1 = "FHELinalg.mul_eint"(%a0, %a1) : (tensor<2x2x3x4x!FHE.eint<2>>, tensor<2x2x2x4x!FHE.eint<2>>) -> tensor<2x10x3x4x!FHE.eint<2>> + return %1 : tensor<2x10x3x4x!FHE.eint<2>> +} + +// ----- + +// Incompatible number of dimension between operands and result +func.func @main(%a0: tensor<2x2x3x4x!FHE.eint<2>>, %a1: tensor<2x2x2x4x!FHE.eint<2>>) -> tensor<2x3x4x!FHE.eint<2>> { + // expected-error @+1 {{'FHELinalg.mul_eint' op should have the number of dimensions of the result equal to the highest number of dimensions of operands, got 3 expect 4}} + %1 = "FHELinalg.mul_eint"(%a0, %a1) : (tensor<2x2x3x4x!FHE.eint<2>>, tensor<2x2x2x4x!FHE.eint<2>>) -> tensor<2x3x4x!FHE.eint<2>> + return %1 : tensor<2x3x4x!FHE.eint<2>> +} + +// ----- + +// Incompatible width between encrypted operands +func.func @main(%a0: tensor<2x3x4x!FHE.eint<2>>, %a1: tensor<2x3x4x!FHE.eint<3>>) -> tensor<2x3x4x!FHE.eint<2>> { + // expected-error @+1 {{'FHELinalg.mul_eint' op should have the width of encrypted equals, got 3 expect 2}} + %1 = 
"FHELinalg.mul_eint"(%a0, %a1) : (tensor<2x3x4x!FHE.eint<2>>, tensor<2x3x4x!FHE.eint<3>>) -> tensor<2x3x4x!FHE.eint<2>> + return %1 : tensor<2x3x4x!FHE.eint<2>> +} + +// ----- + ///////////////////////////////////////////////// // FHELinalg.apply_lookup_table ///////////////////////////////////////////////// diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/ops.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/ops.mlir index 7bbcc6cd5..e66b0f743 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/ops.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/ops.mlir @@ -363,6 +363,60 @@ func.func @mul_eint_int_broadcast_2(%a0: tensor<4x!FHE.eint<2>>, %a1: tensor<3x4 return %1: tensor<3x4x!FHE.eint<2>> } +///////////////////////////////////////////////// +// FHELinalg.mul_eint +///////////////////////////////////////////////// + +// 1D tensor +// CHECK: func.func @mul_eint_1D(%[[a0:.*]]: tensor<4x!FHE.eint<2>>, %[[a1:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.mul_eint"(%[[a0]], %[[a1]]) : (tensor<4x!FHE.eint<2>>, tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<4x!FHE.eint<2>> +// CHECK-NEXT: } +func.func @mul_eint_1D(%a0: tensor<4x!FHE.eint<2>>, %a1: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> { + %1 = "FHELinalg.mul_eint"(%a0, %a1) : (tensor<4x!FHE.eint<2>>, tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> + return %1: tensor<4x!FHE.eint<2>> +} + +// 2D tensor +// CHECK: func.func @mul_eint_2D(%[[a0:.*]]: tensor<2x4x!FHE.eint<2>>, %[[a1:.*]]: tensor<2x4x!FHE.eint<2>>) -> tensor<2x4x!FHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.mul_eint"(%[[a0]], %[[a1]]) : (tensor<2x4x!FHE.eint<2>>, tensor<2x4x!FHE.eint<2>>) -> tensor<2x4x!FHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<2x4x!FHE.eint<2>> +// CHECK-NEXT: } +func.func 
@mul_eint_2D(%a0: tensor<2x4x!FHE.eint<2>>, %a1: tensor<2x4x!FHE.eint<2>>) -> tensor<2x4x!FHE.eint<2>> { + %1 = "FHELinalg.mul_eint"(%a0, %a1) : (tensor<2x4x!FHE.eint<2>>, tensor<2x4x!FHE.eint<2>>) -> tensor<2x4x!FHE.eint<2>> + return %1: tensor<2x4x!FHE.eint<2>> +} + +// 10D tensor +// CHECK: func.func @mul_eint_10D(%[[a0:.*]]: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>, %[[a1:.*]]: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.mul_eint"(%[[a0]], %[[a1]]) : (tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>, tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>> +// CHECK-NEXT: } +func.func @mul_eint_10D(%a0: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>, %a1: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>> { + %1 = "FHELinalg.mul_eint"(%a0, %a1) : (tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>, tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>> + return %1: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>> +} + +// Broadcasting with tensor with dimensions equals to one +// CHECK: func.func @mul_eint_broadcast_1(%[[a0:.*]]: tensor<1x4x5x!FHE.eint<2>>, %[[a1:.*]]: tensor<3x4x1x!FHE.eint<2>>) -> tensor<3x4x5x!FHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.mul_eint"(%[[a0]], %[[a1]]) : (tensor<1x4x5x!FHE.eint<2>>, tensor<3x4x1x!FHE.eint<2>>) -> tensor<3x4x5x!FHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<3x4x5x!FHE.eint<2>> +// CHECK-NEXT: } +func.func @mul_eint_broadcast_1(%a0: tensor<1x4x5x!FHE.eint<2>>, %a1: tensor<3x4x1x!FHE.eint<2>>) -> tensor<3x4x5x!FHE.eint<2>> { + %1 = "FHELinalg.mul_eint"(%a0, %a1) : (tensor<1x4x5x!FHE.eint<2>>, tensor<3x4x1x!FHE.eint<2>>) -> tensor<3x4x5x!FHE.eint<2>> + return %1: tensor<3x4x5x!FHE.eint<2>> +} + +// Broadcasting with a tensor less dimensions of another +// CHECK: 
func.func @mul_eint_broadcast_2(%[[a0:.*]]: tensor<4x!FHE.eint<2>>, %[[a1:.*]]: tensor<3x4x!FHE.eint<2>>) -> tensor<3x4x!FHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.mul_eint"(%[[a0]], %[[a1]]) : (tensor<4x!FHE.eint<2>>, tensor<3x4x!FHE.eint<2>>) -> tensor<3x4x!FHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<3x4x!FHE.eint<2>> +// CHECK-NEXT: } +func.func @mul_eint_broadcast_2(%a0: tensor<4x!FHE.eint<2>>, %a1: tensor<3x4x!FHE.eint<2>>) -> tensor<3x4x!FHE.eint<2>> { + %1 ="FHELinalg.mul_eint"(%a0, %a1) : (tensor<4x!FHE.eint<2>>, tensor<3x4x!FHE.eint<2>>) -> tensor<3x4x!FHE.eint<2>> + return %1: tensor<3x4x!FHE.eint<2>> +} + ///////////////////////////////////////////////// // FHELinalg.apply_lookup_table ///////////////////////////////////////////////// diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhelinalg.yaml b/compilers/concrete-compiler/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhelinalg.yaml index 634b74c8b..ead4990ae 100644 --- a/compilers/concrete-compiler/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhelinalg.yaml +++ b/compilers/concrete-compiler/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhelinalg.yaml @@ -206,6 +206,38 @@ tests: - tensor: [65535, 65535, 23844, 0] shape: [4] --- +description: mul_eint_term_to_term_6bits +program: | + func.func @main(%a0: tensor<4x!FHE.eint<7>>, %a1: tensor<4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> { + %res = "FHELinalg.mul_eint"(%a0, %a1) : (tensor<4x!FHE.eint<7>>, tensor<4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> + return %res : tensor<4x!FHE.eint<7>> + } +tests: + - inputs: + - tensor: [6, 3, 12, 9] + shape: [4] + - tensor: [10, 20, 2, 3] + shape: [4] + outputs: + - tensor: [60, 60, 24, 27] + shape: [4] +--- +description: mul_eint_term_to_term_15bits +program: | + func.func @main(%a0: tensor<4x!FHE.eint<16>>, %a1: tensor<4x!FHE.eint<16>>) -> tensor<4x!FHE.eint<16>> { + %res = "FHELinalg.mul_eint"(%a0, %a1) : 
(tensor<4x!FHE.eint<16>>, tensor<4x!FHE.eint<16>>) -> tensor<4x!FHE.eint<16>> + return %res : tensor<4x!FHE.eint<16>> + } +tests: + - inputs: + - tensor: [300, 5, 30000, 0] + shape: [4] + - tensor: [100, 1, 1, 0] + shape: [4] + outputs: + - tensor: [30000, 5, 30000, 0] + shape: [4] +--- description: transpose1d program: | func.func @main(%input: tensor<3x!FHE.eint<6>>) -> tensor<3x!FHE.eint<6>> {