add a failed test case for jit/nojit rand [pr] (#9574)

currently adding jit produced different rand values
This commit is contained in:
chenyu
2025-03-25 13:32:44 -04:00
committed by GitHub
parent 4cf2b68ca8
commit cddd750d68
2 changed files with 34 additions and 2 deletions

View File

@@ -277,6 +277,38 @@ class TestJit(unittest.TestCase):
assert len(res3) == 5, "All values should be different, rand works in jit."
assert res3 != res2, "Jit rand is diff with diff seeds"
@unittest.expectedFailure # TODO: fix
def test_jit_v_nojit_random_regen(self):
def f(a, b):
rn = Tensor.randn(*a.shape)
rn = rn * a
rn2 = Tensor.randn(*a.shape)
rn2 = rn2 * b
rn = rn + rn2
rn2 = rn2 + Tensor.randn(*a.shape)
return ((a+b)*rn).realize(), ((a+b)*rn2).realize()
Tensor.manual_seed(0)
a = Tensor.randn(10, 10).realize() # realize these before resetting the random seed
b = Tensor.randn(10, 10).realize()
Tensor.manual_seed(1234)
without_jit = set()
for _ in range(5):
o1, o2 = f(a, b)
without_jit.add(o1.numpy()[0][0])
without_jit.add(o2.numpy()[0][0])
assert len(without_jit) == 10, "All values should be different."
Tensor.manual_seed(1234)
jf = TinyJit(f)
with_jit = set()
for _ in range(5):
o1, o2 = jf(a, b)
with_jit.add(o1.numpy()[0][0])
with_jit.add(o2.numpy()[0][0])
assert len(with_jit) == 10, "All values should be different."
assert with_jit == without_jit, "Jit rand produced different values from no jit."
def test_jit_multiple_random_regen(self):
def f(a, b):
rn = Tensor.randn(*a.shape)