mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] no longer uses shared mem or barriers for single-warp reductions (#1915)
0-bytes shared mem buffers don't materialize empty allocation buffers; this could lead to unnecessary barriers. note: reduceop code has become quite messy and will require some cleanup
This commit is contained in:
committed by
Ognjen Plavsic
parent
398d2c7dd0
commit
4215086931
@@ -171,12 +171,18 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
template <BufferT::BufferKind T>
|
||||
void maybeAddScratchBuffer(Operation *op, unsigned bytes) {
|
||||
if (bytes > 0)
|
||||
allocation->addBuffer<T>(op, bytes);
|
||||
}
|
||||
|
||||
/// Initializes temporary shared memory for a given operation.
|
||||
void getScratchValueSize(Operation *op) {
|
||||
if (auto reduceOp = dyn_cast<triton::ReduceOp>(op)) {
|
||||
ReduceOpHelper helper(reduceOp);
|
||||
unsigned bytes = helper.getScratchSizeInBytes();
|
||||
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
|
||||
} else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
|
||||
auto srcTy = cvtLayout.getSrc().getType().cast<RankedTensorType>();
|
||||
auto dstTy = cvtLayout.getResult().getType().cast<RankedTensorType>();
|
||||
@@ -200,7 +206,7 @@ private:
|
||||
srcTy.getElementType().isa<triton::PointerType>()
|
||||
? elems * kPtrBitWidth / 8
|
||||
: elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
|
||||
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
|
||||
} else if (auto atomicRMWOp = dyn_cast<triton::AtomicRMWOp>(op)) {
|
||||
auto value = op->getOperand(0);
|
||||
// only scalar requires scratch memory
|
||||
@@ -217,7 +223,7 @@ private:
|
||||
elemTy.isa<triton::PointerType>()
|
||||
? elems * kPtrBitWidth / 8
|
||||
: elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
|
||||
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
|
||||
}
|
||||
} else if (auto atomicCASOp = dyn_cast<triton::AtomicCASOp>(op)) {
|
||||
auto value = op->getOperand(0);
|
||||
@@ -229,13 +235,13 @@ private:
|
||||
auto bytes = elemTy.isa<triton::PointerType>()
|
||||
? elems * kPtrBitWidth / 8
|
||||
: elems * elemTy.getIntOrFloatBitWidth() / 8;
|
||||
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
|
||||
} else if (auto callOp = dyn_cast<CallOpInterface>(op)) {
|
||||
auto callable = callOp.resolveCallable();
|
||||
auto funcOp = dyn_cast<FunctionOpInterface>(callable);
|
||||
auto *funcAlloc = &(*funcAllocMap)[funcOp];
|
||||
auto bytes = funcAlloc->getSharedMemorySize();
|
||||
allocation->addBuffer<BufferT::BufferKind::Virtual>(op, bytes);
|
||||
maybeAddScratchBuffer<BufferT::BufferKind::Virtual>(op, bytes);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -60,9 +60,10 @@ SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
|
||||
|
||||
auto argLayout = getSrcLayout();
|
||||
auto argLayoutMma = argLayout.dyn_cast<triton::gpu::MmaEncodingAttr>();
|
||||
// if (argLayoutMma && argLayoutMma.getVersionMajor() == 2 &&
|
||||
// triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)
|
||||
// return {{1, 1}, {1, 1}};
|
||||
|
||||
// that case doesn't need inter-warp communication
|
||||
if (isFastReduction() && triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)
|
||||
return {{0, 0}, {0, 0}};
|
||||
|
||||
/// shared memory block0
|
||||
smemShapes[0] = convertType<unsigned>(getSrcShape());
|
||||
|
||||
@@ -349,18 +349,20 @@ private:
|
||||
unsigned elems = product<unsigned>(smemShapes[0]);
|
||||
unsigned maxElems = std::max(elems, product<unsigned>(smemShapes[1]));
|
||||
|
||||
SmallVector<Value> smemBases(op.getNumOperands());
|
||||
smemBases[0] = bitcast(
|
||||
getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]);
|
||||
for (unsigned i = 1; i < op.getNumOperands(); ++i) {
|
||||
smemBases[i] =
|
||||
bitcast(gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(maxElems)),
|
||||
elemPtrTys[i]);
|
||||
}
|
||||
|
||||
unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData();
|
||||
unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData();
|
||||
|
||||
SmallVector<Value> smemBases(op.getNumOperands());
|
||||
if (sizeInterWarps > 1) {
|
||||
smemBases[0] = bitcast(
|
||||
getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]);
|
||||
for (unsigned i = 1; i < op.getNumOperands(); ++i) {
|
||||
smemBases[i] =
|
||||
bitcast(gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(maxElems)),
|
||||
elemPtrTys[i]);
|
||||
}
|
||||
}
|
||||
|
||||
unsigned srcElems = getTotalElemsPerThread(srcTys[0]);
|
||||
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]);
|
||||
auto srcValues = unpackInputs(loc, op, adaptor, rewriter);
|
||||
@@ -418,6 +420,7 @@ private:
|
||||
Value zero = i32_val(0);
|
||||
Value laneZero = icmp_eq(laneIdAxis, zero);
|
||||
|
||||
std::map<SmallVector<unsigned>, SmallVector<Value>> finalAccs;
|
||||
for (auto it : accs) {
|
||||
const SmallVector<unsigned> &key = it.first;
|
||||
SmallVector<Value> acc = it.second;
|
||||
@@ -440,8 +443,13 @@ private:
|
||||
accumulate(rewriter, *combineOp, acc, shfl, false);
|
||||
}
|
||||
|
||||
if (sizeInterWarps == 1) {
|
||||
finalAccs[key] = acc;
|
||||
continue;
|
||||
}
|
||||
|
||||
SmallVector<Value> writeIdx = indices[key];
|
||||
writeIdx[axis] = (sizeInterWarps == 1) ? zero : warpIdAxis;
|
||||
writeIdx[axis] = warpIdAxis;
|
||||
Value writeOffset =
|
||||
linearize(rewriter, loc, writeIdx, smemShapes[0], order);
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
@@ -450,6 +458,30 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
if (sizeInterWarps == 1) {
|
||||
SmallVector<Value> results(op.getNumOperands());
|
||||
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
|
||||
if (auto resultTy =
|
||||
op.getResult()[i].getType().dyn_cast<RankedTensorType>()) {
|
||||
auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
|
||||
unsigned resultElems = getTotalElemsPerThread(resultTy);
|
||||
SmallVector<SmallVector<unsigned>> resultOffset =
|
||||
emitOffsetForLayout(resultLayout, resultTy);
|
||||
SmallVector<Value> resultVals;
|
||||
for (int j = 0; j < resultElems; j++) {
|
||||
auto key = resultOffset[j];
|
||||
key.insert(key.begin() + axis, 0);
|
||||
resultVals.push_back(finalAccs[key][i]);
|
||||
}
|
||||
results[i] = getTypeConverter()->packLLElements(loc, resultVals,
|
||||
rewriter, resultTy);
|
||||
} else
|
||||
results[i] = finalAccs.begin()->second[i];
|
||||
}
|
||||
rewriter.replaceOp(op, results);
|
||||
return success();
|
||||
}
|
||||
|
||||
barrier();
|
||||
|
||||
// The second round of shuffle reduction
|
||||
@@ -508,9 +540,6 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
// We could avoid this barrier in some of the layouts, however this is not
|
||||
// the general case.
|
||||
// TODO: optimize the barrier in case the layouts are accepted.
|
||||
barrier();
|
||||
|
||||
// set output values
|
||||
|
||||
@@ -213,17 +213,23 @@ private:
|
||||
// of shared memory and append it to the operands of the callOp.
|
||||
auto loc = callOp.getLoc();
|
||||
auto caller = callOp->getParentOfType<FunctionOpInterface>();
|
||||
auto base = allocation.getFunctionSharedMemoryBase(caller);
|
||||
auto *funcAllocation = allocation.getFuncData(caller);
|
||||
auto bufferId = funcAllocation->getBufferId(callOp);
|
||||
auto offset = funcAllocation->getOffset(bufferId);
|
||||
auto ptrTy = LLVM::LLVMPointerType::get(
|
||||
this->getTypeConverter()->convertType(rewriter.getI8Type()),
|
||||
NVVM::kSharedMemorySpace);
|
||||
auto offsetValue = gep(ptrTy, base, i32_val(offset));
|
||||
auto promotedOperands = this->getTypeConverter()->promoteOperands(
|
||||
callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
|
||||
adaptor.getOperands(), rewriter);
|
||||
auto base = allocation.getFunctionSharedMemoryBase(caller);
|
||||
auto *funcAllocation = allocation.getFuncData(caller);
|
||||
auto bufferId = funcAllocation->getBufferId(callOp);
|
||||
// function doesn't have a shared mem buffer
|
||||
if (bufferId == (size_t)-1) {
|
||||
promotedOperands.push_back(base);
|
||||
return promotedOperands;
|
||||
}
|
||||
// function has a shared mem buffer
|
||||
auto offset = funcAllocation->getOffset(bufferId);
|
||||
auto offsetValue = gep(ptrTy, base, i32_val(offset));
|
||||
promotedOperands.push_back(offsetValue);
|
||||
return promotedOperands;
|
||||
}
|
||||
|
||||
@@ -113,7 +113,7 @@ def check_type_supported(dtype):
|
||||
class MmaLayout:
|
||||
def __init__(self, version, warps_per_cta):
|
||||
self.version = version
|
||||
self.warps_per_cta = str(warps_per_cta)
|
||||
self.warps_per_cta = warps_per_cta
|
||||
|
||||
def __str__(self):
|
||||
return f"#triton_gpu.mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}}}>"
|
||||
@@ -121,10 +121,10 @@ class MmaLayout:
|
||||
|
||||
class BlockedLayout:
|
||||
def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order):
|
||||
self.sz_per_thread = str(size_per_thread)
|
||||
self.threads_per_warp = str(threads_per_warp)
|
||||
self.warps_per_cta = str(warps_per_cta)
|
||||
self.order = str(order)
|
||||
self.sz_per_thread = size_per_thread
|
||||
self.threads_per_warp = threads_per_warp
|
||||
self.warps_per_cta = warps_per_cta
|
||||
self.order = order
|
||||
|
||||
def __str__(self):
|
||||
return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
|
||||
@@ -1959,7 +1959,6 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
out_dtype = tl.float16
|
||||
else:
|
||||
out_dtype = tl.float32
|
||||
|
||||
pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
|
||||
y_tri, y_tri.stride(0), y_tri.stride(1),
|
||||
w_tri, w_tri.stride(0), w_tri.stride(1),
|
||||
@@ -1974,6 +1973,14 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
CHAIN_DOT=epilogue == 'chain-dot',
|
||||
ALLOW_TF32=allow_tf32,
|
||||
num_warps=num_warps)
|
||||
if epilogue == 'softmax' and (in_dtype != 'float32' or allow_tf32):
|
||||
ptx = pgm.asm["ptx"]
|
||||
start = ptx.find("shfl.sync")
|
||||
end = ptx.find("cvt.rn.f16.f32")
|
||||
red_code = ptx[start:end]
|
||||
assert len(red_code) > 0
|
||||
assert "shared" not in red_code
|
||||
assert "bar.sync" not in red_code
|
||||
# torch result
|
||||
if in_dtype == 'int8':
|
||||
z_ref = np.matmul(x.astype(np.float32),
|
||||
|
||||
Reference in New Issue
Block a user