[BACKEND] Enable reduce with 3D tensors and add tests (#2460)

Author: Zahi Moudallal
Date: 2023-10-06 15:08:22 -07:00
Committed by: GitHub
Parent: a42d517021
Commit: be19cf3103

4 changed files with 60 additions and 26 deletions


@@ -1661,10 +1661,18 @@ reduce_configs2 = [
     for op in ['min', 'max', 'sum']
 ]
+reduce3d_shapes = [(2, 32, 16), (32, 2, 16), (32, 16, 2)]
+reduce_configs3 = [
+    (op, 'float32', shape, axis)
+    for op in ['min', 'max', 'sum', 'argmin', 'argmax']
+    for shape in reduce3d_shapes
+    for axis in [0, 1, 2]
+]
-@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
+@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2 + reduce_configs3)
 @pytest.mark.parametrize("num_ctas", num_ctas_list)
-def test_reduce2d(op, dtype_str, shape, axis, num_ctas, device):
+def test_reduce(op, dtype_str, shape, axis, num_ctas, device):
     check_type_supported(dtype_str, device)  # bfloat16 on cc < 80 will not be tested
     if is_hip():
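
Note: the new reduce_configs3 comprehension above expands to 45 parameter tuples (5 ops x 3 shapes x 3 axes), with each shape placing the size-2 dimension on a different axis. A minimal standalone sketch of that expansion (illustration only, not part of the commit):

import itertools

ops = ['min', 'max', 'sum', 'argmin', 'argmax']
reduce3d_shapes = [(2, 32, 16), (32, 2, 16), (32, 16, 2)]

reduce_configs3 = [(op, 'float32', shape, axis)
                   for op, shape, axis in itertools.product(ops, reduce3d_shapes, [0, 1, 2])]
assert len(reduce_configs3) == 45  # 5 ops * 3 shapes * 3 axes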
@@ -1672,17 +1680,31 @@ def test_reduce2d(op, dtype_str, shape, axis, num_ctas, device):
     # triton kernel
     @triton.jit
-    def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
+    def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr, AXIS: tl.constexpr):
         range_m = tl.arange(0, BLOCK_M)
         range_n = tl.arange(0, BLOCK_N)
-        x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
-        z = GENERATE_TEST_HERE
-        if AXIS is None:
-            tl.store(Z, z)
-        elif AXIS == 1:
-            tl.store(Z + range_m, z)
-        else:
-            tl.store(Z + range_n, z)
+        range_k = tl.arange(0, BLOCK_K)
+        if IS_3D:
+            x = tl.load(X + range_m[:, None, None] * BLOCK_N * BLOCK_K + range_n[None, :, None] * BLOCK_K + range_k[None, None, :])
+        else:
+            x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
+        z = GENERATE_TEST_HERE
+        if IS_3D:
+            if AXIS is None:
+                tl.store(Z, z)
+            elif AXIS == 0:
+                tl.store(Z + range_n[:, None] * BLOCK_K + range_k[None, :], z)
+            elif AXIS == 1:
+                tl.store(Z + range_m[:, None] * BLOCK_K + range_k[None, :], z)
+            else:
+                tl.store(Z + range_m[:, None] * BLOCK_N + range_n[None, :], z)
+        else:
+            if AXIS is None:
+                tl.store(Z, z)
+            elif AXIS == 0:
+                tl.store(Z + range_n, z)
+            else:
+                tl.store(Z + range_m, z)
     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS)'})
     # input
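
The 3D load above linearizes the (BLOCK_M, BLOCK_N, BLOCK_K) index as m * BLOCK_N * BLOCK_K + n * BLOCK_K + k, i.e. the row-major layout of the contiguous input; the per-axis stores drop the reduced dimension the same way. A NumPy sketch (illustration only, not from the commit) checking that offset formula:

import numpy as np

M, N, K = 2, 32, 16
x = np.arange(M * N * K).reshape(M, N, K)  # each element equals its flat offset

range_m, range_n, range_k = np.arange(M), np.arange(N), np.arange(K)
# Same broadcasting pattern as the kernel's tl.load in the 3D case.
offsets = range_m[:, None, None] * N * K + range_n[None, :, None] * K + range_k[None, None, :]

assert (offsets == x).all()  # each (m, n, k) offset hits the right element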
@@ -1705,10 +1727,13 @@ def test_reduce2d(op, dtype_str, shape, axis, num_ctas, device):
     z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
     # triton result
-    ret_numel = 1 if axis is None else shape[1 - axis]
-    z_tri = to_triton(numpy_random((ret_numel,), dtype_str=z_dtype_str, rs=rs),
+    z_shape = (1,) if axis is None else tuple(shape_i for i, shape_i in enumerate(shape) if i != axis)
+    z_tri = to_triton(numpy_random(z_shape, dtype_str=z_dtype_str, rs=rs),
                       device=device, dst_type=z_tri_dtype_str)
+    BLOCK_K = 1 if len(shape) == 2 else shape[2]
+    IS_3D = bool(len(shape) == 3)
     kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0],
-                 BLOCK_N=shape[1], AXIS=axis, num_ctas=num_ctas)
+                 BLOCK_N=shape[1], BLOCK_K=BLOCK_K, IS_3D=IS_3D, AXIS=axis, num_ctas=num_ctas)
     z_tri = to_numpy(z_tri)
     # compare
     if op == 'sum':
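
The removed ret_numel = shape[1 - axis] trick only works for 2D inputs; the generalized z_shape simply drops the reduced axis, which matches NumPy's result shape for the same reduction. A small sketch of that invariant (the helper z_shape_for is hypothetical, mirroring the expression in the diff above):

import numpy as np

def z_shape_for(shape, axis):
    # Mirrors the z_shape expression in the diff above.
    return (1,) if axis is None else tuple(s for i, s in enumerate(shape) if i != axis)

x = np.empty((2, 32, 16))
for axis in [None, 0, 1, 2]:
    expected = (1,) if axis is None else np.min(x, axis=axis).shape
    assert z_shape_for(x.shape, axis) == expected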