From be19cf31034e905578eeb792feb2da8615ef86d0 Mon Sep 17 00:00:00 2001 From: Zahi Moudallal <128723247+zahimoud@users.noreply.github.com> Date: Fri, 6 Oct 2023 15:08:22 -0700 Subject: [PATCH] [BACKEND] Enable reduce with 3D tensors and added tests (#2460) --- include/triton/Analysis/Utility.h | 2 + lib/Analysis/Utility.cpp | 21 ++++++-- .../TritonGPUToLLVM/ReduceOpToLLVM.cpp | 14 ++---- python/test/unit/language/test_core.py | 49 ++++++++++++++----- 4 files changed, 60 insertions(+), 26 deletions(-) diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index 2bede3c93..081ab815a 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -54,6 +54,8 @@ public: SmallVector getScratchConfig(); + SmallVector getOrderWithAxisAtBeginning(); + unsigned getScratchSizeInBytes(); bool isSupportedLayout(); diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index f4def9d59..ee0e3ff6f 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -38,6 +38,17 @@ bool ReduceOpHelper::isReductionOnLayoutFastAxis() { getParentOrder(getSrcLayout())[0]; } +SmallVector ReduceOpHelper::getOrderWithAxisAtBeginning() { + auto srcLayout = getSrcLayout(); + auto order = triton::gpu::getOrder(srcLayout); + auto it = std::find(order.begin(), order.end(), axis); + // delete the axis from order + order.erase(it); + // insert axis at the beginning of order + order.insert(order.begin(), axis); + return order; +} + // Thread offset is the thread index offset of two adjacent threads on the // reduction axis within the warp. unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() { @@ -56,11 +67,11 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() { threadOffset = threadsPerWarp[sliceLayout.getDim()]; } else { auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout); - if (threadsPerWarp.size() == 1) { - threadOffset = 1; - } else { - assert(threadsPerWarp.size() == 2 && "Only supports 2D layouts"); - threadOffset = axis == 0 ? threadsPerWarp[1] : threadsPerWarp[0]; + auto order = triton::gpu::getOrder(srcLayout); + for (unsigned i = 0; i < order.size(); i++) { + if (order[i] == axis) + break; + threadOffset *= threadsPerWarp[order[i]]; } } return threadOffset; diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp index 914696b8d..da1527292 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -398,16 +398,15 @@ private: getMultiDimWarpId(helper, warpId, loc, rewriter); Value warpIdAxis = multiDimWarpId[axis]; - if (!helper.isReductionOnLayoutFastAxis()) { - std::reverse(order.begin(), order.end()); - } + auto smemOrder = helper.getOrderWithAxisAtBeginning(); for (auto it : accs) { const SmallVector &key = it.first; SmallVector &acc = it.second; SmallVector writeIdx = indices[key]; writeIdx[axis] = warpIdAxis; - Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, order); + Value writeOffset = + linearize(rewriter, loc, writeIdx, smemShape, smemOrder); for (unsigned i = 0; i < op.getNumOperands(); ++i) { auto elemPtrTy = getElementPtrType(op, i); Value writePtr = gep(elemPtrTy, smemBases[i], writeOffset); @@ -481,10 +480,7 @@ private: Location loc = op.getLoc(); auto srcLayout = helper.getSrcLayout(); auto axis = op.getAxis(); - auto order = getOrder(srcLayout); - if (!helper.isReductionOnLayoutFastAxis()) { - std::reverse(order.begin(), order.end()); - } + auto smemOrder = helper.getOrderWithAxisAtBeginning(); SmallVector results(op.getNumOperands()); for (unsigned i = 0; i < op.getNumOperands(); ++i) { if (auto resultTy = @@ -500,7 +496,7 @@ private: SmallVector readIdx = resultIndices[j]; readIdx.insert(readIdx.begin() + op.getAxis(), i32_val(0)); Value readOffset = - linearize(rewriter, loc, readIdx, smemShape, order); + linearize(rewriter, loc, readIdx, smemShape, smemOrder); Value readPtr = gep(getElementPtrType(op, i), smemBases[i], readOffset); resultVals[j] = load(readPtr); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index c7c76c6a9..a7517da82 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1661,10 +1661,18 @@ reduce_configs2 = [ for op in ['min', 'max', 'sum'] ] +reduce3d_shapes = [(2, 32, 16), (32, 2, 16), (32, 16, 2)] +reduce_configs3 = [ + (op, 'float32', shape, axis) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for shape in reduce3d_shapes + for axis in [0, 1, 2] +] -@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2) + +@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2 + reduce_configs3) @pytest.mark.parametrize("num_ctas", num_ctas_list) -def test_reduce2d(op, dtype_str, shape, axis, num_ctas, device): +def test_reduce(op, dtype_str, shape, axis, num_ctas, device): check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested if is_hip(): @@ -1672,17 +1680,31 @@ def test_reduce2d(op, dtype_str, shape, axis, num_ctas, device): # triton kernel @triton.jit - def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr): + def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr, AXIS: tl.constexpr): range_m = tl.arange(0, BLOCK_M) range_n = tl.arange(0, BLOCK_N) - x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :]) - z = GENERATE_TEST_HERE - if AXIS is None: - tl.store(Z, z) - elif AXIS == 1: - tl.store(Z + range_m, z) + range_k = tl.arange(0, BLOCK_K) + if IS_3D: + x = tl.load(X + range_m[:, None, None] * BLOCK_N * BLOCK_K + range_n[None, :, None] * BLOCK_K + range_k[None, None, :]) else: - tl.store(Z + range_n, z) + x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :]) + z = GENERATE_TEST_HERE + if IS_3D: + if AXIS is None: + tl.store(Z, z) + elif AXIS == 0: + tl.store(Z + range_n[:, None] * BLOCK_K + range_k[None, :], z) + elif AXIS == 1: + tl.store(Z + range_m[:, None] * BLOCK_K + range_k[None, :], z) + else: + tl.store(Z + range_m[:, None] * BLOCK_N + range_n[None, :], z) + else: + if AXIS is None: + tl.store(Z, z) + elif AXIS == 0: + tl.store(Z + range_n, z) + else: + tl.store(Z + range_m, z) kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS)'}) # input @@ -1705,10 +1727,13 @@ def test_reduce2d(op, dtype_str, shape, axis, num_ctas, device): z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str)) # triton result ret_numel = 1 if axis is None else shape[1 - axis] - z_tri = to_triton(numpy_random((ret_numel,), dtype_str=z_dtype_str, rs=rs), + z_shape = (1,) if axis is None else tuple(shape_i for i, shape_i in enumerate(shape) if i != axis) + z_tri = to_triton(numpy_random(z_shape, dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str) + BLOCK_K = 1 if len(shape) == 2 else shape[2] + IS_3D = bool(len(shape) == 3) kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], - BLOCK_N=shape[1], AXIS=axis, num_ctas=num_ctas) + BLOCK_N=shape[1], BLOCK_K=BLOCK_K, IS_3D=IS_3D, AXIS=axis, num_ctas=num_ctas) z_tri = to_numpy(z_tri) # compare if op == 'sum':