add tiny test for randomness + remove ulong buffers (#7648)

* add tiny test for randomness

* Tensor._device_seeds is a Tuple

* no tuple, just a 2 element tensor

* no more longs

* fix tests, and maybe ocelot works now

* NV still doesn't work. cleanup rules

* test + two more rules
This commit is contained in:
George Hotz
2024-11-12 12:45:52 +08:00
committed by GitHub
parent c06a5a9c72
commit 4f1f823021
4 changed files with 31 additions and 7 deletions

View File

@@ -6,6 +6,7 @@ import torch
from tinygrad import nn, dtypes, Tensor, Device, TinyJit
from tinygrad.helpers import getenv, CI
from tinygrad.device import is_dtype_supported
from tinygrad.engine.realize import lower_schedule, CompiledRunner
from hypothesis import given, settings, strategies as strat
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
@@ -92,10 +93,16 @@ class TestRandomness(unittest.TestCase):
counts = Tensor.arange(20, dtype=dtypes.uint32)
counts0, counts1 = counts.chunk(2)
r = Tensor._threefry_random_bits(1337 << 32, counts0, counts1).numpy()
r = Tensor._threefry_random_bits(Tensor([0, 1337], dtype='uint32'), counts0, counts1).numpy()
np.testing.assert_allclose(jr, r)
def test_threefry_doesnt_use_long(self):
for ei in lower_schedule(Tensor.rand(20).schedule()):
if isinstance(ei.prg, CompiledRunner):
for u in ei.prg.p.uops:
self.assertNotIn(u.dtype, {dtypes.long, dtypes.ulong}, msg=f"long found in {ei.prg.p.name}")
def test_threefry_against_reference_full(self):
Tensor.manual_seed(1337)

View File

@@ -29,6 +29,14 @@ class TestTiny(unittest.TestCase):
self.assertListEqual((out:=a@b).flatten().tolist(), [1.0]*(N*N))
if IMAGE < 2: self.assertEqual(out.dtype, out_dtype)
# *** randomness ***
def test_random(self):
out = Tensor.rand(10)
for x in out.tolist():
self.assertGreaterEqual(x, 0.0)
self.assertLessEqual(x, 1.0)
# *** JIT (for Python speed) ***
def test_jit(self):

View File

@@ -274,8 +274,17 @@ sym = symbolic_flat+PatternMatcher([
# tensor core cleanups
(UPat.var("add") + UPat(Ops.WMMA, name="wmma"),
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
# threefry
# threefry + remove longs
(UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32),
(UPat.var('x', dtypes.uint32).cast(dtypes.uint64).cast(dtypes.uint32), lambda x: x), # cast there and back is noop (TODO: genericize)
((UPat.var('x', dtypes.uint64)&0xFFFFFFFF).cast(dtypes.uint32), lambda x: x.cast(dtypes.uint32)), # cast does truncation
(((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
(((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
# hacks for threefry long removal when padded (TODO: genericize)
(UPat.var('x', dtypes.uint32).cast(dtypes.uint64) * UPat.var('y').where(UPat.const(dtypes.uint64, 1<<32), UPat.const(dtypes.uint64, 0)),
lambda x,y: y.where(x, UOp.const(dtypes.uint32, 0)).cast(dtypes.uint64) * (1<<32)),
((UPat.var('x', dtypes.uint64)&(UPat.var('y').where(UPat.const(dtypes.uint64, 0xFFFFFFFF), UPat.const(dtypes.uint64, 0)))).cast(dtypes.uint32),
lambda x,y: y.where(x.cast(dtypes.uint32), UOp.const(dtypes.uint32, 0))),
# arange loop folding
(acc_pat.assign(UPat.any(arange_m, arange_m+UPat.var("extra"))+acc_pat), loop_collapse),
# indexing, with cast or where

View File

@@ -464,9 +464,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
Tensor._seed, Tensor._device_seeds, Tensor._device_rng_counters = seed, {}, {}
@staticmethod
def _threefry_random_bits(key, counts0, counts1):
def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor):
x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
x = F.Threefry.apply(*x._broadcasted(key))
x = F.Threefry.apply(x, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)
return counts0.cat(counts1)
@@ -494,9 +494,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
# generate per device seeds and rng counter if we haven't seen this device yet
if device not in Tensor._device_seeds:
Tensor._device_seeds[device] = Tensor([((Tensor._seed & 0xffffffff) << 32) \
| int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big") & 0xffffffff],
device=device, dtype=dtypes.uint64, requires_grad=False)
Tensor._device_seeds[device] = Tensor(
[int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big"), Tensor._seed],
device=device, dtype=dtypes.uint32, requires_grad=False)
Tensor._device_rng_counters[device] = Tensor([0], device=device, dtype=dtypes.uint32, requires_grad=False)
had_counter = False
else: had_counter = True