mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[HOPPER][WS] fix TMA store hang in ws mode (#2056)
This commit is contained in:
@@ -146,6 +146,13 @@ Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
|
||||
Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
|
||||
ArrayRef<unsigned> shape);
|
||||
|
||||
// Returns null if the op is not inside a agent region (warp specialization
|
||||
// mode). Note that there should be at most one agent id attached to the
|
||||
// operation.
|
||||
std::optional<int> getWSAgentId(Operation *op);
|
||||
std::optional<int> getWSRoleId(Operation *op);
|
||||
void setRoleId(Operation *op, int roleId);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
|
||||
|
||||
@@ -372,7 +372,7 @@ def TTNG_RegAllocOp : TTNG_Op<"reg_alloc", []> {
|
||||
|
||||
let arguments = (ins I32Attr: $regCount);
|
||||
|
||||
let assemblyFormat = "$regCount attr-dict `:` type(operands)";
|
||||
let assemblyFormat = "$regCount attr-dict";
|
||||
}
|
||||
|
||||
def TTNG_RegDeallocOp : TTNG_Op<"reg_dealloc", []> {
|
||||
@@ -380,7 +380,7 @@ def TTNG_RegDeallocOp : TTNG_Op<"reg_dealloc", []> {
|
||||
|
||||
let arguments = (ins I32Attr: $regCount);
|
||||
|
||||
let assemblyFormat = "$regCount attr-dict `:` type(operands)";
|
||||
let assemblyFormat = "$regCount attr-dict";
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -73,6 +73,8 @@ createTritonNvidiaGPUFenceInsertionPass(int computeCapability = 90);
|
||||
std::unique_ptr<Pass>
|
||||
createTritonGPURewriteTensorPointerPass(int computeCapability = 80);
|
||||
|
||||
std::unique_ptr<Pass> createTritonNvidiaGPUWSFixupMissingAttrs();
|
||||
|
||||
/// Generate the code for registering passes.
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
|
||||
|
||||
@@ -225,4 +225,22 @@ def TritonGPURewriteTensorPointer : Pass</*cli-arg*/"tritongpu-rewrite-tensor-po
|
||||
];
|
||||
}
|
||||
|
||||
def TritonGPUWSFixupMissingAttrs : Pass<"triton-nvidia-gpu-ws-fixup-missing-attrs", "mlir::ModuleOp"> {
|
||||
let summary = "Fixup missing WS related attributes";
|
||||
|
||||
let description = [{
|
||||
WS related attributes are attached to some key operations and are used when lowering to llvm.
|
||||
However these attributes maybe be dropped in the following IR transform. This pass tries to
|
||||
fixup the missing attributes.
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonNvidiaGPUWSFixupMissingAttrs()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
|
||||
"mlir::scf::SCFDialect",
|
||||
"mlir::arith::ArithDialect"];
|
||||
}
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "../lib/Conversion/TritonGPUToLLVM/Utility.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
|
||||
|
||||
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
||||
@@ -117,15 +118,28 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
|
||||
return;
|
||||
}
|
||||
|
||||
if (isa<triton::gpu::AsyncWaitOp>(op) &&
|
||||
!isa<gpu::BarrierOp>(op->getNextNode())) {
|
||||
if (isa<triton::gpu::AsyncWaitOp, triton::gpu::AsyncBulkWaitOp>(op) &&
|
||||
!isa<gpu::BarrierOp>(op->getNextNode()) &&
|
||||
!(isa<LLVM::InlineAsmOp>(op->getNextNode()) &&
|
||||
(dyn_cast<LLVM::InlineAsmOp>(op->getNextNode())
|
||||
.getAsmString()
|
||||
.find("bar.sync") != std::string::npos))) {
|
||||
// If the current op is an async wait and the next op is not a barrier we
|
||||
// insert a barrier op and sync
|
||||
blockInfo->sync();
|
||||
OpBuilder::InsertionGuard g(*builder);
|
||||
builder->setInsertionPointAfter(op);
|
||||
builder->create<gpu::BarrierOp>(op->getLoc());
|
||||
blockInfo->sync();
|
||||
if (auto optionalAgentId = getWSAgentId(op)) {
|
||||
int agentId = *optionalAgentId, roleId = 0;
|
||||
if (auto optionalRoleId = getWSRoleId(op))
|
||||
roleId = *optionalRoleId;
|
||||
int barId = agentId + roleId + nameBarrierIdBegin;
|
||||
assert(barId < nameBarrierIdEnd);
|
||||
barSync(*builder, op, barId, 128);
|
||||
} else {
|
||||
builder->create<gpu::BarrierOp>(op->getLoc());
|
||||
blockInfo->sync();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -180,10 +194,10 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
|
||||
// TODO(Keren): Don't expose LLVM Dialect ops here
|
||||
// TODO[shuhaoj]: Change hard code style of numThreads. Hide async_agent
|
||||
// attr. Better way to determine barId (number of agents are limited).
|
||||
if (op->hasAttr("async_agent")) {
|
||||
int agentId = getAgentIds(op).front(), roleId = 0;
|
||||
if (op->hasAttr("agent.mutex_role"))
|
||||
roleId = op->getAttrOfType<IntegerAttr>("agent.mutex_role").getInt();
|
||||
if (auto optionalAgentId = getWSAgentId(op)) {
|
||||
int agentId = *optionalAgentId, roleId = 0;
|
||||
if (auto optionalRoleId = getWSRoleId(op))
|
||||
roleId = *optionalRoleId;
|
||||
int barId = agentId + roleId + nameBarrierIdBegin;
|
||||
assert(barId < nameBarrierIdEnd);
|
||||
barSync(*builder, op, barId, 128);
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#include "ConvertLayoutOpToLLVM.h"
|
||||
#include "Utility.h"
|
||||
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
|
||||
|
||||
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
@@ -589,11 +591,10 @@ private:
|
||||
if (repId != 0) {
|
||||
// TODO[shuhaoj]: change hard code style of numThreads. Hide async
|
||||
// attr. Better way to determine barId (number of agents are limited).
|
||||
if (op->hasAttr("async_agent")) {
|
||||
int agentId = getAgentIds(op).front(), roleId = 0;
|
||||
if (op->hasAttr("agent.mutex_role"))
|
||||
roleId =
|
||||
op->getAttrOfType<IntegerAttr>("agent.mutex_role").getInt();
|
||||
if (auto optionalAgentId = getWSAgentId(op)) {
|
||||
int agentId = *optionalAgentId, roleId = 0;
|
||||
if (auto optionalRoleId = getWSRoleId(op))
|
||||
roleId = *optionalRoleId;
|
||||
int barId = agentId + roleId + nameBarrierIdBegin;
|
||||
assert(barId < nameBarrierIdEnd);
|
||||
auto bar = rewriter.create<LLVM::ConstantOp>(
|
||||
@@ -624,10 +625,10 @@ private:
|
||||
|
||||
// TODO[shuhaoj]: change hard code style of numThreads. Hide async_agent
|
||||
// attr. Better way to determine barId (number of agents are limited).
|
||||
if (op->hasAttr("async_agent")) {
|
||||
int agentId = getAgentIds(op).front(), roleId = 0;
|
||||
if (op->hasAttr("agent.mutex_role"))
|
||||
roleId = op->getAttrOfType<IntegerAttr>("agent.mutex_role").getInt();
|
||||
if (auto optionalAgentId = getWSAgentId(op)) {
|
||||
int agentId = *optionalAgentId, roleId = 0;
|
||||
if (auto optionalRoleId = getWSRoleId(op))
|
||||
roleId = *optionalRoleId;
|
||||
int barId = agentId + roleId + nameBarrierIdBegin;
|
||||
assert(barId < nameBarrierIdEnd);
|
||||
auto bar = rewriter.create<LLVM::ConstantOp>(
|
||||
@@ -793,10 +794,10 @@ private:
|
||||
}
|
||||
// TODO[shuhaoj]: change hard code style of numThreads. Hide async_agent
|
||||
// attr. Better way to determine barId (number of agents are limited).
|
||||
if (op->hasAttr("async_agent")) {
|
||||
int agentId = getAgentIds(op).front(), roleId = 0;
|
||||
if (op->hasAttr("agent.mutex_role"))
|
||||
roleId = op->getAttrOfType<IntegerAttr>("agent.mutex_role").getInt();
|
||||
if (auto optionalAgentId = getWSAgentId(op)) {
|
||||
int agentId = *optionalAgentId, roleId = 0;
|
||||
if (auto optionalRoleId = getWSRoleId(op))
|
||||
roleId = *optionalRoleId;
|
||||
int barId = agentId + roleId + nameBarrierIdBegin;
|
||||
assert(barId < nameBarrierIdEnd);
|
||||
auto bar = rewriter.create<LLVM::ConstantOp>(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "ReduceOpToLLVM.h"
|
||||
#include "Utility.h"
|
||||
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -289,7 +290,7 @@ private:
|
||||
triton::ReduceOp op) const {
|
||||
// TODO[shuhaoj]: change hard code style of numThreads. Hide async_agent
|
||||
// attr.
|
||||
if (op->hasAttr("async_agent")) {
|
||||
if (getWSAgentId(op)) {
|
||||
barSync(rewriter, op, getAgentIds(op).front(), 128);
|
||||
} else {
|
||||
barrier();
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
@@ -47,6 +48,12 @@ public:
|
||||
// Sink conversions into loops when they will increase
|
||||
// register pressure
|
||||
DenseMap<Operation *, Operation *> opToMove;
|
||||
auto moveAfter = [](Operation *lhs, Operation *rhs) {
|
||||
auto lhsId = getWSRoleId(lhs);
|
||||
auto rhsId = getWSRoleId(rhs);
|
||||
if (lhsId == rhsId)
|
||||
lhs->moveAfter(rhs);
|
||||
};
|
||||
m.walk([&](triton::gpu::ConvertLayoutOp op) {
|
||||
if (!willIncreaseRegisterPressure(op))
|
||||
return;
|
||||
@@ -70,7 +77,7 @@ public:
|
||||
Operation *argOp = op.getOperand().getDefiningOp();
|
||||
if (!argOp)
|
||||
return;
|
||||
op->moveAfter(argOp);
|
||||
moveAfter(op, argOp);
|
||||
});
|
||||
// Move transpositions just after their definition
|
||||
opToMove.clear();
|
||||
@@ -78,7 +85,7 @@ public:
|
||||
Operation *argOp = op.getOperand().getDefiningOp();
|
||||
if (!argOp)
|
||||
return;
|
||||
op->moveAfter(argOp);
|
||||
moveAfter(op, argOp);
|
||||
});
|
||||
// Move `dot` operand so that conversions to opIdx=1 happens after
|
||||
// conversions to opIdx=0
|
||||
@@ -104,7 +111,7 @@ public:
|
||||
// after the conversion to OpIdx=0.
|
||||
if (!dom.dominates(op.getOperation(), AOp.getOperation()))
|
||||
return;
|
||||
op->moveAfter(AOp);
|
||||
moveAfter(op, AOp);
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -731,4 +731,28 @@ Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
|
||||
return linear;
|
||||
}
|
||||
|
||||
std::optional<int> getWSAgentId(Operation *op) {
|
||||
int prevAgentId = -1;
|
||||
if (auto attr = op->getAttrOfType<DenseIntElementsAttr>("async_agent")) {
|
||||
for (auto agentId : attr.getValues<int>()) {
|
||||
assert(prevAgentId == -1 && "support at most one agent id");
|
||||
prevAgentId = agentId;
|
||||
}
|
||||
}
|
||||
if (prevAgentId == -1)
|
||||
return std::nullopt;
|
||||
return prevAgentId;
|
||||
}
|
||||
|
||||
std::optional<int> getWSRoleId(Operation *op) {
|
||||
if (!op->hasAttr("agent.mutex_role"))
|
||||
return std::nullopt;
|
||||
return op->getAttrOfType<IntegerAttr>("agent.mutex_role").getInt();
|
||||
}
|
||||
|
||||
void setRoleId(Operation *op, int roleId) {
|
||||
auto attr = IntegerAttr::get(IntegerType::get(op->getContext(), 32), roleId);
|
||||
op->setAttr("agent.mutex_role", attr);
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
@@ -6,6 +6,7 @@ add_mlir_dialect_library(TritonNvidiaGPUTransforms
|
||||
WSPipeline.cpp
|
||||
WSMutex.cpp
|
||||
WSMaterialization.cpp
|
||||
WSFixupMissingAttrs.cpp
|
||||
FenceInsertion.cpp
|
||||
RewriteTensorPointer.cpp
|
||||
Utility.cpp
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace ttng = triton::nvidia_gpu;
|
||||
|
||||
namespace {
|
||||
|
||||
class TritonGPUWSFixupMissingAttrsPass
|
||||
: public TritonGPUWSFixupMissingAttrsBase<
|
||||
TritonGPUWSFixupMissingAttrsPass> {
|
||||
public:
|
||||
TritonGPUWSFixupMissingAttrsPass() = default;
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp mod = getOperation();
|
||||
if (!ttng::TritonNvidiaGPUDialect::getWSSupportedAttr(mod))
|
||||
return;
|
||||
OpBuilder builder(mod);
|
||||
mod->walk([&](mlir::triton::FuncOp funcOp) {
|
||||
for (Operation &op : funcOp.getBody().front().getOperations()) {
|
||||
if (!isa<scf::IfOp>(&op))
|
||||
continue;
|
||||
auto agentIds = getAgentIds(&op);
|
||||
if (agentIds.size() != 1)
|
||||
continue;
|
||||
Block *roleIdBlock = nullptr;
|
||||
op.walk<WalkOrder::PreOrder>([&](Operation *subOp) {
|
||||
setAgentIds(subOp, agentIds);
|
||||
// Find the outter most common block that has roleId.
|
||||
// The below implementation assumes that:
|
||||
// - all lock/unlock ops are in the same block (denoted as B).
|
||||
// - there is always one scf.if op in the front of `B` which has
|
||||
// role id attached.
|
||||
// The above assumptions are maintained by WSMutex pass currently.
|
||||
if (!roleIdBlock && isa<scf::IfOp>(subOp) && getWSRoleId(subOp))
|
||||
roleIdBlock = subOp->getBlock();
|
||||
});
|
||||
if (!roleIdBlock)
|
||||
continue;
|
||||
int roleId = 0;
|
||||
for (Operation &roleOp : roleIdBlock->getOperations()) {
|
||||
auto optionalRoleId = getWSRoleId(&roleOp);
|
||||
if (!optionalRoleId) {
|
||||
setRoleId(&roleOp, roleId);
|
||||
} else {
|
||||
roleId = *optionalRoleId;
|
||||
}
|
||||
roleOp.walk([&](Operation *subOp) { setRoleId(subOp, roleId); });
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createTritonNvidiaGPUWSFixupMissingAttrs() {
|
||||
return std::make_unique<TritonGPUWSFixupMissingAttrsPass>();
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
@@ -708,32 +708,6 @@ struct WSMaterializationPass
|
||||
materializeMutexOperations(mod);
|
||||
tryRegisterRealloc(mod);
|
||||
|
||||
mod->walk([](Operation *op) {
|
||||
bool hasTensor = 0;
|
||||
auto results = op->getResults();
|
||||
auto operands = op->getOperands();
|
||||
for (auto i : results) {
|
||||
if (isa<RankedTensorType>(i.getType())) {
|
||||
hasTensor = 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!hasTensor) {
|
||||
for (auto i : operands) {
|
||||
if (isa<RankedTensorType>(i.getType())) {
|
||||
hasTensor = 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!hasTensor && !isa<ttng::MBarrierWaitOp>(op) &&
|
||||
!isa<ttng::ExtractMBarrierOp>(op) &&
|
||||
!isa<ttng::MBarrierArriveOp>(op)) {
|
||||
op->removeAttr("async_agent");
|
||||
}
|
||||
});
|
||||
|
||||
// TODO: More flexible way to set num-warps
|
||||
// One dma, one math warp group, set num-warps = 8
|
||||
auto i32_ty = IntegerType::get(mod->getContext(), 32);
|
||||
|
||||
@@ -264,8 +264,9 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp,
|
||||
});
|
||||
for (int i = 0; i < numRoles; ++i) {
|
||||
if (lockLocs[i] == op) {
|
||||
if (roleId != -1)
|
||||
op->setAttr("agent.mutex_role", builder.getI32IntegerAttr(roleId));
|
||||
roleId = i;
|
||||
op->setAttr("agent.mutex_role", builder.getI32IntegerAttr(i));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1628,6 +1628,10 @@ void init_triton_ir(py::module &&m) {
|
||||
self.addPass(mlir::createTritonNvidiaGPUWSMaterializationPass(
|
||||
computeCapability));
|
||||
})
|
||||
.def("add_tritongpu_ws_fixup_missing_attrs_pass",
|
||||
[](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createTritonNvidiaGPUWSFixupMissingAttrs());
|
||||
})
|
||||
.def(
|
||||
"add_convert_triton_to_tritongpu_pass",
|
||||
[](mlir::PassManager &self, int numWarps, int threadsPerWarp,
|
||||
@@ -1752,6 +1756,7 @@ void init_triton_translation(py::module &m) {
|
||||
});
|
||||
m.def("get_num_warps", [](mlir::ModuleOp mod) {
|
||||
auto shared = mod->getAttrOfType<mlir::IntegerAttr>("triton_gpu.num-warps");
|
||||
assert(shared);
|
||||
return shared.getInt();
|
||||
});
|
||||
|
||||
|
||||
@@ -93,11 +93,6 @@ def matmul_no_scf_kernel(
|
||||
]))
|
||||
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
|
||||
def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE, ENABLE_WS):
|
||||
if '-'.join(map(str, [USE_TMA_EPILOGUE, ENABLE_WS])) in [
|
||||
'True-True'
|
||||
]:
|
||||
pytest.skip("error, skip")
|
||||
|
||||
if (TRANS_A):
|
||||
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
|
||||
else:
|
||||
@@ -335,12 +330,6 @@ def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A,
|
||||
]:
|
||||
pytest.skip('Known legacy issue, ldmatrix can only support x4')
|
||||
|
||||
# with ENABLE_TMA=1 and ENABLE_MMA_V3=1
|
||||
if ENABLE_WS:
|
||||
# example:
|
||||
# [128-128-64-4-1-None-None-None-False-False-False-chain-dot-float16-False-3-True]
|
||||
pytest.skip('hang!')
|
||||
|
||||
M = BLOCK_M if M is None else M
|
||||
N = BLOCK_N if N is None else N
|
||||
K = BLOCK_K if K is None else K
|
||||
|
||||
@@ -120,11 +120,13 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, arch,
|
||||
pm.add_tritongpu_optimize_dot_operands_pass()
|
||||
pm.add_tritongpu_remove_layout_conversions_pass()
|
||||
pm.add_tritongpu_decompose_conversions_pass()
|
||||
pm.add_tritongpu_ws_fixup_missing_attrs_pass()
|
||||
pm.add_tritongpu_reorder_instructions_pass()
|
||||
pm.add_cse_pass()
|
||||
pm.add_symbol_dce_pass()
|
||||
if arch // 10 >= 9:
|
||||
pm.add_tritongpu_fence_insertion_pass()
|
||||
pm.add_tritongpu_ws_fixup_missing_attrs_pass()
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
@@ -556,15 +558,15 @@ def compile(fn, **kwargs):
|
||||
asm[ir_name] = str(next_module)
|
||||
if ir_name == "llir" and "shared" not in metadata:
|
||||
metadata["shared"] = get_shared_memory_size(module)
|
||||
if ir_name == "ttgir" and enable_warp_specialization:
|
||||
metadata["num_warps"] = get_num_warps(module)
|
||||
if ir_name == "ttgir":
|
||||
metadata["enable_warp_specialization"] = _triton.ir.is_ws_supported(next_module)
|
||||
if metadata["enable_warp_specialization"]:
|
||||
metadata["num_warps"] = get_num_warps(next_module)
|
||||
if ir_name == "ptx":
|
||||
metadata["name"] = get_kernel_name(next_module, pattern='// .globl')
|
||||
if ir_name == "amdgcn":
|
||||
metadata["name"] = get_kernel_name(next_module[0], pattern='.globl')
|
||||
asm["hsaco_path"] = next_module[1]
|
||||
if ir_name == "ttgir":
|
||||
metadata["enable_warp_specialization"] = _triton.ir.is_ws_supported(next_module)
|
||||
if not is_cuda and not is_hip:
|
||||
_device_backend.add_meta_info(ir_name, module, next_module, metadata, asm)
|
||||
module = next_module
|
||||
|
||||
Reference in New Issue
Block a user