[DOCS] Add missing ops and corresponding comments (#1699)

This commit is contained in:
Keren Zhou
2023-05-21 15:18:48 -04:00
committed by GitHub
parent a2433f3135
commit 74dbb2fc0a
2 changed files with 69 additions and 4 deletions

View File

@@ -23,6 +23,8 @@ Creation Ops
:nosignatures:
arange
cat
full
zeros
@@ -33,11 +35,13 @@ Shape Manipulation Ops
:toctree: generated
:nosignatures:
broadcast
broadcast_to
expand_dims
reshape
ravel
reshape
trans
view
Linear Algebra Ops
@@ -83,11 +87,13 @@ Math Ops
abs
exp
log
fdiv
cos
sin
sqrt
sigmoid
softmax
umulhi
Reduction Ops
@@ -151,4 +157,27 @@ Compiler Hint Ops
:toctree: generated
:nosignatures:
debug_barrier
max_contiguous
multiple_of
Debug Ops
-----------------
.. autosummary::
:toctree: generated
:nosignatures:
static_print
static_assert
device_print
device_assert
Iterators
-----------------
.. autosummary::
:toctree: generated
:nosignatures:
static_range

View File

@@ -1451,12 +1451,15 @@ def xor_sum(input, axis, _builder=None, _generator=None):
# -----------------------
# Internal for debugging
# Compiler Hint Ops
# -----------------------
@builtin
def debug_barrier(_builder=None):
'''
Insert a barrier to synchronize all threads in a block.
'''
return semantic.debug_barrier(_builder)
@@ -1498,16 +1501,28 @@ def max_contiguous(input, values, _builder=None):
@builtin
def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False, _builder=None):
'''
Print the values at compile time. The parameters are the same as the builtin :code:`print`.
'''
pass
@builtin
def static_assert(cond, msg="", _builder=None):
'''
Assert the condition at compile time. The parameters are the same as the builtin :code:`assert`.
'''
pass
@builtin
def device_print(prefix, *args, _builder=None):
'''
Print the values at runtime from the device.
:param prefix: a prefix to print before the values. This is required to be a string literal.
:param args: the values to print. They can be any tensor or scalar.
'''
import string
prefix = _constexpr_to_value(prefix)
assert isinstance(prefix, str), f"{prefix} is not string"
@@ -1525,6 +1540,12 @@ def device_print(prefix, *args, _builder=None):
@builtin
def device_assert(cond, msg="", _builder=None):
'''
Assert the condition at runtime from the device.
:param cond: the condition to assert. This is required to be a boolean tensor.
:param msg: the message to print if the assertion fails. This is required to be a string literal.
'''
msg = _constexpr_to_value(msg)
import inspect
frame = inspect.currentframe()
@@ -1550,7 +1571,22 @@ def device_assert(cond, msg="", _builder=None):
class static_range:
"""Iterator that counts upward forever."""
"""
Iterator that counts upward forever.
.. highlight:: python
.. code-block:: python
@triton.jit
def kernel(...):
for i in tl.static_range(10):
...
:note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of
:code:`triton.jit` functions. In addition, it also guides the compiler to unroll the loop aggressively.
:param arg1: the start value.
:param arg2: the end value.
:param step: the step value.
"""
def __init__(self, arg1, arg2=None, step=None):
assert isinstance(arg1, constexpr)