[FRONTEND] Rename tl.reduction -> tl.reduce and improve testing (#1521)

`tl.reduction` is currently tested indirectly through the existing
reduction operators, but it's good to have a direct test for the
function itself.

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
peterbell10
2023-04-14 21:35:31 +00:00
committed by GitHub
parent bfd1f65ac7
commit 0d76c4ca95
7 changed files with 62 additions and 19 deletions

View File

@@ -24,7 +24,7 @@ repos:
rev: v1.6.0
hooks:
- id: autopep8
args: ["-a", "-i", "--max-line-length", "88"]
args: ["-i"]
stages: [commit, push, manual]
- repo: https://github.com/pycqa/flake8
rev: 6.0.0

View File

@@ -96,9 +96,13 @@ Reduction Ops
:toctree: generated
:nosignatures:
argmax
argmin
max
min
reduce
sum
xor_sum
Atomic Ops

8
python/pyproject.toml Normal file
View File

@@ -0,0 +1,8 @@
[build-system]
requires = ["setuptools>=40.8.0", "wheel", "cmake>=3.18"]
[tool.autopep8]
aggressive = 1
ignore = "E501,E701,E731,W690"
max_line_length = 88

View File

@@ -1,8 +0,0 @@
[metadata]
description_file = README.md
[pycodestyle]
ignore = E501,E701,E731
[flake8]
ignore = E501,E701,E731

View File

@@ -1334,6 +1334,43 @@ def test_reduce_layouts(M, N, src_layout, axis, device='cuda'):
np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
@triton.jit
def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
    """Merge two Welford accumulators into one.

    Each accumulator is a (mean, M2, weight) triple, where M2 is the sum of
    squared deviations from the mean and weight is the sample count. This is
    the pairwise combination step of the parallel variance algorithm, used
    as the combine_fn for tl.reduce.
    """
    mean_diff = mean_2 - mean_1
    weight_total = weight_1 + weight_2
    frac_2 = weight_2 / weight_total
    # Shift the first mean toward the second, proportionally to its weight.
    combined_mean = mean_1 + mean_diff * frac_2
    # Cross-term accounts for the two partitions having different means.
    combined_m2 = m2_1 + m2_2 + mean_diff * mean_diff * weight_1 * frac_2
    return (combined_mean, combined_m2, weight_total)
def test_generic_reduction(device='cuda'):
    """Directly exercise tl.reduce with a tuple-valued reduction.

    Computes the mean and (population) variance of a 1-D tensor in a single
    pass using Welford's algorithm, then checks the results against
    torch.var_mean.
    """

    @triton.jit
    def var_mean_kernel(X, out_mean, out_var, BLOCK: tl.constexpr):
        xindex = tl.arange(0, BLOCK)
        x = tl.load(X + xindex)
        # Per-element Welford state: each element starts as a weight-1
        # sample with zero accumulated squared deviation.
        mean = x
        m2 = tl.zeros_like(x)
        weight = tl.full(x.shape, 1, x.dtype)
        # tl.reduce accepts a tuple of tensors plus a @triton.jit combine
        # function, and returns a matching tuple of reduced scalars.
        (mean, m2, weight) = tl.reduce((mean, m2, weight), 0, _welford_combine)
        tl.store(out_mean, mean)
        tl.store(out_var, m2 / weight)

    SIZE = 512
    x = torch.rand(SIZE, device=device)
    # 0-dim outputs: the axis-0 reduction collapses the block to scalars.
    out_mean = torch.empty((), device=device)
    out_var = torch.empty((), device=device)
    var_mean_kernel[(1,)](x, out_mean, out_var, BLOCK=SIZE)
    # correction=0 gives the population variance, matching m2 / weight.
    expect_var, expect_mean = torch.var_mean(x, dim=0, correction=0)
    torch.testing.assert_close(out_mean, expect_mean)
    torch.testing.assert_close(out_var, expect_var)
# ---------------
# test permute
# ---------------

View File

@@ -57,6 +57,7 @@ from .core import (
pointer_type,
program_id,
ravel,
reduce,
reshape,
sigmoid,
sin,
@@ -164,6 +165,7 @@ __all__ = [
"randn",
"randn4x",
"ravel",
"reduce",
"reshape",
"sigmoid",
"sin",

View File

@@ -1199,7 +1199,7 @@ def _insertion_guard(builder):
@builtin
def reduction(input, axis, combine_fn, _builder=None, _generator=None):
def reduce(input, axis, combine_fn, _builder=None, _generator=None):
"""Applies the combine_fn to all elements in :code:`input` tensors along the provided :code:`axis`
:param input: the input tensor, or tuple of tensors
@@ -1208,8 +1208,8 @@ def reduction(input, axis, combine_fn, _builder=None, _generator=None):
"""
if isinstance(input, tensor):
return reduction((input,), axis, combine_fn,
_builder=_builder, _generator=_generator)[0]
return reduce((input,), axis, combine_fn,
_builder=_builder, _generator=_generator)[0]
def make_combine_region(reduce_op):
in_scalar_tys = [t.type.scalar for t in input]
@@ -1261,8 +1261,8 @@ def _argreduce(input, axis, combine_fn, _builder=None, _generator=None):
index = index.__getitem__(expand_dims_index, _builder=_builder)
index = broadcast_to(index, input.shape, _builder=_builder)
rvalue, rindices = reduction((input, index), axis, combine_fn,
_builder=_builder, _generator=_generator)
rvalue, rindices = reduce((input, index), axis, combine_fn,
_builder=_builder, _generator=_generator)
return rindices
@@ -1275,7 +1275,7 @@ def _max_combine(a, b):
@_add_reduction_docstr("maximum")
def max(input, axis):
input = _promote_reduction_input(input)
return reduction(input, axis, _max_combine)
return reduce(input, axis, _max_combine)
@triton.jit
@@ -1305,7 +1305,7 @@ def _min_combine(a, b):
@_add_reduction_docstr("minimum")
def min(input, axis):
input = _promote_reduction_input(input)
return reduction(input, axis, _min_combine)
return reduce(input, axis, _min_combine)
@triton.jit
@@ -1334,7 +1334,7 @@ def _sum_combine(a, b):
@_add_reduction_docstr("sum")
def sum(input, axis):
input = _promote_reduction_input(input)
return reduction(input, axis, _sum_combine)
return reduce(input, axis, _sum_combine)
@triton.jit
@@ -1350,8 +1350,8 @@ def xor_sum(input, axis, _builder=None, _generator=None):
raise ValueError("xor_sum only supported for integers")
input = _promote_reduction_input(input, _builder=_builder)
return reduction(input, axis, _xor_combine,
_builder=_builder, _generator=_generator)
return reduce(input, axis, _xor_combine,
_builder=_builder, _generator=_generator)
# -----------------------