mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
add a failed test case for jit/nojit rand [pr] (#9574)
currently adding jit produced different rand values
This commit is contained in:
@@ -277,6 +277,38 @@ class TestJit(unittest.TestCase):
|
||||
assert len(res3) == 5, "All values should be different, rand works in jit."
|
||||
assert res3 != res2, "Jit rand is diff with diff seeds"
|
||||
|
||||
@unittest.expectedFailure # TODO: fix
|
||||
def test_jit_v_nojit_random_regen(self):
|
||||
def f(a, b):
|
||||
rn = Tensor.randn(*a.shape)
|
||||
rn = rn * a
|
||||
rn2 = Tensor.randn(*a.shape)
|
||||
rn2 = rn2 * b
|
||||
rn = rn + rn2
|
||||
rn2 = rn2 + Tensor.randn(*a.shape)
|
||||
return ((a+b)*rn).realize(), ((a+b)*rn2).realize()
|
||||
Tensor.manual_seed(0)
|
||||
a = Tensor.randn(10, 10).realize() # realize these before resetting the random seed
|
||||
b = Tensor.randn(10, 10).realize()
|
||||
|
||||
Tensor.manual_seed(1234)
|
||||
without_jit = set()
|
||||
for _ in range(5):
|
||||
o1, o2 = f(a, b)
|
||||
without_jit.add(o1.numpy()[0][0])
|
||||
without_jit.add(o2.numpy()[0][0])
|
||||
assert len(without_jit) == 10, "All values should be different."
|
||||
|
||||
Tensor.manual_seed(1234)
|
||||
jf = TinyJit(f)
|
||||
with_jit = set()
|
||||
for _ in range(5):
|
||||
o1, o2 = jf(a, b)
|
||||
with_jit.add(o1.numpy()[0][0])
|
||||
with_jit.add(o2.numpy()[0][0])
|
||||
assert len(with_jit) == 10, "All values should be different."
|
||||
assert with_jit == without_jit, "Jit rand produced different values from no jit."
|
||||
|
||||
def test_jit_multiple_random_regen(self):
|
||||
def f(a, b):
|
||||
rn = Tensor.randn(*a.shape)
|
||||
|
||||
Reference in New Issue
Block a user