mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
add tiny test for randomness + remove ulong buffers (#7648)
* add tiny test for randomness
* Tensor._device_seeds is a Tuple
* no tuple, just a 2 element tensor
* no more longs
* fix tests, and maybe ocelot works now
* NV still doesn't work. cleanup rules
* test + two more rules
This commit is contained in:
@@ -6,6 +6,7 @@ import torch
|
||||
from tinygrad import nn, dtypes, Tensor, Device, TinyJit
|
||||
from tinygrad.helpers import getenv, CI
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.engine.realize import lower_schedule, CompiledRunner
|
||||
from hypothesis import given, settings, strategies as strat
|
||||
|
||||
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
|
||||
@@ -92,10 +93,16 @@ class TestRandomness(unittest.TestCase):
|
||||
|
||||
counts = Tensor.arange(20, dtype=dtypes.uint32)
|
||||
counts0, counts1 = counts.chunk(2)
|
||||
r = Tensor._threefry_random_bits(1337 << 32, counts0, counts1).numpy()
|
||||
r = Tensor._threefry_random_bits(Tensor([0, 1337], dtype='uint32'), counts0, counts1).numpy()
|
||||
|
||||
np.testing.assert_allclose(jr, r)
|
||||
|
||||
def test_threefry_doesnt_use_long(self):
|
||||
for ei in lower_schedule(Tensor.rand(20).schedule()):
|
||||
if isinstance(ei.prg, CompiledRunner):
|
||||
for u in ei.prg.p.uops:
|
||||
self.assertNotIn(u.dtype, {dtypes.long, dtypes.ulong}, msg=f"long found in {ei.prg.p.name}")
|
||||
|
||||
def test_threefry_against_reference_full(self):
|
||||
Tensor.manual_seed(1337)
|
||||
|
||||
|
||||
@@ -29,6 +29,14 @@ class TestTiny(unittest.TestCase):
|
||||
self.assertListEqual((out:=a@b).flatten().tolist(), [1.0]*(N*N))
|
||||
if IMAGE < 2: self.assertEqual(out.dtype, out_dtype)
|
||||
|
||||
# *** randomness ***
|
||||
|
||||
def test_random(self):
|
||||
out = Tensor.rand(10)
|
||||
for x in out.tolist():
|
||||
self.assertGreaterEqual(x, 0.0)
|
||||
self.assertLessEqual(x, 1.0)
|
||||
|
||||
# *** JIT (for Python speed) ***
|
||||
|
||||
def test_jit(self):
|
||||
|
||||
@@ -274,8 +274,17 @@ sym = symbolic_flat+PatternMatcher([
|
||||
# tensor core cleanups
|
||||
(UPat.var("add") + UPat(Ops.WMMA, name="wmma"),
|
||||
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
|
||||
# threefry
|
||||
# threefry + remove longs
|
||||
(UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32),
|
||||
(UPat.var('x', dtypes.uint32).cast(dtypes.uint64).cast(dtypes.uint32), lambda x: x), # cast there and back is noop (TODO: genericize)
|
||||
((UPat.var('x', dtypes.uint64)&0xFFFFFFFF).cast(dtypes.uint32), lambda x: x.cast(dtypes.uint32)), # cast does truncation
|
||||
(((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
|
||||
(((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
|
||||
# hacks for threefry long removal when padded (TODO: genericize)
|
||||
(UPat.var('x', dtypes.uint32).cast(dtypes.uint64) * UPat.var('y').where(UPat.const(dtypes.uint64, 1<<32), UPat.const(dtypes.uint64, 0)),
|
||||
lambda x,y: y.where(x, UOp.const(dtypes.uint32, 0)).cast(dtypes.uint64) * (1<<32)),
|
||||
((UPat.var('x', dtypes.uint64)&(UPat.var('y').where(UPat.const(dtypes.uint64, 0xFFFFFFFF), UPat.const(dtypes.uint64, 0)))).cast(dtypes.uint32),
|
||||
lambda x,y: y.where(x.cast(dtypes.uint32), UOp.const(dtypes.uint32, 0))),
|
||||
# arange loop folding
|
||||
(acc_pat.assign(UPat.any(arange_m, arange_m+UPat.var("extra"))+acc_pat), loop_collapse),
|
||||
# indexing, with cast or where
|
||||
|
||||
@@ -464,9 +464,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
||||
Tensor._seed, Tensor._device_seeds, Tensor._device_rng_counters = seed, {}, {}
|
||||
|
||||
@staticmethod
|
||||
def _threefry_random_bits(key, counts0, counts1):
|
||||
def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor):
|
||||
x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
|
||||
x = F.Threefry.apply(*x._broadcasted(key))
|
||||
x = F.Threefry.apply(x, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
|
||||
counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)
|
||||
return counts0.cat(counts1)
|
||||
|
||||
@@ -494,9 +494,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
||||
|
||||
# generate per device seeds and rng counter if we haven't seen this device yet
|
||||
if device not in Tensor._device_seeds:
|
||||
Tensor._device_seeds[device] = Tensor([((Tensor._seed & 0xffffffff) << 32) \
|
||||
| int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big") & 0xffffffff],
|
||||
device=device, dtype=dtypes.uint64, requires_grad=False)
|
||||
Tensor._device_seeds[device] = Tensor(
|
||||
[int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big"), Tensor._seed],
|
||||
device=device, dtype=dtypes.uint32, requires_grad=False)
|
||||
Tensor._device_rng_counters[device] = Tensor([0], device=device, dtype=dtypes.uint32, requires_grad=False)
|
||||
had_counter = False
|
||||
else: had_counter = True
|
||||
|
||||
Reference in New Issue
Block a user