mirror of
https://github.com/ROCm/ROCm.git
synced 2026-02-21 03:00:39 -05:00
[BACKEND] Enable reduce with 3D tensors and added tests (#2460)
This commit is contained in:
@@ -1661,10 +1661,18 @@ reduce_configs2 = [
|
||||
for op in ['min', 'max', 'sum']
|
||||
]
|
||||
|
||||
reduce3d_shapes = [(2, 32, 16), (32, 2, 16), (32, 16, 2)]
|
||||
reduce_configs3 = [
|
||||
(op, 'float32', shape, axis)
|
||||
for op in ['min', 'max', 'sum', 'argmin', 'argmax']
|
||||
for shape in reduce3d_shapes
|
||||
for axis in [0, 1, 2]
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
|
||||
|
||||
@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2 + reduce_configs3)
|
||||
@pytest.mark.parametrize("num_ctas", num_ctas_list)
|
||||
def test_reduce2d(op, dtype_str, shape, axis, num_ctas, device):
|
||||
def test_reduce(op, dtype_str, shape, axis, num_ctas, device):
|
||||
check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested
|
||||
|
||||
if is_hip():
|
||||
@@ -1672,17 +1680,31 @@ def test_reduce2d(op, dtype_str, shape, axis, num_ctas, device):
|
||||
# triton kernel
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
|
||||
def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr, AXIS: tl.constexpr):
|
||||
range_m = tl.arange(0, BLOCK_M)
|
||||
range_n = tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
|
||||
z = GENERATE_TEST_HERE
|
||||
if AXIS is None:
|
||||
tl.store(Z, z)
|
||||
elif AXIS == 1:
|
||||
tl.store(Z + range_m, z)
|
||||
range_k = tl.arange(0, BLOCK_K)
|
||||
if IS_3D:
|
||||
x = tl.load(X + range_m[:, None, None] * BLOCK_N * BLOCK_K + range_n[None, :, None] * BLOCK_K + range_k[None, None, :])
|
||||
else:
|
||||
tl.store(Z + range_n, z)
|
||||
x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
|
||||
z = GENERATE_TEST_HERE
|
||||
if IS_3D:
|
||||
if AXIS is None:
|
||||
tl.store(Z, z)
|
||||
elif AXIS == 0:
|
||||
tl.store(Z + range_n[:, None] * BLOCK_K + range_k[None, :], z)
|
||||
elif AXIS == 1:
|
||||
tl.store(Z + range_m[:, None] * BLOCK_K + range_k[None, :], z)
|
||||
else:
|
||||
tl.store(Z + range_m[:, None] * BLOCK_N + range_n[None, :], z)
|
||||
else:
|
||||
if AXIS is None:
|
||||
tl.store(Z, z)
|
||||
elif AXIS == 0:
|
||||
tl.store(Z + range_n, z)
|
||||
else:
|
||||
tl.store(Z + range_m, z)
|
||||
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS)'})
|
||||
# input
|
||||
@@ -1705,10 +1727,13 @@ def test_reduce2d(op, dtype_str, shape, axis, num_ctas, device):
|
||||
z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
|
||||
# triton result
|
||||
ret_numel = 1 if axis is None else shape[1 - axis]
|
||||
z_tri = to_triton(numpy_random((ret_numel,), dtype_str=z_dtype_str, rs=rs),
|
||||
z_shape = (1,) if axis is None else tuple(shape_i for i, shape_i in enumerate(shape) if i != axis)
|
||||
z_tri = to_triton(numpy_random(z_shape, dtype_str=z_dtype_str, rs=rs),
|
||||
device=device, dst_type=z_tri_dtype_str)
|
||||
BLOCK_K = 1 if len(shape) == 2 else shape[2]
|
||||
IS_3D = bool(len(shape) == 3)
|
||||
kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0],
|
||||
BLOCK_N=shape[1], AXIS=axis, num_ctas=num_ctas)
|
||||
BLOCK_N=shape[1], BLOCK_K=BLOCK_K, IS_3D=IS_3D, AXIS=axis, num_ctas=num_ctas)
|
||||
z_tri = to_numpy(z_tri)
|
||||
# compare
|
||||
if op == 'sum':
|
||||
|
||||
Reference in New Issue
Block a user