mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
fix failed threefry (#10646)
This commit is contained in:
@@ -118,7 +118,7 @@ class SpeedyResNet:
|
||||
# hyper-parameters were exactly the same as the original repo
|
||||
bias_scaler = 58
|
||||
hyp = {
|
||||
'seed' : 209,
|
||||
'seed' : 200,
|
||||
'opt': {
|
||||
'bias_lr': 1.76 * bias_scaler/512,
|
||||
'non_bias_lr': 1.76 / 512,
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 1.5 MiB After Width: | Height: | Size: 1.5 MiB |
Binary file not shown.
@@ -136,9 +136,7 @@ class TestRandomness(unittest.TestCase):
|
||||
jr = np.array([0.9614430665969849, 0.059279561042785645, 0.01909029483795166, 0.47882091999053955, 0.9677121639251709,
|
||||
0.36863112449645996, 0.3102607727050781, 0.06608951091766357, 0.35329878330230713, 0.26518797874450684], dtype=np.float32)
|
||||
r = Tensor.rand(10).numpy()
|
||||
# TODO: this failed because increment happened before _threefry_random_bits
|
||||
with self.assertRaises(AssertionError):
|
||||
np.testing.assert_allclose(r, jr, atol=1e-5, rtol=1e-5)
|
||||
np.testing.assert_allclose(r, jr, atol=1e-5, rtol=1e-5)
|
||||
|
||||
@unittest.skipIf(not_support_multi_device(), "no multi")
|
||||
def test_threefry_tensors_cnt(self):
|
||||
|
||||
@@ -174,12 +174,12 @@ class TestSchedule(unittest.TestCase):
|
||||
|
||||
def test_rand(self):
|
||||
x = Tensor.rand(32)
|
||||
check_schedule(x, 3, [Tensor._device_rng_counters[x.device]])
|
||||
check_schedule(x, 4, [Tensor._device_rng_counters[x.device]])
|
||||
|
||||
def test_rand_recompute_arange(self):
|
||||
x = Tensor.rand(32)
|
||||
with Context(DONT_GROUP_REDUCES=1):
|
||||
check_schedule(x, 2, [Tensor._device_rng_counters[x.device]])
|
||||
check_schedule(x, 3, [Tensor._device_rng_counters[x.device]])
|
||||
|
||||
@unittest.skip("TODO: do not divide by zero given x.idiv(VALID)")
|
||||
def test_rand_handcoded(self):
|
||||
|
||||
@@ -512,12 +512,13 @@ class Tensor(MathTrait):
|
||||
Tensor._device_seeds[device] = Tensor(
|
||||
[int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big"), Tensor._seed],
|
||||
device=device, dtype=dtypes.uint32, requires_grad=False)
|
||||
Tensor._device_rng_counters[device] = Tensor([0], device=device, dtype=dtypes.uint32, requires_grad=False)
|
||||
Tensor._device_rng_counters[device] = Tensor([num], device=device, dtype=dtypes.uint32, requires_grad=False)
|
||||
# increment rng counter for devices
|
||||
else: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num).contiguous()
|
||||
|
||||
# threefry random bits
|
||||
counts0 = (Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._device_rng_counters[device])
|
||||
bits_count = Tensor._device_rng_counters[device] - num
|
||||
counts0 = (Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)+bits_count)
|
||||
counts1 = counts0 + ceildiv(num, 2)
|
||||
bits = Tensor._threefry_random_bits(Tensor._device_seeds[device], counts0, counts1)[:num]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user