mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
split arange threefry (#4590)
This commit is contained in:
@@ -336,17 +336,18 @@ class Tensor:
|
||||
|
||||
# threefry
|
||||
if (num := prod((shape:=argfix(*shape)))) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
|
||||
counts = (Tensor.arange(num, device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._rng_counter.to(device)).realize().pad(((0,num%2),))
|
||||
counts1 = (Tensor.arange(math.ceil(num / 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._rng_counter.to(device)).realize()
|
||||
counts2 = counts1 + math.ceil(num / 2)
|
||||
Tensor._rng_counter.assign(Tensor._rng_counter + num).realize()
|
||||
|
||||
rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
|
||||
ks = [0x0, Tensor._seed ^ 0x0 ^ 0x1BD11BDA, Tensor._seed]
|
||||
|
||||
x = [(c := counts.chunk(2))[0] + ks[-1], c[1] + ks[0]]
|
||||
x = [counts1 + ks[-1], counts2 + ks[0]]
|
||||
for i in range(5):
|
||||
for r in rotations[i % 2]: x[0], x[1] = (x0 := x[0] + x[1]), x0 ^ ((x[1] << r) + (x[1] >> (32 - r)))
|
||||
x = [(x[0] + ks[i % 3]), (x[1] + ks[(i + 1) % 3] + i + 1)]
|
||||
out = x[0].cat(x[1])[:num].rshift(8).cast(dtypes.float32).div(2 ** 24)
|
||||
out = x[0].cat(x[1]).rshift(8).cast(dtypes.float32).div(2 ** 24)[:num]
|
||||
out = out.reshape(shape).cast(dtypes.default_float if dtype is None else dtype)
|
||||
out.requires_grad = kwargs.get("requires_grad")
|
||||
return out.contiguous()
|
||||
|
||||
Reference in New Issue
Block a user