mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] fix checks for valid slice and avoid hitting an obscure exception. (#1720)
When comparing to the expected slides, using the `==` operator will dispatch to the component of the slice. If the user writes `a[10:20]` these are `triton.constexpr` instances, and the `__eq__` operator which is implemented as: `return constexpr(self.value == other.value)`. At this point the access to `.value` on the provided `None` yields an exception that isn't very friendly to the user. I am not sure if the implementation of `constexpr` should be hardened instead? Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
@@ -467,6 +467,21 @@ def test_broadcast(dtype):
|
||||
broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N)
|
||||
assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all()
|
||||
|
||||
# ------------------
|
||||
# test invalid slice
|
||||
# ------------------
|
||||
|
||||
|
||||
def test_invalid_slice():
|
||||
dst = torch.empty(128, device='cuda')
|
||||
|
||||
@triton.jit
|
||||
def _kernel(dst):
|
||||
dst[10:]
|
||||
|
||||
with pytest.raises(triton.CompilationError, match='unsupported tensor index'):
|
||||
_kernel[(1,)](dst=dst)
|
||||
|
||||
|
||||
# ----------------
|
||||
# test expand_dims
|
||||
|
||||
@@ -711,7 +711,7 @@ class tensor:
|
||||
for dim, sl in enumerate(slices):
|
||||
if isinstance(sl, constexpr) and sl.value is None:
|
||||
ret = semantic.expand_dims(ret, dim, _builder)
|
||||
elif sl == slice(None, None, None):
|
||||
elif isinstance(sl, slice) and sl.start is None and sl.stop is None and sl.step is None:
|
||||
pass
|
||||
else:
|
||||
assert False, f"unsupported tensor index: {sl}"
|
||||
|
||||
Reference in New Issue
Block a user