Files
ROCm/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp
Jason Furmanek 3a6dc5ad8d resolve some merge conflicts
fix more conflicts

Resolve merge conflicts

Some more build and conflict fixes

Resolve conflicts for 06-fused-attention.py

resolve merge conflicts for the tutorial group gemm example

Fixes for some LIT tests

resolve remaining conflicts in tests

Fix empty kernel

set capability 0
2023-11-06 23:13:10 +00:00

456 lines
21 KiB
C++

#include "ScanOpToLLVM.h"
#include "TritonGPUToLLVMBase.h"
#include "triton/Analysis/Utility.h"
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::delinearize;
using ::mlir::LLVM::linearize;
using ::mlir::LLVM::shflIdxSync;
using ::mlir::LLVM::shflUpSync;
using ::mlir::LLVM::storeShared;
// Apply the combine region to `a` and `b` and return the result. If `a` or
// `b` is null, the other operand is returned unchanged and no IR is emitted.
//
// The combine region is cloned to the front of the region that contains the
// current insertion block, and the cloned block is then inlined at the
// rewriter's insertion point with `a` and `b` bound to its block arguments.
// Only the first operand of the region's terminator is returned.
static Value accumulate(ConversionPatternRewriter &rewriter, Region &combineOp,
                        Value a, Value b) {
  if (!a) {
    return b;
  }
  if (!b) {
    return a;
  }
  // Create a new copy of the combine block, and inline it.
  Block *currentBlock = rewriter.getBlock();
  Region &parent = *currentBlock->getParent();
  rewriter.cloneRegionBefore(combineOp, &parent.front());
  auto &newScan = parent.front();
  // The terminator is dereferenced unconditionally below, so use `cast`
  // (asserts on a type mismatch) instead of `dyn_cast` (which would hand back
  // a null op that is then used without a check).
  auto returnOp = cast<triton::ScanReturnOp>(newScan.getTerminator());
  llvm::SmallVector<Value> combineArgs = {a, b};
  rewriter.inlineBlockBefore(&newScan, &*rewriter.getInsertionPoint(),
                             combineArgs);
  auto results = returnOp.getResult();
  Value acc = results[0];
  // Delete the terminator, which is no longer used.
  rewriter.eraseOp(returnOp);
  return acc;
}
// Run the combine region sequentially over the values owned by one thread and
// write every running result back into `srcValues`.
//
// Elements that are contiguous along the scan axis are not necessarily
// adjacent in `srcValues`, so one running value is kept per chunk and each
// source index is mapped onto the chunk it belongs to.
static void scanThreadContiguousElements(SmallVector<Value> &srcValues,
                                         ConversionPatternRewriter &rewriter,
                                         ScanLoweringHelper &helper) {
  const unsigned elemsPerThread = helper.getAxisNumElementsPerThread();
  const unsigned axisStride = helper.getAxisElementStride();
  // One (initially null) running value per chunk.
  SmallVector<Value> running(srcValues.size() / elemsPerThread);
  unsigned idx = 0;
  for (Value &current : srcValues) {
    const unsigned offsetInStride = idx % axisStride;
    const unsigned chunkAlongAxis = (idx / axisStride) / elemsPerThread;
    Value &slot = running[offsetInStride + chunkAlongAxis * axisStride];
    // `accumulate` returns `current` untouched while the slot is still null.
    slot = accumulate(rewriter, helper.getCombineOp(), slot, current);
    current = slot;
    ++idx;
  }
}
// Scan across the lanes of a warp, touching only the last element of every
// chunk of contiguous elements. Results are written back into `srcValues`.
//
// For shuffle distances 1, 2, 4, ... up to half the number of lanes with
// unique data on the axis, each lane combines its value with the one shuffled
// up from `distance` lanes away; lanes whose axis id is below the distance
// keep their current value.
static void warpScan(SmallVector<Value> &srcValues,
                     ConversionPatternRewriter &rewriter,
                     ScanLoweringHelper &helper, Value laneIdAxis,
                     Value laneId) {
  Location loc = helper.getLoc();
  const unsigned elemsPerThread = helper.getAxisNumElementsPerThread();
  const unsigned elemStride = helper.getAxisElementStride();
  const unsigned laneStride = helper.getAxisThreadStride();
  const unsigned lanesOnAxis = helper.getAxisNumThreadsPerWarpWithUniqueData();
  const unsigned maxDistance = lanesOnAxis / 2;

  for (unsigned idx = 0, end = srcValues.size(); idx != end; ++idx) {
    // Skip everything that is not the tail of a contiguous chunk.
    const bool isChunkTail =
        (idx / elemStride) % elemsPerThread == elemsPerThread - 1;
    if (!isChunkTail)
      continue;

    Value running = srcValues[idx];
    for (unsigned distance = 1; distance <= maxDistance; distance <<= 1) {
      Value shuffled =
          shflUpSync(loc, rewriter, running, distance * laneStride, laneId);
      Value combined =
          accumulate(rewriter, helper.getCombineOp(), shuffled, running);
      Value distanceVal = i32_val(distance);
      Value keepCurrent = icmp_slt(laneIdAxis, distanceVal);
      running = select(keepCurrent, running, combined);
    }
    srcValues[idx] = running;
  }
}
// Write the last element of every contiguous chunk to shared memory, so each
// warp publishes one partial result per chunk and per parallel scan. The
// store is predicated on `laneId` being the last lane with unique data on the
// axis. The slot written is
//   parallelLaneId + warpId * numParallelLane
//                  + chunkId * numParallelLane * axisNumWarps
// which gives the following layout:
// -----------------------------------------------------------------
// chunk 0: | acc[0] warp 0 | acc[1] warp 0 | acc[0] warp 1 | acc[1] warp 1 |
// chunk 1: | acc[0] warp 0 | acc[1] warp 0 | acc[0] warp 1 | acc[1] warp 1 |
static void storeWarpAccumulator(SmallVector<Value> &srcValues,
                                 ConversionPatternRewriter &rewriter,
                                 ScanLoweringHelper &helper, Value laneId,
                                 Value warpId, Value baseSharedMemPtr,
                                 Value parallelLaneId) {
  Location loc = helper.getLoc();
  const unsigned elemsPerThread = helper.getAxisNumElementsPerThread();
  const unsigned lanesOnAxis = helper.getAxisNumThreadsPerWarpWithUniqueData();
  const unsigned parallelLanes = helper.getNonAxisNumThreadsPerCTA();
  const unsigned warpsOnAxis = helper.getAxisNumWarpsWithUniqueData();
  const unsigned elemStride = helper.getAxisElementStride();

  unsigned chunk = 0;
  for (unsigned idx = 0, end = srcValues.size(); idx != end; ++idx) {
    // Only the tail of each contiguous chunk carries the chunk's total.
    const bool isChunkTail =
        (idx / elemStride) % elemsPerThread == elemsPerThread - 1;
    if (!isChunkTail)
      continue;

    Value isLastLane = icmp_eq(laneId, i32_val(lanesOnAxis - 1));
    Value warpOffset = mul(warpId, i32_val(parallelLanes));
    Value slot = add(parallelLaneId, warpOffset);
    slot = add(slot, i32_val(chunk * parallelLanes * warpsOnAxis));
    Value dst = gep(baseSharedMemPtr.getType(), baseSharedMemPtr, slot);
    storeShared(rewriter, loc, dst, srcValues[idx], isLastLane);
    ++chunk;
  }
}
// Read the partial reductions from shared memory from each chunk of contiguous
// elements for each warp and parallel scan. Then combine the partial reduction
// with the right elements. Within a given contiguous element chunk we update
// all the elements by accumulating the value from the last element of the
// reduced value from the previous lane.
//
// `srcValues` is updated in place. `sharedMemoryPtr` is the base pointer that
// is indexed and loaded from; the slot index used here is
//   parallelLaneId + numParallelLane * (warp + chunkId * axisNumWarps).
// `warpId` and `laneIdAxis` are compared against constants to build the
// masks; `laneId` is only forwarded to the shuffle helper.
static void AddPartialReduce(SmallVector<Value> &srcValues,
ConversionPatternRewriter &rewriter,
ScanLoweringHelper &helper, Value sharedMemoryPtr,
Value warpId, Value laneIdAxis,
Value parallelLaneId, Value laneId) {
Location loc = helper.getLoc();
unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA();
unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread();
unsigned parallelElementsPerThread = helper.getNonAxisNumElementsPerThread();
unsigned elementStride = helper.getAxisElementStride();
unsigned threadStride = helper.getAxisThreadStride();
unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData();
// Masks selecting warp 0, axis-lane 0, and the thread that is both.
Value maskFirstWarp = icmp_eq(warpId, i32_val(0));
Value maskFirstLane = icmp_eq(laneIdAxis, i32_val(0));
Value maskFirstThread = and_(maskFirstWarp, maskFirstLane);
// Per-parallel-scan running state carried across chunks:
//  - `acc` combines the partial results of every warp seen so far;
//  - `maskedAcc` is the value selected by the `warpId < i + 1` masks below,
//    i.e. the running value this warp actually combines into its elements.
struct Accumulator {
Value acc;
Value maskedAcc;
};
unsigned numScanBlocks = helper.getAxisNumBlocks();
unsigned numParallelBlocks = helper.getNonAxisNumBlocks();
// Sanity check: the four factors must account for every per-thread value.
assert(numScanBlocks * numParallelBlocks * parallelElementsPerThread *
scanElementsPerThreads ==
srcValues.size());
// One accumulator per (parallel block, parallel element) pair; entries start
// out null and are seeded by the first partial result that is loaded.
SmallVector<Accumulator> accumulators(numParallelBlocks *
parallelElementsPerThread);
// Counts the chunk tails visited so far.
unsigned chunkId = 0;
unsigned blockStride = helper.getAxisBlockStride();
for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) {
unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads;
// Only consider the last element of each contiguous chunk of elements.
if (elementIdx != scanElementsPerThreads - 1)
continue;
// Accumulate the partial reduction from shared memory. Decide which
// accumulator to combine based on whether the elements belong to the same
// dimension along axis.
unsigned blockId = chunkId / parallelElementsPerThread;
unsigned parallelBlockId =
blockId % blockStride +
((blockId / blockStride) / numScanBlocks) * blockStride;
unsigned accumulatorIndex = chunkId % parallelElementsPerThread +
parallelBlockId * parallelElementsPerThread;
Accumulator &accumulator = accumulators[accumulatorIndex];
// Position of this block along the scan axis; 0 marks the first block.
unsigned axisBlockId = (blockId / blockStride) % numScanBlocks;
// Fold in the partial result of every warp for this chunk.
for (unsigned i = 0; i < axisNumWarps; ++i) {
Value index = add(parallelLaneId, i32_val(numParallelLane *
(i + chunkId * axisNumWarps)));
Value ptr = gep(sharedMemoryPtr.getType(), sharedMemoryPtr, index);
Value partialReduce = load(ptr);
// The very first partial result seeds both running values.
if (!accumulator.acc) {
accumulator.acc = partialReduce;
accumulator.maskedAcc = partialReduce;
continue;
}
accumulator.acc = accumulate(rewriter, helper.getCombineOp(),
accumulator.acc, partialReduce);
// Keep the previous masked value when warpId < i + 1, otherwise take the
// freshly extended total.
Value mask = icmp_slt(warpId, i32_val(i + 1));
accumulator.maskedAcc =
select(mask, accumulator.maskedAcc, accumulator.acc);
}
Value temp = accumulate(rewriter, helper.getCombineOp(),
accumulator.maskedAcc, srcValues[srcIndex]);
if (axisBlockId == 0) {
// For the first warp and first chunk we don't have anything to
// accumulate.
temp = select(maskFirstWarp, srcValues[srcIndex], temp);
}
srcValues[srcIndex] = temp;
// Update the rest of the contiguous elements.
// Fetch the updated chunk tail via a shuffle-up by `threadStride`; the first
// lane on the axis uses the masked accumulator instead.
Value lastElement =
shflUpSync(loc, rewriter, srcValues[srcIndex], threadStride, laneId);
lastElement = select(maskFirstLane, accumulator.maskedAcc, lastElement);
// Walk backwards from the chunk tail through the other elements of the chunk.
for (unsigned i = 1; i < scanElementsPerThreads; ++i) {
Value laneValue = srcValues[srcIndex - i * elementStride];
laneValue =
accumulate(rewriter, helper.getCombineOp(), lastElement, laneValue);
if (axisBlockId == 0) {
// For the first warp and first chunk we don't have anything to
// accumulate.
laneValue = select(maskFirstThread,
srcValues[srcIndex - i * elementStride], laneValue);
}
srcValues[srcIndex - i * elementStride] = laneValue;
}
// For the next chunk start back from the value containing the
// accumulated value of all the warps.
accumulator.maskedAcc = accumulator.acc;
chunkId++;
}
}
// Single-warp variant of the partial-result propagation: no shared memory is
// involved, values are exchanged with warp shuffles only. `srcValues` is
// updated in place.
//
// For each chunk tail, the running accumulator of the matching parallel scan
// is either seeded (first block along the axis) or combined into the tail.
// The remaining elements of the chunk are then combined with the value
// obtained from the neighbouring lane (or from the accumulator on the first
// axis lane). `laneIdLast` is the lane the accumulator is re-read from when
// there is more than one block along the scan axis.
static void AddPartialReduceOneWarp(SmallVector<Value> &srcValues,
                                    ConversionPatternRewriter &rewriter,
                                    ScanLoweringHelper &helper, Value warpId,
                                    Value laneIdAxis, Value laneIdLast,
                                    Value laneId) {
  Location loc = helper.getLoc();
  unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread();
  unsigned parallelElementsPerThread = helper.getNonAxisNumElementsPerThread();
  unsigned elementStride = helper.getAxisElementStride();
  unsigned threadStride = helper.getAxisThreadStride();
  unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData();
  // Masks selecting warp 0, axis-lane 0, and the thread that is both.
  Value maskFirstWarp = icmp_eq(warpId, i32_val(0));
  Value maskFirstLane = icmp_eq(laneIdAxis, i32_val(0));
  Value maskFirstThread = and_(maskFirstWarp, maskFirstLane);
  unsigned numScanBlocks = helper.getAxisNumBlocks();
  unsigned numParallelBlocks = helper.getNonAxisNumBlocks();
  // Sanity check: the four factors must account for every per-thread value.
  assert(numScanBlocks * numParallelBlocks * parallelElementsPerThread *
             scanElementsPerThreads ==
         srcValues.size());
  // One running value per (parallel block, parallel element) pair.
  SmallVector<Value> accumulators(numParallelBlocks *
                                  parallelElementsPerThread);
  // Counts the chunk tails visited so far.
  unsigned chunkId = 0;
  unsigned blockStride = helper.getAxisBlockStride();
  for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) {
    unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads;
    // Only consider the last element of each contiguous chunk of elements.
    if (elementIdx != scanElementsPerThreads - 1)
      continue;
    unsigned blockId = chunkId / parallelElementsPerThread;
    unsigned parallelBlockId =
        blockId % blockStride +
        ((blockId / blockStride) / numScanBlocks) * blockStride;
    unsigned accumulatorIndex = chunkId % parallelElementsPerThread +
                                parallelBlockId * parallelElementsPerThread;
    Value &accumulator = accumulators[accumulatorIndex];
    // Position of this block along the scan axis; 0 marks the first block.
    unsigned axisBlockId = (blockId / blockStride) % numScanBlocks;
    if (axisBlockId == 0) // First chunk and first block
      accumulator = srcValues[srcIndex];
    else
      srcValues[srcIndex] = accumulate(rewriter, helper.getCombineOp(),
                                       accumulator, srcValues[srcIndex]);
    // Update the rest of the contiguous elements.
    Value lastElement = srcValues[srcIndex];
    // NOTE(review): when scanDim == 1 neither `lastElement` nor `accumulator`
    // is refreshed below, so with numScanBlocks > 1 the accumulator keeps the
    // value seeded by the first block. Confirm this is intended.
    if (scanDim > 1) {
      lastElement =
          shflUpSync(loc, rewriter, srcValues[srcIndex], threadStride, laneId);
      lastElement = select(maskFirstLane, accumulator, lastElement);
      if (numScanBlocks > 1)
        // Update accumulator with the value from the last lane.
        accumulator =
            shflIdxSync(loc, rewriter, srcValues[srcIndex], laneIdLast);
    }
    // Walk backwards from the chunk tail through the other elements.
    for (unsigned i = 1; i < scanElementsPerThreads; ++i) {
      Value laneValue = srcValues[srcIndex - i * elementStride];
      laneValue =
          accumulate(rewriter, helper.getCombineOp(), lastElement, laneValue);
      if (axisBlockId == 0)
        // For the first warp and first chunk we don't have anything to
        // accumulate.
        laneValue = select(maskFirstThread,
                           srcValues[srcIndex - i * elementStride], laneValue);
      srcValues[srcIndex - i * elementStride] = laneValue;
    }
    chunkId++;
  }
}
namespace {
// Conversion pattern that lowers `triton::ScanOp` to LLVM dialect operations.
struct ScanOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::ScanOp> {
public:
  using ConvertTritonGPUOpToLLVMPattern<
      triton::ScanOp>::ConvertTritonGPUOpToLLVMPattern;

