[FRONTEND] Added a default of axis=None for reductions; with axis=None the reduction runs across all axes. (#1712)

This commit is contained in:
Philippe Tillet
2023-05-28 16:13:21 -07:00
committed by GitHub
parent 420e4acecc
commit 4e2f57add5
4 changed files with 23 additions and 9 deletions

View File

@@ -1379,6 +1379,9 @@ reduce_configs2 = [
for op in ['min', 'max', 'sum', 'argmin', 'argmax']
for shape in reduce2d_shapes
for axis in [0, 1]
] + [
(op, 'float32', [16, 32], None)
for op in ['min', 'max', 'sum']
]
@@ -1393,7 +1396,9 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
range_n = tl.arange(0, BLOCK_N)
x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
z = GENERATE_TEST_HERE
if AXIS == 1:
if AXIS is None:
tl.store(Z, z)
elif AXIS == 1:
tl.store(Z + range_m, z)
else:
tl.store(Z + range_n, z)
@@ -1418,7 +1423,8 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
else:
z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
# triton result
z_tri = to_triton(numpy_random((shape[1 - axis],), dtype_str=z_dtype_str, rs=rs),
ret_numel = 1 if axis is None else shape[1 - axis]
z_tri = to_triton(numpy_random((ret_numel,), dtype_str=z_dtype_str, rs=rs),
device=device, dst_type=z_tri_dtype_str)
kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
z_tri = to_numpy(z_tri)

View File

@@ -388,7 +388,8 @@ class CodeGenerator(ast.NodeVisitor):
for name, value in zip(names, values):
# by default, constexpr are assigned into python variable
value = _unwrap_if_constexpr(value)
if not _is_triton_tensor(value) and \
if value is not None and \
not _is_triton_tensor(value) and \
not isinstance(value, native_nontensor_types):
value = language.core._to_tensor(value, self.builder)
self.set_value(name, value)

View File

@@ -1297,8 +1297,8 @@ def reduce(input, axis, combine_fn, _builder=None, _generator=None):
else:
handles = [r.handle for r in results]
_builder.create_reduce_ret(*handles)
axis = _constexpr_to_value(axis)
if axis is not None:
axis = _constexpr_to_value(axis)
return semantic.reduction(input, axis, make_combine_region, _builder)
@@ -1369,7 +1369,7 @@ def _max_combine(a, b):
@triton.jit
@_add_reduction_docstr("maximum")
def max(input, axis):
def max(input, axis=None):
input = _promote_reduction_input(input)
return reduce(input, axis, _max_combine)
@@ -1399,7 +1399,7 @@ def _min_combine(a, b):
@triton.jit
@_add_reduction_docstr("minimum")
def min(input, axis):
def min(input, axis=None):
input = _promote_reduction_input(input)
return reduce(input, axis, _min_combine)
@@ -1428,7 +1428,7 @@ def _sum_combine(a, b):
@triton.jit
@_add_reduction_docstr("sum")
def sum(input, axis):
def sum(input, axis=None):
input = _promote_reduction_input(input)
return reduce(input, axis, _sum_combine)
@@ -1440,7 +1440,7 @@ def _xor_combine(a, b):
@builtin
@_add_reduction_docstr("xor sum")
def xor_sum(input, axis, _builder=None, _generator=None):
def xor_sum(input, axis=None, _builder=None, _generator=None):
scalar_ty = input.type.scalar
if not scalar_ty.is_int():
raise ValueError("xor_sum only supported for integers")

View File

@@ -1236,6 +1236,13 @@ def where(condition: tl.tensor,
def reduction(
inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder
) -> Tuple[tl.tensor, ...]:
if axis is None:
new_inputs = []
for i in range(len(inputs)):
new_shape = [inputs[i].numel.value]
new_inputs.append(view(inputs[i], new_shape, builder))
inputs = tuple(new_inputs)
axis = 0
# get result shape
shape = inputs[0].type.shape
ret_shape = [s for i, s in enumerate(shape) if i != axis]