From da9dbcef1d2f905439642d70ea82a7217f2b2b11 Mon Sep 17 00:00:00 2001 From: Antoniu Pop Date: Sat, 30 Jul 2022 13:51:35 +0100 Subject: [PATCH] feat(compiler): coarsen task granularity by aggregating lightweight operations. --- .../RT/Analysis/BuildDataflowTaskGraph.cpp | 210 ++++++++---------- 1 file changed, 88 insertions(+), 122 deletions(-) diff --git a/compiler/lib/Dialect/RT/Analysis/BuildDataflowTaskGraph.cpp b/compiler/lib/Dialect/RT/Analysis/BuildDataflowTaskGraph.cpp index caf2b887b..122552ec3 100644 --- a/compiler/lib/Dialect/RT/Analysis/BuildDataflowTaskGraph.cpp +++ b/compiler/lib/Dialect/RT/Analysis/BuildDataflowTaskGraph.cpp @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -43,46 +44,49 @@ namespace { // TODO: adjust these two functions based on cost model static bool isCandidateForTask(Operation *op) { return isa< - FHE::AddEintIntOp, FHE::AddEintOp, FHE::SubIntEintOp, FHE::MulEintIntOp, - FHE::ApplyLookupTableEintOp, FHELinalg::MatMulIntEintOp, - FHELinalg::MatMulEintIntOp, FHELinalg::AddEintIntOp, FHELinalg::AddEintOp, - FHELinalg::SubIntEintOp, FHELinalg::NegEintOp, FHELinalg::MulEintIntOp, - FHELinalg::ApplyLookupTableEintOp, FHELinalg::ApplyMultiLookupTableEintOp, + FHE::ApplyLookupTableEintOp, FHELinalg::MatMulEintIntOp, + FHELinalg::AddEintIntOp, FHELinalg::AddEintOp, FHELinalg::SubIntEintOp, + FHELinalg::SubEintIntOp, FHELinalg::SubEintOp, FHELinalg::NegEintOp, + FHELinalg::MulEintIntOp, FHELinalg::ApplyLookupTableEintOp, + FHELinalg::ApplyMultiLookupTableEintOp, FHELinalg::ApplyMappedLookupTableEintOp, FHELinalg::Dot, FHELinalg::MatMulEintIntOp, FHELinalg::MatMulIntEintOp, FHELinalg::SumOp, - FHELinalg::ConcatOp>(op); + FHELinalg::ConcatOp, FHELinalg::Conv2dOp, FHELinalg::TransposeOp>(op); } -/// Identify operations that are beneficial to sink into tasks. These +/// Identify operations that are beneficial to aggregate into tasks. These /// operations must not have side-effects and not be `isCandidateForTask` -static bool isSinkingBeneficiary(Operation *op) { - return isa(op); +static bool isAggregatingBeneficiary(Operation *op) { + return isa(op); } static bool -extractBeneficiaryOps(Operation *op, SetVector existingDependencies, - SetVector &beneficiaryOps, - llvm::SmallPtrSetImpl &availableValues) { +aggregateBeneficiaryOps(Operation *op, SetVector &beneficiaryOps, + llvm::SmallPtrSetImpl &availableValues) { if (beneficiaryOps.count(op)) return true; - if (!isSinkingBeneficiary(op)) + if (!isAggregatingBeneficiary(op)) return false; - for (Value operand : op->getOperands()) { - // It is already visible in the kernel, keep going. - if (availableValues.count(operand)) - continue; - // Else check whether it can be made available via sinking or already is a - // dependency. - Operation *definingOp = operand.getDefiningOp(); - if ((!definingOp || - !extractBeneficiaryOps(definingOp, existingDependencies, - beneficiaryOps, availableValues)) && - !existingDependencies.count(operand)) - return false; + // Gather the new potential dependences created by sinking this op. + llvm::SmallPtrSet newDependencesIfSunk; + for (Value operand : op->getOperands()) + if (!availableValues.count(operand)) + newDependencesIfSunk.insert(operand); + + // We further attempt to sink any new dependence + for (auto dep : newDependencesIfSunk) { + Operation *definingOp = dep.getDefiningOp(); + if (definingOp) + aggregateBeneficiaryOps(definingOp, beneficiaryOps, availableValues); } + // We will sink the operation, mark its results as now available. beneficiaryOps.insert(op); for (Value result : op->getResults()) @@ -90,12 +94,18 @@ extractBeneficiaryOps(Operation *op, SetVector existingDependencies, return true; } -static func::FuncOp getCalledFunction(CallOpInterface callOp) { - SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast(); +static bool isFunctionCallName(OpOperand *use, StringRef name) { + func::CallOp call = dyn_cast_or_null(use->getOwner()); + if (!call) + return false; + SymbolRefAttr sym = call.getCallableForCallee().dyn_cast(); if (!sym) - return nullptr; - return dyn_cast_or_null( - SymbolTable::lookupNearestSymbolFrom(callOp, sym)); + return false; + func::FuncOp called = dyn_cast_or_null( + SymbolTable::lookupNearestSymbolFrom(call, sym)); + if (!called) + return false; + return called.getName() == name; } static void getAliasedUses(Value val, DenseSet &aliasedUses) { @@ -106,9 +116,8 @@ static void getAliasedUses(Value val, DenseSet &aliasedUses) { } } -static bool extractOutputMemrefAllocations( - Operation *op, SetVector existingDependencies, - SetVector &beneficiaryOps, +static bool aggregateOutputMemrefAllocations( + Operation *op, SetVector &beneficiaryOps, llvm::SmallPtrSetImpl &availableValues, RT::DataflowTaskOp taskOp) { if (beneficiaryOps.count(op)) return true; @@ -125,29 +134,19 @@ static bool extractOutputMemrefAllocations( // Call ops targeting concrete-ffi do not have memory effects // interface, so handle apart. // TODO: this could be handled better in BConcrete or higher. - if (isa(use->getOwner())) { - if (getCalledFunction(use->getOwner()).getName() == - "memref_expand_lut_in_trivial_glwe_ct_u64" || - getCalledFunction(use->getOwner()).getName() == - "memref_add_lwe_ciphertexts_u64" || - getCalledFunction(use->getOwner()).getName() == - "memref_add_plaintext_lwe_ciphertext_u64" || - getCalledFunction(use->getOwner()).getName() == - "memref_mul_cleartext_lwe_ciphertext_u64" || - getCalledFunction(use->getOwner()).getName() == - "memref_negate_lwe_ciphertext_u64" || - getCalledFunction(use->getOwner()).getName() == - "memref_keyswitch_lwe_u64" || - getCalledFunction(use->getOwner()).getName() == - "memref_bootstrap_lwe_u64") - if (use->getOwner()->getOperand(0) == use->get()) - return true; + if (isFunctionCallName(use, "memref_expand_lut_in_trivial_glwe_ct_u64") || + isFunctionCallName(use, "memref_add_lwe_ciphertexts_u64") || + isFunctionCallName(use, "memref_add_plaintext_lwe_ciphertext_u64") || + isFunctionCallName(use, "memref_mul_cleartext_lwe_ciphertext_u64") || + isFunctionCallName(use, "memref_negate_lwe_ciphertext_u64") || + isFunctionCallName(use, "memref_keyswitch_lwe_u64") || + isFunctionCallName(use, "memref_bootstrap_lwe_u64")) + if (use->getOwner()->getOperand(0) == use->get()) + return true; - if (getCalledFunction(use->getOwner()).getName() == - "memref_copy_one_rank") - if (use->getOwner()->getOperand(1) == use->get()) - return true; - } + if (isFunctionCallName(use, "memref_copy_one_rank")) + if (use->getOwner()->getOperand(1) == use->get()) + return true; // Otherwise we rely on the memory effect interface auto effectInterface = dyn_cast(use->getOwner()); @@ -176,7 +175,7 @@ static bool extractOutputMemrefAllocations( return false; } -LogicalResult sinkOperationsIntoDFTask(RT::DataflowTaskOp taskOp) { +LogicalResult coarsenDFTask(RT::DataflowTaskOp taskOp) { Region &taskOpBody = taskOp.body(); // Identify uses from values defined outside of the scope. @@ -184,14 +183,15 @@ LogicalResult sinkOperationsIntoDFTask(RT::DataflowTaskOp taskOp) { getUsedValuesDefinedAbove(taskOpBody, sinkCandidates); SetVector toBeSunk; - llvm::SmallPtrSet availableValues; + llvm::SmallPtrSet availableValues(sinkCandidates.begin(), + sinkCandidates.end()); for (Value operand : sinkCandidates) { Operation *operandOp = operand.getDefiningOp(); if (!operandOp) continue; - extractBeneficiaryOps(operandOp, sinkCandidates, toBeSunk, availableValues); - extractOutputMemrefAllocations(operandOp, sinkCandidates, toBeSunk, - availableValues, taskOp); + aggregateBeneficiaryOps(operandOp, toBeSunk, availableValues); + aggregateOutputMemrefAllocations(operandOp, toBeSunk, availableValues, + taskOp); } // Insert operations so that the defs get cloned before uses. @@ -202,16 +202,13 @@ LogicalResult sinkOperationsIntoDFTask(RT::DataflowTaskOp taskOp) { Operation *clonedOp = builder.clone(*op, map); for (auto pair : llvm::zip(op->getResults(), clonedOp->getResults())) replaceAllUsesInRegionWith(std::get<0>(pair), std::get<1>(pair), - taskOp.body()); - // Once this is sunk, remove all operands of the DFT covered by this - for (auto result : op->getResults()) - for (auto operand : llvm::enumerate(taskOp.getOperands())) - if (operand.value() == result) { - taskOp->eraseOperand(operand.index()); - // Once removed, we assume there are no duplicates - break; - } + taskOpBody); } + + SetVector deps; + getUsedValuesDefinedAbove(taskOpBody, deps); + taskOp->setOperands(deps.takeVector()); + return success(); } @@ -251,11 +248,14 @@ protected: builder.setInsertionPointAfter(op); auto dftop = builder.create( op->getLoc(), op->getResultTypes(), op->getOperands()); + // Add the operation to the task OpBuilder tbbuilder(dftop.body()); Operation *clonedOp = tbbuilder.clone(*op, map); - // Add sinkable operations to the task - assert(!failed(sinkOperationsIntoDFTask(dftop)) && + + // Coarsen granularity by aggregating all dependence related + // lower-weight operations. + assert(!failed(coarsenDFTask(dftop)) && "Failing to sink operations into DFT"); // Add terminator @@ -282,44 +282,6 @@ std::unique_ptr createBuildDataflowTaskGraphPass(bool debug) { return std::make_unique(debug); } -namespace { -/// Marker to avoid infinite recursion of the rewriting pattern -static const mlir::StringLiteral kTransformMarker = - "_internal_RT_FixDataflowTaskOpInputsPattern_marker__"; - -class FixDataflowTaskOpInputsPattern - : public mlir::OpRewritePattern { -public: - FixDataflowTaskOpInputsPattern(mlir::MLIRContext *context) - : mlir::OpRewritePattern( - context, ::mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {} - - LogicalResult - matchAndRewrite(RT::DataflowTaskOp op, - mlir::PatternRewriter &rewriter) const override { - mlir::OpBuilder::InsertionGuard guard(rewriter); - - if (op->hasAttr(kTransformMarker)) - return failure(); - - // Identify which values need to be passed as dependences to the - // task - this is very conservative and will add constants, index - // operations, etc. A simplification will occur later. - SetVector deps; - getUsedValuesDefinedAbove(op.body(), deps); - auto newop = rewriter.create( - op.getLoc(), op.getResultTypes(), deps.getArrayRef()); - rewriter.mergeBlocks(op.getBody(), newop.getBody(), - newop.getBody()->getArguments()); - rewriter.replaceOp(op, {newop.getResults()}); - - // Mark this as processed to prevent infinite loop - newop.getOperation()->setAttr(kTransformMarker, rewriter.getUnitAttr()); - return success(); - } -}; -} // namespace - namespace { /// For documentation see Autopar.td struct FixupDataflowTaskOpsPass @@ -329,21 +291,25 @@ struct FixupDataflowTaskOpsPass auto module = getOperation(); auto *context = &getContext(); - RewritePatternSet patterns(context); - patterns.add(context); - - if (mlir::applyPatternsAndFoldGreedily(module, std::move(patterns)) - .failed()) - signalPassFailure(); - - // Clear mark and sink any newly created constants or indexing - // operations, etc. to reduce the number of input dependences to - // the task module->walk([](RT::DataflowTaskOp op) { - op.getOperation()->removeAttr(kTransformMarker); - assert(!failed(sinkOperationsIntoDFTask(op)) && + assert(!failed(coarsenDFTask(op)) && "Failing to sink operations into DFT"); }); + + // Finally clear up any remaining alloc/dealloc ops that are + // meaningless + SetVector eraseOps; + module->walk([&](memref::AllocOp op) { + // If this memref.alloc's only use left is the + // dealloc, erase both. + if (op->hasOneUse() && + isa(op->use_begin()->getOwner())) { + eraseOps.insert(op->use_begin()->getOwner()); + eraseOps.insert(op); + } + }); + for (auto op : eraseOps) + op->erase(); } FixupDataflowTaskOpsPass(bool debug) : debug(debug){};