From 5977df267f08919802de0cca211d33e8d339f624 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 10 Oct 2025 10:25:25 +0800 Subject: [PATCH] outerworld uses expand (#12578) --- test/test_outerworld.py | 75 +++++++++++++++++++++++++++++++++++ test/test_rangeify.py | 72 +-------------------------------- tinygrad/schedule/indexing.py | 5 ++- 3 files changed, 79 insertions(+), 73 deletions(-) create mode 100644 test/test_outerworld.py diff --git a/test/test_outerworld.py b/test/test_outerworld.py new file mode 100644 index 0000000000..449d122017 --- /dev/null +++ b/test/test_outerworld.py @@ -0,0 +1,75 @@ +import unittest +from tinygrad import Tensor, UOp, GlobalCounters, Context + +class TestOuterworld(unittest.TestCase): + def test_range_plus_1(self): + t = Tensor.arange(100).reshape(10,10).realize() + + # passthrough ranges + a = UOp.range(10, -1) + sel = t[a] + 1 + assert sel.shape == (10,) + cpy = sel.reshape(1, 10).expand(a, 10).contiguous().realize() + + self.assertTrue((t+1==cpy).all().item()) + + def test_flip_range(self): + t = Tensor.rand(10, 10).realize() + + # passthrough ranges + a = UOp.range(10, -1) + sel = t[9-a] + cpy = sel.reshape(1, 10).expand(a, 10).contiguous().realize() + + self.assertTrue((t.flip(0)==cpy).all().item()) + + def test_vmap(self): + def f(x): return x.sum(axis=0)*2 + + x = Tensor.ones(3, 10, 2).contiguous() + + # vmap across axis 0 + a = UOp.range(3, -1) + out = f(x[a]) + out = out.reshape(1, 2).expand(a, 2).contiguous() + + # 3x2 grid of 20 + out.realize() + self.assertTrue((out==20).all().item()) + + @unittest.skip("opts don't work") + def test_triple_gemm(self): + x = Tensor.rand(1, 16).realize() + W = Tensor.rand(3, 16, 16).realize() + + manual = (x @ W[0] @ W[1] @ W[2]).contiguous().realize() + + a = UOp.range(3, -1) + x = x.assign(x @ W[a]) + out = x.contiguous(a)[-1].contiguous().realize() + + self.assertTrue((manual==out).all().item()) + + def test_setitem_pyrange(self): + with Context(DEBUG=0): + t = Tensor.rand(10).realize() + o = Tensor.empty(10) + GlobalCounters.reset() + for i in range(10): + o[i] = t[i] + o.realize() + self.assertTrue((t==o).all().item()) + + @unittest.skip("TODO: fix this") + def test_setitem(self): + with Context(DEBUG=0): + t = Tensor.rand(10).realize() + o = Tensor.empty(10) + GlobalCounters.reset() + i = UOp.range(10, -1) + o[i] = t[i] + o.contiguous(i).realize() + self.assertTrue((t==o).all().item()) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/test/test_rangeify.py b/test/test_rangeify.py index 144375c6f4..13cd9d9204 100644 --- a/test/test_rangeify.py +++ b/test/test_rangeify.py @@ -1,7 +1,7 @@ import unittest from tinygrad import Tensor, nn from tinygrad.helpers import Context, GlobalCounters -from tinygrad.uop.ops import UOp, graph_rewrite, PatternMatcher, UPat, Ops +from tinygrad.uop.ops import graph_rewrite, PatternMatcher, UPat, Ops class TestRangeifyAssign(unittest.TestCase): def test_assign_permuted(self): @@ -229,76 +229,6 @@ class TestRangeify(unittest.TestCase): # contiguous + reduce can support ranges? -@unittest.skip("okay to disable this for now") -class TestOuterworld(unittest.TestCase): - def test_passthrough_range(self): - t = Tensor.rand(10, 10).realize() - - # passthrough ranges - a = UOp.range(10, -1) - sel = t[a] - cpy = sel.contiguous(a).realize() - - self.assertTrue((t==cpy).all().item()) - - def test_flip_range(self): - t = Tensor.rand(10, 10).realize() - - # passthrough ranges - a = UOp.range(10, -1) - sel = t[9-a] - cpy = sel.contiguous(a).realize() - - self.assertTrue((t.flip(0)==cpy).all().item()) - - def test_vmap(self): - def f(x): return x.sum(axis=0)*2 - - x = Tensor.ones(3, 10, 2).contiguous() - - # vmap across axis 0 - a = UOp.range(3, -1) - out = f(x[a]) - out = out.contiguous(a) - - # 3x2 grid of 20 - out.realize() - print(out.numpy()) - - @unittest.skip("opts don't work") - def test_triple_gemm(self): - x = Tensor.rand(1, 16).realize() - W = Tensor.rand(3, 16, 16).realize() - - manual = (x @ W[0] @ W[1] @ W[2]).contiguous().realize() - - a = UOp.range(3, -1) - x = x.assign(x @ W[a]) - out = x.contiguous(a)[-1].contiguous().realize() - - self.assertTrue((manual==out).all().item()) - - def test_setitem_pyrange(self): - with Context(DEBUG=0): - t = Tensor.rand(10).realize() - o = Tensor.empty(10) - GlobalCounters.reset() - for i in range(10): - o[i] = t[i] - o.realize() - self.assertTrue((t==o).all().item()) - - @unittest.skip("TODO: fix this") - def test_setitem(self): - with Context(DEBUG=0): - t = Tensor.rand(10).realize() - o = Tensor.empty(10) - GlobalCounters.reset() - i = UOp.range(10, -1) - o[i] = t[i] - o.contiguous(i).realize() - self.assertTrue((t==o).all().item()) - @unittest.skip("pm_rangeify no longer exists. test this in a different way") class TestRangeifyPM(unittest.TestCase): def setUp(self): self.base = Tensor.empty(10*10).reshape(10, 10).contiguous() diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py index 8741401f33..20fa054852 100644 --- a/tinygrad/schedule/indexing.py +++ b/tinygrad/schedule/indexing.py @@ -157,7 +157,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]: consumer_rngs = [rctx.range_map[c][0] for c in consumer_map[x] if c in rctx.range_map] if x in rctx.realize_map: # if this is in the realize_map, we create new ranges (at the output) - out_rngs = [rctx.new_range(s) for s in x.shape] + out_rngs = [rctx.new_range(s) if not isinstance(s, UOp) or s.op is not Ops.RANGE else s for s in x.shape] # all ranges are ended now ending_ranges[x] = False elif x.op in {Ops.MSTACK, Ops.MSELECT}: @@ -207,7 +207,8 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]: # apply movement ops if x.op in GroupOp.Movement: rngs = apply_movement_op(x, rngs) - if x.op is Ops.EXPAND: ending_ranges[x] = True + # if the EXPAND is used to inject a range, we don't mark it as ending_ranges. otherwise we do. + if x.op is Ops.EXPAND and all(isinstance(y, int) or y.op is not Ops.RANGE for y in x.shape): ending_ranges[x] = True # REDUCE_AXIS creates ranges for the axes it is reducing if x.op is Ops.REDUCE_AXIS: