From 74dbb2fc0a30b0e602adbdaba01ed3ddd09e70c4 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Sun, 21 May 2023 15:18:48 -0400 Subject: [PATCH] [DOCS] Add missing ops and corresponding comments (#1699) --- docs/python-api/triton.language.rst | 33 ++++++++++++++++++++++-- python/triton/language/core.py | 40 +++++++++++++++++++++++++++-- 2 files changed, 69 insertions(+), 4 deletions(-) diff --git a/docs/python-api/triton.language.rst b/docs/python-api/triton.language.rst index 5013a0242..6d520f0d6 100644 --- a/docs/python-api/triton.language.rst +++ b/docs/python-api/triton.language.rst @@ -23,6 +23,8 @@ Creation Ops :nosignatures: arange + cat + full zeros @@ -33,11 +35,13 @@ Shape Manipulation Ops :toctree: generated :nosignatures: + broadcast broadcast_to expand_dims - reshape ravel - + reshape + trans + view Linear Algebra Ops @@ -83,11 +87,13 @@ Math Ops abs exp log + fdiv cos sin sqrt sigmoid softmax + umulhi Reduction Ops @@ -151,4 +157,27 @@ Compiler Hint Ops :toctree: generated :nosignatures: + debug_barrier + max_contiguous multiple_of + +Debug Ops +----------------- + +.. autosummary:: + :toctree: generated + :nosignatures: + + static_print + static_assert + device_print + device_assert + +Iterators +----------------- + +.. autosummary:: + :toctree: generated + :nosignatures: + + static_range diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 8764cd42a..a3c460996 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1451,12 +1451,15 @@ def xor_sum(input, axis, _builder=None, _generator=None): # ----------------------- -# Internal for debugging +# Compiler Hint Ops # ----------------------- @builtin def debug_barrier(_builder=None): + ''' + Insert a barrier to synchronize all threads in a block. 
+ ''' return semantic.debug_barrier(_builder) @@ -1498,16 +1501,28 @@ def max_contiguous(input, values, _builder=None): @builtin def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False, _builder=None): + ''' + Print the values at compile time. The parameters are the same as the builtin :code:`print`. + ''' pass @builtin def static_assert(cond, msg="", _builder=None): + ''' + Assert the condition at compile time. The parameters are the same as the builtin :code:`assert`. + ''' pass @builtin def device_print(prefix, *args, _builder=None): + ''' + Print the values at runtime from the device. + + :param prefix: a prefix to print before the values. This is required to be a string literal. + :param args: the values to print. They can be any tensor or scalar. + ''' import string prefix = _constexpr_to_value(prefix) assert isinstance(prefix, str), f"{prefix} is not string" @@ -1525,6 +1540,12 @@ def device_print(prefix, *args, _builder=None): @builtin def device_assert(cond, msg="", _builder=None): + ''' + Assert the condition at runtime from the device. + + :param cond: the condition to assert. This is required to be a boolean tensor. + :param msg: the message to print if the assertion fails. This is required to be a string literal. + ''' msg = _constexpr_to_value(msg) import inspect frame = inspect.currentframe() @@ -1550,7 +1571,22 @@ def device_assert(cond, msg="", _builder=None): class static_range: - """Iterator that counts upward forever.""" + """ + Iterator over a bounded range of values, analogous to Python's :code:`range`. + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(...): + for i in tl.static_range(10): + ... + :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of + :code:`triton.jit` functions. In addition, it also guides the compiler to unroll the loop aggressively. + :param arg1: the start value. + :param arg2: the end value. + :param step: the step value. 
+ """ def __init__(self, arg1, arg2=None, step=None): assert isinstance(arg1, constexpr)