outerworld uses expand (#12578)

This commit is contained in:
George Hotz
2025-10-10 10:25:25 +08:00
committed by GitHub
parent f2c3a72b0c
commit 5977df267f
3 changed files with 79 additions and 73 deletions

View File

@@ -1,7 +1,7 @@
import unittest
from tinygrad import Tensor, nn
from tinygrad.helpers import Context, GlobalCounters
from tinygrad.uop.ops import UOp, graph_rewrite, PatternMatcher, UPat, Ops
from tinygrad.uop.ops import graph_rewrite, PatternMatcher, UPat, Ops
class TestRangeifyAssign(unittest.TestCase):
def test_assign_permuted(self):
@@ -229,76 +229,6 @@ class TestRangeify(unittest.TestCase):
# contiguous + reduce can support ranges?
@unittest.skip("okay to disable this for now")
class TestOuterworld(unittest.TestCase):
def test_passthrough_range(self):
t = Tensor.rand(10, 10).realize()
# passthrough ranges
a = UOp.range(10, -1)
sel = t[a]
cpy = sel.contiguous(a).realize()
self.assertTrue((t==cpy).all().item())
def test_flip_range(self):
t = Tensor.rand(10, 10).realize()
# passthrough ranges
a = UOp.range(10, -1)
sel = t[9-a]
cpy = sel.contiguous(a).realize()
self.assertTrue((t.flip(0)==cpy).all().item())
def test_vmap(self):
def f(x): return x.sum(axis=0)*2
x = Tensor.ones(3, 10, 2).contiguous()
# vmap across axis 0
a = UOp.range(3, -1)
out = f(x[a])
out = out.contiguous(a)
# 3x2 grid of 20
out.realize()
print(out.numpy())
@unittest.skip("opts don't work")
def test_triple_gemm(self):
x = Tensor.rand(1, 16).realize()
W = Tensor.rand(3, 16, 16).realize()
manual = (x @ W[0] @ W[1] @ W[2]).contiguous().realize()
a = UOp.range(3, -1)
x = x.assign(x @ W[a])
out = x.contiguous(a)[-1].contiguous().realize()
self.assertTrue((manual==out).all().item())
def test_setitem_pyrange(self):
with Context(DEBUG=0):
t = Tensor.rand(10).realize()
o = Tensor.empty(10)
GlobalCounters.reset()
for i in range(10):
o[i] = t[i]
o.realize()
self.assertTrue((t==o).all().item())
@unittest.skip("TODO: fix this")
def test_setitem(self):
with Context(DEBUG=0):
t = Tensor.rand(10).realize()
o = Tensor.empty(10)
GlobalCounters.reset()
i = UOp.range(10, -1)
o[i] = t[i]
o.contiguous(i).realize()
self.assertTrue((t==o).all().item())
@unittest.skip("pm_rangeify no longer exists. test this in a different way")
class TestRangeifyPM(unittest.TestCase):
def setUp(self): self.base = Tensor.empty(10*10).reshape(10, 10).contiguous()