// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include #include #include #include #include "boost/outcome.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Pass/PassManager.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Pass.h" #include "llvm/Support/raw_ostream.h" #include "concrete-optimizer.hpp" #include "concretelang/Common/Error.h" #include "concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h" #include "concretelang/Dialect/FHE/Analysis/utils.h" #include "concretelang/Dialect/FHE/IR/FHEOps.h" #include "concretelang/Dialect/FHE/IR/FHETypes.h" #include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h" #include "concretelang/Support/V0Parameters.h" #include "concretelang/Support/logging.h" #define GEN_PASS_CLASSES #include "concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h.inc" namespace mlir { namespace concretelang { namespace optimizer { namespace { template rust::Slice slice(const std::vector &vec) { return rust::Slice(vec.data(), vec.size()); } template rust::Slice slice(const llvm::ArrayRef &vec) { return rust::Slice(vec.data(), vec.size()); } struct FunctionToDag { // Inputs of operators using Inputs = std::vector; const double NEGLIGIBLE_COMPLEXITY = 0.0; mlir::func::FuncOp func; optimizer::Config config; llvm::DenseMap index; FunctionToDag(mlir::func::FuncOp func, optimizer::Config config) : func(func), config(config) {} #define DEBUG(MSG) \ if (mlir::concretelang::isVerbose()) { \ mlir::concretelang::log_verbose() << MSG << "\n"; \ } outcome::checked, ::concretelang::error::StringError> build() { auto dag = concrete_optimizer::dag::empty(); // Converting arguments as Input for (auto &arg : func.getArguments()) { addArg(dag, arg); } // Converting ops for (auto &bb : func.getBody().getBlocks()) { for (auto &op : bb.getOperations()) { addOperation(dag, op); op.removeAttr("SMANP"); } } if (index.empty()) { // Dag is empty <=> classical function without encryption DEBUG("!!! concrete-optimizer: nothing to do in " << func.getName() << "\n"); return llvm::None; }; DEBUG(std::string(dag->dump())); return std::move(dag); } void addArg(optimizer::Dag &dag, mlir::Value &arg) { DEBUG("Arg " << arg << " " << arg.getType()); if (!fhe::utils::isEncryptedValue(arg)) { return; } auto precision = fhe::utils::getEintPrecision(arg); auto shape = getShape(arg); auto opI = dag->add_input(precision, slice(shape)); index[arg] = opI; } bool hasEncryptedResult(mlir::Operation &op) { for (auto val : op.getResults()) { if (fhe::utils::isEncryptedValue(val)) { return true; } } return false; } void addOperation(optimizer::Dag &dag, mlir::Operation &op) { DEBUG("Instr " << op); if (isReturn(op)) { // This op has no result return; } auto encrypted_inputs = encryptedInputs(op); if (!hasEncryptedResult(op)) { // This op is unrelated to FHE assert(encrypted_inputs.empty()); return; } assert(op.getNumResults() == 1); auto val = op.getResult(0); auto precision = fhe::utils::getEintPrecision(val); if (isLut(op)) { addLut(dag, val, encrypted_inputs, precision); return; } if (auto dot = asDot(op)) { auto weightsOpt = dotWeights(dot); if (weightsOpt) { addDot(dag, val, encrypted_inputs, weightsOpt.getValue()); return; } // If can't find weights return default leveled op DEBUG("Replace Dot by LevelledOp on " << op); } // default addLevelledOp(dag, op, encrypted_inputs); } void addLut(optimizer::Dag &dag, mlir::Value &val, Inputs &encrypted_inputs, int precision) { assert(encrypted_inputs.size() == 1); // No need to distinguish different lut kind until we do approximate // paradigm on outputs auto encrypted_input = encrypted_inputs[0]; std::vector unknowFunction; index[val] = dag->add_lut(encrypted_input, slice(unknowFunction), precision); } void addDot(optimizer::Dag &dag, mlir::Value &val, Inputs &encrypted_inputs, std::vector &weights_vector) { assert(encrypted_inputs.size() == 1); auto weights = concrete_optimizer::weights::vector(slice(weights_vector)); index[val] = dag->add_dot(slice(encrypted_inputs), std::move(weights)); } std::string loc_to_string(mlir::Location location) { std::string loc; llvm::raw_string_ostream loc_stream(loc); location.print(loc_stream); return loc; } void addLevelledOp(optimizer::Dag &dag, mlir::Operation &op, Inputs &inputs) { auto val = op.getResult(0); auto out_shape = getShape(val); if (inputs.empty()) { // Trivial encrypted constants encoding // There are converted to input + levelledop auto precision = fhe::utils::getEintPrecision(val); auto opI = dag->add_input(precision, slice(out_shape)); inputs.push_back(opI); } // Default complexity is negligible double fixed_cost = NEGLIGIBLE_COMPLEXITY; double lwe_dim_cost_factor = NEGLIGIBLE_COMPLEXITY; auto smanp_int = op.getAttrOfType("SMANP"); auto loc = loc_to_string(op.getLoc()); assert(smanp_int && "Missing manp value on a crypto operation"); // TODO: use APIFloat.sqrt when it's available double manp = sqrt(smanp_int.getValue().roundToDouble()); auto comment = std::string(op.getName().getStringRef()) + " " + loc; index[val] = dag->add_levelled_op(slice(inputs), lwe_dim_cost_factor, fixed_cost, manp, slice(out_shape), comment); } Inputs encryptedInputs(mlir::Operation &op) { Inputs inputs; for (auto operand : op.getOperands()) { auto entry = index.find(operand); if (entry == index.end()) { assert(!fhe::utils::isEncryptedValue(operand)); DEBUG("Ignoring as input " << operand); continue; } inputs.push_back(entry->getSecond()); } return inputs; } bool isLut(mlir::Operation &op) { return llvm::isa< mlir::concretelang::FHE::ApplyLookupTableEintOp, mlir::concretelang::FHELinalg::ApplyLookupTableEintOp, mlir::concretelang::FHELinalg::ApplyMultiLookupTableEintOp, mlir::concretelang::FHELinalg::ApplyMappedLookupTableEintOp>(op); } mlir::concretelang::FHELinalg::Dot asDot(mlir::Operation &op) { return llvm::dyn_cast(op); } bool isReturn(mlir::Operation &op) { return llvm::isa(op); } bool isConst(mlir::Operation &op) { return llvm::isa(op); } bool isArg(const mlir::Value &value) { return value.isa(); } llvm::Optional> resolveConstantVectorWeights(mlir::arith::ConstantOp &cstOp) { std::vector values; mlir::DenseIntElementsAttr denseVals = cstOp->getAttrOfType("value"); for (llvm::APInt val : denseVals.getValues()) { if (val.getActiveBits() > 64) { return llvm::None; } values.push_back(val.getSExtValue()); } return values; } llvm::Optional> resolveConstantWeights(mlir::Value &value) { if (auto cstOp = llvm::dyn_cast_or_null( value.getDefiningOp())) { auto shape = getShape(value); switch (shape.size()) { case 1: return resolveConstantVectorWeights(cstOp); default: DEBUG("High-Rank tensor: rely on MANP and levelledOp"); return llvm::None; } } else { DEBUG("Dynamic Weights: rely on MANP and levelledOp"); return llvm::None; } } llvm::Optional> dotWeights(mlir::concretelang::FHELinalg::Dot &dot) { if (dot.getOperands().size() != 2) { return llvm::None; } auto weights = dot.getOperands()[1]; return resolveConstantWeights(weights); } std::vector getShape(mlir::Value &value) { return getShape(value.getType()); } std::vector getShape(mlir::Type type_) { if (auto ranked_tensor = type_.dyn_cast_or_null()) { std::vector shape; for (auto v : ranked_tensor.getShape()) { shape.push_back(v); } return shape; } else { return {}; } } }; } // namespace struct DagPass : ConcreteOptimizerBase { optimizer::Config config; optimizer::FunctionsDag &dags; void runOnOperation() override { mlir::func::FuncOp func = getOperation(); auto name = std::string(func.getName()); DEBUG("ConcreteOptimizer Dag: " << name); auto dag = FunctionToDag(func, config).build(); if (dag) { dags.insert( optimizer::FunctionsDag::value_type(name, std::move(dag.value()))); } else { this->signalPassFailure(); } } DagPass() = delete; DagPass(optimizer::Config config, optimizer::FunctionsDag &dags) : config(config), dags(dags) {} }; // Create an instance of the ConcreteOptimizerPass pass. // A global pass result is communicated using `dags`. // If `debug` is true, for each operation, the pass emits a // remark containing the squared Minimal Arithmetic Noise Padding of // the equivalent dot operation. std::unique_ptr createDagPass(optimizer::Config config, optimizer::FunctionsDag &dags) { return std::make_unique(config, dags); } } // namespace optimizer } // namespace concretelang } // namespace mlir