#include "triton/Analysis/Utility.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Matchers.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include <algorithm>
#include <deque>

namespace mlir {

namespace {

/// Translate `axis` of a (possibly nested) slice layout into the matching axis
/// of the innermost non-slice parent layout. Each slice level removed the
/// dimension `getDim()` from its parent, so axes at or beyond it shift up by
/// one before recursing into the parent.
int getParentAxis(Attribute layout, int axis) {
  if (auto sliceEncoding = layout.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
    axis = axis < sliceEncoding.getDim() ? axis : axis + 1;
    return getParentAxis(sliceEncoding.getParent(), axis);
  }
  return axis;
}

/// Return the dimension order of the innermost non-slice parent of `layout`.
SmallVector<unsigned> getParentOrder(Attribute layout) {
  if (auto sliceEncoding = layout.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
    return getParentOrder(sliceEncoding.getParent());
  }
  return triton::gpu::getOrder(layout);
}

} // namespace

/// A reduction is "fast" when the reduced axis, mapped through any slice
/// layouts, is the first entry of the parent layout's order. Setting the
/// DISABLE_FAST_REDUCTION environment variable forces the non-fast path.
bool ReduceOpHelper::isFastReduction() {
  // Disable fast reduction only for debugging purposes.
  if (::triton::tools::getBoolEnv("DISABLE_FAST_REDUCTION"))
    return false;
  return getParentAxis(getSrcLayout(), axis) ==
         getParentOrder(getSrcLayout())[0];
}

/// Warps along the reduction axis, capped by the number of values left after
/// each warp has combined getIntraWarpSize() of them.
unsigned ReduceOpHelper::getInterWarpSize() {
  auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
  unsigned sizeIntraWarps = getIntraWarpSize();
  return std::min(srcReduceDimSize / sizeIntraWarps,
                  triton::gpu::getWarpsPerCTA(getSrcLayout())[axis]);
}

/// Lanes of a warp along the reduction axis, capped by the axis length.
unsigned ReduceOpHelper::getIntraWarpSize() {
  auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
  return std::min(srcReduceDimSize,
                  triton::gpu::getThreadsPerWarp(getSrcLayout())[axis]);
}

/// Same as getInterWarpSize(), but using the "WithUniqueData" warp count of
/// the layout for the source shape.
unsigned ReduceOpHelper::getInterWarpSizeWithUniqueData() {
  auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
  unsigned sizeIntraWarps = getIntraWarpSizeWithUniqueData();
  return std::min(srcReduceDimSize / sizeIntraWarps,
                  triton::gpu::getWarpsPerCTAWithUniqueData(
                      getSrcLayout(), getSrcShape())[axis]);
}

/// Same as getIntraWarpSize(), but the axis length is first divided by the
/// unique contiguous elements per thread and the lane count uses the
/// "WithUniqueData" variant.
unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() {
  auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
  unsigned elementPerThreads = triton::gpu::getUniqueContigPerThread(
      getSrcLayout(), getSrcShape())[axis];
  return std::min(srcReduceDimSize / elementPerThreads,
                  triton::gpu::getThreadsPerWarpWithUniqueData(
                      getSrcLayout(), getSrcShape())[axis]);
}

/// Product of the "WithUniqueData" lanes-per-warp and warps-per-CTA along the
/// reduction axis.
unsigned ReduceOpHelper::getThreadsReductionAxis() {
  auto layout = getSrcLayout();
  auto shape = getSrcShape(); // local name avoids shadowing member `srcShape`
  return triton::gpu::getThreadsPerWarpWithUniqueData(layout, shape)[axis] *
         triton::gpu::getWarpsPerCTAWithUniqueData(layout, shape)[axis];
}

/// Scratch shape for the non-fast path: the source shape with the reduction
/// axis clamped to getThreadsReductionAxis().
SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
  auto smemShape = convertType<unsigned>(getSrcShape());
  smemShape[axis] = std::min(smemShape[axis], getThreadsReductionAxis());
  return smemShape;
}

/// True when the reduction is fast and a single warp spans the reduction axis.
bool ReduceOpHelper::isWarpSynchronous() {
  auto argsLayout = getSrcLayout();
  return isFastReduction() &&
         (triton::gpu::getWarpsPerCTA(argsLayout)[axis] == 1);
}

