mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Handle AtomicCASOp in GPU IR conversion (#2514)
Addressing https://github.com/openai/triton/issues/2011 Co-authored-by: Philippe Tillet <phil@openai.com> Co-authored-by: Keren Zhou <kerenzhou@openai.com>
This commit is contained in:
@@ -298,17 +298,23 @@ private:
|
||||
scratchAlignment);
|
||||
}
|
||||
} else if (auto atomicCASOp = dyn_cast<triton::AtomicCASOp>(op)) {
|
||||
// only scalar requires scratch memory
|
||||
// make it explicit for readability
|
||||
auto value = op->getOperand(0);
|
||||
auto smemShape = getScratchConfigForAtomicCAS(atomicCASOp);
|
||||
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
|
||||
std::multiplies{});
|
||||
auto elemTy =
|
||||
value.getType().cast<triton::PointerType>().getPointeeType();
|
||||
auto bytes = elemTy.isa<triton::PointerType>()
|
||||
? elems * kPtrBitWidth / 8
|
||||
: elems * elemTy.getIntOrFloatBitWidth() / 8;
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
|
||||
scratchAlignment);
|
||||
if (value.getType().dyn_cast<RankedTensorType>()) {
|
||||
// nothing to do
|
||||
} else {
|
||||
auto smemShape = getScratchConfigForAtomicCAS(atomicCASOp);
|
||||
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
|
||||
std::multiplies{});
|
||||
auto elemTy =
|
||||
value.getType().cast<triton::PointerType>().getPointeeType();
|
||||
auto bytes = elemTy.isa<triton::PointerType>()
|
||||
? elems * kPtrBitWidth / 8
|
||||
: elems * elemTy.getIntOrFloatBitWidth() / 8;
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
|
||||
scratchAlignment);
|
||||
}
|
||||
} else if (auto callOp = dyn_cast<CallOpInterface>(op)) {
|
||||
auto callable = callOp.resolveCallable();
|
||||
auto funcOp = dyn_cast<FunctionOpInterface>(callable);
|
||||
|
||||
Reference in New Issue
Block a user