mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Add error messages for arange (#1218)
Fix issue https://github.com/openai/triton/issues/244 Check `end` is greater than `start`. Check if the range can fit in `int32`. Check the number of elements less than or equal to `TRITON_MAX_TENSOR_NUMEL = 131072`. --------- Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
@@ -9,6 +9,8 @@ from triton._C.libtriton.triton import ir
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
TRITON_MAX_TENSOR_NUMEL = 131072
|
||||
|
||||
|
||||
def _to_tensor(x, builder):
|
||||
if isinstance(x, bool):
|
||||
@@ -254,6 +256,8 @@ class block_type(dtype):
|
||||
self.numel = 1
|
||||
for s in self.shape:
|
||||
self.numel *= s
|
||||
if self.numel > TRITON_MAX_TENSOR_NUMEL:
|
||||
raise ValueError(f"numel ({self.numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})")
|
||||
|
||||
self.name = self.__str__()
|
||||
|
||||
@@ -702,12 +706,13 @@ def num_programs(axis, _builder=None):
|
||||
@builtin
|
||||
def arange(start, end, _builder=None):
|
||||
"""
|
||||
Returns contiguous values within the open interval [:code:`start`, :code:`end`).
|
||||
Returns contiguous values within the left-closed and right-open interval [:code:`start`, :code:`end`). \
|
||||
End - Start must be less than or equal to TRITON_MAX_TENSOR_NUMEL = 131072
|
||||
|
||||
:param start: Start of the interval. Must be a power of two.
|
||||
:type start: int
|
||||
:param stop: End of the interval. Must be a power of two >= start.
|
||||
:type stop: int
|
||||
:type start: int32
|
||||
:param end: End of the interval. Must be a power of two > start.
|
||||
:type end: int32
|
||||
"""
|
||||
start = _constexpr_to_value(start)
|
||||
end = _constexpr_to_value(end)
|
||||
|
||||
@@ -478,6 +478,12 @@ def not_equal(input: tl.tensor,
|
||||
def arange(start: int, end: int, builder: ir.builder) -> tl.tensor:
|
||||
if not isinstance(start, int) or not isinstance(end, int):
|
||||
raise ValueError("arange's arguments must be of type tl.constexpr")
|
||||
is_start_int64 = bool(start >> 32)
|
||||
is_end_int64 = bool(end >> 32)
|
||||
if is_start_int64 or is_end_int64:
|
||||
raise ValueError("arange must fit in int32")
|
||||
if end <= start:
|
||||
raise ValueError("arange's end argument must be greater than the start argument")
|
||||
|
||||
shape = [end - start]
|
||||
ret_ty = tl.block_type(tl.int32, shape)
|
||||
|
||||
Reference in New Issue
Block a user