mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Fix reductions when number of unique element is smaller than layout (#1913)
Fix calculation of unique number of threads within a warp. We need to consider the number of elements per thread in the calculation. Also change the layout test to integer sum in order to catch bugs with unique data as max reduction may hide those kind of problems.
This commit is contained in:
@@ -49,7 +49,8 @@ SmallVector<unsigned> getContigPerThread(Attribute layout);
|
||||
// for thread 0 would be [A_{0, 0}, A_{0, 0}, A_{0, 0}, A_{0, 0}], returns [1,
|
||||
// 1]. Whereas for a tensor shape [128, 128], the elements for thread 0 would be
|
||||
// [A_{0, 0}, A_{0, 1}, A_{0, 2}, A_{0, 3}], returns [1, 4].
|
||||
SmallVector<unsigned> getUniqueContigPerThread(Type type);
|
||||
SmallVector<unsigned> getUniqueContigPerThread(Attribute layout,
|
||||
ArrayRef<int64_t> tensorShape);
|
||||
|
||||
// Returns the number of threads per warp that have access to non-replicated
|
||||
// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1,
|
||||
|
||||
@@ -918,7 +918,8 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
|
||||
auto order = triton::gpu::getOrder(layout);
|
||||
unsigned align = getPtrAlignment(ptr);
|
||||
|
||||
auto uniqueContigPerThread = triton::gpu::getUniqueContigPerThread(tensorTy);
|
||||
auto uniqueContigPerThread =
|
||||
triton::gpu::getUniqueContigPerThread(layout, tensorTy.getShape());
|
||||
assert(order[0] < uniqueContigPerThread.size() &&
|
||||
"Unxpected uniqueContigPerThread size");
|
||||
unsigned contiguity = uniqueContigPerThread[order[0]];
|
||||
|
||||
@@ -60,7 +60,9 @@ unsigned ReduceOpHelper::getInterWarpSizeWithUniqueData() {
|
||||
|
||||
unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() {
|
||||
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
|
||||
return std::min(srcReduceDimSize,
|
||||
unsigned elementPerThreads = triton::gpu::getUniqueContigPerThread(
|
||||
getSrcLayout(), getSrcShape())[axis];
|
||||
return std::min(srcReduceDimSize / elementPerThreads,
|
||||
triton::gpu::getThreadsPerWarpWithUniqueData(
|
||||
getSrcLayout(), getSrcShape())[axis]);
|
||||
}
|
||||
|
||||
@@ -221,20 +221,15 @@ SmallVector<unsigned> getContigPerThread(Attribute layout) {
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<unsigned> getUniqueContigPerThread(Type type) {
|
||||
if (type.isIntOrIndexOrFloat() || type.isa<triton::PointerType>())
|
||||
return SmallVector<unsigned>(1, 1);
|
||||
auto tensorType = type.cast<RankedTensorType>();
|
||||
auto shape = tensorType.getShape();
|
||||
SmallVector<unsigned> getUniqueContigPerThread(Attribute layout,
|
||||
ArrayRef<int64_t> shape) {
|
||||
// If slice layout, call recursively on parent layout, and drop
|
||||
// sliced dim
|
||||
if (auto sliceLayout =
|
||||
tensorType.getEncoding().dyn_cast<SliceEncodingAttr>()) {
|
||||
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||
auto parentLayout = sliceLayout.getParent();
|
||||
auto parentShape = sliceLayout.paddedShape(shape);
|
||||
auto parentTy = RankedTensorType::get(
|
||||
parentShape, tensorType.getElementType(), parentLayout);
|
||||
auto parentUniqueContigPerThread = getUniqueContigPerThread(parentTy);
|
||||
auto parentUniqueContigPerThread =
|
||||
getUniqueContigPerThread(parentLayout, parentShape);
|
||||
parentUniqueContigPerThread.erase(parentUniqueContigPerThread.begin() +
|
||||
sliceLayout.getDim());
|
||||
return parentUniqueContigPerThread;
|
||||
@@ -242,7 +237,7 @@ SmallVector<unsigned> getUniqueContigPerThread(Type type) {
|
||||
// Base case
|
||||
auto rank = shape.size();
|
||||
SmallVector<unsigned> ret(rank);
|
||||
auto contigPerThread = getContigPerThread(tensorType.getEncoding());
|
||||
auto contigPerThread = getContigPerThread(layout);
|
||||
assert(contigPerThread.size() == rank && "Unexpected contigPerThread size");
|
||||
for (int d = 0; d < rank; ++d) {
|
||||
ret[d] = std::min<unsigned>(shape[d], contigPerThread[d]);
|
||||
|
||||
@@ -1614,12 +1614,13 @@ def test_scan_layouts(M, N, src_layout, axis, device):
|
||||
layouts = [
|
||||
BlockedLayout([1, 4], [8, 4], [4, 1], [1, 0]),
|
||||
BlockedLayout([1, 4], [8, 4], [4, 1], [0, 1]),
|
||||
BlockedLayout([4, 4], [2, 16], [4, 1], [1, 0]),
|
||||
MmaLayout(version=(2, 0), warps_per_cta=[4, 1]),
|
||||
MmaLayout(version=(2, 0), warps_per_cta=[2, 2])
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M, N", [[128, 16], [128, 128], [32, 128]])
|
||||
@pytest.mark.parametrize("M, N", [[128, 16], [128, 128], [32, 128], [32, 32]])
|
||||
@pytest.mark.parametrize("src_layout", layouts)
|
||||
@pytest.mark.parametrize("axis", [0, 1])
|
||||
def test_reduce_layouts(M, N, src_layout, axis, device):
|
||||
@@ -1630,31 +1631,30 @@ def test_reduce_layouts(M, N, src_layout, axis, device):
|
||||
#blocked = #triton_gpu.blocked<{{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}}>
|
||||
#src = {src_layout}
|
||||
module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
|
||||
tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}) {{
|
||||
tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
|
||||
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>
|
||||
%1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M}x1xi32, #blocked>
|
||||
%2 = tt.splat %arg1 : (i32) -> tensor<{M}x1xi32, #blocked>
|
||||
%3 = arith.muli %1, %2 : tensor<{M}x1xi32, #blocked>
|
||||
%4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<{M}x1x!tt.ptr<f32>, #blocked>
|
||||
%5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr<f32>, #blocked>, tensor<{M}x1xi32, #blocked>
|
||||
%4 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x1x!tt.ptr<i32>, #blocked>
|
||||
%5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr<i32>, #blocked>, tensor<{M}x1xi32, #blocked>
|
||||
%6 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>
|
||||
%7 = tt.expand_dims %6 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>) -> tensor<1x{N}xi32, #blocked>
|
||||
%8 = tt.broadcast %5 : (tensor<{M}x1x!tt.ptr<f32>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<f32>, #blocked>
|
||||
%8 = tt.broadcast %5 : (tensor<{M}x1x!tt.ptr<i32>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<i32>, #blocked>
|
||||
%9 = tt.broadcast %7 : (tensor<1x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked>
|
||||
%10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr<f32>, #blocked>, tensor<{M}x{N}xi32, #blocked>
|
||||
%11 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<{rdims_2d}x!tt.ptr<f32>, #blocked>
|
||||
%12 = tt.addptr %11, {store_range} : tensor<{rdims_2d}x!tt.ptr<f32>, #blocked>, tensor<{rdims_2d}xi32, #blocked>
|
||||
%13 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xf32, #blocked>
|
||||
%14 = triton_gpu.convert_layout %13 : (tensor<{M}x{N}xf32, #blocked>) -> tensor<{M}x{N}xf32, #src>
|
||||
%10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr<i32>, #blocked>, tensor<{M}x{N}xi32, #blocked>
|
||||
%11 = tt.splat %arg2 : (!tt.ptr<i32>) -> tensor<{rdims_2d}x!tt.ptr<i32>, #blocked>
|
||||
%12 = tt.addptr %11, {store_range} : tensor<{rdims_2d}x!tt.ptr<i32>, #blocked>, tensor<{rdims_2d}xi32, #blocked>
|
||||
%13 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #blocked>
|
||||
%14 = triton_gpu.convert_layout %13 : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #src>
|
||||
%15 = "tt.reduce"(%14) ({{
|
||||
^bb0(%arg3: f32, %arg4: f32):
|
||||
%16 = "triton_gpu.cmpf"(%arg3, %arg4) {{predicate = 2 : i64}} : (f32, f32) -> i1
|
||||
%17 = arith.select %16, %arg3, %arg4 : f32
|
||||
tt.reduce.return %17 : f32
|
||||
}}) {{axis = {axis} : i32}} : (tensor<{M}x{N}xf32, #src>) -> tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #src}}>>
|
||||
%18 = triton_gpu.convert_layout %15 : (tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #src}}>>) -> tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #blocked}}>>
|
||||
%19 = tt.expand_dims %18 {{axis = {axis} : i32}} : (tensor<{rdims_1d}xf32, #triton_gpu.slice<{{dim = {axis}, parent = #blocked}}>>) -> tensor<{rdims_2d}xf32, #blocked>
|
||||
tt.store %12, %19 {{cache = 1 : i32, evict = 1 : i32}} : tensor<{rdims_2d}xf32, #blocked>
|
||||
^bb0(%arg3: i32, %arg4: i32):
|
||||
%17 = arith.addi %arg3, %arg4 : i32
|
||||
tt.reduce.return %17 : i32
|
||||
}}) {{axis = {axis} : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{rdims_1d}xi32, #triton_gpu.slice<{{dim = {axis}, parent = #src}}>>
|
||||
%18 = triton_gpu.convert_layout %15 : (tensor<{rdims_1d}xi32, #triton_gpu.slice<{{dim = {axis}, parent = #src}}>>) -> tensor<{rdims_1d}xi32, #triton_gpu.slice<{{dim = {axis}, parent = #blocked}}>>
|
||||
%19 = tt.expand_dims %18 {{axis = {axis} : i32}} : (tensor<{rdims_1d}xi32, #triton_gpu.slice<{{dim = {axis}, parent = #blocked}}>>) -> tensor<{rdims_2d}xi32, #blocked>
|
||||
tt.store %12, %19 {{cache = 1 : i32, evict = 1 : i32}} : tensor<{rdims_2d}xi32, #blocked>
|
||||
tt.return
|
||||
}}
|
||||
}}
|
||||
@@ -1667,22 +1667,21 @@ def test_reduce_layouts(M, N, src_layout, axis, device):
|
||||
kernel = triton.compile(f.name)
|
||||
|
||||
rs = RandomState(17)
|
||||
x = rs.randint(0, 4, (M, N)).astype('float32')
|
||||
x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
|
||||
x = rs.randint(0, 20, (M, N)).astype('int32')
|
||||
|
||||
if axis == 0:
|
||||
z = np.zeros((1, N)).astype('float32')
|
||||
z = np.zeros((1, N)).astype('int32')
|
||||
else:
|
||||
z = np.zeros((M, 1)).astype('float32')
|
||||
z = np.zeros((M, 1)).astype('int32')
|
||||
|
||||
x_tri = torch.tensor(x, device=device)
|
||||
z_tri = torch.tensor(z, device=device)
|
||||
|
||||
pgm = kernel[(1, 1, 4)](x_tri, x_tri.stride(0), z_tri)
|
||||
|
||||
z_ref = np.max(x, axis=axis, keepdims=True)
|
||||
z_ref = np.sum(x, axis=axis, keepdims=True)
|
||||
|
||||
np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
|
||||
np.testing.assert_equal(z_ref, z_tri.cpu().numpy())
|
||||
|
||||
|
||||
layouts = [
|
||||
|
||||
Reference in New Issue
Block a user