diff --git a/test/test_jit.py b/test/test_jit.py index 11657e3397..a7d4536cfd 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -833,5 +833,20 @@ class TestJitGraphSplit(unittest.TestCase): multigraph=[self.ji_graph(2), self.ji_copy(), self.ji_comp()], hcqgraph=[self.ji_graph(4)]) +class TestJitRandom(unittest.TestCase): + def test_jit_rangeify(self): + tst = {0:[], 1:[]} + for r in [0,1]: + Tensor.manual_seed(1337) + with Context(RANGEIFY=r): + _ = Tensor.randint(4, high=3) + # this second one makes the behavior different + _ = Tensor.randint(4, high=3) + @TinyJit + def f(): return Tensor.randint(20, high=5) + for _ in range(5): tst[r].append(f().tolist()) + for i, (t0, t1) in enumerate(zip(tst[0], tst[1])): + self.assertListEqual(t0, t1, msg=f"mismatch at list {i}") + if __name__ == '__main__': unittest.main() diff --git a/test/test_softmax_fusion.py b/test/test_softmax_fusion.py index 54a801f72c..554f793d69 100644 --- a/test/test_softmax_fusion.py +++ b/test/test_softmax_fusion.py @@ -30,7 +30,7 @@ def single_kernel_softmax(x_in:Tensor, axis=-1, dtype:DTypeLike|None=None) -> Te def run_one_schedule_item(out): lower_schedule_item(get_single_element(out.schedule())).run() class TestFuse(unittest.TestCase): - def _test_fuse(self, fxn, *args, atol=1e-7, allow_multiple=False, **kwargs): + def _test_fuse(self, fxn, *args, atol=1e-6, allow_multiple=False, **kwargs): GlobalCounters.reset() out_single = fxn(*args, **kwargs).fuse() if not allow_multiple: run_one_schedule_item(out_single) diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index c4371dea64..97cc691066 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -42,9 +42,9 @@ earliest_rewrites = double_reshape+PatternMatcher([ (UPat(Ops.COPY, src=(UPat(GroupOp.Movement, name="r"), UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.contiguous(), d)) if r.size != r.base.size else None), - # assign only to buffer + # assign only to buffer, otherwise make it a CONTIGUOUS (UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.BUFFER}, name="target"), UPat(name="x")), name="assign"), - lambda x,target,assign: x.f(Ops.NOOP, tag=assign.tag) if target.base.op is not Ops.BUFFER else None), + lambda x,target,assign: x.f(Ops.CONTIGUOUS, tag=assign.tag) if target.base.op is not Ops.BUFFER else None), # realize before assign if input permutes the target buffer (UPat(Ops.ASSIGN, src=(UPat.var("a"), UPat.var("b")), name="assign"), lambda a,b,assign: assign.replace(src=(a, b.contiguous())) \