mirror of https://github.com/ROCm/ROCm.git
[FRONTEND] Rename tl.reduction -> tl.reduce and improve testing (#1521)
`tl.reduction` is currently tested indirectly through the existing reduction operators, but it's good to have a direct test for the function itself.

Co-authored-by: Philippe Tillet <phil@openai.com>
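For orientation, a minimal sketch (not part of this diff; the kernel and buffer names are illustrative) of how the renamed entry point is called from a kernel:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _max_combine(a, b):
    # Any @triton.jit binary function can serve as the combine function.
    return tl.maximum(a, b)


@triton.jit
def max_kernel(X, Out, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    # Previously spelled tl.reduction(x, 0, _max_combine).
    tl.store(Out, tl.reduce(x, 0, _max_combine))


x = torch.rand(1024, device='cuda')
out = torch.empty((), device='cuda')
max_kernel[(1,)](x, out, BLOCK=1024)
assert out == x.max()
```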
@@ -24,7 +24,7 @@ repos:
     rev: v1.6.0
     hooks:
       - id: autopep8
-        args: ["-a", "-i", "--max-line-length", "88"]
+        args: ["-i"]
         stages: [commit, push, manual]
   - repo: https://github.com/pycqa/flake8
     rev: 6.0.0
@@ -96,9 +96,13 @@ Reduction Ops
     :toctree: generated
     :nosignatures:
 
     argmax
     argmin
     max
     min
+    reduce
     sum
     xor_sum
 
 
 Atomic Ops
python/pyproject.toml (new file)
@@ -0,0 +1,8 @@
+[build-system]
+requires = ["setuptools>=40.8.0", "wheel", "cmake>=3.18"]
+
+[tool.autopep8]
+aggressive = 1
+ignore = "E501,E701,E731,W690"
+max_line_length = 88
+
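The pre-commit hook drops its inline flags because autopep8 can read these same settings from the `[tool.autopep8]` table above, leaving only `-i` (in-place) on the command line. A hedged sketch of the equivalence via autopep8's Python API, with the option values copied from the new pyproject.toml:

```python
import autopep8

# These options mirror [tool.autopep8]; the CLI picks them up from
# pyproject.toml, so the pre-commit hook only needs to pass "-i".
source = "x=1;y=2\n"
fixed = autopep8.fix_code(source, options={
    'aggressive': 1,
    'ignore': ['E501', 'E701', 'E731', 'W690'],
    'max_line_length': 88,
})
print(fixed)  # "x = 1\ny = 2\n" (E702 is not in the ignore list)
```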
@@ -1,8 +0,0 @@
-[metadata]
-description_file = README.md
-
-[pycodestyle]
-ignore = E501,E701,E731
-
-[flake8]
-ignore = E501,E701,E731
@@ -1334,6 +1334,43 @@ def test_reduce_layouts(M, N, src_layout, axis, device='cuda'):
     np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
 
 
+@triton.jit
+def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
+    delta = mean_2 - mean_1
+    new_weight = weight_1 + weight_2
+    w2_over_w = weight_2 / new_weight
+    return (
+        mean_1 + delta * w2_over_w,
+        m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,
+        new_weight,
+    )
+
+
+def test_generic_reduction(device='cuda'):
+
+    @triton.jit
+    def var_mean_kernel(X, out_mean, out_var, BLOCK: tl.constexpr):
+        xindex = tl.arange(0, BLOCK)
+        x = tl.load(X + xindex)
+        mean = x
+        m2 = tl.zeros_like(x)
+        weight = tl.full(x.shape, 1, x.dtype)
+        (mean, m2, weight) = tl.reduce((mean, m2, weight), 0, _welford_combine)
+        tl.store(out_mean, mean)
+        tl.store(out_var, m2 / weight)
+
+    SIZE = 512
+    x = torch.rand(SIZE, device=device)
+    out_mean = torch.empty((), device=device)
+    out_var = torch.empty((), device=device)
+
+    var_mean_kernel[(1,)](x, out_mean, out_var, BLOCK=SIZE)
+
+    expect_var, expect_mean = torch.var_mean(x, dim=0, correction=0)
+    torch.testing.assert_close(out_mean, expect_mean)
+    torch.testing.assert_close(out_var, expect_var)
+
+
 # ---------------
 # test permute
 # ---------------
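`_welford_combine` is the parallel-variance merge of Chan et al.: given two partitions with means m1, m2, sums of squared deviations M2_1, M2_2, and element counts w1, w2, the merged mean is m1 + delta * w2 / (w1 + w2) and the merged M2 is M2_1 + M2_2 + delta^2 * w1 * w2 / (w1 + w2), where delta = m2 - m1. A quick host-side NumPy check of the same algebra (illustrative, not part of the diff):

```python
import numpy as np


def welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
    # Same algebra as the @triton.jit _welford_combine above, on the host.
    delta = mean_2 - mean_1
    new_weight = weight_1 + weight_2
    w2_over_w = weight_2 / new_weight
    return (
        mean_1 + delta * w2_over_w,
        m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,
        new_weight,
    )


x = np.random.rand(512)
a, b = x[:256], x[256:]
# Per-partition statistics: mean, sum of squared deviations, count.
stats_a = (a.mean(), ((a - a.mean()) ** 2).sum(), float(a.size))
stats_b = (b.mean(), ((b - b.mean()) ** 2).sum(), float(b.size))
mean, m2, weight = welford_combine(*stats_a, *stats_b)
assert np.isclose(mean, x.mean())
assert np.isclose(m2 / weight, x.var())  # population variance (correction=0)
```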
@@ -57,6 +57,7 @@ from .core import (
     pointer_type,
     program_id,
     ravel,
+    reduce,
     reshape,
     sigmoid,
     sin,
@@ -164,6 +165,7 @@ __all__ = [
     "randn",
     "randn4x",
     "ravel",
+    "reduce",
     "reshape",
     "sigmoid",
     "sin",
@@ -1199,7 +1199,7 @@ def _insertion_guard(builder):
 
 
 @builtin
-def reduction(input, axis, combine_fn, _builder=None, _generator=None):
+def reduce(input, axis, combine_fn, _builder=None, _generator=None):
     """Applies the combine_fn to all elements in :code:`input` tensors along the provided :code:`axis`
 
     :param input: the input tensor, or tuple of tensors
@@ -1208,8 +1208,8 @@ def reduction(input, axis, combine_fn, _builder=None, _generator=None):
     """
     if isinstance(input, tensor):
-        return reduction((input,), axis, combine_fn,
-                         _builder=_builder, _generator=_generator)[0]
+        return reduce((input,), axis, combine_fn,
+                      _builder=_builder, _generator=_generator)[0]
 
     def make_combine_region(reduce_op):
         in_scalar_tys = [t.type.scalar for t in input]
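As the hunk above shows, `tl.reduce` accepts either a single tensor or a tuple of tensors: the single-tensor path wraps the input in a 1-tuple, reduces, and unwraps the first result, while the tuple path (used by `_argreduce` below and by the Welford test) threads one scalar per tensor through `combine_fn`. A hedged host-side analogue of that convention, over plain Python sequences rather than the real IR-building implementation:

```python
def reduce_like(values, combine_fn):
    # Tuple path: combine_fn takes and returns one scalar per sequence.
    if not isinstance(values, tuple):
        # Single-sequence convenience path, mirroring the
        # isinstance(input, tensor) branch above: wrap, reduce, unwrap.
        return reduce_like((values,), lambda *args: (combine_fn(*args),))[0]
    acc = tuple(seq[0] for seq in values)
    for row in zip(*(seq[1:] for seq in values)):
        acc = combine_fn(*acc, *row)
    return acc


# Joint (value, index) reduction in one pass, as _argreduce does below:
vals, idxs = [3.0, 1.0, 2.0], [0, 1, 2]
assert reduce_like(
    (vals, idxs),
    lambda v1, i1, v2, i2: (v1, i1) if v1 <= v2 else (v2, i2),
) == (1.0, 1)
# Single-sequence form returns a bare value, not a 1-tuple:
assert reduce_like([3, 1, 2], min) == 1
```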
@@ -1261,8 +1261,8 @@ def _argreduce(input, axis, combine_fn, _builder=None, _generator=None):
     index = index.__getitem__(expand_dims_index, _builder=_builder)
     index = broadcast_to(index, input.shape, _builder=_builder)
 
-    rvalue, rindices = reduction((input, index), axis, combine_fn,
-                                 _builder=_builder, _generator=_generator)
+    rvalue, rindices = reduce((input, index), axis, combine_fn,
+                              _builder=_builder, _generator=_generator)
     return rindices
 
 
@@ -1275,7 +1275,7 @@ def _max_combine(a, b):
 @_add_reduction_docstr("maximum")
 def max(input, axis):
     input = _promote_reduction_input(input)
-    return reduction(input, axis, _max_combine)
+    return reduce(input, axis, _max_combine)
 
 
 @triton.jit
@@ -1305,7 +1305,7 @@ def _min_combine(a, b):
 @_add_reduction_docstr("minimum")
 def min(input, axis):
     input = _promote_reduction_input(input)
-    return reduction(input, axis, _min_combine)
+    return reduce(input, axis, _min_combine)
 
 
 @triton.jit
@@ -1334,7 +1334,7 @@ def _sum_combine(a, b):
 @_add_reduction_docstr("sum")
 def sum(input, axis):
     input = _promote_reduction_input(input)
-    return reduction(input, axis, _sum_combine)
+    return reduce(input, axis, _sum_combine)
 
 
 @triton.jit
@@ -1350,8 +1350,8 @@ def xor_sum(input, axis, _builder=None, _generator=None):
         raise ValueError("xor_sum only supported for integers")
 
     input = _promote_reduction_input(input, _builder=_builder)
-    return reduction(input, axis, _xor_combine,
-                     _builder=_builder, _generator=_generator)
+    return reduce(input, axis, _xor_combine,
+                  _builder=_builder, _generator=_generator)
 
 
 # -----------------------