mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 23:08:06 -05:00
remove MOCKGPU workaround in rand (#9565)
also `requires_grad_` to save a line
This commit is contained in:
@@ -507,9 +507,6 @@ class Tensor(SimpleMathTrait):
|
||||
if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
|
||||
num = ceildiv(numel * dtype.itemsize, 4)
|
||||
|
||||
# when using MOCKGPU and NV generate rand on CPU
|
||||
if getenv("MOCKGPU") and device.startswith("NV"): device = "CPU"
|
||||
|
||||
# generate per device seeds and rng counter if we haven't seen this device yet
|
||||
if device not in Tensor._device_seeds:
|
||||
Tensor._device_seeds[device] = Tensor(
|
||||
@@ -532,12 +529,7 @@ class Tensor(SimpleMathTrait):
|
||||
one = Tensor.ones_like(bits, device=bits.device, dtype=dtype).bitcast(uint_dtype)
|
||||
bits = bits.rshift((dtype.itemsize * 8) - nmant).bitwise_or(one)
|
||||
# bitcast back to the original dtype and reshape
|
||||
out = bits.bitcast(dtype)[:numel].sub(1).reshape(shape)
|
||||
|
||||
# move back to the original device if we were using MOCKGPU
|
||||
if getenv("MOCKGPU") and _device: out = out.to(_device)
|
||||
|
||||
out.requires_grad = kwargs.get("requires_grad")
|
||||
out = bits.bitcast(dtype)[:numel].sub(1).reshape(shape).requires_grad_(kwargs.get("requires_grad"))
|
||||
return out.contiguous() if contiguous else out
|
||||
|
||||
# ***** creation helper functions *****
|
||||
|
||||
Reference in New Issue
Block a user