#include "ScanOpToLLVM.h"
|
|
#include "TritonGPUToLLVMBase.h"
|
|
#include "triton/Analysis/Utility.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::triton;
|
|
|
|
using ::mlir::LLVM::delinearize;
|
|
using ::mlir::LLVM::linearize;
|
|
using ::mlir::LLVM::shflIdxSync;
|
|
using ::mlir::LLVM::shflUpSync;
|
|
using ::mlir::LLVM::storeShared;
|
|
|
|
// Apply the combine region to a and b and return the result. If a or b is
// null, return the other operand.
static Value accumulate(ConversionPatternRewriter &rewriter, Region &combineOp,
                        Value a, Value b) {
  if (!a) {
    return b;
  }
  if (!b) {
    return a;
  }
  // Create a new copy of the combine block, and inline it.
  Block *currentBlock = rewriter.getBlock();
  Region &parent = *currentBlock->getParent();
  rewriter.cloneRegionBefore(combineOp, &parent.front());
  auto &newScan = parent.front();
  auto returnOp = dyn_cast<triton::ScanReturnOp>(newScan.getTerminator());
  llvm::SmallVector<Value> combineArgs = {a, b};
  rewriter.inlineBlockBefore(&newScan, &*rewriter.getInsertionPoint(),
                             combineArgs);
  auto results = returnOp.getResult();
  Value acc = results[0];
  // Delete the terminator, which is no longer used after inlining.
  rewriter.eraseOp(returnOp);
  return acc;
}

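// Note that each call to accumulate clones and inlines the combine region,
// so every accumulate call site below materializes one copy of the user's
// combine computation in the generated IR.
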
// Scan the contiguous elements within a thread and update `srcValues` in
// place.
static void scanThreadContiguousElements(SmallVector<Value> &srcValues,
                                         ConversionPatternRewriter &rewriter,
                                         ScanLoweringHelper &helper) {
  // Depending on the layout, elements that are contiguous along the axis dim
  // may not be contiguous in srcValues. Keep track of which elements belong
  // to the same chunk of contiguous elements.
  unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread();
  unsigned numChunks = srcValues.size() / scanElementsPerThreads;
  unsigned stride = helper.getAxisElementStride();
  SmallVector<Value> accs(numChunks);
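  // For example, with stride = 2 and scanElementsPerThreads = 4, the
  // expression below maps source indices 0,2,4,6 to accs[0], 1,3,5,7 to
  // accs[1], 8,10,12,14 to accs[2], and so on, so each accumulator tracks
  // one chunk of axis-contiguous elements.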
  for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) {
    unsigned accIndex = (srcIndex % stride) +
                        ((srcIndex / stride) / scanElementsPerThreads) * stride;

    accs[accIndex] = accumulate(rewriter, helper.getCombineOp(), accs[accIndex],
                                srcValues[srcIndex]);
    srcValues[srcIndex] = accs[accIndex];
  }
}

// Apply a scan across threads of the warp for the last element of each
// contiguous group of elements.
static void warpScan(SmallVector<Value> &srcValues,
                     ConversionPatternRewriter &rewriter,
                     ScanLoweringHelper &helper, Value laneIdAxis,
                     Value laneId) {
  Location loc = helper.getLoc();
  unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread();
  unsigned elementStride = helper.getAxisElementStride();
  unsigned threadStride = helper.getAxisThreadStride();
  unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData();
  for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) {
    unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads;
    // Only consider the last element of each contiguous chunk of elements.
    if (elementIdx != scanElementsPerThreads - 1)
      continue;
    // Reduce within warps.
    Value acc = srcValues[srcIndex];
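    // This is a Hillis-Steele style inclusive scan: after log2(scanDim)
    // shuffle steps each lane holds the combination of its own value with
    // those of all lower lanes. The select keeps `acc` unchanged for lanes
    // whose shuffle source would fall below the start of the warp segment.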
    for (unsigned i = 1; i <= (scanDim) / 2; i = i << 1) {
      Value shfl = shflUpSync(loc, rewriter, acc, i * threadStride, laneId);
      Value tempAcc = accumulate(rewriter, helper.getCombineOp(), shfl, acc);
      Value mask = icmp_slt(laneIdAxis, i32_val(i));
      acc = select(mask, acc, tempAcc);
    }
    srcValues[srcIndex] = acc;
  }
}

// For each set of contiguous elements within a thread we store the partial
// reduction into shared memory. Each parallel scan and each warp stores its
// own partial reductions. The shared memory is organized as follows:
//   -----------------------------------------------------------------
//   chunk 0: | acc[0] warp 0 | acc[1] warp 0 | acc[0] warp 1 | acc[1] warp 1 |
//   chunk 1: | acc[0] warp 0 | acc[1] warp 0 | acc[0] warp 1 | acc[1] warp 1 |
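// For example, with numParallelLane = 4 and axisNumWarps = 2, chunk 0 of
// warp 1 writes to offsets 4..7 and chunk 1 of warp 0 writes to offsets
// 8..11 of this layout.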
static void storeWarpAccumulator(SmallVector<Value> &srcValues,
                                 ConversionPatternRewriter &rewriter,
                                 ScanLoweringHelper &helper, Value laneId,
                                 Value warpId, Value baseSharedMemPtr,
                                 Value parallelLaneId) {
  Location loc = helper.getLoc();
  unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread();
  unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData();
  unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA();
  unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData();
  unsigned chunkId = 0;
  unsigned elementStride = helper.getAxisElementStride();
  for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) {
    unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads;
    // Only consider the last element of each contiguous chunk of elements.
    if (elementIdx != scanElementsPerThreads - 1)
      continue;
    Value lastElement = srcValues[srcIndex];
    // Only the last lane along the scan dimension writes its accumulator.
    Value mask = icmp_eq(laneId, i32_val(scanDim - 1));
    Value index = add(parallelLaneId, mul(warpId, i32_val(numParallelLane)));
    index = add(index, i32_val(chunkId * numParallelLane * axisNumWarps));
    Value writePtr = gep(baseSharedMemPtr.getType(), baseSharedMemPtr, index);
    storeShared(rewriter, loc, writePtr, lastElement, mask);
    chunkId++;
  }
}

// Read the partial reductions from shared memory for each chunk of contiguous
// elements for each warp and parallel scan, then combine each partial
// reduction with the right elements. Within a given chunk of contiguous
// elements, every element is updated by accumulating the reduced value
// carried over from the last element of the previous lane.
static void AddPartialReduce(SmallVector<Value> &srcValues,
                             ConversionPatternRewriter &rewriter,
                             ScanLoweringHelper &helper, Value sharedMemoryPtr,
                             Value warpId, Value laneIdAxis,
                             Value parallelLaneId, Value laneId) {
  Location loc = helper.getLoc();
  unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA();
  unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread();
  unsigned parallelElementsPerThread = helper.getNonAxisNumElementsPerThread();
  unsigned elementStride = helper.getAxisElementStride();
  unsigned threadStride = helper.getAxisThreadStride();
  unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData();
  Value maskFirstWarp = icmp_eq(warpId, i32_val(0));
  Value maskFirstLane = icmp_eq(laneIdAxis, i32_val(0));
  Value maskFirstThread = and_(maskFirstWarp, maskFirstLane);
  struct Accumulator {
    // Running combination of the per-warp partials read so far.
    Value acc;
    // Same value, but masked so that each warp only sees the partials of
    // warps with a smaller id along the axis.
    Value maskedAcc;
  };
  unsigned numScanBlocks = helper.getAxisNumBlocks();
  unsigned numParallelBlocks = helper.getNonAxisNumBlocks();
  assert(numScanBlocks * numParallelBlocks * parallelElementsPerThread *
             scanElementsPerThreads ==
         srcValues.size());
  SmallVector<Accumulator> accumulators(numParallelBlocks *
                                        parallelElementsPerThread);
  unsigned chunkId = 0;
  unsigned blockStride = helper.getAxisBlockStride();
  for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) {
    unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads;
    // Only consider the last element of each contiguous chunk of elements.
    if (elementIdx != scanElementsPerThreads - 1)
      continue;
    // Accumulate the partial reduction from shared memory. Decide which
    // accumulator to combine based on whether the elements belong to the same
    // dimension along axis.
    unsigned blockId = chunkId / parallelElementsPerThread;
    unsigned parallelBlockId =
        blockId % blockStride +
        ((blockId / blockStride) / numScanBlocks) * blockStride;
    unsigned accumulatorIndex = chunkId % parallelElementsPerThread +
                                parallelBlockId * parallelElementsPerThread;
    Accumulator &accumulator = accumulators[accumulatorIndex];
    unsigned axisBlockId = (blockId / blockStride) % numScanBlocks;
    for (unsigned i = 0; i < axisNumWarps; ++i) {
      Value index = add(parallelLaneId, i32_val(numParallelLane *
                                                (i + chunkId * axisNumWarps)));
      Value ptr = gep(sharedMemoryPtr.getType(), sharedMemoryPtr, index);
      Value partialReduce = load(ptr);
      if (!accumulator.acc) {
        accumulator.acc = partialReduce;
        accumulator.maskedAcc = partialReduce;
        continue;
      }
      accumulator.acc = accumulate(rewriter, helper.getCombineOp(),
                                   accumulator.acc, partialReduce);
      Value mask = icmp_slt(warpId, i32_val(i + 1));
      accumulator.maskedAcc =
          select(mask, accumulator.maskedAcc, accumulator.acc);
    }
    Value temp = accumulate(rewriter, helper.getCombineOp(),
                            accumulator.maskedAcc, srcValues[srcIndex]);
    if (axisBlockId == 0) {
      // For the first warp and first chunk we don't have anything to
      // accumulate.
      temp = select(maskFirstWarp, srcValues[srcIndex], temp);
    }
    srcValues[srcIndex] = temp;
    // Update the rest of the contiguous elements.
    Value lastElement =
        shflUpSync(loc, rewriter, srcValues[srcIndex], threadStride, laneId);
    lastElement = select(maskFirstLane, accumulator.maskedAcc, lastElement);
    for (unsigned i = 1; i < scanElementsPerThreads; ++i) {
      Value laneValue = srcValues[srcIndex - i * elementStride];
      laneValue =
          accumulate(rewriter, helper.getCombineOp(), lastElement, laneValue);
      if (axisBlockId == 0) {
        // For the first warp and first chunk we don't have anything to
        // accumulate.
        laneValue = select(maskFirstThread,
                           srcValues[srcIndex - i * elementStride], laneValue);
      }
      srcValues[srcIndex - i * elementStride] = laneValue;
    }
    // For the next chunk start back from the value containing the
    // accumulated value of all the warps.
    accumulator.maskedAcc = accumulator.acc;
    chunkId++;
  }
}

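// Variant of AddPartialReduce for the case where a single warp holds unique
// data along the axis: per-warp partials are carried directly in registers
// (and broadcast with shflIdxSync across blocks) instead of going through
// shared memory.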
static void AddPartialReduceOneWarp(SmallVector<Value> &srcValues,
                                    ConversionPatternRewriter &rewriter,
                                    ScanLoweringHelper &helper, Value warpId,
                                    Value laneIdAxis, Value laneIdLast,
                                    Value laneId) {
  Location loc = helper.getLoc();
  unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread();
  unsigned parallelElementsPerThread = helper.getNonAxisNumElementsPerThread();
  unsigned elementStride = helper.getAxisElementStride();
  unsigned threadStride = helper.getAxisThreadStride();
  unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData();
  unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA();
  unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData();
  Value maskFirstWarp = icmp_eq(warpId, i32_val(0));
  Value maskFirstLane = icmp_eq(laneIdAxis, i32_val(0));
  Value maskFirstThread = and_(maskFirstWarp, maskFirstLane);
  unsigned numScanBlocks = helper.getAxisNumBlocks();
  unsigned numParallelBlocks = helper.getNonAxisNumBlocks();
  assert(numScanBlocks * numParallelBlocks * parallelElementsPerThread *
             scanElementsPerThreads ==
         srcValues.size());
  SmallVector<Value> accumulators(numParallelBlocks *
                                  parallelElementsPerThread);
  unsigned chunkId = 0;
  unsigned blockStride = helper.getAxisBlockStride();
  for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) {
    unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads;
    // Only consider the last element of each contiguous chunk of elements.
    if (elementIdx != scanElementsPerThreads - 1)
      continue;
    unsigned blockId = chunkId / parallelElementsPerThread;
    unsigned parallelBlockId =
        blockId % blockStride +
        ((blockId / blockStride) / numScanBlocks) * blockStride;
    unsigned accumulatorIndex = chunkId % parallelElementsPerThread +
                                parallelBlockId * parallelElementsPerThread;
    Value &accumulator = accumulators[accumulatorIndex];
    unsigned axisBlockId = (blockId / blockStride) % numScanBlocks;
    if (axisBlockId == 0) // First chunk and first block
      accumulator = srcValues[srcIndex];
    else
      srcValues[srcIndex] = accumulate(rewriter, helper.getCombineOp(),
                                       accumulator, srcValues[srcIndex]);
    // Update the rest of the contiguous elements.
    Value lastElement = srcValues[srcIndex];
    if (scanDim > 1) {
      lastElement =
          shflUpSync(loc, rewriter, srcValues[srcIndex], threadStride, laneId);
      lastElement = select(maskFirstLane, accumulator, lastElement);
      if (numScanBlocks > 1)
        // Update accumulator with the value from the last lane.
        accumulator =
            shflIdxSync(loc, rewriter, srcValues[srcIndex], laneIdLast);
    }
    for (unsigned i = 1; i < scanElementsPerThreads; ++i) {
      Value laneValue = srcValues[srcIndex - i * elementStride];
      laneValue =
          accumulate(rewriter, helper.getCombineOp(), lastElement, laneValue);
      if (axisBlockId == 0)
        // For the first warp and first chunk we don't have anything to
        // accumulate.
        laneValue = select(maskFirstThread,
                           srcValues[srcIndex - i * elementStride], laneValue);
      srcValues[srcIndex - i * elementStride] = laneValue;
    }
    // Move on to the next chunk.
    chunkId++;
  }
}

namespace {
struct ScanOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::ScanOp> {
public:
  using ConvertTritonGPUOpToLLVMPattern<
      triton::ScanOp>::ConvertTritonGPUOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(triton::ScanOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (succeeded(emitFastScan(op, adaptor, rewriter)))
      return success();
    return failure();
  }

private:
  SmallVector<Value> getMultiDimLaneId(ConversionPatternRewriter &rewriter,
                                       ScanLoweringHelper &helper,
                                       Value laneId) const;
  SmallVector<Value> getMultiDimWarpId(ConversionPatternRewriter &rewriter,
                                       ScanLoweringHelper &helper,
                                       Value warpId) const;
  std::tuple<Value, Value, Value>
  getDelinearizedIds(ConversionPatternRewriter &rewriter,
                     ScanLoweringHelper &helper, Value laneId,
                     Value warpId) const;
  LogicalResult emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const;
};

SmallVector<Value>
ScanOpConversion::getMultiDimLaneId(ConversionPatternRewriter &rewriter,
                                    ScanLoweringHelper &helper,
                                    Value laneId) const {
  auto loc = helper.getLoc();
  unsigned axis = helper.getAxis();
  auto srcEncoding = helper.getEncoding();

  auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding);
  auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
  auto order = triton::gpu::getOrder(srcEncoding);
  return delinearize(rewriter, loc, laneId, threadsPerWarp, order);
}

SmallVector<Value>
ScanOpConversion::getMultiDimWarpId(ConversionPatternRewriter &rewriter,
                                    ScanLoweringHelper &helper,
                                    Value warpId) const {
  auto loc = helper.getLoc();
  unsigned axis = helper.getAxis();
  auto srcEncoding = helper.getEncoding();

  auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding);
  auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
  auto order = triton::gpu::getOrder(srcEncoding);
  return delinearize(rewriter, loc, warpId, warpsPerCTA, order);
}

// Break up the threadId into lane and warp id along the scan dimension and
// compute a flat id for the parallel dimensions.
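// For a 2D tensor scanned along axis 1, for instance, laneIdAxis and
// warpIdAxis are the lane's and warp's coordinates along dim 1, while
// flatIdParallel linearizes the remaining (dim 0) coordinates of both lane
// and warp into a single index.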
std::tuple<Value, Value, Value>
ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter,
                                     ScanLoweringHelper &helper, Value laneId,
                                     Value warpId) const {
  auto loc = helper.getLoc();
  unsigned axis = helper.getAxis();
  auto srcEncoding = helper.getEncoding();

  auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding);
  auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
  auto order = triton::gpu::getOrder(srcEncoding);
  SmallVector<Value> multiDimLaneId =
      delinearize(rewriter, loc, laneId, threadsPerWarp, order);
  SmallVector<Value> multiDimWarpId =
      delinearize(rewriter, loc, warpId, warpsPerCTA, order);

  Value laneIdAxis = multiDimLaneId[axis];
  Value warpIdAxis = multiDimWarpId[axis];

  multiDimLaneId[axis] = i32_val(0);
  threadsPerWarp[axis] = 1;
  Value laneIdParallel =
      linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, order);
  multiDimWarpId[axis] = i32_val(0);
  warpsPerCTA[axis] = 1;
  Value warpIdParallel =
      linearize(rewriter, loc, multiDimWarpId, warpsPerCTA, order);
  Value flatIdParallel =
      add(laneIdParallel,
          mul(warpIdParallel, i32_val(helper.getNonAxisNumThreadsPerWarp())));
  return std::make_tuple(laneIdAxis, warpIdAxis, flatIdParallel);
}

// Lowering that uses warp shuffle operations to perform the warp-level scan.
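// The sequence is:
//   1. scanThreadContiguousElements scans the elements owned by each thread.
//   2. warpScan scans the per-thread totals across the warp with shuffles.
//   3. If several warps hold unique data along the axis, each warp publishes
//      its total to shared memory (storeWarpAccumulator), threads
//      synchronize, and AddPartialReduce folds the cross-warp prefixes back
//      in; with a single warp, AddPartialReduceOneWarp handles the remaining
//      blocks entirely in registers.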
LogicalResult
ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  ScanLoweringHelper helper(op);
  auto loc = helper.getLoc();
  if (!helper.isSupported())
    return failure();

  Value threadId = getThreadId(rewriter, loc);
  auto mod = op->getParentOfType<ModuleOp>();
  unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
  Value warpSize = i32_val(iWarpSize);
  Value warpId = udiv(threadId, warpSize);
  Value laneId = urem(threadId, warpSize);

  auto [laneIdAxis, warpIdAxis, flatIdParallel] =
      getDelinearizedIds(rewriter, helper, laneId, warpId);
  auto input = adaptor.getOperands()[0];
  auto type = op.getOperand(0).getType().cast<RankedTensorType>();
  auto axisNumWarps = helper.getAxisNumWarpsWithUniqueData();
  warpIdAxis = urem(warpIdAxis, i32_val(axisNumWarps));
  SmallVector<Value> srcValues =
      getTypeConverter()->unpackLLElements(loc, input, rewriter, type);

  // Scan contiguous elements in a thread and update `srcValues`.
  scanThreadContiguousElements(srcValues, rewriter, helper);
  // Apply warp level scan to the last element of each chunk of contiguous
  // elements.
  warpScan(srcValues, rewriter, helper, laneIdAxis, laneId);

  if (axisNumWarps > 1) {
    // Slow path for the case where there are multiple warps with unique data
    // on the axis.
    Type elemPtrTys = LLVM::LLVMPointerType::get(srcValues[0].getType(), 3);
    Value baseSharedMemPtr = bitcast(
        getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys);
    // Store the partial reduction of each warp into shared memory.
    storeWarpAccumulator(srcValues, rewriter, helper, laneIdAxis, warpIdAxis,
                         baseSharedMemPtr, flatIdParallel);
    barrier();
    // Read back the partial reduction of each warp and accumulate them based
    // on warpId. Then update each chunk of contiguous elements by adding the
    // accumulated value from the previous lane.
    AddPartialReduce(srcValues, rewriter, helper, baseSharedMemPtr, warpIdAxis,
                     laneIdAxis, flatIdParallel, laneId);
  } else if (srcValues.size() > 1) {
    // Fast path for the case where there is only one warp with unique data on
    // the axis.
    unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData();
    auto multiDimLaneId = getMultiDimLaneId(rewriter, helper, laneId);
    multiDimLaneId[helper.getAxis()] = i32_val(scanDim - 1);
    auto threadsPerWarp = triton::gpu::getThreadsPerWarp(helper.getEncoding());
    auto laneIdLast = linearize(rewriter, loc, multiDimLaneId, threadsPerWarp,
                                triton::gpu::getOrder(helper.getEncoding()));
    AddPartialReduceOneWarp(srcValues, rewriter, helper, warpIdAxis, laneIdAxis,
                            laneIdLast, laneId);
  } // else axisNumWarps == 1 and srcValues.size() == 1, nothing to do.

  Value results = getTypeConverter()->packLLElements(loc, srcValues, rewriter,
                                                     input.getType());
  rewriter.replaceOp(op, results);
  return success();
}
} // namespace

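// Note: numWarps and axisInfoAnalysis are accepted here but not used by
// ScanOpConversion.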
void populateScanOpToLLVMPatterns(
    TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis,
    ModuleAllocation &allocation,
    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
    PatternBenefit benefit) {
  patterns.add<ScanOpConversion>(typeConverter, allocation, indexCacheInfo,
                                 benefit);
}