mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
update arange range check (#15794)
it was not checking negative steps correctly
This commit is contained in:
@@ -281,6 +281,17 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([], lambda: torch.arange(5.5, 175.5, 2.5), lambda: Tensor.arange(5.5, 175.5, 2.5), forward_only=True)
|
||||
helper_test_op([], lambda: torch.arange(-30.2, -0.3, 0.75), lambda: Tensor.arange(-30.2, -0.3, 0.75), forward_only=True)
|
||||
helper_test_op([], lambda: torch.arange(-50.3, -380.2, -2.25), lambda: Tensor.arange(-50.3, -380.2, -2.25), forward_only=True)
|
||||
# boundary values that fit exactly in int8 (min=-128, max=127)
|
||||
helper_test_op([], lambda: torch.arange(128, dtype=torch.int8), lambda: Tensor.arange(128, dtype=dtypes.int8), forward_only=True)
|
||||
helper_test_op([], lambda: torch.arange(-128, 128, dtype=torch.int8), lambda: Tensor.arange(-128, 128, dtype=dtypes.int8), forward_only=True)
|
||||
helper_test_op([], lambda: torch.arange(127, -129, -1, dtype=torch.int8),
|
||||
lambda: Tensor.arange(127, -129, -1, dtype=dtypes.int8), forward_only=True)
|
||||
# overflow: tinygrad raises (torch silently wraps)
|
||||
with self.assertRaises(OverflowError): Tensor.arange(2**33, dtype=dtypes.int)
|
||||
with self.assertRaises(OverflowError): Tensor.arange(129, dtype=dtypes.int8) # last=128 overflows
|
||||
with self.assertRaises(OverflowError): Tensor.arange(-129, 128, dtype=dtypes.int8) # start=-129 overflows
|
||||
with self.assertRaises(OverflowError): Tensor.arange(128, 0, -1, dtype=dtypes.int8) # start=128 overflows
|
||||
with self.assertRaises(OverflowError): Tensor.arange(127, -130, -1, dtype=dtypes.int8) # last=-129 overflows
|
||||
|
||||
def test_arange_big(self):
|
||||
helper_test_op([], lambda: torch.arange(256, dtype=torch.int32), lambda: Tensor.arange(256), forward_only=True)
|
||||
|
||||
@@ -107,10 +107,6 @@ class TestIdxUpcast(unittest.TestCase):
|
||||
uops = self._schedule_render(a)
|
||||
assert all(uop.dtype is not dtypes.long for uop in uops)
|
||||
|
||||
def test_arange_raise_overflow(self):
|
||||
with self.assertRaises(ValueError):
|
||||
self._schedule_render(Tensor.arange(2**33, dtype=dtypes.int))
|
||||
|
||||
@unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported")
|
||||
def test_int64_unsupported_overflow_sym(self):
|
||||
with self.assertRaises((KeyError, RuntimeError)):
|
||||
|
||||
@@ -740,7 +740,8 @@ class Tensor(OpMixin):
|
||||
"""
|
||||
if stop is None: stop, start = start, 0
|
||||
dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
|
||||
if start < (dt:=to_dtype(dtype)).min or dt.max < (stop-step): raise ValueError(f"arange [{start}, {stop}) is not representable in dtype {dtype}")
|
||||
lo, hi = (start, stop-step) if step > 0 else (stop-step, start)
|
||||
if lo < (dt:=to_dtype(dtype)).min or dt.max < hi: raise OverflowError(f"arange [{start}, {stop}) is not representable in dtype {dtype}")
|
||||
# NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs
|
||||
if (output_len:=ceildiv(stop-start, step)) <= 0: return Tensor([], dtype=dtype, **kwargs)
|
||||
return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumalu(0, Ops.ADD) + (start - step)).cast(dtype)
|
||||
|
||||
Reference in New Issue
Block a user