From 2fc0bd150b2f88d371c0e320aa97c5969333a1c7 Mon Sep 17 00:00:00 2001
From: Sieds Lykles <93992551+S-Lykles@users.noreply.github.com>
Date: Sat, 13 Sep 2025 00:18:25 +0200
Subject: [PATCH] Arange overflow raises error and one_hot upcast (#11975)

* add error

* to_dtype

* shorten line

* add test

* upcast one hot dim if overflows
---
 test/test_tensor.py | 4 ++++
 tinygrad/tensor.py  | 6 ++++--
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/test/test_tensor.py b/test/test_tensor.py
index 3b243773d8..27c17ae04e 100644
--- a/test/test_tensor.py
+++ b/test/test_tensor.py
@@ -918,6 +918,10 @@ class TestIdxUpcast(unittest.TestCase):
     uops = self._schedule_render(a)
     assert all(uop.dtype is not dtypes.long for uop in uops)

+  def test_arange_raise_overflow(self):
+    with self.assertRaises(ValueError):
+      self._schedule_render(Tensor.arange(2**33, dtype=dtypes.int))
+
   @unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported")
   def test_int64_unsupported_overflow_sym(self):
     with self.assertRaises(KeyError):
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index a6116e0291..fb37bd7b18 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -8,7 +8,7 @@ from tinygrad.dtype import _from_np_dtype, _to_np_dtype
 from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
 from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap, DEBUG, is_numpy_ndarray, RANGEIFY, FUSE_ATTENTION
 from tinygrad.gradient import compute_gradient
-from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, MathTrait, identity_element, all_metadata, index_to_concrete_int
+from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, MathTrait, identity_element, all_metadata, index_to_concrete_int, sint_to_uop
 from tinygrad.uop.spec import tensor_uop_spec, type_verify
 from tinygrad.device import Device, Buffer
 from tinygrad.engine.realize import run_schedule
@@ -632,6 +632,7 @@ class Tensor(MathTrait):
     """
     if stop is None: stop, start = start, 0
     dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
+    if start < (dt:=to_dtype(dtype)).min or dt.max < (stop-step): raise ValueError(f"arange [{start}, {stop}) is not representable in dtype {dtype}")
     # NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs
     if (output_len:=ceildiv(stop-start, step)) <= 0: return Tensor([], dtype=dtype, **kwargs)
     return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumalu(0, Ops.ADD) + (start - step)).cast(dtype)
@@ -3897,7 +3898,8 @@ class Tensor(MathTrait):
   def _one_hot_along_dim(self:Tensor, num_classes:sint, dim:int=-1) -> Tensor:
     if not dtypes.is_int(self.dtype): raise RuntimeError(f"_one_hot_along_dim expects int index tensor, getting {self.dtype}")
     offset = self.ndim - self._resolve_dim(dim) - 1
-    return self == Tensor.arange(num_classes, device=self.device, requires_grad=False).reshape((num_classes,) + (1,) * offset)
+    dt = dtypes.int64 if sint_to_uop(num_classes).overflows(dtypes.int32) else dtypes.int32
+    return self == Tensor.arange(num_classes, dtype=dt, device=self.device, requires_grad=False).reshape((num_classes,) + (1,) * offset)

   def one_hot(self, num_classes:int=-1) -> Tensor:
     """
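
A quick usage sketch of the behavior this patch adds (not part of the diff; a
minimal example assuming a tinygrad build that includes this change):

from tinygrad import Tensor, dtypes

# arange now raises eagerly when [start, stop) cannot be represented in the
# requested dtype: 2**33 exceeds the int32 maximum of 2**31 - 1.
try:
  Tensor.arange(2**33, dtype=dtypes.int)
except ValueError as e:
  print(e)  # prints something like: arange [0, 8589934592) is not representable in dtype dtypes.int

# A representable range is unchanged.
print(Tensor.arange(5, dtype=dtypes.int).tolist())  # [0, 1, 2, 3, 4]

# one_hot's internal index arange is built in int64 only when num_classes
# overflows int32, so small cases keep the cheaper 32-bit index.
print(Tensor([0, 2, 1]).one_hot(3).tolist())  # [[1, 0, 0], [0, 0, 1], [0, 1, 0]]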