  // The shuffle-based lowering is the only strategy; its result is reported
  // as-is.
  LogicalResult
  matchAndRewrite(triton::ScanOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return emitFastScan(op, adaptor, rewriter);
  }

private:
  // Emits the scan lowering; fails when the helper reports the op as
  // unsupported.
  LogicalResult emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const;

  // Splits `laneId` / `warpId` into the id along the scan axis (lane, warp)
  // and a single flat id covering the remaining dimensions.
  std::tuple<Value, Value, Value>
  getDelinearizedIds(ConversionPatternRewriter &rewriter,
                     ScanLoweringHelper &helper, Value laneId,
                     Value warpId) const;

  // Per-dimension coordinates of a lane inside its warp.
  SmallVector<Value> getMultiDimLaneId(ConversionPatternRewriter &rewriter,
                                       ScanLoweringHelper &helper,
                                       Value laneId) const;

  // Per-dimension coordinates of a warp inside the CTA.
  SmallVector<Value> getMultiDimWarpId(ConversionPatternRewriter &rewriter,
                                       ScanLoweringHelper &helper,
                                       Value warpId) const;
};
// Delinearize `laneId` into one coordinate per dimension, using the
// threads-per-warp shape and dimension order of the scan's source encoding.
SmallVector<Value>
ScanOpConversion::getMultiDimLaneId(ConversionPatternRewriter &rewriter,
                                    ScanLoweringHelper &helper,
                                    Value laneId) const {
  auto loc = helper.getLoc();
  auto srcEncoding = helper.getEncoding();
  auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding);
  auto order = triton::gpu::getOrder(srcEncoding);
  return delinearize(rewriter, loc, laneId, threadsPerWarp, order);
}
// Delinearize `warpId` into one coordinate per dimension, using the
// warps-per-CTA shape and dimension order of the scan's source encoding.
SmallVector<Value>
ScanOpConversion::getMultiDimWarpId(ConversionPatternRewriter &rewriter,
                                    ScanLoweringHelper &helper,
                                    Value warpId) const {
  auto loc = helper.getLoc();
  auto srcEncoding = helper.getEncoding();
  auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
  auto order = triton::gpu::getOrder(srcEncoding);
  return delinearize(rewriter, loc, warpId, warpsPerCTA, order);
}
// Split the lane and warp ids into their component along the scan axis and a
// single flattened id over the remaining (parallel) dimensions.
//
// Returns (lane id on the axis, warp id on the axis, flat parallel id), where
// the flat parallel id is
//   laneParallel + warpParallel * helper.getNonAxisNumThreadsPerWarp().
std::tuple<Value, Value, Value>
ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter,
                                     ScanLoweringHelper &helper, Value laneId,
                                     Value warpId) const {
  auto loc = helper.getLoc();
  const unsigned scanAxis = helper.getAxis();
  auto layout = helper.getEncoding();
  auto laneShape = triton::gpu::getThreadsPerWarp(layout);
  auto warpShape = triton::gpu::getWarpsPerCTA(layout);
  auto dimOrder = triton::gpu::getOrder(layout);

  // Per-dimension coordinates of the lane and of the warp.
  SmallVector<Value> laneCoords =
      delinearize(rewriter, loc, laneId, laneShape, dimOrder);
  SmallVector<Value> warpCoords =
      delinearize(rewriter, loc, warpId, warpShape, dimOrder);
  Value laneOnAxis = laneCoords[scanAxis];
  Value warpOnAxis = warpCoords[scanAxis];

  // Collapse the scan axis (coordinate 0, extent 1) and re-linearize to get
  // an id that only depends on the parallel dimensions.
  laneCoords[scanAxis] = i32_val(0);
  laneShape[scanAxis] = 1;
  Value laneOffAxis =
      linearize(rewriter, loc, laneCoords, laneShape, dimOrder);

  warpCoords[scanAxis] = i32_val(0);
  warpShape[scanAxis] = 1;
  Value warpOffAxis =
      linearize(rewriter, loc, warpCoords, warpShape, dimOrder);

  Value lanesOffAxisPerWarp = i32_val(helper.getNonAxisNumThreadsPerWarp());
  Value warpBase = mul(warpOffAxis, lanesOffAxisPerWarp);
  Value flatOffAxis = add(laneOffAxis, warpBase);
  return std::make_tuple(laneOnAxis, warpOnAxis, flatOffAxis);
}
// Lowering using warp shuffle operations to do warp level scan.
//
// Steps, in emission order:
//  1. scan the contiguous elements inside each thread;
//  2. scan the chunk tails across the lanes of a warp;
//  3. propagate partial results either through shared memory (more than one
//     warp with unique data on the axis) or with shuffles only (single warp,
//     more than one value per thread).
// Returns failure() without touching the IR when the helper reports the op as
// unsupported; otherwise replaces `op` with the packed results.
LogicalResult
ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
ScanLoweringHelper helper(op);
auto loc = helper.getLoc();
if (!helper.isSupported())
return failure();
// Derive warp and lane ids from the thread id and the module's warp size.
Value threadId = getThreadId(rewriter, loc);
auto mod = op->getParentOfType<ModuleOp>();
unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
Value warpSize = i32_val(iWarpSize);
Value warpId = udiv(threadId, warpSize);
Value laneId = urem(threadId, warpSize);
auto [laneIdAxis, warpIdAxis, flatIdParallel] =
getDelinearizedIds(rewriter, helper, laneId, warpId);
auto input = adaptor.getOperands()[0];
auto type = op.getOperand(0).getType().cast<RankedTensorType>();
auto axisNumWarps = helper.getAxisNumWarpsWithUniqueData();
// Wrap the axis warp id into the range of warps that hold unique data.
warpIdAxis = urem(warpIdAxis, i32_val(axisNumWarps));
SmallVector<Value> srcValues =
getTypeConverter()->unpackLLElements(loc, input, rewriter, type);
// Scan contiguous elements in a thread and update `srcValues`.
scanThreadContiguousElements(srcValues, rewriter, helper);
// Apply warp level scan to the last element of each chunk of contiguous
// elements.
warpScan(srcValues, rewriter, helper, laneIdAxis, laneId);
if (axisNumWarps > 1) {
// Slow path for the case where there are multiple warps with unique data on
// the axis.
// Pointer to the element type in address space 3, obtained by bitcasting the
// op's shared memory base.
Type elemPtrTys = LLVM::LLVMPointerType::get(srcValues[0].getType(), 3);
Value baseSharedMemPtr = bitcast(
getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys);
// Store the partial reduction of each warp into shared memory.
storeWarpAccumulator(srcValues, rewriter, helper, laneIdAxis, warpIdAxis,
baseSharedMemPtr, flatIdParallel);
// The barrier separates the stores above from the loads below.
barrier();
// Read back the partial reduction of each warp and accumulate them based on
// warpId. Then update each chunk of contiguous elements by adding the
// accumulated value from the previous lane.
AddPartialReduce(srcValues, rewriter, helper, baseSharedMemPtr, warpIdAxis,
laneIdAxis, flatIdParallel, laneId);
} else if (srcValues.size() > 1) {
// Fast path for the case where there is only one warp with unique data on
// the axis.
unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData();
// Build the linear id of the lane that shares this lane's parallel
// coordinates but sits at position scanDim - 1 along the scan axis.
auto multiDimLaneId = getMultiDimLaneId(rewriter, helper, laneId);
multiDimLaneId[helper.getAxis()] = i32_val(scanDim - 1);
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(helper.getEncoding());
auto laneIdLast = linearize(rewriter, loc, multiDimLaneId, threadsPerWarp,
triton::gpu::getOrder(helper.getEncoding()));
AddPartialReduceOneWarp(srcValues, rewriter, helper, warpIdAxis, laneIdAxis,
laneIdLast, laneId);
} // else axisNumWarps == 1 and srcValues.size() == 1, nothing to do.
// Pack the per-thread values back into the converted input type and replace
// the op.
Value results = getTypeConverter()->packLLElements(loc, srcValues, rewriter,
input.getType());
rewriter.replaceOp(op, results);
return success();
}
} // namespace
// Register the ScanOp-to-LLVM conversion pattern in `patterns`.
//
// Only `typeConverter`, `allocation`, `indexCacheInfo` and `benefit` are
// forwarded to the pattern; `numWarps` and `axisInfoAnalysis` are not used by
// this function.
void populateScanOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit) {
patterns.add<ScanOpConversion>(typeConverter, allocation, indexCacheInfo,
benefit);
}