// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #define GEN_PASS_CLASSES #include namespace mlir { namespace concretelang { namespace { // TODO: adjust these two functions based on cost model static bool isCandidateForTask(Operation *op) { return isa< FHE::ApplyLookupTableEintOp, FHELinalg::MatMulEintIntOp, FHELinalg::AddEintIntOp, FHELinalg::AddEintOp, FHELinalg::SubIntEintOp, FHELinalg::SubEintIntOp, FHELinalg::SubEintOp, FHELinalg::NegEintOp, FHELinalg::MulEintIntOp, FHELinalg::ApplyLookupTableEintOp, FHELinalg::ApplyMultiLookupTableEintOp, FHELinalg::ApplyMappedLookupTableEintOp, FHELinalg::Dot, FHELinalg::MatMulEintIntOp, FHELinalg::MatMulIntEintOp, FHELinalg::SumOp, FHELinalg::ConcatOp, FHELinalg::Conv2dOp, FHELinalg::TransposeOp>(op); } /// Identify operations that are beneficial to aggregate into tasks. These /// operations must not have side-effects and not be `isCandidateForTask` static bool isAggregatingBeneficiary(Operation *op) { return isa(op); } static bool aggregateBeneficiaryOps(Operation *op, SetVector &beneficiaryOps, llvm::SmallPtrSetImpl &availableValues) { if (beneficiaryOps.count(op)) return true; if (!isAggregatingBeneficiary(op)) return false; // Gather the new potential dependences created by sinking this op. llvm::SmallPtrSet newDependencesIfSunk; for (Value operand : op->getOperands()) if (!availableValues.count(operand)) newDependencesIfSunk.insert(operand); // We further attempt to sink any new dependence for (auto dep : newDependencesIfSunk) { Operation *definingOp = dep.getDefiningOp(); if (definingOp) aggregateBeneficiaryOps(definingOp, beneficiaryOps, availableValues); } // We will sink the operation, mark its results as now available. beneficiaryOps.insert(op); for (Value result : op->getResults()) availableValues.insert(result); return true; } LogicalResult coarsenDFTask(RT::DataflowTaskOp taskOp) { Region &taskOpBody = taskOp.body(); // Identify uses from values defined outside of the scope. SetVector sinkCandidates; getUsedValuesDefinedAbove(taskOpBody, sinkCandidates); SetVector toBeSunk; llvm::SmallPtrSet availableValues(sinkCandidates.begin(), sinkCandidates.end()); for (Value operand : sinkCandidates) { Operation *operandOp = operand.getDefiningOp(); if (!operandOp) continue; aggregateBeneficiaryOps(operandOp, toBeSunk, availableValues); } // Insert operations so that the defs get cloned before uses. BlockAndValueMapping map; OpBuilder builder(taskOpBody); for (Operation *op : toBeSunk) { OpBuilder::InsertionGuard guard(builder); Operation *clonedOp = builder.clone(*op, map); for (auto pair : llvm::zip(op->getResults(), clonedOp->getResults())) replaceAllUsesInRegionWith(std::get<0>(pair), std::get<1>(pair), taskOpBody); } SetVector deps; getUsedValuesDefinedAbove(taskOpBody, deps); taskOp->setOperands(deps.takeVector()); return success(); } /// For documentation see Autopar.td struct BuildDataflowTaskGraphPass : public BuildDataflowTaskGraphBase { void runOnOperation() override { auto module = getOperation(); module.walk([&](mlir::func::FuncOp func) { if (!func->getAttr("_dfr_work_function_attribute")) func.walk( [&](mlir::Operation *childOp) { this->processOperation(childOp); }); // Perform simplifications, in particular DCE here in case some // of the operations sunk in tasks are no longer needed in the // main function. If the function fails it only means that // nothing was simplified. Doing this here - rather than later // in the compilation pipeline - allows to take advantage of // higher level semantics which we can attach to operations // (e.g., NoSideEffect on FHE::ZeroEintOp). IRRewriter rewriter(func->getContext()); (void)mlir::simplifyRegions(rewriter, func->getRegions()); }); } BuildDataflowTaskGraphPass(bool debug) : debug(debug){}; protected: void processOperation(mlir::Operation *op) { if (isCandidateForTask(op)) { BlockAndValueMapping map; Region &opBody = getOperation().getBody(); OpBuilder builder(opBody); // Create a DFTask for this operation builder.setInsertionPointAfter(op); auto dftop = builder.create( op->getLoc(), op->getResultTypes(), op->getOperands()); // Add the operation to the task OpBuilder tbbuilder(dftop.body()); Operation *clonedOp = tbbuilder.clone(*op, map); // Coarsen granularity by aggregating all dependence related // lower-weight operations. assert(!failed(coarsenDFTask(dftop)) && "Failing to sink operations into DFT"); // Add terminator tbbuilder.create(dftop.getLoc(), mlir::TypeRange(), op->getResults()); // Replace the uses of defined values for (auto pair : llvm::zip(op->getResults(), clonedOp->getResults())) replaceAllUsesInRegionWith(std::get<0>(pair), std::get<1>(pair), dftop.body()); // Replace uses of the values defined by the task for (auto pair : llvm::zip(op->getResults(), dftop->getResults())) replaceAllUsesInRegionWith(std::get<0>(pair), std::get<1>(pair), opBody); // Once uses are re-targeted to the task, delete the operation op->erase(); } } bool debug; }; } // end anonymous namespace std::unique_ptr createBuildDataflowTaskGraphPass(bool debug) { return std::make_unique(debug); } } // end namespace concretelang } // end namespace mlir