update arange range check (#15794)

It was not checking negative steps correctly.
This commit is contained in:
chenyu
2026-04-17 16:07:50 -04:00
committed by GitHub
parent 23ca680a3a
commit 0191cc73dc
3 changed files with 13 additions and 5 deletions

View File

@@ -281,6 +281,17 @@ class TestOps(unittest.TestCase):
helper_test_op([], lambda: torch.arange(5.5, 175.5, 2.5), lambda: Tensor.arange(5.5, 175.5, 2.5), forward_only=True)
helper_test_op([], lambda: torch.arange(-30.2, -0.3, 0.75), lambda: Tensor.arange(-30.2, -0.3, 0.75), forward_only=True)
helper_test_op([], lambda: torch.arange(-50.3, -380.2, -2.25), lambda: Tensor.arange(-50.3, -380.2, -2.25), forward_only=True)
# boundary values that fit exactly in int8 (min=-128, max=127)
helper_test_op([], lambda: torch.arange(128, dtype=torch.int8), lambda: Tensor.arange(128, dtype=dtypes.int8), forward_only=True)
helper_test_op([], lambda: torch.arange(-128, 128, dtype=torch.int8), lambda: Tensor.arange(-128, 128, dtype=dtypes.int8), forward_only=True)
helper_test_op([], lambda: torch.arange(127, -129, -1, dtype=torch.int8),
lambda: Tensor.arange(127, -129, -1, dtype=dtypes.int8), forward_only=True)
# overflow: tinygrad raises (torch silently wraps)
with self.assertRaises(OverflowError): Tensor.arange(2**33, dtype=dtypes.int)
with self.assertRaises(OverflowError): Tensor.arange(129, dtype=dtypes.int8) # last=128 overflows
with self.assertRaises(OverflowError): Tensor.arange(-129, 128, dtype=dtypes.int8) # start=-129 overflows
with self.assertRaises(OverflowError): Tensor.arange(128, 0, -1, dtype=dtypes.int8) # start=128 overflows
with self.assertRaises(OverflowError): Tensor.arange(127, -130, -1, dtype=dtypes.int8) # last=-129 overflows
# NOTE(review): this is diff context in a commit view; the method body in the
# actual test file may continue past the single line shown here — confirm
# against the full file before treating this as the whole test.
def test_arange_big(self):
# forward_only=True: only the forward result is compared against torch;
# torch side pins int32 explicitly, tinygrad side uses its default int dtype.
helper_test_op([], lambda: torch.arange(256, dtype=torch.int32), lambda: Tensor.arange(256), forward_only=True)

View File

@@ -107,10 +107,6 @@ class TestIdxUpcast(unittest.TestCase):
uops = self._schedule_render(a)
assert all(uop.dtype is not dtypes.long for uop in uops)
# NOTE(review): this hunk shrinks from 10 to 6 lines (@@ -107,10 +107,6 @@),
# so these lines are plausibly the *removed* side of the diff — the same
# commit's tensor.py hunk changes the arange range check to raise
# OverflowError, and the new OverflowError tests live in test_ops. A retained
# test expecting ValueError here would be stale; verify against the full diff.
def test_arange_raise_overflow(self):
# 2**33 does not fit in a 32-bit int, so scheduling/rendering this arange
# is expected to raise.
with self.assertRaises(ValueError):
self._schedule_render(Tensor.arange(2**33, dtype=dtypes.int))
@unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported")
def test_int64_unsupported_overflow_sym(self):
with self.assertRaises((KeyError, RuntimeError)):

View File

@@ -740,7 +740,8 @@ class Tensor(OpMixin):
"""
if stop is None: stop, start = start, 0
dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
if start < (dt:=to_dtype(dtype)).min or dt.max < (stop-step): raise ValueError(f"arange [{start}, {stop}) is not representable in dtype {dtype}")
lo, hi = (start, stop-step) if step > 0 else (stop-step, start)
if lo < (dt:=to_dtype(dtype)).min or dt.max < hi: raise OverflowError(f"arange [{start}, {stop}) is not representable in dtype {dtype}")
# NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs
if (output_len:=ceildiv(stop-start, step)) <= 0: return Tensor([], dtype=dtype, **kwargs)
return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumalu(0, Ops.ADD) + (start - step)).cast(dtype)