/// Scratch shapes for the fast path.
///  - [0]: the source shape with the reduction axis set to getInterWarpSize().
///  - [1]: a flat buffer of numWarps * threadsPerWarp elements.
///  - [2]: left empty.
/// A warp-synchronous reduction needs no inter-warp buffers, so it returns two
/// zero-sized shapes instead.
SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
  // that case doesn't need inter-warp communication
  if (isWarpSynchronous())
    return {{0, 0}, {0, 0}};

  SmallVector<SmallVector<unsigned>> smemShapes(3);

  /// shared memory block0
  smemShapes[0] = convertType<unsigned>(getSrcShape());
  smemShapes[0][axis] = getInterWarpSize();

  /// FIXME(Qingyi): This size is actually larger than required.
  /// shared memory block1:
  auto mod = op->getParentOfType<ModuleOp>();
  unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
  unsigned threadsPerWarp =
      triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
  smemShapes[1].push_back(numWarps * threadsPerWarp);

  return smemShapes;
}

/// Scratch bytes needed by the reduction.
///
/// The element count is the largest fast-path buffer, or the basic buffer.
/// It is multiplied by the summed byte width of all source element types.
unsigned ReduceOpHelper::getScratchSizeInBytes() {
  unsigned elems = 0;
  if (isFastReduction()) {
    for (const auto &smemShape : getScratchConfigsFast())
      elems = std::max(elems, product<unsigned>(smemShape));
  } else {
    auto smemShape = getScratchConfigBasic();
    elems = product<unsigned>(smemShape);
  }

  // Round each element up to whole bytes: a plain `bits / 8` is 0 for i1,
  // which would allocate no scratch for boolean reductions.
  unsigned bytesPerElem = 0;
  for (const auto &ty : srcElementTypes)
    bytesPerElem += ceil<unsigned>(ty.getIntOrFloatBitWidth(), 8);
  return bytesPerElem * elems;
}

/// Layouts this helper can reduce over: blocked, MFMA, slice, and
/// Ampere-flavoured MMA layouts.
bool ReduceOpHelper::isSupportedLayout() {
  auto srcLayout = getSrcLayout();
  if (srcLayout.isa<triton::gpu::BlockedEncodingAttr,
                    triton::gpu::MfmaEncodingAttr,
                    triton::gpu::SliceEncodingAttr>())
    return true;
  if (auto mmaLayout = srcLayout.dyn_cast<triton::gpu::MmaEncodingAttr>())
    return mmaLayout.isAmpere();
  return false;
}

/// Elements held by each thread along the scan axis.
unsigned ScanLoweringHelper::getAxisNumElementsPerThread() {
  return getEncoding().getSizePerThread()[getAxis()];
}

/// Product of contiguous elements per thread over all non-scan axes.
unsigned ScanLoweringHelper::getNonAxisNumElementsPerThread() {
  SmallVector<unsigned> sizePerThreads = getContigPerThread(getEncoding());
  sizePerThreads[getAxis()] = 1;
  return product<unsigned>(sizePerThreads);
}

/// The combine region of the wrapped scan op.
Region &ScanLoweringHelper::getCombineOp() { return scanOp.getCombineOp(); }

/// Lanes of a warp along the scan axis.
unsigned ScanLoweringHelper::getAxisNumThreadsPerWarp() {
  return triton::gpu::getThreadsPerWarp(getEncoding())[getAxis()];
}

/// Product of lanes per warp over all non-scan axes.
unsigned ScanLoweringHelper::getNonAxisNumThreadsPerWarp() {
  auto threadsPerWarp = triton::gpu::getThreadsPerWarp(getEncoding());
  threadsPerWarp[getAxis()] = 1;
  return product<unsigned>(threadsPerWarp);
}

