// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #define GEN_PASS_CLASSES #include namespace mlir { namespace concretelang { namespace { // TODO: adjust these two functions based on cost model static bool isCandidateForTask(Operation *op) { return isa< FHE::AddEintIntOp, FHE::AddEintOp, FHE::SubIntEintOp, FHE::MulEintIntOp, FHE::ApplyLookupTableEintOp, FHELinalg::MatMulIntEintOp, FHELinalg::MatMulEintIntOp, FHELinalg::AddEintIntOp, FHELinalg::AddEintOp, FHELinalg::SubIntEintOp, FHELinalg::NegEintOp, FHELinalg::MulEintIntOp, FHELinalg::ApplyLookupTableEintOp, FHELinalg::ApplyMultiLookupTableEintOp, FHELinalg::ApplyMappedLookupTableEintOp, FHELinalg::Dot, FHELinalg::MatMulEintIntOp, FHELinalg::MatMulIntEintOp, FHELinalg::SumOp, FHELinalg::ConcatOp, FHELinalg::FhelinalgConv2DNchwFchwOp>(op); } // Identify operations that are beneficial to sink into tasks. These // operations must not have side-effects and not be `isCandidateForTask` static bool isSinkingBeneficiary(Operation *op) { return isa(op); } static bool extractBeneficiaryOps(Operation *op, SetVector existingDependencies, SetVector &beneficiaryOps, llvm::SmallPtrSetImpl &availableValues) { if (beneficiaryOps.count(op)) return true; if (!isSinkingBeneficiary(op)) return false; for (Value operand : op->getOperands()) { // It is already visible in the kernel, keep going. if (availableValues.count(operand)) continue; // Else check whether it can be made available via sinking or already is a // dependency. Operation *definingOp = operand.getDefiningOp(); if ((!definingOp || !extractBeneficiaryOps(definingOp, existingDependencies, beneficiaryOps, availableValues)) && !existingDependencies.count(operand)) return false; } // We will sink the operation, mark its results as now available. beneficiaryOps.insert(op); for (Value result : op->getResults()) availableValues.insert(result); return true; } LogicalResult sinkOperationsIntoDFTask(RT::DataflowTaskOp taskOp) { Region &taskOpBody = taskOp.body(); // Identify uses from values defined outside of the scope. SetVector sinkCandidates; getUsedValuesDefinedAbove(taskOpBody, sinkCandidates); SetVector toBeSunk; llvm::SmallPtrSet availableValues; for (Value operand : sinkCandidates) { Operation *operandOp = operand.getDefiningOp(); if (!operandOp) continue; extractBeneficiaryOps(operandOp, sinkCandidates, toBeSunk, availableValues); } // Insert operations so that the defs get cloned before uses. BlockAndValueMapping map; OpBuilder builder(taskOpBody); for (Operation *op : toBeSunk) { OpBuilder::InsertionGuard guard(builder); Operation *clonedOp = builder.clone(*op, map); for (auto pair : llvm::zip(op->getResults(), clonedOp->getResults())) replaceAllUsesInRegionWith(std::get<0>(pair), std::get<1>(pair), taskOp.body()); // Once this is sunk, remove all operands of the DFT covered by this for (auto result : op->getResults()) for (auto operand : llvm::enumerate(taskOp.getOperands())) if (operand.value() == result) { taskOp->eraseOperand(operand.index()); // Once removed, we assume there are no duplicates break; } } return success(); } // For documentation see Autopar.td struct BuildDataflowTaskGraphPass : public BuildDataflowTaskGraphBase { void runOnOperation() override { auto module = getOperation(); module.walk([&](mlir::func::FuncOp func) { if (!func->getAttr("_dfr_work_function_attribute")) func.walk( [&](mlir::Operation *childOp) { this->processOperation(childOp); }); // Perform simplifications, in particular DCE here in case some // of the operations sunk in tasks are no longer needed in the // main function. If the function fails it only means that // nothing was simplified. Doing this here - rather than later // in the compilation pipeline - allows to take advantage of // higher level semantics which we can attach to operations // (e.g., NoSideEffect on FHE::ZeroEintOp). IRRewriter rewriter(func->getContext()); (void)mlir::simplifyRegions(rewriter, func->getRegions()); }); } BuildDataflowTaskGraphPass(bool debug) : debug(debug){}; protected: void processOperation(mlir::Operation *op) { if (isCandidateForTask(op)) { BlockAndValueMapping map; Region &opBody = getOperation().getBody(); OpBuilder builder(opBody); // Create a DFTask for this operation builder.setInsertionPointAfter(op); auto dftop = builder.create( op->getLoc(), op->getResultTypes(), op->getOperands()); // Add the operation to the task OpBuilder tbbuilder(dftop.body()); Operation *clonedOp = tbbuilder.clone(*op, map); // Add sinkable operations to the task assert(!failed(sinkOperationsIntoDFTask(dftop)) && "Failing to sink operations into DFT"); // Add terminator tbbuilder.create(dftop.getLoc(), mlir::TypeRange(), op->getResults()); // Replace the uses of defined values for (auto pair : llvm::zip(op->getResults(), clonedOp->getResults())) replaceAllUsesInRegionWith(std::get<0>(pair), std::get<1>(pair), dftop.body()); // Replace uses of the values defined by the task for (auto pair : llvm::zip(op->getResults(), dftop->getResults())) replaceAllUsesInRegionWith(std::get<0>(pair), std::get<1>(pair), opBody); // Once uses are re-targeted to the task, delete the operation op->erase(); } } bool debug; }; } // end anonymous namespace std::unique_ptr createBuildDataflowTaskGraphPass(bool debug) { return std::make_unique(debug); } namespace { // Marker to avoid infinite recursion of the rewriting pattern static const mlir::StringLiteral kTransformMarker = "_internal_RT_FixDataflowTaskOpInputsPattern_marker__"; class FixDataflowTaskOpInputsPattern : public mlir::OpRewritePattern { public: FixDataflowTaskOpInputsPattern(mlir::MLIRContext *context) : mlir::OpRewritePattern( context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} LogicalResult matchAndRewrite(RT::DataflowTaskOp op, mlir::PatternRewriter &rewriter) const override { mlir::OpBuilder::InsertionGuard guard(rewriter); if (op->hasAttr(kTransformMarker)) return failure(); // Identify which values need to be passed as dependences to the // task - this is very conservative and will add constants, index // operations, etc. A simplification will occur later. SetVector deps; getUsedValuesDefinedAbove(op.body(), deps); auto newop = rewriter.create( op.getLoc(), op.getResultTypes(), deps.getArrayRef()); rewriter.mergeBlocks(op.getBody(), newop.getBody(), newop.getBody()->getArguments()); rewriter.replaceOp(op, {newop.getResults()}); // Mark this as processed to prevent infinite loop newop.getOperation()->setAttr(kTransformMarker, rewriter.getUnitAttr()); return success(); } }; } // namespace namespace { // For documentation see Autopar.td struct FixupDataflowTaskOpsPass : public FixupDataflowTaskOpsBase { void runOnOperation() override { auto module = getOperation(); auto *context = &getContext(); RewritePatternSet patterns(context); patterns.add(context); if (mlir::applyPatternsAndFoldGreedily(module, std::move(patterns)) .failed()) signalPassFailure(); // Clear mark and sink any newly created constants or indexing // operations, etc. to reduce the number of input dependences to // the task module->walk([](RT::DataflowTaskOp op) { op.getOperation()->removeAttr(kTransformMarker); assert(!failed(sinkOperationsIntoDFTask(op)) && "Failing to sink operations into DFT"); }); } FixupDataflowTaskOpsPass(bool debug) : debug(debug){}; protected: bool debug; }; } // end anonymous namespace std::unique_ptr createFixupDataflowTaskOpsPass(bool debug) { return std::make_unique(debug); } } // end namespace concretelang } // end namespace mlir