Mirror of https://github.com/ROCm/ROCm.git, synced 2026-04-05 03:01:17 -04:00
[BACKEND] Support scan on dimensions other than the fastest moving one (#1863)
This relaxes the restriction in the scan lowering to support layouts where the scan is along a dimension that is not the fastest moving one. This is done by relaxing how elements are accessed during scanning, allowing them to be strided.
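The crux of the change is a chunk-index computation: when the scan axis is not the fastest moving dimension, the elements of one scan chunk are separated by a fixed stride in the per-thread value vector instead of being adjacent. Below is a minimal standalone sketch of that mapping with made-up sizes and illustrative names (the real formula appears in scanThreadContiguousElements further down):

    #include <cstdio>

    int main() {
      // Hypothetical per-thread layout: 8 values, 2 scan elements per chunk,
      // consecutive scan elements sitting stride 2 apart.
      const unsigned numValues = 8, scanElementsPerThread = 2, stride = 2;
      for (unsigned srcIndex = 0; srcIndex < numValues; srcIndex++) {
        // Same formula as in scanThreadContiguousElements below.
        unsigned accIndex =
            (srcIndex % stride) +
            ((srcIndex / stride) / scanElementsPerThread) * stride;
        std::printf("srcIndex %u -> chunk %u\n", srcIndex, accIndex);
      }
      // Chunks come out as {0,2}, {1,3}, {4,6}, {5,7}: each scan chunk is
      // strided, not contiguous, in the value vector.
      return 0;
    }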
@@ -72,7 +72,7 @@ public:
   // Return true if the lowering of the scan op is supported.
   bool isSupported();
   // Return the number of elements per thread along axis dim.
-  unsigned getAxisNumElementsPerThreads();
+  unsigned getAxisNumElementsPerThread();
   // Return the number of elements per thread along non-axis dims.
   unsigned getNonAxisNumElementsPerThread();
   // Return the number of threads per warp along non-axis dims.

@@ -90,6 +90,13 @@ public:
   // Return the size of the scratch space needed for scan lowering.
   unsigned getScratchSizeInBytes();
 
+  // Stride between contiguous elements along the axis dim.
+  unsigned getAxisElementStride();
+  // Stride between contiguous threads along the axis dim.
+  unsigned getAxisThreadStride();
+  // Stride between contiguous blocks along the axis dim.
+  unsigned getAxisBlockStride();
+
   Location getLoc() { return scanOp.getLoc(); }
   unsigned getAxis() { return scanOp.getAxis(); }
   triton::gpu::BlockedEncodingAttr getEncoding();

@@ -117,13 +117,12 @@ bool ReduceOpHelper::isSupportedLayout() {
   return false;
 }
 
-unsigned ScanLoweringHelper::getAxisNumElementsPerThreads() {
+unsigned ScanLoweringHelper::getAxisNumElementsPerThread() {
   return getEncoding().getSizePerThread()[getAxis()];
 }
 
 unsigned ScanLoweringHelper::getNonAxisNumElementsPerThread() {
-  SmallVector<unsigned> sizePerThreads(getEncoding().getSizePerThread().begin(),
-                                       getEncoding().getSizePerThread().end());
+  SmallVector<unsigned> sizePerThreads = getContigPerThread(getEncoding());
   sizePerThreads[getAxis()] = 1;
   return product<unsigned>(sizePerThreads);
 }

@@ -159,8 +158,9 @@ unsigned ScanLoweringHelper::getAxisNumBlocks() {
   auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding);
   auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
   unsigned axis = getAxis();
-  return type.getShape()[axis] /
-         (sizePerThreads[axis] * threadsPerWarp[axis] * warpsPerCTA[axis]);
+  return ceil<unsigned>(
+      type.getShape()[axis],
+      (sizePerThreads[axis] * threadsPerWarp[axis] * warpsPerCTA[axis]));
 }
 
 unsigned ScanLoweringHelper::getNonAxisNumBlocks() {

@@ -173,19 +173,18 @@ unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
   for (unsigned i = 0; i < sizePerThreads.size(); i++) {
     if (i == axis)
       continue;
-    numBlocks *= type.getShape()[i] /
-                 (sizePerThreads[i] * threadsPerWarp[i] * warpsPerCTA[i]);
+    numBlocks *= ceil<unsigned>(
+        type.getShape()[i],
+        (sizePerThreads[i] * threadsPerWarp[i] * warpsPerCTA[i]));
   }
   return numBlocks;
 }
 
 bool ScanLoweringHelper::isSupported() {
   // TODO: Support the following cases:
-  // 1. Scan on the non-fast changing dimension
-  // 2. Scan on non-blocking encodings
-  // 3. Scan with multiple operands
-  if (getAxis() != triton::gpu::getOrder(srcEncoding)[0] ||
-      !isa<triton::gpu::BlockedEncodingAttr>(srcEncoding))
+  // 1. Scan on non-blocked encodings
+  // 2. Scan with multiple operands
+  if (!isa<triton::gpu::BlockedEncodingAttr>(srcEncoding))
     return false;
   if (scanOp.getNumOperands() != 1)
     return false;

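The switch from exact division to ceil<unsigned> in the two block-count helpers lets the block count round up when the tensor shape is not a multiple of what one block covers; this is what makes the (8, 32) shape added to the tests below legal. A sketch of the assumed round-up semantics:

    #include <cstdio>

    // Assumed semantics of the ceil<unsigned> helper: round-up integer division.
    unsigned ceilDiv(unsigned x, unsigned y) { return (x + y - 1) / y; }

    int main() {
      // Shape 8 along the axis, 16 elements covered per block: one block, not zero.
      std::printf("%u\n", ceilDiv(8, 16)); // prints 1
      return 0;
    }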
@@ -194,16 +193,58 @@ bool ScanLoweringHelper::isSupported() {
 
 unsigned ScanLoweringHelper::getScratchSizeInBytes() {
   auto type = scanOp.getOperand(0).getType().cast<RankedTensorType>();
-  unsigned numElement =
-      type.getNumElements() * type.getElementTypeBitWidth() / 8;
-  return numElement /
-         (getAxisNumElementsPerThreads() * getAxisNumThreadsPerWarp());
+  unsigned elementSizeInBytes = type.getElementTypeBitWidth() / 8;
+  auto mod = scanOp->getParentOfType<ModuleOp>();
+  unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
+  unsigned numNonAxisElementsPerWarp =
+      getNonAxisNumThreadsPerWarp() * getNonAxisNumElementsPerThread();
+  unsigned numElements = numWarps * numNonAxisElementsPerWarp *
+                         getAxisNumBlocks() * getNonAxisNumBlocks();
+  return elementSizeInBytes * numElements;
 }
 
 triton::gpu::BlockedEncodingAttr ScanLoweringHelper::getEncoding() {
   return srcEncoding.cast<triton::gpu::BlockedEncodingAttr>();
 }
 
+unsigned ScanLoweringHelper::getAxisElementStride() {
+  auto order = triton::gpu::getOrder(srcEncoding);
+  unsigned stride = 1;
+  for (unsigned dim : order) {
+    if (dim == getAxis())
+      return stride;
+    stride *= getContigPerThread(getEncoding())[dim];
+  }
+  llvm_unreachable("Axis not found in order");
+}
+
+unsigned ScanLoweringHelper::getAxisThreadStride() {
+  auto order = triton::gpu::getOrder(srcEncoding);
+  unsigned stride = 1;
+  for (unsigned dim : order) {
+    if (dim == getAxis())
+      return stride;
+    stride *= getEncoding().getThreadsPerWarp()[dim];
+  }
+  llvm_unreachable("Axis not found in order");
+}
+
+unsigned ScanLoweringHelper::getAxisBlockStride() {
+  auto order = triton::gpu::getOrder(srcEncoding);
+  unsigned stride = 1;
+  auto type = scanOp.getOperand(0).getType().cast<RankedTensorType>();
+  auto sizePerThreads = triton::gpu::getSizePerThread(srcEncoding);
+  auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding);
+  auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
+  for (unsigned dim : order) {
+    if (dim == getAxis())
+      return stride;
+    stride *= type.getShape()[dim] /
+              (sizePerThreads[dim] * threadsPerWarp[dim] * warpsPerCTA[dim]);
+  }
+  llvm_unreachable("Axis not found in order");
+}
+
 bool maybeSharedAllocationOp(Operation *op) {
   // TODO(Keren): This function can be replaced by adding
   // MemoryEffectOpInterface. We can then use the MemoryEffectOpInterface to

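The three new stride helpers above share one pattern: walk the layout order (fastest moving dimension first) and multiply per-dimension extents, stopping when the scan axis is reached. A standalone sketch under hypothetical layout values (strideAlongAxis is an illustrative name, not a Triton function):

    #include <cstdio>
    #include <vector>

    unsigned strideAlongAxis(const std::vector<unsigned> &order,
                             const std::vector<unsigned> &extentPerDim,
                             unsigned axis) {
      unsigned stride = 1;
      for (unsigned dim : order) {
        if (dim == axis)
          break;
        stride *= extentPerDim[dim];
      }
      return stride;
    }

    int main() {
      // Hypothetical layout: sizePerThread = [1, 4], order = [1, 0]
      // (dim 1 fastest). Scanning along dim 0 gives an element stride of 4.
      std::vector<unsigned> order = {1, 0};
      std::vector<unsigned> sizePerThread = {1, 4};
      std::printf("element stride: %u\n",
                  strideAlongAxis(order, sizePerThread, /*axis=*/0));
      return 0;
    }

Passing the threads-per-warp or per-block counts as the extents yields the thread and block strides the same way.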
@@ -37,19 +37,21 @@ static void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp,
 static void scanThreadContiguousElements(SmallVector<Value> &srcValues,
                                          ConversionPatternRewriter &rewriter,
                                          ScanLoweringHelper &helper) {
-  // TODO: this assumes that axis is the fastest moving dimension. We should
-  // relax that.
-  unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThreads();
-  // Loop through the blocks of contiguous elements.
-  for (unsigned j = 0; j < srcValues.size(); j += scanElementsPerThreads) {
-    // Reset the accumulator at the beginning of each block of contiguous
-    // elements.
-    Value acc;
-    // Loop through the contiguous elements.
-    for (unsigned i = 0; i < scanElementsPerThreads; ++i) {
-      accumulate(rewriter, helper.getCombineOp(), acc, srcValues[i + j]);
-      srcValues[i + j] = acc;
-    }
+  // Depending on the layout, contiguous elements along the axis dim may not
+  // be contiguous in srcValues. Keep track of which elements belong to the
+  // same chunk of contiguous elements.
+  unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread();
+  unsigned parallelElementsPerThread = helper.getAxisNumElementsPerThread();
+  unsigned numChunks = srcValues.size() / scanElementsPerThreads;
+  unsigned stride = helper.getAxisElementStride();
+  SmallVector<Value> accs(numChunks);
+  for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) {
+    unsigned accIndex = (srcIndex % stride) +
+                        ((srcIndex / stride) / scanElementsPerThreads) * stride;
+
+    accumulate(rewriter, helper.getCombineOp(), accs[accIndex],
+               srcValues[srcIndex]);
+    srcValues[srcIndex] = accs[accIndex];
   }
 }

@@ -59,20 +61,25 @@ static void warpScan(SmallVector<Value> &srcValues,
                      ConversionPatternRewriter &rewriter,
                      ScanLoweringHelper &helper, Value laneId) {
   Location loc = helper.getLoc();
-  unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThreads();
-  for (unsigned j = scanElementsPerThreads - 1; j < srcValues.size();
-       j += scanElementsPerThreads) {
-    Value acc = srcValues[j];
-    unsigned scanDim = helper.getAxisNumThreadsPerWarp();
+  unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread();
+  unsigned elementStride = helper.getAxisElementStride();
+  unsigned threadStride = helper.getAxisThreadStride();
+  unsigned scanDim = helper.getAxisNumThreadsPerWarp();
+  for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) {
+    unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads;
+    // Only consider the last element of each contiguous chunk of elements.
+    if (elementIdx != scanElementsPerThreads - 1)
+      continue;
     // Reduce within warps.
-    for (unsigned i = 1; i <= scanDim / 2; i = i << 1) {
-      Value shfl = shflUpSync(loc, rewriter, acc, i);
+    Value acc = srcValues[srcIndex];
+    for (unsigned i = 1; i <= (scanDim) / 2; i = i << 1) {
+      Value shfl = shflUpSync(loc, rewriter, acc, i * threadStride);
       Value tempAcc = acc;
       accumulate(rewriter, helper.getCombineOp(), tempAcc, shfl);
       Value mask = icmp_slt(laneId, i32_val(i));
       acc = select(mask, acc, tempAcc);
     }
-    srcValues[j] = acc;
+    srcValues[srcIndex] = acc;
   }
 }

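The inner loop above is the shuffle-based (Hillis-Steele) inclusive scan; the new twist is scaling the shuffle offset by threadStride, since threads that are neighbors along the scan axis need not be neighboring lanes. A host-side simulation of the unscaled pattern, assuming a hypothetical 8-lane warp:

    #include <cstdio>

    int main() {
      const unsigned warpSize = 8;
      int v[warpSize] = {1, 1, 1, 1, 1, 1, 1, 1};
      for (unsigned i = 1; i <= warpSize / 2; i <<= 1) {
        int shfl[warpSize];
        for (unsigned lane = 0; lane < warpSize; ++lane)
          shfl[lane] = lane >= i ? v[lane - i] : 0; // shflUpSync by offset i
        for (unsigned lane = 0; lane < warpSize; ++lane)
          if (lane >= i) // lanes with laneId < i keep their accumulator
            v[lane] += shfl[lane];
      }
      for (unsigned lane = 0; lane < warpSize; ++lane)
        std::printf("lane %u: %d\n", lane, v[lane]); // prefix sums 1..8
      return 0;
    }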
@@ -88,19 +95,24 @@ static void storeWarpAccumulator(SmallVector<Value> &srcValues,
                                  Value warpId, Value baseSharedMemPtr,
                                  Value parallelLaneId) {
   Location loc = helper.getLoc();
-  unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThreads();
+  unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread();
   unsigned scanDim = helper.getAxisNumThreadsPerWarp();
   unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA();
   unsigned numWarps = helper.getAxisNumWarps();
   unsigned chunkId = 0;
-  for (unsigned j = scanElementsPerThreads - 1; j < srcValues.size();
-       j += scanElementsPerThreads, ++chunkId) {
-    Value lastElement = srcValues[j];
+  unsigned elementStride = helper.getAxisElementStride();
+  for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) {
+    unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads;
+    // Only consider the last element of each contiguous chunk of elements.
+    if (elementIdx != scanElementsPerThreads - 1)
+      continue;
+    Value lastElement = srcValues[srcIndex];
     Value mask = icmp_eq(laneId, i32_val(scanDim - 1));
     Value index = add(parallelLaneId, mul(warpId, i32_val(numParallelLane)));
     index = add(index, i32_val(chunkId * numParallelLane * numWarps));
     Value writePtr = gep(baseSharedMemPtr.getType(), baseSharedMemPtr, index);
     storeShared(rewriter, loc, writePtr, lastElement, mask);
+    chunkId++;
   }
 }

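Each chunk's per-warp partial value lands in a flat shared-memory slot addressed by (parallel lane, warp, chunk), lane fastest, matching the index arithmetic above. A toy rendering with assumed sizes:

    #include <cstdio>

    int main() {
      const unsigned numParallelLane = 4, numWarps = 2, numChunks = 2;
      for (unsigned chunkId = 0; chunkId < numChunks; ++chunkId)
        for (unsigned warpId = 0; warpId < numWarps; ++warpId)
          for (unsigned lane = 0; lane < numParallelLane; ++lane) {
            // Same layout as storeWarpAccumulator's index computation.
            unsigned index = lane + warpId * numParallelLane +
                             chunkId * numParallelLane * numWarps;
            std::printf("chunk %u warp %u lane %u -> slot %u\n", chunkId,
                        warpId, lane, index);
          }
      return 0;
    }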
@@ -116,8 +128,10 @@ static void AddPartialReduce(SmallVector<Value> &srcValues,
   Location loc = helper.getLoc();
   unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA();
   unsigned numWarps = helper.getAxisNumWarps();
-  unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThreads();
+  unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread();
+  unsigned parallelElementsPerThread = helper.getNonAxisNumElementsPerThread();
+  unsigned elementStride = helper.getAxisElementStride();
+  unsigned threadStride = helper.getAxisThreadStride();
   Value maskFirstWarp = icmp_eq(warpId, i32_val(0));
   Value maskFirstLane = icmp_eq(laneId, i32_val(0));
   Value maskFirstThread = and_(maskFirstWarp, maskFirstLane);

@@ -133,63 +147,66 @@ static void AddPartialReduce(SmallVector<Value> &srcValues,
   SmallVector<Accumulator> accumulators(numParallelBlocks *
                                         parallelElementsPerThread);
   unsigned chunkId = 0;
-  for (unsigned parallelBlockId = 0; parallelBlockId < numParallelBlocks;
-       ++parallelBlockId) {
-    for (unsigned scanBlockId = 0; scanBlockId < numScanBlocks; ++scanBlockId) {
-      for (unsigned parallelElementId = 0;
-           parallelElementId < parallelElementsPerThread; ++parallelElementId) {
-        unsigned accumulatorIndex =
-            parallelElementId + parallelBlockId * parallelElementsPerThread;
-        Accumulator &accumulator = accumulators[accumulatorIndex];
-        for (unsigned i = 0; i < numWarps; ++i) {
-          Value index = add(parallelLaneId, i32_val(numParallelLane *
-                                                    (i + chunkId * numWarps)));
-          Value ptr = gep(sharedMemoryPtr.getType(), sharedMemoryPtr, index);
-          Value partialReduce = load(ptr);
-          if (!accumulator.acc) {
-            accumulator.acc = partialReduce;
-            accumulator.maskedAcc = partialReduce;
-            continue;
-          }
-          accumulate(rewriter, helper.getCombineOp(), accumulator.acc,
-                     partialReduce);
-          Value mask = icmp_slt(warpId, i32_val(i + 1));
-          accumulator.maskedAcc =
-              select(mask, accumulator.maskedAcc, accumulator.acc);
-        }
-        unsigned lastElementIndex =
-            chunkId * scanElementsPerThreads + scanElementsPerThreads - 1;
-        Value temp = srcValues[lastElementIndex];
-        accumulate(rewriter, helper.getCombineOp(), temp,
-                   accumulator.maskedAcc);
-        if (scanBlockId == 0) {
-          // For the first warp and first chunk we don't have anything to
-          // accumulate.
-          temp = select(maskFirstWarp, srcValues[lastElementIndex], temp);
-        }
-        srcValues[lastElementIndex] = temp;
-
-        // Update the rest of the contiguous elements.
-        Value lastElement =
-            shflUpSync(loc, rewriter, srcValues[lastElementIndex], 1);
-        lastElement = select(maskFirstLane, accumulator.maskedAcc, lastElement);
-        for (unsigned i = 1; i < scanElementsPerThreads; ++i) {
-          Value laneValue = srcValues[lastElementIndex - i];
-          accumulate(rewriter, helper.getCombineOp(), laneValue, lastElement);
-          if (scanBlockId == 0) {
-            // For the first warp and first chunk we don't have anything to
-            // accumulate.
-            laneValue = select(maskFirstThread, srcValues[lastElementIndex - i],
-                               laneValue);
-          }
-          srcValues[lastElementIndex - i] = laneValue;
-        }
-        // For the next chunk start back from the value containing the
-        // accumulated value of all the warps.
-        accumulator.maskedAcc = accumulator.acc;
-        chunkId++;
-      }
-    }
+  unsigned blockStride = helper.getAxisBlockStride();
+  for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) {
+    unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads;
+    // Only consider the last element of each contiguous chunk of elements.
+    if (elementIdx != scanElementsPerThreads - 1)
+      continue;
+    // Accumulate the partial reduction from shared memory. Decide which
+    // accumulator to combine based on whether the elements belong to the same
+    // dimension along axis.
+    unsigned blockId = chunkId / parallelElementsPerThread;
+    unsigned parallelBlockId =
+        blockId % blockStride +
+        ((blockId / blockStride) / numScanBlocks) * blockStride;
+    unsigned accumulatorIndex = chunkId % parallelElementsPerThread +
+                                parallelBlockId * parallelElementsPerThread;
+    Accumulator &accumulator = accumulators[accumulatorIndex];
+    for (unsigned i = 0; i < numWarps; ++i) {
+      Value index = add(parallelLaneId,
+                        i32_val(numParallelLane * (i + chunkId * numWarps)));
+      Value ptr = gep(sharedMemoryPtr.getType(), sharedMemoryPtr, index);
+      Value partialReduce = load(ptr);
+      if (!accumulator.acc) {
+        accumulator.acc = partialReduce;
+        accumulator.maskedAcc = partialReduce;
+        continue;
+      }
+      accumulate(rewriter, helper.getCombineOp(), accumulator.acc,
+                 partialReduce);
+      Value mask = icmp_slt(warpId, i32_val(i + 1));
+      accumulator.maskedAcc =
+          select(mask, accumulator.maskedAcc, accumulator.acc);
+    }
+    Value temp = srcValues[srcIndex];
+    accumulate(rewriter, helper.getCombineOp(), temp, accumulator.maskedAcc);
+    unsigned axisBlockId = (blockId / blockStride) % numScanBlocks;
+    if (axisBlockId == 0) {
+      // For the first warp and first chunk we don't have anything to
+      // accumulate.
+      temp = select(maskFirstWarp, srcValues[srcIndex], temp);
+    }
+    srcValues[srcIndex] = temp;
+    // Update the rest of the contiguous elements.
+    Value lastElement =
+        shflUpSync(loc, rewriter, srcValues[srcIndex], threadStride);
+    lastElement = select(maskFirstLane, accumulator.maskedAcc, lastElement);
+    for (unsigned i = 1; i < scanElementsPerThreads; ++i) {
+      Value laneValue = srcValues[srcIndex - i * elementStride];
+      accumulate(rewriter, helper.getCombineOp(), laneValue, lastElement);
+      if (axisBlockId == 0) {
+        // For the first warp and first chunk we don't have anything to
+        // accumulate.
+        laneValue = select(maskFirstThread,
+                           srcValues[srcIndex - i * elementStride], laneValue);
+      }
+      srcValues[srcIndex - i * elementStride] = laneValue;
+    }
+    // For the next chunk start back from the value containing the
+    // accumulated value of all the warps.
+    accumulator.maskedAcc = accumulator.acc;
+    chunkId++;
   }
 }

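The blockId arithmetic above splits the flat chunk/block counter into an axis part and a parallel part: blockStride parallel blocks are interleaved between consecutive blocks along the scan axis. A toy decomposition with assumed sizes:

    #include <cstdio>

    int main() {
      const unsigned blockStride = 2, numScanBlocks = 3, numParallelBlocks = 2;
      for (unsigned blockId = 0; blockId < numScanBlocks * numParallelBlocks;
           ++blockId) {
        // Same formulas as in AddPartialReduce above.
        unsigned parallelBlockId =
            blockId % blockStride +
            ((blockId / blockStride) / numScanBlocks) * blockStride;
        unsigned axisBlockId = (blockId / blockStride) % numScanBlocks;
        std::printf("blockId %u -> axis block %u, parallel block %u\n",
                    blockId, axisBlockId, parallelBlockId);
      }
      return 0;
    }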
@@ -1526,18 +1526,20 @@ def test_reduce2d(op, dtype_str, shape, axis, device):
     np.testing.assert_equal(z_ref, z_tri)
 
 
-scan2d_shapes = [(16, 32), (32, 16), (2, 1024), (1024, 2), (32, 32), (1, 1024)]
+scan2d_shapes = [(8, 32), (16, 32), (32, 16), (2, 1024), (1024, 2), (32, 32), (1, 1024)]
 
 scan_configs = [
-    (op, type, shape, 1)
+    (op, type, shape, axis, num_warps)
+    for num_warps in [4, 16]
     for type in ['int32', 'float32']
+    for axis in [1, 0]
     for shape in scan2d_shapes
     for op in ['cumsum']
 ]
 
 
-@pytest.mark.parametrize("op, dtype_str, shape, axis", scan_configs)
-def test_scan2d(op, dtype_str, shape, axis, device):
+@pytest.mark.parametrize("op, dtype_str, shape, axis, num_warps", scan_configs)
+def test_scan2d(op, dtype_str, shape, axis, num_warps, device):
     check_type_supported(dtype_str, device)
 
     # triton kernel

@@ -1549,7 +1551,7 @@ def test_scan2d(op, dtype_str, shape, axis, device):
         z = GENERATE_TEST_HERE
         tl.store(Z + range_m[:, None] * BLOCK_N + range_n[None, :], z)
 
-    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=1)'})
+    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis={axis})'})
     # input
     rs = RandomState(17)
     x = numpy_random(shape, dtype_str=dtype_str, rs=rs)

@@ -1560,7 +1562,7 @@ def test_scan2d(op, dtype_str, shape, axis, device):
     z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
     # triton result
     z_tri = to_triton(z, device=device)
-    kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
+    kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps)
     z_tri = to_numpy(z_tri)
     # compare
     if dtype_str == 'float32':

@@ -1570,6 +1572,12 @@ def test_scan2d(op, dtype_str, shape, axis, device):
 
 
 scan_layouts = [
+    BlockedLayout([1, 4], [4, 8], [4, 1], [0, 1]),
+    BlockedLayout([1, 4], [8, 4], [4, 1], [0, 1]),
+    BlockedLayout([4, 1], [4, 8], [1, 4], [0, 1]),
+    BlockedLayout([2, 2], [4, 8], [2, 2], [0, 1]),
+    BlockedLayout([2, 2], [8, 4], [2, 2], [0, 1]),
+
     BlockedLayout([1, 4], [4, 8], [4, 1], [1, 0]),
     BlockedLayout([1, 4], [8, 4], [4, 1], [1, 0]),
     BlockedLayout([4, 1], [4, 8], [1, 4], [1, 0]),

@@ -1578,34 +1586,36 @@ scan_layouts = [
 ]
 
 
+@pytest.mark.parametrize("M, N", [[32, 32], [32, 64], [64, 32]])
 @pytest.mark.parametrize("src_layout", scan_layouts)
-def test_scan_layouts(src_layout, device):
+@pytest.mark.parametrize("axis", [0, 1])
+def test_scan_layouts(M, N, src_layout, axis, device):
     ir = f"""
     #blocked = {src_layout}
     module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{
     tt.func public @kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
-    %cst = arith.constant dense<32> : tensor<32x1xi32, #blocked>
-    %0 = tt.make_range {{end = 32 : i32, start = 0 : i32}} : tensor<32xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>
-    %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<32xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<32x1xi32, #blocked>
-    %2 = arith.muli %1, %cst : tensor<32x1xi32, #blocked>
-    %3 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<32x1x!tt.ptr<i32>, #blocked>
-    %4 = tt.addptr %3, %2 : tensor<32x1x!tt.ptr<i32>, #blocked>, tensor<32x1xi32, #blocked>
-    %5 = tt.make_range {{end = 32 : i32, start = 0 : i32}} : tensor<32xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>
-    %6 = tt.expand_dims %5 {{axis = 0 : i32}} : (tensor<32xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>) -> tensor<1x32xi32, #blocked>
-    %7 = tt.broadcast %4 : (tensor<32x1x!tt.ptr<i32>, #blocked>) -> tensor<32x32x!tt.ptr<i32>, #blocked>
-    %8 = tt.broadcast %6 : (tensor<1x32xi32, #blocked>) -> tensor<32x32xi32, #blocked>
-    %9 = tt.addptr %7, %8 : tensor<32x32x!tt.ptr<i32>, #blocked>, tensor<32x32xi32, #blocked>
-    %10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<32x32xi32, #blocked>
-    %11 = "tt.scan"(%10) <{{axis = 1 : i32}}> ({{
+    %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked>
+    %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>
+    %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M}x1xi32, #blocked>
+    %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #blocked>
+    %3 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x1x!tt.ptr<i32>, #blocked>
+    %4 = tt.addptr %3, %2 : tensor<{M}x1x!tt.ptr<i32>, #blocked>, tensor<{M}x1xi32, #blocked>
+    %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>
+    %6 = tt.expand_dims %5 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>) -> tensor<1x{N}xi32, #blocked>
+    %7 = tt.broadcast %4 : (tensor<{M}x1x!tt.ptr<i32>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<i32>, #blocked>
+    %8 = tt.broadcast %6 : (tensor<1x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked>
+    %9 = tt.addptr %7, %8 : tensor<{M}x{N}x!tt.ptr<i32>, #blocked>, tensor<{M}x{N}xi32, #blocked>
+    %10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #blocked>
+    %11 = "tt.scan"(%10) <{{axis = {axis} : i32}}> ({{
     ^bb0(%arg2: i32, %arg3: i32):
       %16 = arith.addi %arg2, %arg3 : i32
       tt.scan.return %16 : i32
-    }}) : (tensor<32x32xi32, #blocked>) -> tensor<32x32xi32, #blocked>
-    %12 = tt.splat %arg1 : (!tt.ptr<i32>) -> tensor<32x1x!tt.ptr<i32>, #blocked>
-    %13 = tt.addptr %12, %2 : tensor<32x1x!tt.ptr<i32>, #blocked>, tensor<32x1xi32, #blocked>
-    %14 = tt.broadcast %13 : (tensor<32x1x!tt.ptr<i32>, #blocked>) -> tensor<32x32x!tt.ptr<i32>, #blocked>
-    %15 = tt.addptr %14, %8 : tensor<32x32x!tt.ptr<i32>, #blocked>, tensor<32x32xi32, #blocked>
-    tt.store %15, %11 {{cache = 1 : i32, evict = 1 : i32}} : tensor<32x32xi32, #blocked>
+    }}) : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked>
+    %12 = tt.splat %arg1 : (!tt.ptr<i32>) -> tensor<{M}x1x!tt.ptr<i32>, #blocked>
+    %13 = tt.addptr %12, %2 : tensor<{M}x1x!tt.ptr<i32>, #blocked>, tensor<{M}x1xi32, #blocked>
+    %14 = tt.broadcast %13 : (tensor<{M}x1x!tt.ptr<i32>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<i32>, #blocked>
+    %15 = tt.addptr %14, %8 : tensor<{M}x{N}x!tt.ptr<i32>, #blocked>, tensor<{M}x{N}xi32, #blocked>
+    tt.store %15, %11 {{cache = 1 : i32, evict = 1 : i32}} : tensor<{M}x{N}xi32, #blocked>
     tt.return
     }}
 }}

@@ -1616,8 +1626,6 @@ def test_scan_layouts(src_layout, device):
         f.write(ir)
         f.flush()
         kernel = triton.compile(f.name)
-    M = 32
-    N = 32
     rs = RandomState(17)
     x = rs.randint(-100, 100, (M, N)).astype('int32')
 

@@ -1627,7 +1635,7 @@ def test_scan_layouts(src_layout, device):
 
     kernel[(1, 1, 1)](x_tri, z_tri)
 
-    z_ref = np.cumsum(x, axis=1)
+    z_ref = np.cumsum(x, axis=axis)
 
     np.testing.assert_equal(z_ref, z_tri.cpu().numpy())