#include "triton/Analysis/Membar.h"
#include "triton/Analysis/Alias.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include <deque>

// NOTE(review): the reviewed copy of this file had lost everything between
// angle brackets (template arguments and the <deque> include). They are
// reconstructed below from context; please confirm each one -- in particular
// the walk order in resolve() -- against Membar.h and the dialect headers.

namespace mlir {

/// Runs the analysis on the operation owned by `allocation`, creating the
/// builder that is used to insert barrier ops.
void MembarAnalysis::run() {
  auto *operation = allocation->getOperation();
  OpBuilder builder(operation);
  resolve(operation, &builder);
}

/// Propagates block infos through the control-flow graph nested under
/// `operation` until a fixed point is reached, inserting barriers via
/// `builder` where update() detects a conflict.
///
/// If any op of the "scf" dialect is found, an error is emitted on that op
/// and the analysis is abandoned without touching the IR.
void MembarAnalysis::resolve(Operation *operation, OpBuilder *builder) {
  // Initialize the blockList with every entry block under `operation`.
  std::deque<Block *> blockList;
  WalkResult walkResult =
      operation->walk<WalkOrder::PreOrder>([&](Block *block) -> WalkResult {
        for (Operation &op : block->getOperations()) {
          // Check if the operation belongs to scf dialect, if so, we need to
          // report an error. getDialect() is null for ops of unregistered
          // dialects, so guard the dereference.
          Dialect *dialect = op.getDialect();
          if (dialect && dialect->getNamespace() == "scf") {
            op.emitError(
                "scf dialect is not supported in membar. Please lower it "
                "to cf dialect first.");
            // Stop the whole walk: returning from the callback alone would
            // only skip this block and let the analysis run anyway.
            return WalkResult::interrupt();
          }
        }
        if (block->isEntryBlock())
          blockList.emplace_back(block);
        return WalkResult::advance();
      });
  if (walkResult.wasInterrupted())
    return;

  // A fixed point algorithm
  while (!blockList.empty()) {
    auto *block = blockList.front();
    blockList.pop_front();
    // Make a copy of the inputblockInfo but not update
    auto inputBlockInfo = inputBlockInfoMap.lookup(block);
    SmallVector<Block *> successors;
    for (auto &op : block->getOperations()) {
      if (op.hasTrait<OpTrait::IsTerminator>()) {
        visitTerminator(&op, successors);
      } else {
        update(&op, &inputBlockInfo, builder);
      }
    }
    // Get the reference because we want to update if it changed
    if (outputBlockInfoMap.count(block) &&
        inputBlockInfo == outputBlockInfoMap[block]) {
      // If we have seen the block before and the inputBlockInfo is the same as
      // the outputBlockInfo, we skip the successors
      continue;
    }
    // Update the current block
    outputBlockInfoMap[block].join(inputBlockInfo);
    // Update the successors
    for (auto *successor : successors) {
      inputBlockInfoMap[successor].join(outputBlockInfoMap[block]);
      blockList.emplace_back(successor);
    }
  }
}

/// Appends the successors of the block containing terminator `op` to
/// `successors`. Terminators that are not branch-like contribute nothing and
/// are asserted to be return ops.
void MembarAnalysis::visitTerminator(Operation *op,
                                     SmallVector<Block *> &successors) {
  if (auto branchInterface = dyn_cast<BranchOpInterface>(op)) {
    Block *parentBlock = branchInterface->getBlock();
    for (Block *successor : parentBlock->getSuccessors()) {
      successors.push_back(successor);
    }
    return;
  }
  // Otherwise, it could be a return op
  assert(isa<func::ReturnOp>(op) && "Unknown terminator");
}

/// Updates `blockInfo` with the buffers read and written by `op`, inserting a
/// barrier through `builder` when they intersect with the buffers already
/// recorded in `blockInfo`.
void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
                            OpBuilder *builder) {
  if (isa<triton::gpu::ExtractSliceOp>(op) ||
      isa<triton::gpu::AllocTensorOp>(op) || isa<triton::TransOp>(op)) {
    // alloc is an allocation op without memory write.
    // FIXME(Keren): extract_slice is always alias for now
    return;
  }

  if (isa<gpu::BarrierOp>(op)) {
    // If the current op is a barrier, we sync previous reads and writes
    blockInfo->sync();
    return;
  }

  if (isa<triton::gpu::AsyncWaitOp>(op)) {
    // getNextNode() is null when `op` is the last op of its block; isa<>
    // must not be handed a null pointer, so test for that first.
    Operation *nextOp = op->getNextNode();
    if (!nextOp || !isa<gpu::BarrierOp>(nextOp)) {
      // If the current op is an async wait and the next op is not a barrier
      // we insert a barrier op and sync
      blockInfo->sync();
      OpBuilder::InsertionGuard g(*builder);
      builder->setInsertionPointAfter(op);
      builder->create<gpu::BarrierOp>(op->getLoc());
      blockInfo->sync();
      return;
    }
    // An async wait that is already followed by a barrier is handled like
    // any other op below.
  }

  BlockInfo curBlockInfo;
  for (Value value : op->getOperands()) {
    for (auto bufferId : allocation->getBufferIds(value)) {
      if (bufferId != Allocation::InvalidBufferId) {
        if (isa<triton::gpu::InsertSliceAsyncOp>(op) ||
            isa<tensor::InsertSliceOp>(op)) {
          // FIXME(Keren): insert_slice and insert_slice_async are always
          // alias for now
          curBlockInfo.syncWriteBuffers.insert(bufferId);
        } else {
          // ConvertLayoutOp: shared memory -> registers
          curBlockInfo.syncReadBuffers.insert(bufferId);
        }
      }
    }
  }
  for (Value value : op->getResults()) {
    // ConvertLayoutOp: registers -> shared memory
    auto bufferId = allocation->getBufferId(value);
    if (bufferId != Allocation::InvalidBufferId) {
      curBlockInfo.syncWriteBuffers.insert(bufferId);
    }
  }
  // Scratch buffer is considered as both shared memory write & read
  auto bufferId = allocation->getBufferId(op);
  if (bufferId != Allocation::InvalidBufferId) {
    curBlockInfo.syncWriteBuffers.insert(bufferId);
    curBlockInfo.syncReadBuffers.insert(bufferId);
  }

  if (blockInfo->isIntersected(curBlockInfo, allocation)) {
    // Conflict with the buffers recorded so far: place a barrier right
    // before `op` and reset the recorded state.
    OpBuilder::InsertionGuard g(*builder);
    builder->setInsertionPoint(op);
    builder->create<gpu::BarrierOp>(op->getLoc());
    blockInfo->sync();
  }
  // Update the region info, even if barrier is inserted, we have to maintain
  // the current op's read/write buffers.
  blockInfo->join(curBlockInfo);
}

} // namespace mlir