diff --git a/docs/python-api/triton.language.rst b/docs/python-api/triton.language.rst index 6d520f0d6..e6ab2c716 100644 --- a/docs/python-api/triton.language.rst +++ b/docs/python-api/triton.language.rst @@ -111,6 +111,16 @@ Reduction Ops sum xor_sum +Scan Ops +------------- + +.. autosummary:: + :toctree: generated + :nosignatures: + + associative_scan + cumsum + cumprod Atomic Ops ---------- diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 7c9e44f75..a566cb07a 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1534,7 +1534,7 @@ scan_configs = [ for type in ['int32', 'float32'] for axis in [1, 0] for shape in scan2d_shapes - for op in ['cumsum'] + for op in ['cumsum', 'cumprod'] ] @@ -1557,7 +1557,7 @@ def test_scan2d(op, dtype_str, shape, axis, num_warps, device): x = numpy_random(shape, dtype_str=dtype_str, rs=rs) z = np.empty_like(x) x_tri = to_triton(x, device=device) - numpy_op = {'cumsum': np.cumsum}[op] + numpy_op = {'cumsum': np.cumsum, 'cumprod': np.cumprod}[op] z_dtype_str = dtype_str z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str)) # triton result @@ -1566,7 +1566,10 @@ def test_scan2d(op, dtype_str, shape, axis, num_warps, device): z_tri = to_numpy(z_tri) # compare if dtype_str == 'float32': - np.testing.assert_allclose(z_ref, z_tri, rtol=0.01) + if op == 'cumprod': + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01, atol=1e-3) + else: + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01) else: np.testing.assert_equal(z_ref, z_tri) diff --git a/python/triton/interpreter/tl_lang.py b/python/triton/interpreter/tl_lang.py index 50ca578a8..cbe601324 100644 --- a/python/triton/interpreter/tl_lang.py +++ b/python/triton/interpreter/tl_lang.py @@ -629,3 +629,9 @@ class TritonLangProxy: if axis is None: return torch.cumsum(input) return torch.cumsum(input, dim=axis) + + @_tensor_operation + def cumprod(self, input, axis=None): + if axis is None: + return torch.cumprod(input) + return torch.cumprod(input, dim=axis) diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index a279c624e..b9ed276fc 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -35,6 +35,7 @@ from .core import ( cat, constexpr, cos, + cumprod, cumsum, debug_barrier, device_assert, @@ -130,6 +131,7 @@ __all__ = [ "cdiv", "constexpr", "cos", + "cumprod", "cumsum", "debug_barrier", "device_assert", diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 746bc4703..c10cb55f5 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1562,6 +1562,20 @@ def cumsum(input, axis=0): input = _promote_reduction_input(input) return associative_scan(input, axis, _sum_combine) +# cumprod + + +@jit +def _prod_combine(a, b): + return a * b + + +@jit +@_add_scan_docstr("cumprod") +def cumprod(input, axis=0): + # todo rename this to a generic function name + input = _promote_reduction_input(input) + return associative_scan(input, axis, _prod_combine) # ----------------------- # Compiler Hint Ops