mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Added default axis=None for reduction, which reduces across all the axes. (#1712)
This commit is contained in:
@@ -1379,6 +1379,9 @@ reduce_configs2 = [
|
||||
for op in ['min', 'max', 'sum', 'argmin', 'argmax']
|
||||
for shape in reduce2d_shapes
|
||||
for axis in [0, 1]
|
||||
] + [
|
||||
(op, 'float32', [16, 32], None)
|
||||
for op in ['min', 'max', 'sum']
|
||||
]
|
||||
|
||||
|
||||
@@ -1393,7 +1396,9 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
|
||||
range_n = tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
|
||||
z = GENERATE_TEST_HERE
|
||||
if AXIS == 1:
|
||||
if AXIS is None:
|
||||
tl.store(Z, z)
|
||||
elif AXIS == 1:
|
||||
tl.store(Z + range_m, z)
|
||||
else:
|
||||
tl.store(Z + range_n, z)
|
||||
@@ -1418,7 +1423,8 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
|
||||
else:
|
||||
z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
|
||||
# triton result
|
||||
z_tri = to_triton(numpy_random((shape[1 - axis],), dtype_str=z_dtype_str, rs=rs),
|
||||
ret_numel = 1 if axis is None else shape[1 - axis]
|
||||
z_tri = to_triton(numpy_random((ret_numel,), dtype_str=z_dtype_str, rs=rs),
|
||||
device=device, dst_type=z_tri_dtype_str)
|
||||
kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis)
|
||||
z_tri = to_numpy(z_tri)
|
||||
|
||||
@@ -388,7 +388,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
for name, value in zip(names, values):
|
||||
# by default, constexpr are assigned into python variable
|
||||
value = _unwrap_if_constexpr(value)
|
||||
if not _is_triton_tensor(value) and \
|
||||
if value is not None and \
|
||||
not _is_triton_tensor(value) and \
|
||||
not isinstance(value, native_nontensor_types):
|
||||
value = language.core._to_tensor(value, self.builder)
|
||||
self.set_value(name, value)
|
||||
|
||||
@@ -1297,8 +1297,8 @@ def reduce(input, axis, combine_fn, _builder=None, _generator=None):
|
||||
else:
|
||||
handles = [r.handle for r in results]
|
||||
_builder.create_reduce_ret(*handles)
|
||||
|
||||
axis = _constexpr_to_value(axis)
|
||||
if axis is not None:
|
||||
axis = _constexpr_to_value(axis)
|
||||
return semantic.reduction(input, axis, make_combine_region, _builder)
|
||||
|
||||
|
||||
@@ -1369,7 +1369,7 @@ def _max_combine(a, b):
|
||||
|
||||
@triton.jit
|
||||
@_add_reduction_docstr("maximum")
|
||||
def max(input, axis):
|
||||
def max(input, axis=None):
|
||||
input = _promote_reduction_input(input)
|
||||
return reduce(input, axis, _max_combine)
|
||||
|
||||
@@ -1399,7 +1399,7 @@ def _min_combine(a, b):
|
||||
|
||||
@triton.jit
|
||||
@_add_reduction_docstr("minimum")
|
||||
def min(input, axis):
|
||||
def min(input, axis=None):
|
||||
input = _promote_reduction_input(input)
|
||||
return reduce(input, axis, _min_combine)
|
||||
|
||||
@@ -1428,7 +1428,7 @@ def _sum_combine(a, b):
|
||||
|
||||
@triton.jit
|
||||
@_add_reduction_docstr("sum")
|
||||
def sum(input, axis):
|
||||
def sum(input, axis=None):
|
||||
input = _promote_reduction_input(input)
|
||||
return reduce(input, axis, _sum_combine)
|
||||
|
||||
@@ -1440,7 +1440,7 @@ def _xor_combine(a, b):
|
||||
|
||||
@builtin
|
||||
@_add_reduction_docstr("xor sum")
|
||||
def xor_sum(input, axis, _builder=None, _generator=None):
|
||||
def xor_sum(input, axis=None, _builder=None, _generator=None):
|
||||
scalar_ty = input.type.scalar
|
||||
if not scalar_ty.is_int():
|
||||
raise ValueError("xor_sum only supported for integers")
|
||||
|
||||
@@ -1236,6 +1236,13 @@ def where(condition: tl.tensor,
|
||||
def reduction(
|
||||
inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder
|
||||
) -> Tuple[tl.tensor, ...]:
|
||||
if axis is None:
|
||||
new_inputs = []
|
||||
for i in range(len(inputs)):
|
||||
new_shape = [inputs[i].numel.value]
|
||||
new_inputs.append(view(inputs[i], new_shape, builder))
|
||||
inputs = tuple(new_inputs)
|
||||
axis = 0
|
||||
# get result shape
|
||||
shape = inputs[0].type.shape
|
||||
ret_shape = [s for i, s in enumerate(shape) if i != axis]
|
||||
|
||||
Reference in New Issue
Block a user