mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[TEST] Fixed and re-enabled reduce test (#1644)
Re-enabled reduce test after fixing the %cst stride in the ttgir, and modifying the sweep parameters to make sure the shape per CTA to be less than or equal to the tensor shape.
This commit is contained in:
@@ -39,6 +39,10 @@ public:
|
||||
|
||||
unsigned getIntraWarpSize();
|
||||
|
||||
unsigned getInterWarpSizeWithUniqueData();
|
||||
|
||||
unsigned getIntraWarpSizeWithUniqueData();
|
||||
|
||||
unsigned getThreadsReductionAxis();
|
||||
|
||||
SmallVector<unsigned> getScratchConfigBasic();
|
||||
|
||||
@@ -16,12 +16,25 @@ bool ReduceOpHelper::isFastReduction() {
|
||||
unsigned ReduceOpHelper::getInterWarpSize() {
|
||||
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
|
||||
unsigned sizeIntraWarps = getIntraWarpSize();
|
||||
return std::min(srcReduceDimSize / sizeIntraWarps,
|
||||
triton::gpu::getWarpsPerCTA(getSrcLayout())[axis]);
|
||||
}
|
||||
|
||||
unsigned ReduceOpHelper::getIntraWarpSize() {
|
||||
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
|
||||
return std::min(srcReduceDimSize,
|
||||
triton::gpu::getThreadsPerWarp(getSrcLayout())[axis]);
|
||||
}
|
||||
|
||||
unsigned ReduceOpHelper::getInterWarpSizeWithUniqueData() {
|
||||
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
|
||||
unsigned sizeIntraWarps = getIntraWarpSizeWithUniqueData();
|
||||
return std::min(srcReduceDimSize / sizeIntraWarps,
|
||||
triton::gpu::getWarpsPerCTAWithUniqueData(
|
||||
getSrcLayout(), getSrcShape())[axis]);
|
||||
}
|
||||
|
||||
unsigned ReduceOpHelper::getIntraWarpSize() {
|
||||
unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() {
|
||||
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
|
||||
return std::min(srcReduceDimSize,
|
||||
triton::gpu::getThreadsPerWarpWithUniqueData(
|
||||
|
||||
@@ -336,8 +336,8 @@ private:
|
||||
elemPtrTys[i]);
|
||||
}
|
||||
|
||||
unsigned sizeIntraWarps = helper.getIntraWarpSize();
|
||||
unsigned sizeInterWarps = helper.getInterWarpSize();
|
||||
unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData();
|
||||
unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData();
|
||||
|
||||
unsigned srcElems = getTotalElemsPerThread(srcTys[0]);
|
||||
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]);
|
||||
|
||||
@@ -1605,68 +1605,66 @@ def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
|
||||
)
|
||||
|
||||
|
||||
# layouts = [
|
||||
# BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]),
|
||||
# BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]),
|
||||
# BlockedLayout([1, 4], [1, 32], [1, 4], [1, 0]),
|
||||
# BlockedLayout([1, 4], [8, 4], [2, 2], [0, 1])
|
||||
# ]
|
||||
layouts = [
|
||||
BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]),
|
||||
BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]),
|
||||
BlockedLayout([1, 4], [1, 32], [1, 4], [1, 0]),
|
||||
BlockedLayout([1, 4], [8, 4], [2, 2], [0, 1])
|
||||
]
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("M, N", [[32, 128], [128, 128], [128, 32]])
|
||||
# @pytest.mark.parametrize("src_layout", layouts)
|
||||
# def test_reduce_2d(M, N, src_layout, device='cuda'):
|
||||
# ir = f"""
|
||||
# #src = {src_layout}
|
||||
# module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
|
||||
# tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
|
||||
# %cst = arith.constant dense<{M}> : tensor<{M}x1xi32, #src>
|
||||
# %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
|
||||
# %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
|
||||
# %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src>
|
||||
# %3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
|
||||
# %4 = tt.expand_dims %3 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src>
|
||||
# %5 = tt.broadcast %2 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src>
|
||||
# %6 = tt.broadcast %4 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src>
|
||||
# %7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src>
|
||||
# %8 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x{N}x!tt.ptr<i32>, #src>
|
||||
# %9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr<i32>, #src>, tensor<{M}x{N}xi32, #src>
|
||||
# %10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #src>
|
||||
# %11 = "tt.reduce"(%10) ({{
|
||||
# ^bb0(%arg2: i32, %arg3: i32):
|
||||
# %13 = arith.addi %arg2, %arg3 : i32
|
||||
# tt.reduce.return %13 : i32
|
||||
# }}) {{axis = 1 : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
|
||||
# %12 = "tt.reduce"(%11) ({{
|
||||
# ^bb0(%arg2: i32, %arg3: i32):
|
||||
# %13 = arith.addi %arg2, %arg3 : i32
|
||||
# tt.reduce.return %13 : i32
|
||||
# }}) {{axis = 0 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> i32
|
||||
# tt.store %arg1, %12 {{cache = 1 : i32, evict = 1 : i32}} : i32
|
||||
# tt.return
|
||||
# }}
|
||||
# }}
|
||||
# """
|
||||
# import tempfile
|
||||
# with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
|
||||
# f.write(ir)
|
||||
# f.flush()
|
||||
# kernel = triton.compile(f.name)
|
||||
#
|
||||
# rs = RandomState(17)
|
||||
# x = rs.randint(0, 4, (M, N)).astype('int32')
|
||||
# x = (x.view('uint32') & np.uint32(0xffffe000)).view('int32')
|
||||
#
|
||||
# z = np.zeros((1,)).astype('int32')
|
||||
#
|
||||
# x_tri = torch.tensor(x, device=device)
|
||||
# z_tri = torch.tensor(z, device=device)
|
||||
#
|
||||
# pgm = kernel[(1, 1, 1)](x_tri, z_tri)
|
||||
#
|
||||
# z_ref = np.sum(x)
|
||||
#
|
||||
# np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
|
||||
@pytest.mark.parametrize("M, N", [[128, 128], [256, 128], [256, 256], [128, 256]])
|
||||
@pytest.mark.parametrize("src_layout", layouts)
|
||||
def test_chain_reduce(M, N, src_layout, device='cuda'):
|
||||
ir = f"""
|
||||
#src = {src_layout}
|
||||
module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
|
||||
tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
|
||||
%cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src>
|
||||
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
|
||||
%1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
|
||||
%2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src>
|
||||
%3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
|
||||
%4 = tt.expand_dims %3 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src>
|
||||
%5 = tt.broadcast %2 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src>
|
||||
%6 = tt.broadcast %4 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src>
|
||||
%7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src>
|
||||
%8 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x{N}x!tt.ptr<i32>, #src>
|
||||
%9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr<i32>, #src>, tensor<{M}x{N}xi32, #src>
|
||||
%10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #src>
|
||||
%11 = "tt.reduce"(%10) ({{
|
||||
^bb0(%arg2: i32, %arg3: i32):
|
||||
%13 = arith.addi %arg2, %arg3 : i32
|
||||
tt.reduce.return %13 : i32
|
||||
}}) {{axis = 1 : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
|
||||
%12 = "tt.reduce"(%11) ({{
|
||||
^bb0(%arg2: i32, %arg3: i32):
|
||||
%13 = arith.addi %arg2, %arg3 : i32
|
||||
tt.reduce.return %13 : i32
|
||||
}}) {{axis = 0 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> i32
|
||||
tt.store %arg1, %12 {{cache = 1 : i32, evict = 1 : i32}} : i32
|
||||
tt.return
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
import tempfile
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
|
||||
f.write(ir)
|
||||
f.flush()
|
||||
kernel = triton.compile(f.name)
|
||||
|
||||
rs = RandomState(17)
|
||||
x = rs.randint(0, 4, (M, N)).astype('int32')
|
||||
|
||||
z = np.zeros((1,)).astype('int32')
|
||||
|
||||
x_tri = torch.tensor(x, device=device)
|
||||
z_tri = torch.tensor(z, device=device)
|
||||
|
||||
pgm = kernel[(1, 1, 1)](x_tri, z_tri)
|
||||
z_ref = np.sum(x)
|
||||
|
||||
np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
|
||||
|
||||
|
||||
def test_generic_reduction(device='cuda'):
|
||||
|
||||
Reference in New Issue
Block a user