// Return the flat number of threads computing independent scan results.
unsigned ScanLoweringHelper::getNonAxisNumThreadsPerCTA() { unsigned numParallelThreadsPerWarp = getNonAxisNumThreadsPerWarp(); auto warpsPerCTA = triton::gpu::getWarpsPerCTA(getEncoding()); warpsPerCTA[getAxis()] = 1; unsigned numParallelWarpsPerCTA = product(warpsPerCTA); return numParallelThreadsPerWarp * numParallelWarpsPerCTA; } unsigned ScanLoweringHelper::getAxisNumWarps() { auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); return warpsPerCTA[getAxis()]; } unsigned ScanLoweringHelper::getAxisNumBlocks() { auto type = scanOp.getOperand(0).getType().cast(); auto sizePerThreads = triton::gpu::getSizePerThread(srcEncoding); auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding); auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); unsigned axis = getAxis(); return ceil( type.getShape()[axis], (sizePerThreads[axis] * threadsPerWarp[axis] * warpsPerCTA[axis])); } unsigned ScanLoweringHelper::getNonAxisNumBlocks() { auto type = scanOp.getOperand(0).getType().cast(); auto sizePerThreads = triton::gpu::getSizePerThread(srcEncoding); auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding); auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); unsigned axis = getAxis(); unsigned numBlocks = 1; for (unsigned i = 0; i < sizePerThreads.size(); i++) { if (i == axis) continue; numBlocks *= ceil( type.getShape()[i], (sizePerThreads[i] * threadsPerWarp[i] * warpsPerCTA[i])); } return numBlocks; } bool ScanLoweringHelper::isSupported() { // TODO: Support the following cases: // 1. Scan on non-blocking encodings // 2. 
Scan with multiple operands if (!isa(srcEncoding)) return false; if (scanOp.getNumOperands() != 1) return false; return true; } unsigned ScanLoweringHelper::getScratchSizeInBytes() { auto type = scanOp.getOperand(0).getType().cast(); unsigned elementSizeInBytes = type.getElementTypeBitWidth() / 8; auto mod = scanOp->getParentOfType(); unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); unsigned numNonAxisElementsPerWapr = getNonAxisNumThreadsPerWarp() * getNonAxisNumElementsPerThread(); unsigned numElements = numWarps * numNonAxisElementsPerWapr * getAxisNumBlocks() * getNonAxisNumBlocks(); return elementSizeInBytes * numElements; } triton::gpu::BlockedEncodingAttr ScanLoweringHelper::getEncoding() { return srcEncoding.cast(); } unsigned ScanLoweringHelper::getAxisElementStride() { auto order = triton::gpu::getOrder(srcEncoding); unsigned stride = 1; for (unsigned dim : order) { if (dim == getAxis()) return stride; stride *= getContigPerThread(getEncoding())[dim]; } llvm_unreachable("Axis not found in order"); } unsigned ScanLoweringHelper::getAxisThreadStride() { auto order = triton::gpu::getOrder(srcEncoding); unsigned stride = 1; for (unsigned dim : order) { if (dim == getAxis()) return stride; stride *= getEncoding().getThreadsPerWarp()[dim]; } llvm_unreachable("Axis not found in order"); } unsigned ScanLoweringHelper::getAxisBlockStride() { auto order = triton::gpu::getOrder(srcEncoding); unsigned stride = 1; auto type = scanOp.getOperand(0).getType().cast(); auto sizePerThreads = triton::gpu::getSizePerThread(srcEncoding); auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding); auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); for (unsigned dim : order) { if (dim == getAxis()) return stride; stride *= type.getShape()[dim] / (sizePerThreads[dim] * threadsPerWarp[dim] * warpsPerCTA[dim]); } llvm_unreachable("Axis not found in order"); } bool maybeSharedAllocationOp(Operation *op) { // TODO(Keren): This function can be 
replaced by adding // MemoryEffectOpInterface. We can then use the MemoryEffectOpInterface to // query the memory effects of the op. auto *dialect = op->getDialect(); return dialect && (dialect->getTypeID() == mlir::TypeID::get() || dialect->getTypeID() == mlir::TypeID::get() || dialect->getTypeID() == mlir::TypeID::get() || dialect->getTypeID() == mlir::TypeID::get()); } bool maybeAliasOp(Operation *op) { return isa(op) || isa(op) || isa(op) || isa(op); } bool supportMMA(triton::DotOp op, int version) { // Refer to mma section for the data type supported by Volta and Hopper // Tensor Core in // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16 auto aElemTy = op.getA().getType().cast().getElementType(); auto bElemTy = op.getB().getType().cast().getElementType(); if (aElemTy.isF32() && bElemTy.isF32()) { return (op.getAllowTF32() && version == 2) || version == 3; } return supportMMA(op.getA(), version) && supportMMA(op.getB(), version); } #ifdef USE_ROCM static bool supportMFMAGranularity(int m, int n, int k) { // these limitations are dtype dependent, in future we may relax them const int granularityMN = 32; const int granularityK = 8; if (m % granularityMN != 0 || n % granularityMN != 0) return false; if (k % granularityK != 0) return false; return true; } bool supportMFMA(triton::DotOp op) { auto aTy = op.getA().getType().cast(); auto bTy = op.getB().getType().cast(); auto aElemTy = aTy.getElementType(); auto bElemTy = bTy.getElementType(); if (aElemTy != bElemTy) return false; auto aShape = aTy.getShape(); auto bShape = bTy.getShape(); assert(aShape[1] == bShape[0]); if (!supportMFMAGranularity(aShape[0], bShape[1], aShape[1])) return false; return aElemTy.isF16() || aElemTy.isBF16() || aElemTy.isF32() || aElemTy.isInteger(8); } #endif bool supportMMA(Value value, int version) { // Tell whether a DotOp support HMMA by the operand type(either $a or $b). 
// We cannot get both the operand types(in TypeConverter), here we assume the // types of both the operands are identical here. assert((version == 1 || version == 2) && "Unexpected MMA layout version found"); auto elemTy = value.getType().cast().getElementType(); return elemTy.isF16() || elemTy.isBF16() || (elemTy.isF32() && version >= 2) || (elemTy.isInteger(8) && version >= 2); } Type getElementType(Value value) { auto type = value.getType(); if (auto tensorType = type.dyn_cast()) return tensorType.getElementType(); return type; } bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { // dot_op = #mma // when #mma = MmaEncoding auto srcLayout = srcTy.getEncoding(); auto dstLayout = dstTy.getEncoding(); auto mmaLayout = srcLayout.cast(); auto dotOperandLayout = dstLayout.cast(); return mmaLayout.getVersionMajor() == 2 && mmaLayout.getWarpsPerCTA()[1] == 1 && dotOperandLayout.getOpIdx() == 0 && dotOperandLayout.getParent() == mmaLayout && !srcTy.getElementType().isF32(); } #ifdef USE_ROCM bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { auto srcLayout = srcTy.getEncoding(); auto dstLayout = dstTy.getEncoding(); auto mfmaLayout = srcLayout.cast(); auto dotOperandLayout = dstLayout.cast(); // TODO: Remove the restriction on the warpsPerCTA once chain dot testing is // improved. In addition, we can enable this shortcut for regular MFMA // layout when opIdx == 1. return mfmaLayout.getWarpsPerCTA()[1] == 1 && dotOperandLayout.getOpIdx() == 0 && dotOperandLayout.getKWidth() == 8 && dotOperandLayout.getParent() == mfmaLayout && mfmaLayout.getIsTransposed() && (srcTy.getElementType().isF16() || srcTy.getElementType().isBF16()); } #endif bool isSingleValue(Value value) { // Don't consider load as expensive if it is loading a scalar. if (auto tensorTy = value.getType().dyn_cast()) return tensorTy.getNumElements() == 1; // TODO: Handle other cases. // For example, when ptr is a tensor of single value. 
// It means that ptr is a resultant of broadcast or generated through // a chain of broadcast and other operations. // Rematerialize it without considering contiguous memory access pattern is // fine. return true; } namespace { /// A data structure similar to SetVector but maintains /// a deque instead of a vector to allow for efficient /// push_back and pop_front operations. /// Using SetVector doesn't suffice our needs because /// it only pushes and pops from the back. /// For example, if we have a queue like this: /// 0->4 1->2->3 /// ^-------- /// where 3 depends on 4, once we pop 3, we found /// 4 is not ready, so we check 2 and push 3 back /// to the queue. struct DFSSubgraphState { DFSSubgraphState() : set(), deque() {} DenseSet set; std::deque deque; bool push_back(Operation *op) { if (set.insert(op).second) { deque.push_back(op); return true; } return false; } Operation *pop_front() { Operation *op = deque.front(); deque.pop_front(); set.erase(op); return op; } bool empty() { return deque.empty(); } }; /// DFS post-order implementation that maintains a global count to work across /// multiple invocations, to help implement topological sort on multi-root DAGs. /// We traverse all operations but only record the ones that appear in /// `toSort` for the final result. struct DFSState { DFSState(const SetVector &set) : toSort(set), seen() {} const SetVector &toSort; SmallVector topologicalCounts; DenseSet seen; /// We mark each op as ready if all its operands are seen. If an op is ready, /// we add it to the queue. Otherwise, we keep adding its operands to the /// ancestors set. 
void addToReadyQueue(Operation *op, DFSSubgraphState &subGraph, SmallVector &readyQueue) { bool ready = true; for (Value operand : op->getOperands()) { auto def = operand.getDefiningOp(); if (def && !seen.count(def)) { subGraph.push_back(def); ready = false; } } if (ready) readyQueue.push_back(op); } }; void dfsPostorder(Operation *root, DFSState *state) { DFSSubgraphState subGraph; subGraph.push_back(root); SmallVector ops; while (!subGraph.empty()) { // Nodes in the ready queue are ready to be processed. // Meaning that either their operands are all seen or they have null // operands. SmallVector readyQueue; auto *current = subGraph.pop_front(); state->addToReadyQueue(current, subGraph, readyQueue); while (!readyQueue.empty()) { Operation *current = readyQueue.pop_back_val(); if (!state->seen.insert(current).second) continue; ops.push_back(current); for (Value result : current->getResults()) { for (Operation *op : result.getUsers()) state->addToReadyQueue(op, subGraph, readyQueue); } for (Region ®ion : current->getRegions()) { for (Operation &op : region.getOps()) state->addToReadyQueue(&op, subGraph, readyQueue); } } } for (Operation *op : llvm::reverse(ops)) { if (state->toSort.count(op) > 0) state->topologicalCounts.push_back(op); } } } // namespace SetVector multiRootTopologicalSort(const SetVector &toSort) { if (toSort.empty()) { return toSort; } // Run from each root with global count and `seen` set. DFSState state(toSort); for (auto *s : toSort) { assert(toSort.count(s) == 1 && "NYI: multi-sets not supported"); dfsPostorder(s, &state); } // Reorder and return. 
SetVector res; for (auto it = state.topologicalCounts.rbegin(), eit = state.topologicalCounts.rend(); it != eit; ++it) { res.insert(*it); } return res; } SetVector multiRootGetSlice(Operation *op, TransitiveFilter backwardFilter, TransitiveFilter forwardFilter) { SetVector slice; slice.insert(op); unsigned currentIndex = 0; SetVector backwardSlice; SetVector forwardSlice; while (currentIndex != slice.size()) { auto *currentOp = (slice)[currentIndex]; // Compute and insert the backwardSlice starting from currentOp. backwardSlice.clear(); getBackwardSlice(currentOp, &backwardSlice, backwardFilter); slice.insert(backwardSlice.begin(), backwardSlice.end()); // Compute and insert the forwardSlice starting from currentOp. forwardSlice.clear(); getForwardSlice(currentOp, &forwardSlice, forwardFilter); slice.insert(forwardSlice.begin(), forwardSlice.end()); ++currentIndex; } return multiRootTopologicalSort(slice); } namespace { // Copied from TestDeadCodeAnalysis.cpp, because some dead code analysis // interacts with constant propagation, but SparseConstantPropagation // doesn't seem to be sufficient. class ConstantAnalysis : public DataFlowAnalysis { public: using DataFlowAnalysis::DataFlowAnalysis; LogicalResult initialize(Operation *top) override { WalkResult result = top->walk([&](Operation *op) { if (failed(visit(op))) return WalkResult::interrupt(); return WalkResult::advance(); }); return success(!result.wasInterrupted()); } LogicalResult visit(ProgramPoint point) override { Operation *op = point.get(); Attribute value; if (matchPattern(op, m_Constant(&value))) { auto *constant = getOrCreate>( op->getResult(0)); propagateIfChanged(constant, constant->join(dataflow::ConstantValue( value, op->getDialect()))); return success(); } // Dead code analysis requires every operands has initialized ConstantValue // state before it is visited. 
// https://github.com/llvm/llvm-project/blob/2ec1aba2b69faa1de5f71832a48e25aa3b5d5314/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp#L322 // That's why we need to set all operands to unknown constants. setAllToUnknownConstants(op->getResults()); for (Region ®ion : op->getRegions()) { for (Block &block : region.getBlocks()) setAllToUnknownConstants(block.getArguments()); } return success(); } private: /// Set all given values as not constants. void setAllToUnknownConstants(ValueRange values) { dataflow::ConstantValue unknownConstant(nullptr, nullptr); for (Value value : values) { auto *constant = getOrCreate>(value); propagateIfChanged(constant, constant->join(unknownConstant)); } } }; } // namespace std::unique_ptr createDataFlowSolver() { auto solver = std::make_unique(); solver->load(); solver->load(); return solver; } } // namespace mlir