[HOPPER][WS] fix TMA store hang in ws mode (#2056)

This commit is contained in:
allatit23
2023-08-08 19:53:52 +08:00
committed by GitHub
parent 2a95d9bf0d
commit 6dee55c912
16 changed files with 184 additions and 69 deletions

View File

@@ -146,6 +146,13 @@ Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
ArrayRef<unsigned> shape);
// Returns std::nullopt if the op is not inside an agent region (warp
// specialization mode). Note that there should be at most one agent id
// attached to the operation.
std::optional<int> getWSAgentId(Operation *op);
std::optional<int> getWSRoleId(Operation *op);
void setRoleId(Operation *op, int roleId);
} // namespace mlir
#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_

View File

@@ -372,7 +372,7 @@ def TTNG_RegAllocOp : TTNG_Op<"reg_alloc", []> {
let arguments = (ins I32Attr: $regCount);
let assemblyFormat = "$regCount attr-dict `:` type(operands)";
let assemblyFormat = "$regCount attr-dict";
}
def TTNG_RegDeallocOp : TTNG_Op<"reg_dealloc", []> {
@@ -380,7 +380,7 @@ def TTNG_RegDeallocOp : TTNG_Op<"reg_dealloc", []> {
let arguments = (ins I32Attr: $regCount);
let assemblyFormat = "$regCount attr-dict `:` type(operands)";
let assemblyFormat = "$regCount attr-dict";
}
#endif

View File

@@ -73,6 +73,8 @@ createTritonNvidiaGPUFenceInsertionPass(int computeCapability = 90);
std::unique_ptr<Pass>
createTritonGPURewriteTensorPointerPass(int computeCapability = 80);
std::unique_ptr<Pass> createTritonNvidiaGPUWSFixupMissingAttrs();
/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"

View File

@@ -225,4 +225,22 @@ def TritonGPURewriteTensorPointer : Pass</*cli-arg*/"tritongpu-rewrite-tensor-po
];
}
def TritonGPUWSFixupMissingAttrs : Pass<"triton-nvidia-gpu-ws-fixup-missing-attrs", "mlir::ModuleOp"> {
let summary = "Fixup missing WS related attributes";
let description = [{
WS related attributes are attached to some key operations and are used when lowering to llvm.
However, these attributes may be dropped by subsequent IR transforms. This pass tries to
fix up the missing attributes.
}];
let constructor = "mlir::createTritonNvidiaGPUWSFixupMissingAttrs()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
"mlir::scf::SCFDialect",
"mlir::arith::ArithDialect"];
}
#endif

View File

