mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix rand > uint32.max (#15330)
need to keep low and high as 1D tensor. `PYTHONPATH=. LLAMA3_SIZE=405B python3 examples/mlperf/models/flat_llama.py` works now
This commit is contained in:
@@ -162,5 +162,13 @@ class TestTensorUnique(unittest.TestCase):
|
||||
Tensor.realize(b,c)
|
||||
self.assertIs(b.uop.buffer, c.uop.buffer)
|
||||
|
||||
class TestRand(unittest.TestCase):
|
||||
def test_rand_large_tensor(self):
|
||||
# large tensor rand (num > uint32.max) should not crash in frontend
|
||||
Tensor.manual_seed(0)
|
||||
Tensor.rand(2**17, 2**17).schedule()
|
||||
Tensor.rand(2**17, 2**17).schedule()
|
||||
Tensor.rand(2**17, 2**17).schedule()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -633,12 +633,12 @@ class Tensor(OpMixin):
|
||||
Tensor._device_rng_counters[device] = Tensor([0, 0], device=device, dtype=dtypes.uint32, requires_grad=False).contiguous()
|
||||
|
||||
# increment rng counter for devices
|
||||
new_low = Tensor._device_rng_counters[device][0] + (num & 0xffffffff)
|
||||
new_high = Tensor._device_rng_counters[device][1] + (num >> 32) + (new_low < Tensor._device_rng_counters[device][0]).cast(dtypes.uint32)
|
||||
Tensor._device_rng_counters[device].assign(Tensor.stack(new_low, new_high))
|
||||
new_low = Tensor._device_rng_counters[device][0:1] + (num & 0xffffffff)
|
||||
new_high = Tensor._device_rng_counters[device][1:2] + (num >> 32) + (new_low < Tensor._device_rng_counters[device][0]).cast(dtypes.uint32)
|
||||
Tensor._device_rng_counters[device].assign(new_low.cat(new_high))
|
||||
|
||||
low = Tensor._device_rng_counters[device][0] - (num & 0xffffffff)
|
||||
high = Tensor._device_rng_counters[device][1] - (num >> 32) - (Tensor._device_rng_counters[device][0] < (num & 0xffffffff)).cast(dtypes.uint32)
|
||||
low = Tensor._device_rng_counters[device][0:1] - (num & 0xffffffff)
|
||||
high = Tensor._device_rng_counters[device][1:2] - (num >> 32) - (Tensor._device_rng_counters[device][0] < (num & 0xffffffff)).cast(dtypes.uint32)
|
||||
|
||||
# threefry random bits
|
||||
if num > dtypes.uint32.max:
|
||||
|
||||
Reference in New Issue
Block a user