fix rand > uint32.max (#15330)

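The RNG counter halves `low` and `high` need to stay 1D tensors (sliced with `[0:1]` / `[1:2]` and joined with `cat`) instead of being indexed down to scalars and re-stacked.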
`PYTHONPATH=. LLAMA3_SIZE=405B python3 examples/mlperf/models/flat_llama.py` works now
This commit is contained in:
chenyu
2026-03-17 22:00:01 -04:00
committed by GitHub
parent b45edeb965
commit 94926d00d8
2 changed files with 13 additions and 5 deletions

View File

@@ -162,5 +162,13 @@ class TestTensorUnique(unittest.TestCase):
Tensor.realize(b,c)
self.assertIs(b.uop.buffer, c.uop.buffer)
class TestRand(unittest.TestCase):
    """Frontend checks for `Tensor.rand` with very large element counts."""

    def test_rand_large_tensor(self):
        """Scheduling rand for more than uint32.max elements should not crash in the frontend."""
        # 2**17 * 2**17 == 2**34 elements, which is larger than uint32.max (2**32 - 1)
        side = 2**17
        Tensor.manual_seed(0)
        # schedule the same oversized rand three times back to back
        for _ in range(3):
            Tensor.rand(side, side).schedule()
# Run the test suite via unittest when this file is executed directly as a script.
if __name__ == '__main__':
unittest.main()

View File

@@ -633,12 +633,12 @@ class Tensor(OpMixin):
Tensor._device_rng_counters[device] = Tensor([0, 0], device=device, dtype=dtypes.uint32, requires_grad=False).contiguous()
# increment rng counter for devices
new_low = Tensor._device_rng_counters[device][0] + (num & 0xffffffff)
new_high = Tensor._device_rng_counters[device][1] + (num >> 32) + (new_low < Tensor._device_rng_counters[device][0]).cast(dtypes.uint32)
Tensor._device_rng_counters[device].assign(Tensor.stack(new_low, new_high))
new_low = Tensor._device_rng_counters[device][0:1] + (num & 0xffffffff)
new_high = Tensor._device_rng_counters[device][1:2] + (num >> 32) + (new_low < Tensor._device_rng_counters[device][0]).cast(dtypes.uint32)
Tensor._device_rng_counters[device].assign(new_low.cat(new_high))
low = Tensor._device_rng_counters[device][0] - (num & 0xffffffff)
high = Tensor._device_rng_counters[device][1] - (num >> 32) - (Tensor._device_rng_counters[device][0] < (num & 0xffffffff)).cast(dtypes.uint32)
low = Tensor._device_rng_counters[device][0:1] - (num & 0xffffffff)
high = Tensor._device_rng_counters[device][1:2] - (num >> 32) - (Tensor._device_rng_counters[device][0] < (num & 0xffffffff)).cast(dtypes.uint32)
# threefry random bits
if num > dtypes.uint32.max: