[BACKEND] Fix reductions when the number of unique elements is smaller than the layout (#1913)

Fix the calculation of the number of threads holding unique data within a
warp. We need to consider the number of elements per thread in the
calculation. Also change the layout test to an integer sum in order to catch
bugs with unique data, as a max reduction may hide those kinds of problems.
This commit is contained in:
Thomas
2023-07-07 19:48:13 -07:00
committed by GitHub
parent 778ed64a66
commit bd900e0a6f
5 changed files with 36 additions and 38 deletions

View File

@@ -49,7 +49,8 @@ SmallVector<unsigned> getContigPerThread(Attribute layout);
// for thread 0 would be [A_{0, 0}, A_{0, 0}, A_{0, 0}, A_{0, 0}], returns [1,
// 1]. Whereas for a tensor shape [128, 128], the elements for thread 0 would be
// [A_{0, 0}, A_{0, 1}, A_{0, 2}, A_{0, 3}], returns [1, 4].
SmallVector<unsigned> getUniqueContigPerThread(Type type);
SmallVector<unsigned> getUniqueContigPerThread(Attribute layout,
ArrayRef<int64_t> tensorShape);
// Returns the number of threads per warp that have access to non-replicated
// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1,

View File

@@ -918,7 +918,8 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
auto order = triton::gpu::getOrder(layout);
unsigned align = getPtrAlignment(ptr);
auto uniqueContigPerThread = triton::gpu::getUniqueContigPerThread(tensorTy);
auto uniqueContigPerThread =
triton::gpu::getUniqueContigPerThread(layout, tensorTy.getShape());
assert(order[0] < uniqueContigPerThread.size() &&
"Unxpected uniqueContigPerThread size");
unsigned contiguity = uniqueContigPerThread[order[0]];

View File

@@ -60,7 +60,9 @@ unsigned ReduceOpHelper::getInterWarpSizeWithUniqueData() {
unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() {
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
return std::min(srcReduceDimSize,
unsigned elementPerThreads = triton::gpu::getUniqueContigPerThread(
getSrcLayout(), getSrcShape())[axis];
return std::min(srcReduceDimSize / elementPerThreads,
triton::gpu::getThreadsPerWarpWithUniqueData(
getSrcLayout(), getSrcShape())[axis]);
}

View File

@@ -221,20 +221,15 @@ SmallVector<unsigned> getContigPerThread(Attribute layout) {
}
}
SmallVector<unsigned> getUniqueContigPerThread(Type type) {
if (type.isIntOrIndexOrFloat() || type.isa<triton::PointerType>())
return SmallVector<unsigned>(1, 1);
auto tensorType = type.cast<RankedTensorType>();
auto shape = tensorType.getShape();
SmallVector<unsigned> getUniqueContigPerThread(Attribute layout,
ArrayRef<int64_t> shape) {
// If slice layout, call recursively on parent layout, and drop
// sliced dim
if (auto sliceLayout =
tensorType.getEncoding().dyn_cast<SliceEncodingAttr>()) {
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parentLayout = sliceLayout.getParent();
auto parentShape = sliceLayout.paddedShape(shape);
auto parentTy = RankedTensorType::get(
parentShape, tensorType.getElementType(), parentLayout);
auto parentUniqueContigPerThread = getUniqueContigPerThread(parentTy);
auto parentUniqueContigPerThread =
getUniqueContigPerThread(parentLayout, parentShape);
parentUniqueContigPerThread.erase(parentUniqueContigPerThread.begin() +
sliceLayout.getDim());
return parentUniqueContigPerThread;
@@ -242,7 +237,7 @@ SmallVector<unsigned> getUniqueContigPerThread(Type type) {
// Base case
auto rank = shape.size();
SmallVector<unsigned> ret(rank);
auto contigPerThread = getContigPerThread(tensorType.getEncoding());
auto contigPerThread = getContigPerThread(layout);
assert(contigPerThread.size() == rank && "Unexpected contigPerThread size");
for (int d = 0; d < rank; ++d) {
ret[d] = std::min<unsigned>(shape[d], contigPerThread[d]);

View File

@@ -1614,12 +1614,13 @@ def test_scan_layouts(M, N, src_layout, axis, device):
layouts = [
BlockedLayout([1, 4], [8, 4], [4, 1], [1, 0]),
BlockedLayout([1, 4], [8, 4], [4, 1], [0, 1]),
BlockedLayout([4, 4], [2, 16], [4, 1], [1, 0]),
MmaLayout(version=(2, 0), warps_per_cta=[4, 1]),
MmaLayout(version=(2, 0), warps_per_cta=[2, 2])
]
@pytest.mark.parametrize("M, N", [[128, 16], [128, 128], [32, 128]])
@pytest.mark.parametrize("M, N", [[128, 16], [128, 128], [32, 128], [32, 32]])
@pytest.mark.parametrize("src_layout", layouts)
@pytest.mark.parametrize("axis", [0, 1])
def test_reduce_layouts(M, N, src_layout, axis, device):
@@ -1630,31 +1631,30 @@ def test_reduce_layouts(M, N, src_layout, axis, device):
#blocked = #triton_gpu.blocked<{{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}}>
#src = {src_layout}
module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}) {{
tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>
%1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M}x1xi32, #blocked>
%2 = tt.splat %arg1 : (i32) -> tensor<{M}x1xi32, #blocked>
%3 = arith.muli %1, %2 : tensor<{M}x1xi32, #blocked>
%4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<{M}x1x!tt.ptr<f32>, #blocked>
%5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr<f32>, #blocked>, tensor<{M}x1xi32, #blocked>
%4 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x1x!tt.ptr<i32>, #blocked>
%5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr<i32>, #blocked>, tensor<{M}x1xi32, #blocked>
%6 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>
%7 = tt.expand_dims %6 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>) -> tensor<1x{N}xi32, #blocked>
%8 = tt.broadcast %5 : (tensor<{M}x1x!tt.ptr<f32>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<f32>, #blocked>
%8 = tt.broadcast %5 : (tensor<{M}x1x!tt.ptr<i32>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<i32>, #blocked>
%9 = tt.broadcast %7 : (tensor<1x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked>
%10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr<f32>, #blocked>, tensor<{M}x{N}xi32, #blocked>
%11 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<{rdims_2d}x!tt.ptr<f32>, #blocked>
%12 = tt.addptr %11, {store_range} : tensor<{rdims_2d}x!tt.ptr<f32>, #blocked>, tensor<{rdims_2d}xi32, #blocked>
%13 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xf32, #blocked>
%14 = triton_gpu.convert_layout %13 : (tensor<{M}x{N}xf32, #blocked>) -> tensor<{M}x{N}xf32, #src>
%10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr<i32>, #blocked>, tensor<{M}x{N}xi32, #blocked>
%11 = tt.splat %arg2 : (!tt.ptr<i32>) -> tensor<{rdims_2d}x!tt.ptr<i32>, #blocked>
%12 = tt.addptr %11, {store_range} : tensor<{rdims_2d}x!tt.ptr<i32>, #blocked>, tensor<{rdims_2d}xi32, #blocked>
%13 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #blocked>
%14 = triton_gpu.convert_layout %13 : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #src>
%15 = "tt.reduce"(%14) ({{
^bb0(%arg3: f32, %arg4: f32):
%16 = "triton_gpu.cmpf"(%arg3, %arg4) {{predicate = 2 : i64}} : (f32, f32) -> i1
%17 = arith.select %16, %arg3, %arg4 : f32
tt.reduce.return %17 : f32
}}) {{axis = {axis} : i32}} : (tensor<{M}x{N}xf32, #src>) -> tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #src}}>>
%18 = triton_gpu.convert_layout %15 : (tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #src}}>>) -> tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #blocked}}>>
%19 = tt.expand_dims %18 {{axis = {axis} : i32}} : (tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #blocked}}>>) -> tensor<{rdims_2d}xf32, #blocked>
tt.store %12, %19 {{cache = 1 : i32, evict = 1 : i32}} : tensor<{rdims_2d}xf32, #blocked>
^bb0(%arg3: i32, %arg4: i32):
%17 = arith.addi %arg3, %arg4 : i32
tt.reduce.return %17 : i32
}}) {{axis = {axis} : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{rdims_1d}xi32, #triton_gpu.slice<{{dim = {axis}, parent = #src}}>>
%18 = triton_gpu.convert_layout %15 : (tensor<{rdims_1d}xi32, #triton_gpu.slice<{{dim = {axis}, parent = #src}}>>) -> tensor<{rdims_1d}xi32, #triton_gpu.slice<{{dim = {axis}, parent = #blocked}}>>
%19 = tt.expand_dims %18 {{axis = {axis} : i32}} : (tensor<{rdims_1d}xi32, #triton_gpu.slice<{{dim = {axis}, parent = #blocked}}>>) -> tensor<{rdims_2d}xi32, #blocked>
tt.store %12, %19 {{cache = 1 : i32, evict = 1 : i32}} : tensor<{rdims_2d}xi32, #blocked>
tt.return
}}
}}
@@ -1667,22 +1667,21 @@ def test_reduce_layouts(M, N, src_layout, axis, device):
kernel = triton.compile(f.name)
rs = RandomState(17)
x = rs.randint(0, 4, (M, N)).astype('float32')
x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
x = rs.randint(0, 20, (M, N)).astype('int32')
if axis == 0:
z = np.zeros((1, N)).astype('float32')
z = np.zeros((1, N)).astype('int32')
else:
z = np.zeros((M, 1)).astype('float32')
z = np.zeros((M, 1)).astype('int32')
x_tri = torch.tensor(x, device=device)
z_tri = torch.tensor(z, device=device)
pgm = kernel[(1, 1, 4)](x_tri, x_tri.stride(0), z_tri)
z_ref = np.max(x, axis=axis, keepdims=True)
z_ref = np.sum(x, axis=axis, keepdims=True)
np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
np.testing.assert_equal(z_ref, z_tri.cpu().numpy())
layouts = [