@@ -5,6 +5,7 @@
#include "../lib/Conversion/TritonGPUToLLVM/Utility.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -117,15 +118,28 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
return;
}
if (isa<triton::gpu::AsyncWaitOp>(op) &&
!isa<gpu::BarrierOp>(op->getNextNode())) {
if (isa<triton::gpu::AsyncWaitOp, triton::gpu::AsyncBulkWaitOp>(op) &&
!isa<gpu::BarrierOp>(op->getNextNode()) &&
!(isa<LLVM::InlineAsmOp>(op->getNextNode()) &&
(dyn_cast<LLVM::InlineAsmOp>(op->getNextNode())
.getAsmString()
.find("bar.sync") != std::string::npos))) {
// If the current op is an async wait and the next op is not a barrier we
// insert a barrier op and sync
blockInfo->sync();
OpBuilder::InsertionGuard g(*builder);
builder->setInsertionPointAfter(op);
builder->create<gpu::BarrierOp>(op->getLoc());
blockInfo->sync();
if (auto optionalAgentId = getWSAgentId(op)) {
int agentId = *optionalAgentId, roleId = 0;
if (auto optionalRoleId = getWSRoleId(op))
roleId = *optionalRoleId;
int barId = agentId + roleId + nameBarrierIdBegin;
assert(barId < nameBarrierIdEnd);
barSync(*builder, op, barId, 128);
} else {
builder->create<gpu::BarrierOp>(op->getLoc());
blockInfo->sync();
}
return;
}
@@ -180,10 +194,10 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
// TODO(Keren): Don't expose LLVM Dialect ops here
// TODO[shuhaoj]: Change hard code style of numThreads. Hide async_agent
// attr. Better way to determine barId (number of agents are limited).
if (op->hasAttr("async_agent")) {
int agentId = getAgentIds(op).front(), roleId = 0;
if (op->hasAttr("agent.mutex_role"))
roleId = op->getAttrOfType<IntegerAttr>("agent.mutex_role").getInt();
if (auto optionalAgentId = getWSAgentId(op)) {
int agentId = *optionalAgentId, roleId = 0;
if (auto optionalRoleId = getWSRoleId(op))
roleId = *optionalRoleId;
int barId = agentId + roleId + nameBarrierIdBegin;
assert(barId < nameBarrierIdEnd);
barSync(*builder, op, barId, 128);

View File

@@ -1,5 +1,7 @@
#include "ConvertLayoutOpToLLVM.h"
#include "Utility.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
@@ -589,11 +591,10 @@ private:
if (repId != 0) {
// TODO[shuhaoj]: change hard code style of numThreads. Hide async
// attr. Better way to determine barId (number of agents are limited).
if (op->hasAttr("async_agent")) {
int agentId = getAgentIds(op).front(), roleId = 0;
if (op->hasAttr("agent.mutex_role"))
roleId =
op->getAttrOfType<IntegerAttr>("agent.mutex_role").getInt();
if (auto optionalAgentId = getWSAgentId(op)) {
int agentId = *optionalAgentId, roleId = 0;
if (auto optionalRoleId = getWSRoleId(op))
roleId = *optionalRoleId;
int barId = agentId + roleId + nameBarrierIdBegin;
assert(barId < nameBarrierIdEnd);
auto bar = rewriter.create<LLVM::ConstantOp>(
@@ -624,10 +625,10 @@ private:
// TODO[shuhaoj]: change hard code style of numThreads. Hide async_agent
// attr. Better way to determine barId (number of agents are limited).
if (op->hasAttr("async_agent")) {
int agentId = getAgentIds(op).front(), roleId = 0;
if (op->hasAttr("agent.mutex_role"))
roleId = op->getAttrOfType<IntegerAttr>("agent.mutex_role").getInt();
if (auto optionalAgentId = getWSAgentId(op)) {
int agentId = *optionalAgentId, roleId = 0;
if (auto optionalRoleId = getWSRoleId(op))
roleId = *optionalRoleId;
int barId = agentId + roleId + nameBarrierIdBegin;
assert(barId < nameBarrierIdEnd);
auto bar = rewriter.create<LLVM::ConstantOp>(
@@ -793,10 +794,10 @@ private:
}
// TODO[shuhaoj]: change hard code style of numThreads. Hide async_agent
// attr. Better way to determine barId (number of agents are limited).
if (op->hasAttr("async_agent")) {
int agentId = getAgentIds(op).front(), roleId = 0;
if (op->hasAttr("agent.mutex_role"))
roleId = op->getAttrOfType<IntegerAttr>("agent.mutex_role").getInt();
if (auto optionalAgentId = getWSAgentId(op)) {
int agentId = *optionalAgentId, roleId = 0;
if (auto optionalRoleId = getWSRoleId(op))
roleId = *optionalRoleId;
int barId = agentId + roleId + nameBarrierIdBegin;
assert(barId < nameBarrierIdEnd);
auto bar = rewriter.create<LLVM::ConstantOp>(

View File

@@ -1,6 +1,7 @@
#include "ReduceOpToLLVM.h"
#include "Utility.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
using namespace mlir;
@@ -289,7 +290,7 @@ private:
triton::ReduceOp op) const {
// TODO[shuhaoj]: change hard code style of numThreads. Hide async_agent
// attr.
if (op->hasAttr("async_agent")) {
if (getWSAgentId(op)) {
barSync(rewriter, op, getAgentIds(op).front(), 128);
} else {
barrier();

View File

@@ -17,6 +17,7 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
@@ -47,6 +48,12 @@ public:
// Sink conversions into loops when they will increase
// register pressure
DenseMap<Operation *, Operation *> opToMove;
auto moveAfter = [](Operation *lhs, Operation *rhs) {
auto lhsId = getWSRoleId(lhs);
auto rhsId = getWSRoleId(rhs);
if (lhsId == rhsId)
lhs->moveAfter(rhs);
};
m.walk([&](triton::gpu::ConvertLayoutOp op) {
if (!willIncreaseRegisterPressure(op))
return;
@@ -70,7 +77,7 @@ public:
Operation *argOp = op.getOperand().getDefiningOp();
if (!argOp)
return;
op->moveAfter(argOp);
moveAfter(op, argOp);
});
// Move transpositions just after their definition
opToMove.clear();
@@ -78,7 +85,7 @@ public:
Operation *argOp = op.getOperand().getDefiningOp();
if (!argOp)
return;
op->moveAfter(argOp);
moveAfter(op, argOp);
});
// Move `dot` operand so that conversions to opIdx=1 happens after
// conversions to opIdx=0
@@ -104,7 +111,7 @@ public:
// after the conversion to OpIdx=0.
if (!dom.dominates(op.getOperation(), AOp.getOperation()))
return;
op->moveAfter(AOp);
moveAfter(op, AOp);
});
return;
}

View File

@@ -731,4 +731,28 @@ Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
return linear;
}
// Returns the single agent id attached to `op` through the "async_agent"
// attribute, or std::nullopt when no agent id is attached. Asserts (in debug
// builds) that at most one agent id is present.
std::optional<int> getWSAgentId(Operation *op) {
  int agent = -1; // sentinel: "no agent id seen yet"
  if (auto ids = op->getAttrOfType<DenseIntElementsAttr>("async_agent"))
    for (int id : ids.getValues<int>()) {
      assert(agent == -1 && "support at most one agent id");
      agent = id;
    }
  return agent == -1 ? std::nullopt : std::optional<int>(agent);
}
// Returns the mutex role id attached to `op` through the "agent.mutex_role"
// attribute, or std::nullopt when the attribute is absent (or not an
// IntegerAttr). Single checked lookup instead of hasAttr + getAttrOfType:
// avoids scanning the attribute list twice and avoids calling getInt() on a
// null attribute if "agent.mutex_role" exists with an unexpected type.
std::optional<int> getWSRoleId(Operation *op) {
  if (auto roleAttr = op->getAttrOfType<IntegerAttr>("agent.mutex_role"))
    return roleAttr.getInt();
  return std::nullopt;
}
// Attaches `roleId` to `op` as the i32 "agent.mutex_role" attribute,
// overwriting any existing value.
void setRoleId(Operation *op, int roleId) {
  auto *ctx = op->getContext();
  op->setAttr("agent.mutex_role",
              IntegerAttr::get(IntegerType::get(ctx, 32), roleId));
}
} // namespace mlir

View File

@@ -6,6 +6,7 @@ add_mlir_dialect_library(TritonNvidiaGPUTransforms
WSPipeline.cpp
WSMutex.cpp
WSMaterialization.cpp
WSFixupMissingAttrs.cpp
FenceInsertion.cpp
RewriteTensorPointer.cpp
Utility.cpp

View File

@@ -0,0 +1,69 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
namespace mlir {
namespace ttng = triton::nvidia_gpu;
namespace {
// Restores WS (warp specialization) attributes -- agent ids ("async_agent")
// and mutex role ids ("agent.mutex_role") -- on operations nested inside
// agent regions, since intervening IR transforms may drop them.
class TritonGPUWSFixupMissingAttrsPass
    : public TritonGPUWSFixupMissingAttrsBase<
          TritonGPUWSFixupMissingAttrsPass> {
public:
  TritonGPUWSFixupMissingAttrsPass() = default;

  void runOnOperation() override {
    ModuleOp mod = getOperation();
    // Nothing to do unless the module carries the WS-supported marker.
    if (!ttng::TritonNvidiaGPUDialect::getWSSupportedAttr(mod))
      return;
    OpBuilder builder(mod);
    mod->walk([&](mlir::triton::FuncOp funcOp) {
      for (Operation &op : funcOp.getBody().front().getOperations()) {
        // Agent regions are top-level scf.if ops carrying exactly one agent
        // id; anything else in the entry block is skipped.
        if (!isa<scf::IfOp>(&op))
          continue;
        auto agentIds = getAgentIds(&op);
        if (agentIds.size() != 1)
          continue;
        Block *roleIdBlock = nullptr;
        op.walk<WalkOrder::PreOrder>([&](Operation *subOp) {
          // Propagate the region's agent id to every nested op.
          setAgentIds(subOp, agentIds);
          // Find the outermost common block that has roleId.
          // The below implementation assumes that:
          // - all lock/unlock ops are in the same block (denoted as B).
          // - there is always one scf.if op in the front of `B` which has
          //   role id attached.
          // The above assumptions are maintained by WSMutex pass currently.
          if (!roleIdBlock && isa<scf::IfOp>(subOp) && getWSRoleId(subOp))
            roleIdBlock = subOp->getBlock();
        });
        if (!roleIdBlock)
          continue;
        // Walk the role block in order, carrying the most recently seen role
        // id forward onto ops (and their nested ops) that lack one.
        int roleId = 0;
        for (Operation &roleOp : roleIdBlock->getOperations()) {
          auto optionalRoleId = getWSRoleId(&roleOp);
          if (!optionalRoleId) {
            setRoleId(&roleOp, roleId);
          } else {
            roleId = *optionalRoleId;
          }
          roleOp.walk([&](Operation *subOp) { setRoleId(subOp, roleId); });
        }
      }
    });
  }
};
} // namespace
/// Creates the pass that restores missing WS-related attributes.
std::unique_ptr<Pass> createTritonNvidiaGPUWSFixupMissingAttrs() {
  return std::make_unique<TritonGPUWSFixupMissingAttrsPass>();
}
} // namespace mlir

View File

@@ -708,32 +708,6 @@ struct WSMaterializationPass
materializeMutexOperations(mod);
tryRegisterRealloc(mod);
mod->walk([](Operation *op) {
bool hasTensor = 0;
auto results = op->getResults();
auto operands = op->getOperands();
for (auto i : results) {
if (isa<RankedTensorType>(i.getType())) {
hasTensor = 1;
break;
}
}
if (!hasTensor) {
for (auto i : operands) {
if (isa<RankedTensorType>(i.getType())) {
hasTensor = 1;
break;
}
}
}
if (!hasTensor && !isa<ttng::MBarrierWaitOp>(op) &&
!isa<ttng::ExtractMBarrierOp>(op) &&
!isa<ttng::MBarrierArriveOp>(op)) {
op->removeAttr("async_agent");
}
});
// TODO: More flexible way to set num-warps
// One dma, one math warp group, set num-warps = 8
auto i32_ty = IntegerType::get(mod->getContext(), 32);

View File

@@ -264,8 +264,9 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp,
});
for (int i = 0; i < numRoles; ++i) {
if (lockLocs[i] == op) {
if (roleId != -1)
op->setAttr("agent.mutex_role", builder.getI32IntegerAttr(roleId));
roleId = i;
op->setAttr("agent.mutex_role", builder.getI32IntegerAttr(i));
break;
}
}

View File

@@ -1628,6 +1628,10 @@ void init_triton_ir(py::module &&m) {
self.addPass(mlir::createTritonNvidiaGPUWSMaterializationPass(
computeCapability));
})
.def("add_tritongpu_ws_fixup_missing_attrs_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonNvidiaGPUWSFixupMissingAttrs());
})
.def(
"add_convert_triton_to_tritongpu_pass",
[](mlir::PassManager &self, int numWarps, int threadsPerWarp,
@@ -1752,6 +1756,7 @@ void init_triton_translation(py::module &m) {
});
m.def("get_num_warps", [](mlir::ModuleOp mod) {
auto shared = mod->getAttrOfType<mlir::IntegerAttr>("triton_gpu.num-warps");
assert(shared);
return shared.getInt();
});

View File

@@ -93,11 +93,6 @@ def matmul_no_scf_kernel(
]))
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE, ENABLE_WS):
if '-'.join(map(str, [USE_TMA_EPILOGUE, ENABLE_WS])) in [
'True-True'
]:
pytest.skip("error, skip")
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
else:
@@ -335,12 +330,6 @@ def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A,
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
# with ENABLE_TMA=1 and ENABLE_MMA_V3=1
if ENABLE_WS:
# example:
# [128-128-64-4-1-None-None-None-False-False-False-chain-dot-float16-False-3-True]
pytest.skip('hang!')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K

View File

@@ -120,11 +120,13 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, arch,
pm.add_tritongpu_optimize_dot_operands_pass()
pm.add_tritongpu_remove_layout_conversions_pass()
pm.add_tritongpu_decompose_conversions_pass()
pm.add_tritongpu_ws_fixup_missing_attrs_pass()
pm.add_tritongpu_reorder_instructions_pass()
pm.add_cse_pass()
pm.add_symbol_dce_pass()
if arch // 10 >= 9:
pm.add_tritongpu_fence_insertion_pass()
pm.add_tritongpu_ws_fixup_missing_attrs_pass()
pm.run(mod)
return mod
@@ -556,15 +558,15 @@ def compile(fn, **kwargs):
asm[ir_name] = str(next_module)
if ir_name == "llir" and "shared" not in metadata:
metadata["shared"] = get_shared_memory_size(module)
if ir_name == "ttgir" and enable_warp_specialization:
metadata["num_warps"] = get_num_warps(module)
if ir_name == "ttgir":
metadata["enable_warp_specialization"] = _triton.ir.is_ws_supported(next_module)
if metadata["enable_warp_specialization"]:
metadata["num_warps"] = get_num_warps(next_module)
if ir_name == "ptx":
metadata["name"] = get_kernel_name(next_module, pattern='// .globl')
if ir_name == "amdgcn":
metadata["name"] = get_kernel_name(next_module[0], pattern='.globl')
asm["hsaco_path"] = next_module[1]
if ir_name == "ttgir":
metadata["enable_warp_specialization"] = _triton.ir.is_ws_supported(next_module)
if not is_cuda and not is_hip:
_device_backend.add_meta_info(ir_name, module, next_module, metadata, asm)
module = next_module