[BACKEND] Support scan on dimensions other than the fastest moving one (#1863)

This relaxes the restriction in the scan lowering to support layouts where
we scan along a dimension which isn't the fastest moving one. This is
done by relaxing how we access elements during scanning, allowing
elements to be strided.
This commit is contained in:
Thomas
2023-06-30 12:40:48 -07:00
committed by GitHub
parent 66ed53d19d
commit 2e3182bab7
4 changed files with 200 additions and 127 deletions

View File

@@ -72,7 +72,7 @@ public:
// Return true if the lowering of the scan op is supported.
bool isSupported();
// Return the number of elements per thread along axis dim.
unsigned getAxisNumElementsPerThreads();
unsigned getAxisNumElementsPerThread();
// Return the number of elements per thread along non-axis dims.
unsigned getNonAxisNumElementsPerThread();
// Return the number of threads per warp along non-axis dims.
@@ -90,6 +90,13 @@ public:
// Return the size of the scratch space needed for scan lowering.
unsigned getScratchSizeInBytes();
// Stride between contiguous element along axis dim.
unsigned getAxisElementStride();
// Stride between contiguous threads along axis dim.
unsigned getAxisThreadStride();
// Stride between contiguous blocks along axis dim.
unsigned getAxisBlockStride();
Location getLoc() { return scanOp.getLoc(); }
unsigned getAxis() { return scanOp.getAxis(); }
triton::gpu::BlockedEncodingAttr getEncoding();

View File

@@ -117,13 +117,12 @@ bool ReduceOpHelper::isSupportedLayout() {
return false;
}
unsigned ScanLoweringHelper::getAxisNumElementsPerThreads() {
unsigned ScanLoweringHelper::getAxisNumElementsPerThread() {
return getEncoding().getSizePerThread()[getAxis()];
}
unsigned ScanLoweringHelper::getNonAxisNumElementsPerThread() {
SmallVector<unsigned> sizePerThreads(getEncoding().getSizePerThread().begin(),
getEncoding().getSizePerThread().end());
SmallVector<unsigned> sizePerThreads = getContigPerThread(getEncoding());
sizePerThreads[getAxis()] = 1;
return product<unsigned>(sizePerThreads);
}
@@ -159,8 +158,9 @@ unsigned ScanLoweringHelper::getAxisNumBlocks() {
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
unsigned axis = getAxis();
return type.getShape()[axis] /
(sizePerThreads[axis] * threadsPerWarp[axis] * warpsPerCTA[axis]);
return ceil<unsigned>(
type.getShape()[axis],
(sizePerThreads[axis] * threadsPerWarp[axis] * warpsPerCTA[axis]));
}
unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
@@ -173,19 +173,18 @@ unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
for (unsigned i = 0; i < sizePerThreads.size(); i++) {
if (i == axis)
continue;
numBlocks *= type.getShape()[i] /
(sizePerThreads[i] * threadsPerWarp[i] * warpsPerCTA[i]);
numBlocks *= ceil<unsigned>(
type.getShape()[i],
(sizePerThreads[i] * threadsPerWarp[i] * warpsPerCTA[i]));
}
return numBlocks;
}
bool ScanLoweringHelper::isSupported() {
// TODO: Support the following cases:
// 1. Scan on the non-fast changing dimension
// 2. Scan on non-blocking encodings
// 3. Scan with multiple operands
if (getAxis() != triton::gpu::getOrder(srcEncoding)[0] ||
!isa<triton::gpu::BlockedEncodingAttr>(srcEncoding))
// 1. Scan on non-blocking encodings
// 2. Scan with multiple operands
if (!isa<triton::gpu::BlockedEncodingAttr>(srcEncoding))
return false;
if (scanOp.getNumOperands() != 1)
return false;
@@ -194,16 +193,58 @@ bool ScanLoweringHelper::isSupported() {
unsigned ScanLoweringHelper::getScratchSizeInBytes() {
auto type = scanOp.getOperand(0).getType().cast<RankedTensorType>();
unsigned numElement =
type.getNumElements() * type.getElementTypeBitWidth() / 8;
return numElement /
(getAxisNumElementsPerThreads() * getAxisNumThreadsPerWarp());
unsigned elementSizeInBytes = type.getElementTypeBitWidth() / 8;
auto mod = scanOp->getParentOfType<ModuleOp>();
unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
unsigned numNonAxisElementsPerWapr =
getNonAxisNumThreadsPerWarp() * getNonAxisNumElementsPerThread();
unsigned numElements = numWarps * numNonAxisElementsPerWapr *
getAxisNumBlocks() * getNonAxisNumBlocks();
return elementSizeInBytes * numElements;
}
triton::gpu::BlockedEncodingAttr ScanLoweringHelper::getEncoding() {
return srcEncoding.cast<triton::gpu::BlockedEncodingAttr>();
}
unsigned ScanLoweringHelper::getAxisElementStride() {
auto order = triton::gpu::getOrder(srcEncoding);
unsigned stride = 1;
for (unsigned dim : order) {
if (dim == getAxis())
return stride;
stride *= getContigPerThread(getEncoding())[dim];
}
llvm_unreachable("Axis not found in order");
}
unsigned ScanLoweringHelper::getAxisThreadStride() {
auto order = triton::gpu::getOrder(srcEncoding);
unsigned stride = 1;
for (unsigned dim : order) {
if (dim == getAxis())
return stride;
stride *= getEncoding().getThreadsPerWarp()[dim];
}
llvm_unreachable("Axis not found in order");
}
unsigned ScanLoweringHelper::getAxisBlockStride() {
auto order = triton::gpu::getOrder(srcEncoding);
unsigned stride = 1;
auto type = scanOp.getOperand(0).getType().cast<RankedTensorType>();
auto sizePerThreads = triton::gpu::getSizePerThread(srcEncoding);
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
for (unsigned dim : order) {
if (dim == getAxis())
return stride;
stride *= type.getShape()[dim] /
(sizePerThreads[dim] * threadsPerWarp[dim] * warpsPerCTA[dim]);
}
llvm_unreachable("Axis not found in order");
}
bool maybeSharedAllocationOp(Operation *op) {
// TODO(Keren): This function can be replaced by adding
// MemoryEffectOpInterface. We can then use the MemoryEffectOpInterface to

View File

@@ -37,19 +37,21 @@ static void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp,
static void scanThreadContiguousElements(SmallVector<Value> &srcValues,
ConversionPatternRewriter &rewriter,
ScanLoweringHelper &helper) {
// TODO: this assumes that axis is the fastest moving dimension. We should
// relax that.
unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThreads();
// Loop through the blocks of contiguous elements.
for (unsigned j = 0; j < srcValues.size(); j += scanElementsPerThreads) {
// Reset the accumulator at the beginning of each block of contiguous
// elements.
Value acc;
// Loop through the contiguous elements.
for (unsigned i = 0; i < scanElementsPerThreads; ++i) {
accumulate(rewriter, helper.getCombineOp(), acc, srcValues[i + j]);
srcValues[i + j] = acc;
}
// Depending on layout contiguous elements along axis dim may not be
// contiguous in srcValues. Keep track of what elements belong to the same
// chunk of contiguous elements.
unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread();
unsigned parallelElementsPerThread = helper.getAxisNumElementsPerThread();
unsigned numChunks = srcValues.size() / scanElementsPerThreads;
unsigned stride = helper.getAxisElementStride();
SmallVector<Value> accs(numChunks);
for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) {
unsigned accIndex = (srcIndex % stride) +
((srcIndex / stride) / scanElementsPerThreads) * stride;
accumulate(rewriter, helper.getCombineOp(), accs[accIndex],
srcValues[srcIndex]);
srcValues[srcIndex] = accs[accIndex];
}
}
@@ -59,20 +61,25 @@ static void warpScan(SmallVector<Value> &srcValues,
ConversionPatternRewriter &rewriter,
ScanLoweringHelper &helper, Value laneId) {
Location loc = helper.getLoc();
unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThreads();
for (unsigned j = scanElementsPerThreads - 1; j < srcValues.size();
j += scanElementsPerThreads) {
Value acc = srcValues[j];
unsigned scanDim = helper.getAxisNumThreadsPerWarp();
unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread();
unsigned elementStride = helper.getAxisElementStride();
unsigned threadStride = helper.getAxisThreadStride();
unsigned scanDim = helper.getAxisNumThreadsPerWarp();
for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) {
unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads;
// Only consider the last element of each contiguous chunk of elements.
if (elementIdx != scanElementsPerThreads - 1)
continue;
// Reduce within warps.
for (unsigned i = 1; i <= scanDim / 2; i = i << 1) {
Value shfl = shflUpSync(loc, rewriter, acc, i);
Value acc = srcValues[srcIndex];
for (unsigned i = 1; i <= (scanDim) / 2; i = i << 1) {
Value shfl = shflUpSync(loc, rewriter, acc, i * threadStride);
Value tempAcc = acc;
accumulate(rewriter, helper.getCombineOp(), tempAcc, shfl);
Value mask = icmp_slt(laneId, i32_val(i));
acc = select(mask, acc, tempAcc);
}
srcValues[j] = acc;
srcValues[srcIndex] = acc;
}
}
@@ -88,19 +95,24 @@ static void storeWarpAccumulator(SmallVector<Value> &srcValues,
Value warpId, Value baseSharedMemPtr,
Value parallelLaneId) {
Location loc = helper.getLoc();
unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThreads();
unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread();
unsigned scanDim = helper.getAxisNumThreadsPerWarp();
unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA();
unsigned numWarps = helper.getAxisNumWarps();
unsigned chunkId = 0;
for (unsigned j = scanElementsPerThreads - 1; j < srcValues.size();
j += scanElementsPerThreads, ++chunkId) {
Value lastElement = srcValues[j];
unsigned elementStride = helper.getAxisElementStride();
for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) {
unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads;
// Only consider the last element of each contiguous chunk of elements.
if (elementIdx != scanElementsPerThreads - 1)
continue;
Value lastElement = srcValues[srcIndex];
Value mask = icmp_eq(laneId, i32_val(scanDim - 1));
Value index = add(parallelLaneId, mul(warpId, i32_val(numParallelLane)));
index = add(index, i32_val(chunkId * numParallelLane * numWarps));
Value writePtr = gep(baseSharedMemPtr.getType(), baseSharedMemPtr, index);
storeShared(rewriter, loc, writePtr, lastElement, mask);
chunkId++;
}
}
@@ -116,8 +128,10 @@ static void AddPartialReduce(SmallVector<Value> &srcValues,
Location loc = helper.getLoc();
unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA();
unsigned numWarps = helper.getAxisNumWarps();
unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThreads();
unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread();
unsigned parallelElementsPerThread = helper.getNonAxisNumElementsPerThread();
unsigned elementStride = helper.getAxisElementStride();
unsigned threadStride = helper.getAxisThreadStride();
Value maskFirstWarp = icmp_eq(warpId, i32_val(0));
Value maskFirstLane = icmp_eq(laneId, i32_val(0));
Value maskFirstThread = and_(maskFirstWarp, maskFirstLane);
@@ -133,63 +147,66 @@ static void AddPartialReduce(SmallVector<Value> &srcValues,
SmallVector<Accumulator> accumulators(numParallelBlocks *
parallelElementsPerThread);
unsigned chunkId = 0;
for (unsigned parallelBlockId = 0; parallelBlockId < numParallelBlocks;
++parallelBlockId) {
for (unsigned scanBlockId = 0; scanBlockId < numScanBlocks; ++scanBlockId) {
for (unsigned parallelElementId = 0;
parallelElementId < parallelElementsPerThread; ++parallelElementId) {
unsigned accumulatorIndex =
parallelElementId + parallelBlockId * parallelElementsPerThread;
Accumulator &accumulator = accumulators[accumulatorIndex];
for (unsigned i = 0; i < numWarps; ++i) {
Value index = add(parallelLaneId, i32_val(numParallelLane *
(i + chunkId * numWarps)));
Value ptr = gep(sharedMemoryPtr.getType(), sharedMemoryPtr, index);
Value partialReduce = load(ptr);
if (!accumulator.acc) {
accumulator.acc = partialReduce;
accumulator.maskedAcc = partialReduce;
continue;
}
accumulate(rewriter, helper.getCombineOp(), accumulator.acc,
partialReduce);
Value mask = icmp_slt(warpId, i32_val(i + 1));
accumulator.maskedAcc =
select(mask, accumulator.maskedAcc, accumulator.acc);
}
unsigned lastElementIndex =
chunkId * scanElementsPerThreads + scanElementsPerThreads - 1;
Value temp = srcValues[lastElementIndex];
accumulate(rewriter, helper.getCombineOp(), temp,
accumulator.maskedAcc);
if (scanBlockId == 0) {
// For the first warp and first chunk we don't have anything to
// accumulate.
temp = select(maskFirstWarp, srcValues[lastElementIndex], temp);
}
srcValues[lastElementIndex] = temp;
// Update the rest of the contiguous elements.
Value lastElement =
shflUpSync(loc, rewriter, srcValues[lastElementIndex], 1);
lastElement = select(maskFirstLane, accumulator.maskedAcc, lastElement);
for (unsigned i = 1; i < scanElementsPerThreads; ++i) {
Value laneValue = srcValues[lastElementIndex - i];
accumulate(rewriter, helper.getCombineOp(), laneValue, lastElement);
if (scanBlockId == 0) {
// For the first warp and first chunk we don't have anything to
// accumulate.
laneValue = select(maskFirstThread, srcValues[lastElementIndex - i],
laneValue);
}
srcValues[lastElementIndex - i] = laneValue;
}
// For the next chunk start back from the value containing the
// accumulated value of all the warps.
accumulator.maskedAcc = accumulator.acc;
chunkId++;
unsigned blockStride = helper.getAxisBlockStride();
for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) {
unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads;
// Only consider the last element of each contiguous chunk of elements.
if (elementIdx != scanElementsPerThreads - 1)
continue;
// Accumulate the partial reduction from shared memory. Decide which
// accumulator to combine based on whether the elements belong to the same
// dimension along axis.
unsigned blockId = chunkId / parallelElementsPerThread;
unsigned parallelBlockId =
blockId % blockStride +
((blockId / blockStride) / numScanBlocks) * blockStride;
unsigned accumulatorIndex = chunkId % parallelElementsPerThread +
parallelBlockId * parallelElementsPerThread;
Accumulator &accumulator = accumulators[accumulatorIndex];
for (unsigned i = 0; i < numWarps; ++i) {
Value index = add(parallelLaneId,
i32_val(numParallelLane * (i + chunkId * numWarps)));
Value ptr = gep(sharedMemoryPtr.getType(), sharedMemoryPtr, index);
Value partialReduce = load(ptr);
if (!accumulator.acc) {
accumulator.acc = partialReduce;
accumulator.maskedAcc = partialReduce;
continue;
}
accumulate(rewriter, helper.getCombineOp(), accumulator.acc,
partialReduce);
Value mask = icmp_slt(warpId, i32_val(i + 1));
accumulator.maskedAcc =
select(mask, accumulator.maskedAcc, accumulator.acc);
}
Value temp = srcValues[srcIndex];
accumulate(rewriter, helper.getCombineOp(), temp, accumulator.maskedAcc);
unsigned axisBlockId = (blockId / blockStride) % numScanBlocks;
if (axisBlockId == 0) {
// For the first warp and first chunk we don't have anything to
// accumulate.
temp = select(maskFirstWarp, srcValues[srcIndex], temp);
}
srcValues[srcIndex] = temp;
// Update the rest of the contiguous elements.
Value lastElement =
shflUpSync(loc, rewriter, srcValues[srcIndex], threadStride);
lastElement = select(maskFirstLane, accumulator.maskedAcc, lastElement);
for (unsigned i = 1; i < scanElementsPerThreads; ++i) {
Value laneValue = srcValues[srcIndex - i * elementStride];
accumulate(rewriter, helper.getCombineOp(), laneValue, lastElement);
if (axisBlockId == 0) {
// For the first warp and first chunk we don't have anything to
// accumulate.
laneValue = select(maskFirstThread,
srcValues[srcIndex - i * elementStride], laneValue);
}
srcValues[srcIndex - i * elementStride] = laneValue;
}
// For the next chunk start back from the value containing the
// accumulated value of all the warps.
accumulator.maskedAcc = accumulator.acc;
chunkId++;
}
}

View File

@@ -1526,18 +1526,20 @@ def test_reduce2d(op, dtype_str, shape, axis, device):
np.testing.assert_equal(z_ref, z_tri)
scan2d_shapes = [(16, 32), (32, 16), (2, 1024), (1024, 2), (32, 32), (1, 1024)]
scan2d_shapes = [(8, 32), (16, 32), (32, 16), (2, 1024), (1024, 2), (32, 32), (1, 1024)]
scan_configs = [
(op, type, shape, 1)
(op, type, shape, axis, num_warps)
for num_warps in [4, 16]
for type in ['int32', 'float32']
for axis in [1, 0]
for shape in scan2d_shapes
for op in ['cumsum']
]
@pytest.mark.parametrize("op, dtype_str, shape, axis", scan_configs)
def test_scan2d(op, dtype_str, shape, axis, device):
@pytest.mark.parametrize("op, dtype_str, shape, axis, num_warps", scan_configs)
def test_scan2d(op, dtype_str, shape, axis, num_warps, device):
check_type_supported(dtype_str, device)
# triton kernel
@@ -1549,7 +1551,7 @@ def test_scan2d(op, dtype_str, shape, axis, device):
z = GENERATE_TEST_HERE
tl.store(Z + range_m[:, None] * BLOCK_N + range_n[None, :], z)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=1)'})
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis={axis})'})
# input
rs = RandomState(17)
x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
@@ -1560,7 +1562,7 @@ def test_scan2d(op, dtype_str, shape, axis, device):
z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
# triton result
z_tri = to_triton(z, device=device)
kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps)
z_tri = to_numpy(z_tri)
# compare
if dtype_str == 'float32':
@@ -1570,6 +1572,12 @@ def test_scan2d(op, dtype_str, shape, axis, device):
scan_layouts = [
BlockedLayout([1, 4], [4, 8], [4, 1], [0, 1]),
BlockedLayout([1, 4], [8, 4], [4, 1], [0, 1]),
BlockedLayout([4, 1], [4, 8], [1, 4], [0, 1]),
BlockedLayout([2, 2], [4, 8], [2, 2], [0, 1]),
BlockedLayout([2, 2], [8, 4], [2, 2], [0, 1]),
BlockedLayout([1, 4], [4, 8], [4, 1], [1, 0]),
BlockedLayout([1, 4], [8, 4], [4, 1], [1, 0]),
BlockedLayout([4, 1], [4, 8], [1, 4], [1, 0]),
@@ -1578,34 +1586,36 @@ scan_layouts = [
]
@pytest.mark.parametrize("M, N", [[32, 32], [32, 64], [64, 32]])
@pytest.mark.parametrize("src_layout", scan_layouts)
def test_scan_layouts(src_layout, device):
@pytest.mark.parametrize("axis", [0, 1])
def test_scan_layouts(M, N, src_layout, axis, device):
ir = f"""
#blocked = {src_layout}
module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{
tt.func public @kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
%cst = arith.constant dense<32> : tensor<32x1xi32, #blocked>
%0 = tt.make_range {{end = 32 : i32, start = 0 : i32}} : tensor<32xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>
%1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<32xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<32x1xi32, #blocked>
%2 = arith.muli %1, %cst : tensor<32x1xi32, #blocked>
%3 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<32x1x!tt.ptr<i32>, #blocked>
%4 = tt.addptr %3, %2 : tensor<32x1x!tt.ptr<i32>, #blocked>, tensor<32x1xi32, #blocked>
%5 = tt.make_range {{end = 32 : i32, start = 0 : i32}} : tensor<32xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>
%6 = tt.expand_dims %5 {{axis = 0 : i32}} : (tensor<32xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>) -> tensor<1x32xi32, #blocked>
%7 = tt.broadcast %4 : (tensor<32x1x!tt.ptr<i32>, #blocked>) -> tensor<32x32x!tt.ptr<i32>, #blocked>
%8 = tt.broadcast %6 : (tensor<1x32xi32, #blocked>) -> tensor<32x32xi32, #blocked>
%9 = tt.addptr %7, %8 : tensor<32x32x!tt.ptr<i32>, #blocked>, tensor<32x32xi32, #blocked>
%10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<32x32xi32, #blocked>
%11 = "tt.scan"(%10) <{{axis = 1 : i32}}> ({{
%cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked>
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>
%1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M}x1xi32, #blocked>
%2 = arith.muli %1, %cst : tensor<{M}x1xi32, #blocked>
%3 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x1x!tt.ptr<i32>, #blocked>
%4 = tt.addptr %3, %2 : tensor<{M}x1x!tt.ptr<i32>, #blocked>, tensor<{M}x1xi32, #blocked>
%5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>
%6 = tt.expand_dims %5 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>) -> tensor<1x{N}xi32, #blocked>
%7 = tt.broadcast %4 : (tensor<{M}x1x!tt.ptr<i32>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<i32>, #blocked>
%8 = tt.broadcast %6 : (tensor<1x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked>
%9 = tt.addptr %7, %8 : tensor<{M}x{N}x!tt.ptr<i32>, #blocked>, tensor<{M}x{N}xi32, #blocked>
%10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #blocked>
%11 = "tt.scan"(%10) <{{axis = {axis} : i32}}> ({{
^bb0(%arg2: i32, %arg3: i32):
%16 = arith.addi %arg2, %arg3 : i32
tt.scan.return %16 : i32
}}) : (tensor<32x32xi32, #blocked>) -> tensor<32x32xi32, #blocked>
%12 = tt.splat %arg1 : (!tt.ptr<i32>) -> tensor<32x1x!tt.ptr<i32>, #blocked>
%13 = tt.addptr %12, %2 : tensor<32x1x!tt.ptr<i32>, #blocked>, tensor<32x1xi32, #blocked>
%14 = tt.broadcast %13 : (tensor<32x1x!tt.ptr<i32>, #blocked>) -> tensor<32x32x!tt.ptr<i32>, #blocked>
%15 = tt.addptr %14, %8 : tensor<32x32x!tt.ptr<i32>, #blocked>, tensor<32x32xi32, #blocked>
tt.store %15, %11 {{cache = 1 : i32, evict = 1 : i32}} : tensor<32x32xi32, #blocked>
}}) : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked>
%12 = tt.splat %arg1 : (!tt.ptr<i32>) -> tensor<{M}x1x!tt.ptr<i32>, #blocked>
%13 = tt.addptr %12, %2 : tensor<{M}x1x!tt.ptr<i32>, #blocked>, tensor<{M}x1xi32, #blocked>
%14 = tt.broadcast %13 : (tensor<{M}x1x!tt.ptr<i32>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<i32>, #blocked>
%15 = tt.addptr %14, %8 : tensor<{M}x{N}x!tt.ptr<i32>, #blocked>, tensor<{M}x{N}xi32, #blocked>
tt.store %15, %11 {{cache = 1 : i32, evict = 1 : i32}} : tensor<{M}x{N}xi32, #blocked>
tt.return
}}
}}
@@ -1616,8 +1626,6 @@ def test_scan_layouts(src_layout, device):
f.write(ir)
f.flush()
kernel = triton.compile(f.name)
M = 32
N = 32
rs = RandomState(17)
x = rs.randint(-100, 100, (M, N)).astype('int32')
@@ -1627,7 +1635,7 @@ def test_scan_layouts(src_layout, device):
kernel[(1, 1, 1)](x_tri, z_tri)
z_ref = np.cumsum(x, axis=1)
z_ref = np.cumsum(x, axis=axis)
np.testing.assert_equal(z_ref, z_tri.cpu().numpy())