[BACKEND] Handle AtomicCASOp in GPU IR conversion (#2514)

Addressing https://github.com/openai/triton/issues/2011

Co-authored-by: Philippe Tillet <phil@openai.com>
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
This commit is contained in:
Hongtao Yu
2023-10-25 12:20:07 -07:00
committed by GitHub
parent 7d55968fee
commit 2323adb387
5 changed files with 133 additions and 40 deletions

View File

@@ -298,17 +298,23 @@ private:
scratchAlignment);
}
} else if (auto atomicCASOp = dyn_cast<triton::AtomicCASOp>(op)) {
// Only a scalar (non-tensor) operand requires scratch memory; the tensor
// case is spelled out as an explicit no-op branch for readability.
auto value = op->getOperand(0);
auto smemShape = getScratchConfigForAtomicCAS(atomicCASOp);
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
std::multiplies{});
auto elemTy =
value.getType().cast<triton::PointerType>().getPointeeType();
auto bytes = elemTy.isa<triton::PointerType>()
? elems * kPtrBitWidth / 8
: elems * elemTy.getIntOrFloatBitWidth() / 8;
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
if (value.getType().dyn_cast<RankedTensorType>()) {
// nothing to do
} else {
auto smemShape = getScratchConfigForAtomicCAS(atomicCASOp);
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
std::multiplies{});
auto elemTy =
value.getType().cast<triton::PointerType>().getPointeeType();
auto bytes = elemTy.isa<triton::PointerType>()
? elems * kPtrBitWidth / 8
: elems * elemTy.getIntOrFloatBitWidth() / 8;
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
}
} else if (auto callOp = dyn_cast<CallOpInterface>(op)) {
auto callable = callOp.resolveCallable();
auto funcOp = dyn_cast<FunctionOpInterface>(callable);