// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include #include #include #include #include #include "boost/outcome.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Pass/PassManager.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Pass.h" #include "llvm/Support/raw_ostream.h" #include "concrete-optimizer.hpp" #include "concretelang/Common/Error.h" #include "concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h" #include "concretelang/Dialect/FHE/Analysis/utils.h" #include "concretelang/Dialect/FHE/IR/FHEOps.h" #include "concretelang/Dialect/FHE/IR/FHETypes.h" #include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h" #include "concretelang/Dialect/Tracing/IR/TracingOps.h" #include "concretelang/Support/V0Parameters.h" #include "concretelang/Support/logging.h" #define GEN_PASS_CLASSES #include "concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h.inc" namespace mlir { namespace concretelang { namespace optimizer { namespace { template rust::Slice slice(const std::vector &vec) { return rust::Slice(vec.data(), vec.size()); } template rust::Slice slice(const llvm::ArrayRef &vec) { return rust::Slice(vec.data(), vec.size()); } struct FunctionToDag { // Inputs of operators using Inputs = std::vector; const double NEGLIGIBLE_COMPLEXITY = 0.0; mlir::func::FuncOp func; optimizer::Config config; llvm::DenseMap index; bool setOptimizerID; FunctionToDag(mlir::func::FuncOp func, optimizer::Config config) : func(func), config(config) { setOptimizerID = config.strategy == optimizer::Strategy::DAG_MULTI; } #define DEBUG(MSG) \ if (mlir::concretelang::isVerbose()) { \ mlir::concretelang::log_verbose() << MSG << "\n"; \ } outcome::checked, ::concretelang::error::StringError> build() { auto dag = concrete_optimizer::dag::empty(); // Converting arguments as Input mlir::Builder builder(func.getContext()); for (size_t i = 0; i < func.getNumArguments(); i++) { auto arg = func.getArgument(i); auto optimizerIdx = addArg(dag, arg); if (optimizerIdx.has_value() && setOptimizerID) { func.setArgAttr(i, "TFHE.OId", builder.getI32IntegerAttr(optimizerIdx->index)); } } // Converting ops for (auto &bb : func.getBody().getBlocks()) { for (auto &op : bb.getOperations()) { addOperation(dag, op); } } for (auto &bb : func.getBody().getBlocks()) { for (auto &op : bb.getOperations()) { op.removeAttr("SMANP"); } } if (index.empty()) { // Dag is empty <=> classical function without encryption DEBUG("!!! concrete-optimizer: nothing to do in " << func.getName() << "\n"); return std::nullopt; }; DEBUG(std::string(dag->dump())); return std::move(dag); } std::optional addArg(optimizer::Dag &dag, mlir::Value &arg) { DEBUG("Arg " << arg << " " << arg.getType()); if (!fhe::utils::isEncryptedValue(arg)) { return std::nullopt; } auto precision = fhe::utils::getEintPrecision(arg); auto shape = getShape(arg); auto opI = dag->add_input(precision, slice(shape)); index[arg] = opI; return opI; } bool hasEncryptedResult(mlir::Operation &op) { for (auto val : op.getResults()) { if (fhe::utils::isEncryptedValue(val)) { return true; } } return false; } void addOperation(optimizer::Dag &dag, mlir::Operation &op) { DEBUG("Instr " << op); if (isReturn(op)) { // This op has no result return; } auto encrypted_inputs = encryptedInputs(op); if (!hasEncryptedResult(op)) { // This op is unrelated to FHE assert(encrypted_inputs.empty() || mlir::isa(op)); return; } assert(op.getNumResults() == 1); auto val = op.getResult(0); auto precision = fhe::utils::getEintPrecision(val); concrete_optimizer::dag::OperatorIndex index; if (auto inputType = isLut(op); inputType != nullptr) { addLut(dag, op, inputType, encrypted_inputs, precision); return; } else if (isRound(op)) { index = addRound(dag, val, encrypted_inputs, precision); } else if (isReinterpretPrecision(op)) { addReinterpretPrecision(dag, val, encrypted_inputs, precision); return; } else if (auto lsb = asLsb(op)) { addLsb(dag, lsb, encrypted_inputs); return; } else if (auto lsb = asLsbTensor(op)) { addLsb(dag, lsb, encrypted_inputs); return; } else if (auto dot = asDot(op)) { auto weightsOpt = dotWeights(dot); if (weightsOpt) { index = addDot(dag, val, encrypted_inputs, weightsOpt.value()); } else { // If can't find weights return default leveled op DEBUG("Replace Dot by LevelledOp on " << op); index = addLevelledOp(dag, op, encrypted_inputs); } } else if (auto dot = asDotEint(op)) { addDotEint(dag, dot, encrypted_inputs, precision); // The above function call sets the OIds, can return right away return; } else if (auto mul = asMul(op)) { // special case as mul are rewritten in several optimizer nodes addMul(dag, mul, encrypted_inputs, precision); return; } else if (auto mul = asMulTensor(op)) { // special case as mul are rewritten in several optimizer nodes addMul(dag, mul, encrypted_inputs, precision); return; } else if (auto max = asMax(op)) { // special case as max are rewritten in several optimizer nodes addMax(dag, max, encrypted_inputs, precision); return; } else if (auto maxpool2d = asMaxpool2d(op)) { // special case as max are rewritten in several optimizer nodes addMaxpool2d(dag, maxpool2d, encrypted_inputs, precision); return; } else if (auto matmulEintEint = asMatmulEintEint(op)) { addEncMatMulTensor(dag, matmulEintEint, encrypted_inputs, precision); return; } else { index = addLevelledOp(dag, op, encrypted_inputs); } mlir::Builder builder(op.getContext()); if (setOptimizerID) op.setAttr("TFHE.OId", builder.getI32IntegerAttr(index.index)); } void addLut(optimizer::Dag &dag, mlir::Operation &op, FHE::FheIntegerInterface inputType, Inputs &encrypted_inputs, int precision) { auto val = op.getResult(0); assert(encrypted_inputs.size() == 1); // No need to distinguish different lut kind until we do approximate // paradigm on outputs auto encrypted_input = encrypted_inputs[0]; std::vector unknowFunction; std::vector operatorIndexes; if (inputType.isSigned()) { // std::vector weights_vector{1}; auto addIndex = dag->add_dot(slice(encrypted_inputs), concrete_optimizer::weights::number(1)); encrypted_input = addIndex; operatorIndexes.push_back(addIndex.index); } auto lutIndex = dag->add_lut(encrypted_input, slice(unknowFunction), precision); operatorIndexes.push_back(lutIndex.index); mlir::Builder builder(op.getContext()); if (setOptimizerID) op.setAttr("TFHE.OId", builder.getDenseI32ArrayAttr(operatorIndexes)); index[val] = lutIndex; } concrete_optimizer::dag::OperatorIndex addRound(optimizer::Dag &dag, mlir::Value &val, Inputs &encrypted_inputs, int rounded_precision) { assert(encrypted_inputs.size() == 1); // No need to distinguish different lut kind until we do approximate // paradigm on outputs auto encrypted_input = encrypted_inputs[0]; index[val] = dag->add_round_op(encrypted_input, rounded_precision); return index[val]; } concrete_optimizer::dag::OperatorIndex addReinterpretPrecision(optimizer::Dag &dag, mlir::Value &val, Inputs &encrypted_inputs, int new_precision) { assert(encrypted_inputs.size() == 1); auto encrypted_input = encrypted_inputs[0]; index[val] = dag->add_unsafe_cast_op(encrypted_input, new_precision); return index[val]; } concrete_optimizer::dag::OperatorIndex addDot(optimizer::Dag &dag, mlir::Value &val, Inputs &encrypted_inputs, std::vector &weights_vector) { assert(encrypted_inputs.size() == 1); auto weights = concrete_optimizer::weights::vector(slice(weights_vector)); index[val] = dag->add_dot(slice(encrypted_inputs), std::move(weights)); return index[val]; } std::string loc_to_string(mlir::Location location) { std::string loc; llvm::raw_string_ostream loc_stream(loc); location.print(loc_stream); return loc; } concrete_optimizer::dag::OperatorIndex addLevelledOp(optimizer::Dag &dag, mlir::Operation &op, Inputs &inputs) { auto val = op.getResult(0); auto out_shape = getShape(val); if (inputs.empty()) { // Trivial encrypted constants encoding // There are converted to input + levelledop auto precision = fhe::utils::getEintPrecision(val); auto opI = dag->add_input(precision, slice(out_shape)); inputs.push_back(opI); } // Default complexity is negligible double fixed_cost = NEGLIGIBLE_COMPLEXITY; double lwe_dim_cost_factor = NEGLIGIBLE_COMPLEXITY; auto smanp_int = op.getAttrOfType("SMANP"); auto loc = loc_to_string(op.getLoc()); assert(smanp_int && "Missing manp value on a crypto operation"); // TODO: use APIFloat.sqrt when it's available double manp = sqrt(smanp_int.getValue().roundToDouble()); auto comment = std::string(op.getName().getStringRef()) + " " + loc; index[val] = dag->add_levelled_op(slice(inputs), lwe_dim_cost_factor, fixed_cost, manp, slice(out_shape), comment); return index[val]; } bool isSignedEint(mlir::Type type) { if (auto tensor = type.dyn_cast(); tensor != nullptr) { type = tensor.getElementType(); } return type.cast().isSigned(); } template void addLsb(optimizer::Dag &dag, LsbOp &lsbOp, Inputs &encrypted_inputs) { assert(encrypted_inputs.size() == 1); auto input = lsbOp.getInput(); auto result = lsbOp.getResult(); auto input_precision = fhe::utils::getEintPrecision(input); auto output_precision = fhe::utils::getEintPrecision(result); auto lsb_shiffted_as_1bit_wop = dag->add_dot(slice(encrypted_inputs), concrete_optimizer::weights::number(1 << input_precision)); std::vector unknownFunction; auto overflow_bit_precision = 0; auto lsb_as_0_bits = dag->add_unsafe_cast_op( lsb_shiffted_as_1bit_wop, overflow_bit_precision); // id for rotation auto lsb_result = dag->add_lut(lsb_as_0_bits, slice(unknownFunction), output_precision); auto lsb_result_corrected = idPlaceholder(dag, lsb_result); index[result] = lsb_result_corrected; if (!setOptimizerID) { return; } mlir::SmallVector operatorIndexes = { // see `extractBitWithClearedLowerBits` in // lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp (int32_t)lsb_shiffted_as_1bit_wop.index, // for shift (int32_t)lsb_as_0_bits.index, // for rotation (int32_t)lsb_result.index, // for ks (int32_t)lsb_result.index, // for bootstrap (int32_t)lsb_result_corrected.index, // for correction }; mlir::Builder builder(lsbOp.getContext()); lsbOp->setAttr("TFHE.OId", builder.getDenseI32ArrayAttr(operatorIndexes)); } template void addMul(optimizer::Dag &dag, MulOp &mulOp, Inputs &inputs, int precision) { // x * y = ((x + y)^2 / 4) - ((x - y)^2 / 4) == tlu(x + y) - tlu(x - y) mlir::Value result = mulOp.getResult(); const std::vector resultShape = getShape(result); Operation *xOp = mulOp.getLhs().getDefiningOp(); Operation *yOp = mulOp.getRhs().getDefiningOp(); const double fixedCost = NEGLIGIBLE_COMPLEXITY; const double lweDimCostFactor = NEGLIGIBLE_COMPLEXITY; llvm::APInt xSmanp = llvm::APInt{1, 1, false}; if (xOp != nullptr) { const auto xSmanpAttr = xOp->getAttrOfType("SMANP"); assert(xSmanpAttr && "Missing SMANP value on a crypto operation"); xSmanp = xSmanpAttr.getValue(); } llvm::APInt ySmanp = llvm::APInt{1, 1, false}; if (yOp != nullptr) { const auto ySmanpAttr = yOp->getAttrOfType("SMANP"); assert(ySmanpAttr && "Missing SMANP value on a crypto operation"); ySmanp = ySmanpAttr.getValue(); } auto loc = loc_to_string(mulOp.getLoc()); auto comment = std::string(mulOp->getName().getStringRef()) + " " + loc; // (x + y) and (x - y) const double addSubManp = sqrt(xSmanp.roundToDouble() + ySmanp.roundToDouble()); // tlu(v) const double tluManp = 1; // tlu(v1) - tlu(v2) const double tluSubManp = sqrt(tluManp + tluManp); // for tlus const std::vector unknownFunction; // tlu(x + y) auto addNode = dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, addSubManp, slice(resultShape), comment); std::optional lhsCorrectionNode; if (isSignedEint(mulOp.getType())) { // If signed mul we need to add the addition node for correction of the // signed tlu addNode = dag->add_dot( slice(std::vector{addNode}), concrete_optimizer::weights::vector( slice(std::vector{1}))); lhsCorrectionNode = addNode; } auto lhsTluNode = dag->add_lut(addNode, slice(unknownFunction), precision); // tlu(x - y) auto subNode = dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, addSubManp, slice(resultShape), comment); // This is a signed tlu so we need to also add the addition for correction // signed tlu auto rhsCorrectionNode = dag->add_dot( slice(std::vector{subNode}), concrete_optimizer::weights::vector( slice(std::vector{1}))); auto rhsTluNode = dag->add_lut(rhsCorrectionNode, slice(unknownFunction), precision); // tlu(x + y) - tlu(x - y) const std::vector subInputs = { lhsTluNode, rhsTluNode}; auto resultNode = dag->add_levelled_op(slice(subInputs), lweDimCostFactor, fixedCost, tluSubManp, slice(resultShape), comment); index[result] = resultNode; mlir::Builder builder(mulOp.getContext()); mlir::SmallVector operatorIndexes = { (int32_t)addNode.index, (int32_t)lhsTluNode.index, (int32_t)subNode.index, (int32_t)rhsCorrectionNode.index, (int32_t)rhsTluNode.index, (int32_t)resultNode.index, }; if (lhsCorrectionNode.has_value()) { // We push that at the end by convention operatorIndexes.push_back(lhsCorrectionNode.value().index); } if (setOptimizerID) mulOp->setAttr("TFHE.OId", builder.getDenseI32ArrayAttr(operatorIndexes)); } template concrete_optimizer::dag::OperatorIndex addTensorInnerProductEncEnc(optimizer::Dag &dag, InnerProductOp &innerProductOp, Inputs &inputs, int precision) { mlir::Value result = innerProductOp.getResult(); const std::vector resultShape = getShape(result); // We assume a first tensorized matmul step // is the construction of matrices of: // - sums of all pairs // - differences of all pairs // where pairs are pairs of values that are to be multiplied together // Compute the number of elements in each of the // matrices of pairs auto lhsType = ((mlir::Type)innerProductOp.getLhs().getType()) .cast(); auto rhsType = ((mlir::Type)innerProductOp.getRhs().getType()) .cast(); std::vector lhsShape = lhsType.getShape(); std::vector rhsShape = rhsType.getShape(); if (rhsShape.size() == 1) rhsShape.push_back(1); if (lhsShape.size() == 1) lhsShape.emplace(lhsShape.begin(), 1); int64_t rhsDims = (int64_t)rhsShape.size(); int64_t lhsDims = (int64_t)lhsShape.size(); // Suppose lhsDims is (5, 3, 2) -> 5 matrices of size 3x2 (2 is the // reduction dimension) and rhsDims is (3, 5, 2, 3) -> 3x 5 matrices of // size 2x3 the pair matrix would have size (3, 5, 3, 3, 2) this is the // shape of the matrix onto which we apply the TLUs that compute the // multiplication of all pairs of values // The RHS can be a (N,) matrix, the outer dimension is supposed to be 1 int64_t rhsOuterDim = rhsShape[rhsDims - 1]; int64_t lhsOuterDim = lhsShape[lhsDims - 2]; std::vector pairMatrixShape; // Compute the output matrix dimension // Corresponding dimensions that are considered "compatible" are "N, 1", "1, // N", "N, N" int64_t rhsDimIter = rhsDims - 3, lhsDimIter = lhsDims - 3; if (rhsDimIter >= 0 && lhsDimIter >= 0) { while (rhsDimIter >= 0 && lhsDimIter >= 0 && (lhsShape[lhsDimIter] == rhsShape[rhsDimIter] || lhsShape[lhsDimIter] == 1 || rhsShape[rhsDimIter] == 1)) { pairMatrixShape.push_back( std::max(rhsShape[rhsDimIter], lhsShape[lhsDimIter])); --lhsDimIter; --rhsDimIter; } } assert((lhsDimIter < 0 || rhsDimIter < 0) && "Bad dimensions given to matmul or dot"); while (lhsDimIter >= 0) { pairMatrixShape.push_back(lhsShape[lhsDimIter]); --lhsDimIter; } while (rhsDimIter >= 0) { pairMatrixShape.push_back(rhsShape[rhsDimIter]); --rhsDimIter; } // Add the outer dimensions of the individual matrices pairMatrixShape.push_back(lhsOuterDim); pairMatrixShape.push_back(rhsOuterDim); // Add the reduction dimension // The number of elements in the dot product // is the number of cells on the reduction axis (aka "destroyed dimension") int64_t reductionDimSize = rhsShape[rhsDims - 2]; assert(lhsShape[lhsDims - 1] == reductionDimSize); pairMatrixShape.push_back(reductionDimSize); // Compute the manp of the various steps // in the matmul of enc x enc: // 1. (x + y) and (x - y) -> supposing broadcasting is used // to tensorize this operation Operation *xOp = innerProductOp.getLhs().getDefiningOp(); Operation *yOp = innerProductOp.getRhs().getDefiningOp(); const double fixedCost = NEGLIGIBLE_COMPLEXITY; const double lweDimCostFactor = NEGLIGIBLE_COMPLEXITY; llvm::APInt xSmanp = llvm::APInt{1, 1, false}; if (xOp != nullptr) { const auto xSmanpAttr = xOp->getAttrOfType("SMANP"); assert(xSmanpAttr && "Missing SMANP value on a crypto operation"); xSmanp = xSmanpAttr.getValue(); } llvm::APInt ySmanp = llvm::APInt{1, 1, false}; if (yOp != nullptr) { const auto ySmanpAttr = yOp->getAttrOfType("SMANP"); assert(ySmanpAttr && "Missing SMANP value on a crypto operation"); ySmanp = ySmanpAttr.getValue(); } auto loc = loc_to_string(innerProductOp.getLoc()); auto comment = std::string(innerProductOp->getName().getStringRef()) + " " + loc; // (x + y) and (x - y) const double addSubManp = sqrt(xSmanp.roundToDouble() + ySmanp.roundToDouble()); // tlu(v) const double tluManp = 1; // tlu(v1) - tlu(v2) const double tluSubManp = sqrt(tluManp + tluManp); // for tlus const std::vector unknownFunction; // tlu(x + y) auto addNode = dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, addSubManp, slice(pairMatrixShape), comment); std::optional lhsCorrectionNode; if (isSignedEint(innerProductOp.getType())) { // If signed mul we need to add the addition node for correction of the // signed tlu addNode = dag->add_dot( slice(std::vector{addNode}), concrete_optimizer::weights::vector( slice(std::vector{1}))); lhsCorrectionNode = addNode; } auto lhsTluNode = dag->add_lut(addNode, slice(unknownFunction), precision); // tlu(x - y) auto subNode = dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, addSubManp, slice(pairMatrixShape), comment); // This is a signed tlu so we need to also add the addition for correction // signed tlu auto rhsCorrectionNode = dag->add_dot( slice(std::vector{subNode}), concrete_optimizer::weights::vector( slice(std::vector{1}))); auto rhsTluNode = dag->add_lut(rhsCorrectionNode, slice(unknownFunction), precision); // tlu(x + y) - tlu(x - y) const std::vector subInputs = { lhsTluNode, rhsTluNode}; auto resultNode = dag->add_levelled_op(slice(subInputs), lweDimCostFactor, fixedCost, tluSubManp, slice(pairMatrixShape), comment); // 3. Sum(tlu(x + y) - tlu(x - y)) // Create a leveled op that simulates concatenation. It takes // as inputs all the intermediary dot product results and produces // the output tensor // Default complexity is negligible double fixed_cost = NEGLIGIBLE_COMPLEXITY; double lwe_dim_cost_factor = NEGLIGIBLE_COMPLEXITY; // For the output of the operation, take the MANP from the MANP pass mlir::Operation *op = innerProductOp.getOperation(); mlir::IntegerAttr smanp_int = op->getAttrOfType("SMANP"); assert(smanp_int && "Missing manp value on a crypto operation"); const std::vector sumOperands = { resultNode}; // TODO: use APIFloat.sqrt when it's available double manp = sqrt(smanp_int.getValue().roundToDouble()); index[result] = dag->add_levelled_op(slice(sumOperands), lwe_dim_cost_factor, fixed_cost, manp, slice(resultShape), comment); // Create the TFHE.OId attributes // The first elements of the vector are nodes for the encrypted // multiplication mlir::Builder builder(innerProductOp.getContext()); mlir::SmallVector operatorIndexes = { (int32_t)addNode.index, (int32_t)lhsTluNode.index, (int32_t)subNode.index, (int32_t)rhsCorrectionNode.index, (int32_t)rhsTluNode.index, (int32_t)resultNode.index, }; if (lhsCorrectionNode.has_value()) { operatorIndexes.push_back(lhsCorrectionNode.value().index); } // The last element of the vector is the node for the addition operatorIndexes.push_back((int32_t)index[result].index); if (setOptimizerID) innerProductOp->setAttr("TFHE.OId", builder.getDenseI32ArrayAttr(operatorIndexes)); return index[result]; } concrete_optimizer::dag::OperatorIndex addEncMatMulTensor(optimizer::Dag &dag, FHELinalg::MatMulEintEintOp &matmulOp, Inputs &inputs, int precision) { return addTensorInnerProductEncEnc( dag, matmulOp, inputs, precision); } concrete_optimizer::dag::OperatorIndex addDotEint(optimizer::Dag &dag, FHELinalg::DotEint &dotOp, Inputs &inputs, int precision) { return addTensorInnerProductEncEnc(dag, dotOp, inputs, precision); } void addMax(optimizer::Dag &dag, FHE::MaxEintOp &maxOp, Inputs &inputs, int precision) { mlir::Value result = maxOp.getResult(); const std::vector resultShape = getShape(result); Operation *xOp = maxOp.getX().getDefiningOp(); Operation *yOp = maxOp.getY().getDefiningOp(); const double fixedCost = NEGLIGIBLE_COMPLEXITY; const double lweDimCostFactor = NEGLIGIBLE_COMPLEXITY; llvm::APInt xSmanp = llvm::APInt{1, 1, false}; if (xOp != nullptr) { const auto xSmanpAttr = xOp->getAttrOfType("SMANP"); assert(xSmanpAttr && "Missing SMANP value on a crypto operation"); xSmanp = xSmanpAttr.getValue(); } llvm::APInt ySmanp = llvm::APInt{1, 1, false}; if (yOp != nullptr) { const auto ySmanpAttr = yOp->getAttrOfType("SMANP"); assert(ySmanpAttr && "Missing SMANP value on a crypto operation"); ySmanp = ySmanpAttr.getValue(); } const double subManp = sqrt(xSmanp.roundToDouble() + ySmanp.roundToDouble()); auto loc = loc_to_string(maxOp.getLoc()); auto comment = std::string(maxOp->getName().getStringRef()) + " " + loc; auto subNode = dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, subManp, slice(resultShape), comment); const double tluNodeManp = 1; const std::vector unknownFunction; auto tluNode = dag->add_lut(subNode, slice(unknownFunction), precision); const double addManp = sqrt(tluNodeManp + ySmanp.roundToDouble()); const std::vector addInputs = { tluNode, inputs[1]}; auto resultNode = dag->add_levelled_op(slice(addInputs), lweDimCostFactor, fixedCost, addManp, slice(resultShape), comment); index[result] = resultNode; // Set attribute on the MLIR node mlir::Builder builder(maxOp.getContext()); mlir::SmallVector operatorIndexes = {(int32_t)subNode.index, (int32_t)tluNode.index, (int32_t)resultNode.index}; if (setOptimizerID) maxOp->setAttr("TFHE.OId", builder.getDenseI32ArrayAttr(operatorIndexes)); } void addMaxpool2d(optimizer::Dag &dag, FHELinalg::Maxpool2dOp &maxpool2dOp, Inputs &inputs, int precision) { mlir::Value result = maxpool2dOp.getResult(); const std::vector resultShape = getShape(result); // all TLUs are flattened into a dimension // to create a single TLU node in optimizer dag std::vector fakeShape = resultShape; uint64_t numberOfComparisons = 1; for (auto dimensionSize : maxpool2dOp.getKernelShape().getValues()) { numberOfComparisons *= dimensionSize; } fakeShape.push_back(numberOfComparisons); Operation *inputOp = maxpool2dOp.getInput().getDefiningOp(); const double fixedCost = NEGLIGIBLE_COMPLEXITY; const double lweDimCostFactor = NEGLIGIBLE_COMPLEXITY; llvm::APInt inputSmanp = llvm::APInt{1, 1, false}; if (inputOp != nullptr) { const auto inputSmanpAttr = inputOp->getAttrOfType("SMANP"); assert(inputSmanpAttr && "Missing SMANP value on a crypto operation"); inputSmanp = inputSmanpAttr.getValue(); } const double subManp = sqrt(2 * inputSmanp.roundToDouble() + 1); auto loc = loc_to_string(maxpool2dOp.getLoc()); auto comment = std::string(maxpool2dOp->getName().getStringRef()) + " " + loc; auto subNode = dag->add_levelled_op(slice(inputs), lweDimCostFactor, fixedCost, subManp, slice(fakeShape), comment); const std::vector unknownFunction; auto tluNode = dag->add_lut(subNode, slice(unknownFunction), precision); const double addManp = sqrt(inputSmanp.roundToDouble() + 1); const std::vector addInputs = { tluNode, inputs[0]}; auto resultNode = dag->add_levelled_op(slice(addInputs), lweDimCostFactor, fixedCost, addManp, slice(resultShape), comment); index[result] = resultNode; // Set attribute on the MLIR node mlir::Builder builder(maxpool2dOp.getContext()); mlir::SmallVector operatorIndexes = {(int32_t)subNode.index, (int32_t)tluNode.index, (int32_t)resultNode.index}; // TODO : The substraction of the signed case is not given to the optimizer // which could lead to some issue with the dag partitioning of the // optimizer. // Note: Should not be an issue while the partition are computed // on the precision. if (setOptimizerID) maxpool2dOp->setAttr("TFHE.OId", builder.getDenseI32ArrayAttr(operatorIndexes)); } concrete_optimizer::dag::OperatorIndex idPlaceholder(optimizer::Dag &dag, concrete_optimizer::dag::OperatorIndex input) { std::vector inputs = {input}; return dag->add_dot(slice(inputs), concrete_optimizer::weights::number(1)); } Inputs encryptedInputs(mlir::Operation &op) { Inputs inputs; for (auto operand : op.getOperands()) { auto entry = index.find(operand); if (entry == index.end()) { assert(!fhe::utils::isEncryptedValue(operand)); DEBUG("Ignoring as input " << operand); continue; } inputs.push_back(entry->getSecond()); } return inputs; } template FHE::FheIntegerInterface getEintTypeOfLut(LinalgApplyLookupTable op) { auto tensorType = op.getT().getType().template dyn_cast(); auto eint = tensorType.getElementType() .template dyn_cast(); assert(eint != nullptr); return eint; } // Returns the FHE integer type on which the lut is performed else return a // nullptr FHE::FheIntegerInterface isLut(mlir::Operation &op) { if (auto lut = llvm::dyn_cast(op); lut != nullptr) { auto eint = lut.getA().getType().dyn_cast(); assert(eint != nullptr); return eint; } if (auto lut = llvm::dyn_cast< mlir::concretelang::FHELinalg::ApplyLookupTableEintOp>(op); lut != nullptr) { return getEintTypeOfLut(lut); } if (auto lut = llvm::dyn_cast< mlir::concretelang::FHELinalg::ApplyMultiLookupTableEintOp>(op); lut != nullptr) { return getEintTypeOfLut(lut); } if (auto lut = llvm::dyn_cast< mlir::concretelang::FHELinalg::ApplyMappedLookupTableEintOp>(op); lut != nullptr) { return getEintTypeOfLut(lut); } return nullptr; } bool isRound(mlir::Operation &op) { return llvm::isa(op) || llvm::isa(op); } bool isReinterpretPrecision(mlir::Operation &op) { return llvm::isa(op) || llvm::isa( op); } mlir::concretelang::FHELinalg::Dot asDot(mlir::Operation &op) { return llvm::dyn_cast(op); } mlir::concretelang::FHELinalg::DotEint asDotEint(mlir::Operation &op) { return llvm::dyn_cast(op); } mlir::concretelang::FHE::MulEintOp asMul(mlir::Operation &op) { return llvm::dyn_cast(op); } mlir::concretelang::FHELinalg::MulEintOp asMulTensor(mlir::Operation &op) { return llvm::dyn_cast(op); } mlir::concretelang::FHE::MaxEintOp asMax(mlir::Operation &op) { return llvm::dyn_cast(op); } mlir::concretelang::FHELinalg::Maxpool2dOp asMaxpool2d(mlir::Operation &op) { return llvm::dyn_cast(op); } mlir::concretelang::FHELinalg::MatMulEintEintOp asMatmulEintEint(mlir::Operation &op) { return llvm::dyn_cast(op); } mlir::concretelang::FHE::LsbEintOp asLsb(mlir::Operation &op) { return llvm::dyn_cast(op); } mlir::concretelang::FHELinalg::LsbEintOp asLsbTensor(mlir::Operation &op) { return llvm::dyn_cast(op); } bool isReturn(mlir::Operation &op) { return llvm::isa(op); } bool isConst(mlir::Operation &op) { return llvm::isa(op); } bool isArg(const mlir::Value &value) { return value.isa(); } std::optional> resolveConstantVectorWeights(mlir::arith::ConstantOp &cstOp) { std::vector values; mlir::DenseIntElementsAttr denseVals = cstOp->getAttrOfType("value"); for (llvm::APInt val : denseVals.getValues()) { if (val.getActiveBits() > 64) { return std::nullopt; } values.push_back(val.getSExtValue()); } return values; } std::optional> resolveConstantWeights(mlir::Value &value) { if (auto cstOp = llvm::dyn_cast_or_null( value.getDefiningOp())) { auto shape = getShape(value); switch (shape.size()) { case 1: return resolveConstantVectorWeights(cstOp); default: DEBUG("High-Rank tensor: rely on MANP and levelledOp"); return std::nullopt; } } else { DEBUG("Dynamic Weights: rely on MANP and levelledOp"); return std::nullopt; } } std::optional> dotWeights(mlir::concretelang::FHELinalg::Dot &dot) { if (dot.getOperands().size() != 2) { return std::nullopt; } auto weights = dot.getOperands()[1]; return resolveConstantWeights(weights); } std::vector getShape(mlir::Value &value) { return getShape(value.getType()); } std::vector getShape(mlir::Type type_) { if (auto ranked_tensor = type_.dyn_cast_or_null()) { std::vector shape; for (auto v : ranked_tensor.getShape()) { shape.push_back(v); } return shape; } else { return {}; } } }; } // namespace struct DagPass : ConcreteOptimizerBase { optimizer::Config config; optimizer::FunctionsDag &dags; void runOnOperation() override { mlir::func::FuncOp func = getOperation(); auto name = std::string(func.getName()); DEBUG("ConcreteOptimizer Dag: " << name); auto dag = FunctionToDag(func, config).build(); if (dag) { dags.insert( optimizer::FunctionsDag::value_type(name, std::move(dag.value()))); } else { this->signalPassFailure(); } } DagPass() = delete; DagPass(optimizer::Config config, optimizer::FunctionsDag &dags) : config(config), dags(dags) {} }; // Create an instance of the ConcreteOptimizerPass pass. // A global pass result is communicated using `dags`. // If `debug` is true, for each operation, the pass emits a // remark containing the squared Minimal Arithmetic Noise Padding of // the equivalent dot operation. std::unique_ptr createDagPass(optimizer::Config config, optimizer::FunctionsDag &dags) { return std::make_unique(config, dags); } } // namespace optimizer } // namespace concretelang } // namespace mlir