split arange threefry (#4590)

This commit is contained in:
wozeparrot
2024-05-14 21:10:22 -07:00
committed by GitHub
parent 9425973bc7
commit 7f009cf9fa

View File

@@ -336,17 +336,18 @@ class Tensor:
# threefry
if (num := prod((shape:=argfix(*shape)))) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
counts = (Tensor.arange(num, device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._rng_counter.to(device)).realize().pad(((0,num%2),))
counts1 = (Tensor.arange(math.ceil(num / 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._rng_counter.to(device)).realize()
counts2 = counts1 + math.ceil(num / 2)
Tensor._rng_counter.assign(Tensor._rng_counter + num).realize()
rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
ks = [0x0, Tensor._seed ^ 0x0 ^ 0x1BD11BDA, Tensor._seed]
x = [(c := counts.chunk(2))[0] + ks[-1], c[1] + ks[0]]
x = [counts1 + ks[-1], counts2 + ks[0]]
for i in range(5):
for r in rotations[i % 2]: x[0], x[1] = (x0 := x[0] + x[1]), x0 ^ ((x[1] << r) + (x[1] >> (32 - r)))
x = [(x[0] + ks[i % 3]), (x[1] + ks[(i + 1) % 3] + i + 1)]
out = x[0].cat(x[1])[:num].rshift(8).cast(dtypes.float32).div(2 ** 24)
out = x[0].cat(x[1]).rshift(8).cast(dtypes.float32).div(2 ** 24)[:num]
out = out.reshape(shape).cast(dtypes.default_float if dtype is None else dtype)
out.requires_grad = kwargs.get("requires_grad")
return out.contiguous()