[FRONTEND] fix checks for valid slice and avoid hitting an obscure exception. (#1720)

When comparing to the expected slides, using the `==` operator will
dispatch to the component of the slice. If the user writes `a[10:20]`
these are `triton.constexpr` instances, and the `__eq__` operator which
is implemented as: `return constexpr(self.value == other.value)`. At
this point the access to `.value` on the provided `None` yields an
exception that isn't very friendly to the user.

I am not sure if the implementation of `constexpr` should be hardened
instead?

Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
Mehdi Amini
2023-05-31 09:37:19 -07:00
committed by GitHub
parent 327d362cca
commit 19c65d6007
2 changed files with 16 additions and 1 deletions

View File

@@ -467,6 +467,21 @@ def test_broadcast(dtype):
broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N)
assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all()
# ------------------
# test invalid slice
# ------------------
def test_invalid_slice():
dst = torch.empty(128, device='cuda')
@triton.jit
def _kernel(dst):
dst[10:]
with pytest.raises(triton.CompilationError, match='unsupported tensor index'):
_kernel[(1,)](dst=dst)
# ----------------
# test expand_dims

View File

@@ -711,7 +711,7 @@ class tensor:
for dim, sl in enumerate(slices):
if isinstance(sl, constexpr) and sl.value is None:
ret = semantic.expand_dims(ret, dim, _builder)
elif sl == slice(None, None, None):
elif isinstance(sl, slice) and sl.start is None and sl.stop is None and sl.step is None:
pass
else:
assert False, f"unsupported tensor index: {sl}"