mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add tiny test for randomness + remove ulong buffers (#7648)
* add tiny test for randomness * Tensor._device_seeds is a Tuple * no tuple, just a 2 element tensor * no more longs * fix tests, and maybe ocelot works now * NV still doesn't work. cleanup rules * test + two more rules
This commit is contained in:
@@ -6,6 +6,7 @@ import torch
|
||||
from tinygrad import nn, dtypes, Tensor, Device, TinyJit
|
||||
from tinygrad.helpers import getenv, CI
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.engine.realize import lower_schedule, CompiledRunner
|
||||
from hypothesis import given, settings, strategies as strat
|
||||
|
||||
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
|
||||
@@ -92,10 +93,16 @@ class TestRandomness(unittest.TestCase):
|
||||
|
||||
counts = Tensor.arange(20, dtype=dtypes.uint32)
|
||||
counts0, counts1 = counts.chunk(2)
|
||||
r = Tensor._threefry_random_bits(1337 << 32, counts0, counts1).numpy()
|
||||
r = Tensor._threefry_random_bits(Tensor([0, 1337], dtype='uint32'), counts0, counts1).numpy()
|
||||
|
||||
np.testing.assert_allclose(jr, r)
|
||||
|
||||
def test_threefry_doesnt_use_long(self):
|
||||
for ei in lower_schedule(Tensor.rand(20).schedule()):
|
||||
if isinstance(ei.prg, CompiledRunner):
|
||||
for u in ei.prg.p.uops:
|
||||
self.assertNotIn(u.dtype, {dtypes.long, dtypes.ulong}, msg=f"long found in {ei.prg.p.name}")
|
||||
|
||||
def test_threefry_against_reference_full(self):
|
||||
Tensor.manual_seed(1337)
|
||||
|
||||
|
||||
@@ -29,6 +29,14 @@ class TestTiny(unittest.TestCase):
|
||||
self.assertListEqual((out:=a@b).flatten().tolist(), [1.0]*(N*N))
|
||||
if IMAGE < 2: self.assertEqual(out.dtype, out_dtype)
|
||||
|
||||
# *** randomness ***
|
||||
|
||||
def test_random(self):
|
||||
out = Tensor.rand(10)
|
||||
for x in out.tolist():
|
||||
self.assertGreaterEqual(x, 0.0)
|
||||
self.assertLessEqual(x, 1.0)
|
||||
|
||||
# *** JIT (for Python speed) ***
|
||||
|
||||
def test_jit(self):
|
||||
|
||||
Reference in New Issue
Block a user