fix failed threefry (#10646)

This commit is contained in:
wozeparrot
2025-06-05 17:17:42 -07:00
committed by GitHub
parent e67642d430
commit 0d86f8d375
6 changed files with 7 additions and 8 deletions

View File

@@ -118,7 +118,7 @@ class SpeedyResNet:
# hyper-parameters were exactly the same as the original repo
bias_scaler = 58
hyp = {
'seed' : 209,
'seed' : 200,
'opt': {
'bias_lr': 1.76 * bias_scaler/512,
'non_bias_lr': 1.76 / 512,

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.5 MiB

After

Width:  |  Height:  |  Size: 1.5 MiB

View File

@@ -136,9 +136,7 @@ class TestRandomness(unittest.TestCase):
jr = np.array([0.9614430665969849, 0.059279561042785645, 0.01909029483795166, 0.47882091999053955, 0.9677121639251709,
0.36863112449645996, 0.3102607727050781, 0.06608951091766357, 0.35329878330230713, 0.26518797874450684], dtype=np.float32)
r = Tensor.rand(10).numpy()
# TODO: this failed because increment happened before _threefry_random_bits
with self.assertRaises(AssertionError):
np.testing.assert_allclose(r, jr, atol=1e-5, rtol=1e-5)
np.testing.assert_allclose(r, jr, atol=1e-5, rtol=1e-5)
@unittest.skipIf(not_support_multi_device(), "no multi")
def test_threefry_tensors_cnt(self):

View File

@@ -174,12 +174,12 @@ class TestSchedule(unittest.TestCase):
def test_rand(self):
x = Tensor.rand(32)
check_schedule(x, 3, [Tensor._device_rng_counters[x.device]])
check_schedule(x, 4, [Tensor._device_rng_counters[x.device]])
def test_rand_recompute_arange(self):
x = Tensor.rand(32)
with Context(DONT_GROUP_REDUCES=1):
check_schedule(x, 2, [Tensor._device_rng_counters[x.device]])
check_schedule(x, 3, [Tensor._device_rng_counters[x.device]])
@unittest.skip("TODO: do not divide by zero given x.idiv(VALID)")
def test_rand_handcoded(self):

View File

@@ -512,12 +512,13 @@ class Tensor(MathTrait):
Tensor._device_seeds[device] = Tensor(
[int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big"), Tensor._seed],
device=device, dtype=dtypes.uint32, requires_grad=False)
Tensor._device_rng_counters[device] = Tensor([0], device=device, dtype=dtypes.uint32, requires_grad=False)
Tensor._device_rng_counters[device] = Tensor([num], device=device, dtype=dtypes.uint32, requires_grad=False)
# increment rng counter for devices
else: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num).contiguous()
# threefry random bits
counts0 = (Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._device_rng_counters[device])
bits_count = Tensor._device_rng_counters[device] - num
counts0 = (Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)+bits_count)
counts1 = counts0 + ceildiv(num, 2)
bits = Tensor._threefry_random_bits(Tensor._device_seeds[device], counts0, counts1)[:num]