[TEST] Fixed and re-enabled reduce test (#1644)

Re-enabled the reduce test after fixing the %cst stride in the ttgir, and
modifying the sweep parameters to make sure the shape per CTA is less
than or equal to the tensor shape.
This commit is contained in:
Zahi Moudallal
2023-05-10 15:15:11 -07:00
committed by GitHub
parent 147ec4384d
commit fb40bf1954
4 changed files with 78 additions and 63 deletions

View File

@@ -39,6 +39,10 @@ public:
// Extent of the reduction handled within a single warp along the axis.
unsigned getIntraWarpSize();
// Cross-warp reduction extent, restricted to warps holding unique data.
unsigned getInterWarpSizeWithUniqueData();
// In-warp reduction extent, restricted to lanes holding unique data.
unsigned getIntraWarpSizeWithUniqueData();
// NOTE(review): presumably the number of threads participating along the
// reduction axis — confirm against the .cpp implementation.
unsigned getThreadsReductionAxis();
// Scratch buffer shape for the basic (non-fast) reduction path.
SmallVector<unsigned> getScratchConfigBasic();

View File

@@ -16,12 +16,25 @@ bool ReduceOpHelper::isFastReduction() {
// Number of warps cooperating across the reduction axis: the reduced
// dimension divided by the per-warp extent, capped by how many warps the
// source layout places along that axis.
unsigned ReduceOpHelper::getInterWarpSize() {
  const auto reduceDimSize = static_cast<unsigned>(srcShape[axis]);
  const unsigned intraWarpSize = getIntraWarpSize();
  const unsigned warpsAlongAxis =
      triton::gpu::getWarpsPerCTA(getSrcLayout())[axis];
  return std::min(reduceDimSize / intraWarpSize, warpsAlongAxis);
}
// Reduction extent covered inside one warp: the lanes the layout assigns
// along the axis, capped by the size of the reduced dimension itself.
unsigned ReduceOpHelper::getIntraWarpSize() {
  const auto reduceDimSize = static_cast<unsigned>(srcShape[axis]);
  const unsigned threadsAlongAxis =
      triton::gpu::getThreadsPerWarp(getSrcLayout())[axis];
  return std::min(reduceDimSize, threadsAlongAxis);
}
// Cross-warp reduction extent counting only warps that hold unique data
// for the source shape (warps replicated by broadcasting are excluded).
unsigned ReduceOpHelper::getInterWarpSizeWithUniqueData() {
  const auto reduceDimSize = static_cast<unsigned>(srcShape[axis]);
  const unsigned intraWarpSize = getIntraWarpSizeWithUniqueData();
  const unsigned uniqueWarpsAlongAxis =
      triton::gpu::getWarpsPerCTAWithUniqueData(getSrcLayout(),
                                                getSrcShape())[axis];
  return std::min(reduceDimSize / intraWarpSize, uniqueWarpsAlongAxis);
}
unsigned ReduceOpHelper::getIntraWarpSize() {
unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() {
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
return std::min(srcReduceDimSize,
triton::gpu::getThreadsPerWarpWithUniqueData(

View File

@@ -336,8 +336,8 @@ private:
elemPtrTys[i]);
}
unsigned sizeIntraWarps = helper.getIntraWarpSize();
unsigned sizeInterWarps = helper.getInterWarpSize();
unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData();
unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData();
unsigned srcElems = getTotalElemsPerThread(srcTys[0]);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]);

View File

@@ -1605,68 +1605,66 @@ def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
)
# layouts = [
# BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]),
# BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]),
# BlockedLayout([1, 4], [1, 32], [1, 4], [1, 0]),
# BlockedLayout([1, 4], [8, 4], [2, 2], [0, 1])
# ]
# Blocked layouts swept by the reduce tests below. Arguments are presumably
# (sizePerThread, threadsPerWarp, warpsPerCTA, order) — confirm against the
# BlockedLayout helper. Every entry uses 4 warps total (4x1, 2x2, 1x4, 2x2),
# matching the "triton_gpu.num-warps" = 4 attribute in the generated IR.
layouts = [
BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]),
BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]),
BlockedLayout([1, 4], [1, 32], [1, 4], [1, 0]),
BlockedLayout([1, 4], [8, 4], [2, 2], [0, 1])
]
# @pytest.mark.parametrize("M, N", [[32, 128], [128, 128], [128, 32]])
# @pytest.mark.parametrize("src_layout", layouts)
# def test_reduce_2d(M, N, src_layout, device='cuda'):
# ir = f"""
# #src = {src_layout}
# module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
# tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
# %cst = arith.constant dense<{M}> : tensor<{M}x1xi32, #src>
# %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
# %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
# %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src>
# %3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
# %4 = tt.expand_dims %3 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src>
# %5 = tt.broadcast %2 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src>
# %6 = tt.broadcast %4 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src>
# %7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src>
# %8 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x{N}x!tt.ptr<i32>, #src>
# %9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr<i32>, #src>, tensor<{M}x{N}xi32, #src>
# %10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #src>
# %11 = "tt.reduce"(%10) ({{
# ^bb0(%arg2: i32, %arg3: i32):
# %13 = arith.addi %arg2, %arg3 : i32
# tt.reduce.return %13 : i32
# }}) {{axis = 1 : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
# %12 = "tt.reduce"(%11) ({{
# ^bb0(%arg2: i32, %arg3: i32):
# %13 = arith.addi %arg2, %arg3 : i32
# tt.reduce.return %13 : i32
# }}) {{axis = 0 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> i32
# tt.store %arg1, %12 {{cache = 1 : i32, evict = 1 : i32}} : i32
# tt.return
# }}
# }}
# """
# import tempfile
# with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
# f.write(ir)
# f.flush()
# kernel = triton.compile(f.name)
#
# rs = RandomState(17)
# x = rs.randint(0, 4, (M, N)).astype('int32')
# x = (x.view('uint32') & np.uint32(0xffffe000)).view('int32')
#
# z = np.zeros((1,)).astype('int32')
#
# x_tri = torch.tensor(x, device=device)
# z_tri = torch.tensor(z, device=device)
#
# pgm = kernel[(1, 1, 1)](x_tri, z_tri)
#
# z_ref = np.sum(x)
#
# np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
@pytest.mark.parametrize("M, N", [[128, 128], [256, 128], [256, 256], [128, 256]])
@pytest.mark.parametrize("src_layout", layouts)
def test_chain_reduce(M, N, src_layout, device='cuda'):
# Compiles a hand-written TTGIR kernel that loads an MxN i32 tensor and
# chains two tt.reduce sum ops (axis 1, then axis 0) down to a scalar,
# then checks the device result against np.sum on the host.
# The %cst constant dense<{N}> is the row stride (offset = row*N + col);
# the commit this page shows fixed that stride.
ir = f"""
#src = {src_layout}
module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
%cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src>
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
%2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src>
%3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
%4 = tt.expand_dims %3 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src>
%5 = tt.broadcast %2 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src>
%6 = tt.broadcast %4 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src>
%7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src>
%8 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x{N}x!tt.ptr<i32>, #src>
%9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr<i32>, #src>, tensor<{M}x{N}xi32, #src>
%10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #src>
%11 = "tt.reduce"(%10) ({{
^bb0(%arg2: i32, %arg3: i32):
%13 = arith.addi %arg2, %arg3 : i32
tt.reduce.return %13 : i32
}}) {{axis = 1 : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%12 = "tt.reduce"(%11) ({{
^bb0(%arg2: i32, %arg3: i32):
%13 = arith.addi %arg2, %arg3 : i32
tt.reduce.return %13 : i32
}}) {{axis = 0 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> i32
tt.store %arg1, %12 {{cache = 1 : i32, evict = 1 : i32}} : i32
tt.return
}}
}}
"""
# Compile the textual IR by round-tripping it through a temporary .ttgir file.
import tempfile
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(ir)
f.flush()
kernel = triton.compile(f.name)
# Fixed seed for reproducibility; values in [0, 4) keep the i32 sum far from overflow.
rs = RandomState(17)
x = rs.randint(0, 4, (M, N)).astype('int32')
z = np.zeros((1,)).astype('int32')
x_tri = torch.tensor(x, device=device)
z_tri = torch.tensor(z, device=device)
# Single-CTA launch: one program instance reduces the whole tensor.
pgm = kernel[(1, 1, 1)](x_tri, z_tri)
z_ref = np.sum(x)
np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
def test_generic_reduction(device='cuda'):