From 94926d00d8c102467147afdf2f794ceef8b9b92b Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 17 Mar 2026 22:00:01 -0400 Subject: [PATCH] fix rand > uint32.max (#15330) need to keep low and high as 1D tensor. `PYTHONPATH=. LLAMA3_SIZE=405B python3 examples/mlperf/models/flat_llama.py` works now --- test/null/test_tensor.py | 8 ++++++++ tinygrad/tensor.py | 10 +++++----- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/test/null/test_tensor.py b/test/null/test_tensor.py index 4bf3d68f06..84fdae1a5a 100644 --- a/test/null/test_tensor.py +++ b/test/null/test_tensor.py @@ -162,5 +162,13 @@ class TestTensorUnique(unittest.TestCase): Tensor.realize(b,c) self.assertIs(b.uop.buffer, c.uop.buffer) +class TestRand(unittest.TestCase): + def test_rand_large_tensor(self): + # large tensor rand (num > uint32.max) should not crash in frontend + Tensor.manual_seed(0) + Tensor.rand(2**17, 2**17).schedule() + Tensor.rand(2**17, 2**17).schedule() + Tensor.rand(2**17, 2**17).schedule() + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 2da7ff16e6..bfafeb08ed 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -633,12 +633,12 @@ class Tensor(OpMixin): Tensor._device_rng_counters[device] = Tensor([0, 0], device=device, dtype=dtypes.uint32, requires_grad=False).contiguous() # increment rng counter for devices - new_low = Tensor._device_rng_counters[device][0] + (num & 0xffffffff) - new_high = Tensor._device_rng_counters[device][1] + (num >> 32) + (new_low < Tensor._device_rng_counters[device][0]).cast(dtypes.uint32) - Tensor._device_rng_counters[device].assign(Tensor.stack(new_low, new_high)) + new_low = Tensor._device_rng_counters[device][0:1] + (num & 0xffffffff) + new_high = Tensor._device_rng_counters[device][1:2] + (num >> 32) + (new_low < Tensor._device_rng_counters[device][0]).cast(dtypes.uint32) + Tensor._device_rng_counters[device].assign(new_low.cat(new_high)) - low = Tensor._device_rng_counters[device][0] - (num & 0xffffffff) - high = Tensor._device_rng_counters[device][1] - (num >> 32) - (Tensor._device_rng_counters[device][0] < (num & 0xffffffff)).cast(dtypes.uint32) + low = Tensor._device_rng_counters[device][0:1] - (num & 0xffffffff) + high = Tensor._device_rng_counters[device][1:2] - (num >> 32) - (Tensor._device_rng_counters[device][0] < (num & 0xffffffff)).cast(dtypes.uint32) # threefry random bits if num > dtypes.uint32.max: