From 600a39771d6ec1b714d774790367f280ccb7ca1f Mon Sep 17 00:00:00 2001 From: chenyu Date: Sun, 28 Jul 2024 14:34:10 -0400 Subject: [PATCH] fix Tensor.arange if (stop-start) and step have different signs (#5775) --- test/test_dtype.py | 3 +++ tinygrad/tensor.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/test/test_dtype.py b/test/test_dtype.py index d5c7d2c0ce..872c46dac1 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -453,6 +453,9 @@ class TestTypeSpec(unittest.TestCase): _assert_eq(Tensor.arange(5, dtype=dtypes.float16), dtypes.float16, np.arange(5)) _assert_eq(Tensor.arange(3, 9, 0.7), dtypes.default_float, np.arange(3, 9, 0.7)) _assert_eq(Tensor.arange(3, 8.5, 3), dtypes.default_float, np.arange(3, 8.5, 3)) + # stop-start and step have different signs + _assert_eq(Tensor.arange(3, 5, -2), dtypes.default_int, np.arange(3, 5, -2)) + _assert_eq(Tensor.arange(5.0, 3.0), dtypes.default_float, np.arange(5.0, 3.0)) @given(strat.sampled_from(core_dtypes), strat.sampled_from([operator.gt, operator.ge, operator.le, operator.lt, operator.eq, operator.ne])) def test_bool_ops(self, dtype, op): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 7d9a37def7..03468396a9 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -505,6 +505,8 @@ class Tensor: if stop is None: stop, start = start, 0 assert all(isinstance(s, (int, float)) for s in (start, stop, step)), f"symbolic arange not supported {start=}, {stop=}, {step=}" dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int) + # NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs + if (stop-start)/step <= 0: return Tensor([], dtype=dtype, **kwargs) return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype) @staticmethod