[FRONTEND] Add error messages for arange (#1218)

Fix issue https://github.com/openai/triton/issues/244

Check `end` is greater than `start`.
Check if the range can fit in `int32`.
Check the number of elements less than or equal to
`TRITON_MAX_TENSOR_NUMEL = 131072`.

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
Eric Wang
2023-02-21 18:37:28 -06:00
committed by GitHub
parent d023e1cb06
commit 320ae18093
2 changed files with 15 additions and 4 deletions

View File

@@ -9,6 +9,8 @@ from triton._C.libtriton.triton import ir
T = TypeVar('T')
TRITON_MAX_TENSOR_NUMEL = 131072
def _to_tensor(x, builder):
if isinstance(x, bool):
@@ -254,6 +256,8 @@ class block_type(dtype):
self.numel = 1
for s in self.shape:
self.numel *= s
if self.numel > TRITON_MAX_TENSOR_NUMEL:
raise ValueError(f"numel ({self.numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})")
self.name = self.__str__()
@@ -702,12 +706,13 @@ def num_programs(axis, _builder=None):
@builtin
def arange(start, end, _builder=None):
"""
Returns contiguous values within the open interval [:code:`start`, :code:`end`).
Returns contiguous values within the left-closed and right-open interval [:code:`start`, :code:`end`). \
End - Start must be less than or equal to TRITON_MAX_TENSOR_NUMEL = 131072
:param start: Start of the interval. Must be a power of two.
:type start: int
:param stop: End of the interval. Must be a power of two >= start.
:type stop: int
:type start: int32
:param end: End of the interval. Must be a power of two > start.
:type end: int32
"""
start = _constexpr_to_value(start)
end = _constexpr_to_value(end)

View File

@@ -478,6 +478,12 @@ def not_equal(input: tl.tensor,
def arange(start: int, end: int, builder: ir.builder) -> tl.tensor:
if not isinstance(start, int) or not isinstance(end, int):
raise ValueError("arange's arguments must be of type tl.constexpr")
is_start_int64 = bool(start >> 32)
is_end_int64 = bool(end >> 32)
if is_start_int64 or is_end_int64:
raise ValueError("arange must fit in int32")
if end <= start:
raise ValueError("arange's end argument must be greater than the start argument")
shape = [end - start]
ret_ty = tl.block_type(tl.int32, shape)