mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
outerworld uses expand (#12578)
This commit is contained in:
75
test/test_outerworld.py
Normal file
75
test/test_outerworld.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor, UOp, GlobalCounters, Context
|
||||
|
||||
class TestOuterworld(unittest.TestCase):
|
||||
def test_range_plus_1(self):
|
||||
t = Tensor.arange(100).reshape(10,10).realize()
|
||||
|
||||
# passthrough ranges
|
||||
a = UOp.range(10, -1)
|
||||
sel = t[a] + 1
|
||||
assert sel.shape == (10,)
|
||||
cpy = sel.reshape(1, 10).expand(a, 10).contiguous().realize()
|
||||
|
||||
self.assertTrue((t+1==cpy).all().item())
|
||||
|
||||
def test_flip_range(self):
|
||||
t = Tensor.rand(10, 10).realize()
|
||||
|
||||
# passthrough ranges
|
||||
a = UOp.range(10, -1)
|
||||
sel = t[9-a]
|
||||
cpy = sel.reshape(1, 10).expand(a, 10).contiguous().realize()
|
||||
|
||||
self.assertTrue((t.flip(0)==cpy).all().item())
|
||||
|
||||
def test_vmap(self):
|
||||
def f(x): return x.sum(axis=0)*2
|
||||
|
||||
x = Tensor.ones(3, 10, 2).contiguous()
|
||||
|
||||
# vmap across axis 0
|
||||
a = UOp.range(3, -1)
|
||||
out = f(x[a])
|
||||
out = out.reshape(1, 2).expand(a, 2).contiguous()
|
||||
|
||||
# 3x2 grid of 20
|
||||
out.realize()
|
||||
self.assertTrue((out==20).all().item())
|
||||
|
||||
@unittest.skip("opts don't work")
|
||||
def test_triple_gemm(self):
|
||||
x = Tensor.rand(1, 16).realize()
|
||||
W = Tensor.rand(3, 16, 16).realize()
|
||||
|
||||
manual = (x @ W[0] @ W[1] @ W[2]).contiguous().realize()
|
||||
|
||||
a = UOp.range(3, -1)
|
||||
x = x.assign(x @ W[a])
|
||||
out = x.contiguous(a)[-1].contiguous().realize()
|
||||
|
||||
self.assertTrue((manual==out).all().item())
|
||||
|
||||
def test_setitem_pyrange(self):
|
||||
with Context(DEBUG=0):
|
||||
t = Tensor.rand(10).realize()
|
||||
o = Tensor.empty(10)
|
||||
GlobalCounters.reset()
|
||||
for i in range(10):
|
||||
o[i] = t[i]
|
||||
o.realize()
|
||||
self.assertTrue((t==o).all().item())
|
||||
|
||||
@unittest.skip("TODO: fix this")
|
||||
def test_setitem(self):
|
||||
with Context(DEBUG=0):
|
||||
t = Tensor.rand(10).realize()
|
||||
o = Tensor.empty(10)
|
||||
GlobalCounters.reset()
|
||||
i = UOp.range(10, -1)
|
||||
o[i] = t[i]
|
||||
o.contiguous(i).realize()
|
||||
self.assertTrue((t==o).all().item())
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,7 +1,7 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor, nn
|
||||
from tinygrad.helpers import Context, GlobalCounters
|
||||
from tinygrad.uop.ops import UOp, graph_rewrite, PatternMatcher, UPat, Ops
|
||||
from tinygrad.uop.ops import graph_rewrite, PatternMatcher, UPat, Ops
|
||||
|
||||
class TestRangeifyAssign(unittest.TestCase):
|
||||
def test_assign_permuted(self):
|
||||
@@ -229,76 +229,6 @@ class TestRangeify(unittest.TestCase):
|
||||
|
||||
# contiguous + reduce can support ranges?
|
||||
|
||||
@unittest.skip("okay to disable this for now")
|
||||
class TestOuterworld(unittest.TestCase):
|
||||
def test_passthrough_range(self):
|
||||
t = Tensor.rand(10, 10).realize()
|
||||
|
||||
# passthrough ranges
|
||||
a = UOp.range(10, -1)
|
||||
sel = t[a]
|
||||
cpy = sel.contiguous(a).realize()
|
||||
|
||||
self.assertTrue((t==cpy).all().item())
|
||||
|
||||
def test_flip_range(self):
|
||||
t = Tensor.rand(10, 10).realize()
|
||||
|
||||
# passthrough ranges
|
||||
a = UOp.range(10, -1)
|
||||
sel = t[9-a]
|
||||
cpy = sel.contiguous(a).realize()
|
||||
|
||||
self.assertTrue((t.flip(0)==cpy).all().item())
|
||||
|
||||
def test_vmap(self):
|
||||
def f(x): return x.sum(axis=0)*2
|
||||
|
||||
x = Tensor.ones(3, 10, 2).contiguous()
|
||||
|
||||
# vmap across axis 0
|
||||
a = UOp.range(3, -1)
|
||||
out = f(x[a])
|
||||
out = out.contiguous(a)
|
||||
|
||||
# 3x2 grid of 20
|
||||
out.realize()
|
||||
print(out.numpy())
|
||||
|
||||
@unittest.skip("opts don't work")
|
||||
def test_triple_gemm(self):
|
||||
x = Tensor.rand(1, 16).realize()
|
||||
W = Tensor.rand(3, 16, 16).realize()
|
||||
|
||||
manual = (x @ W[0] @ W[1] @ W[2]).contiguous().realize()
|
||||
|
||||
a = UOp.range(3, -1)
|
||||
x = x.assign(x @ W[a])
|
||||
out = x.contiguous(a)[-1].contiguous().realize()
|
||||
|
||||
self.assertTrue((manual==out).all().item())
|
||||
|
||||
def test_setitem_pyrange(self):
|
||||
with Context(DEBUG=0):
|
||||
t = Tensor.rand(10).realize()
|
||||
o = Tensor.empty(10)
|
||||
GlobalCounters.reset()
|
||||
for i in range(10):
|
||||
o[i] = t[i]
|
||||
o.realize()
|
||||
self.assertTrue((t==o).all().item())
|
||||
|
||||
@unittest.skip("TODO: fix this")
|
||||
def test_setitem(self):
|
||||
with Context(DEBUG=0):
|
||||
t = Tensor.rand(10).realize()
|
||||
o = Tensor.empty(10)
|
||||
GlobalCounters.reset()
|
||||
i = UOp.range(10, -1)
|
||||
o[i] = t[i]
|
||||
o.contiguous(i).realize()
|
||||
self.assertTrue((t==o).all().item())
|
||||
|
||||
@unittest.skip("pm_rangeify no longer exists. test this in a different way")
|
||||
class TestRangeifyPM(unittest.TestCase):
|
||||
def setUp(self): self.base = Tensor.empty(10*10).reshape(10, 10).contiguous()
|
||||
|
||||
@@ -157,7 +157,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
||||
consumer_rngs = [rctx.range_map[c][0] for c in consumer_map[x] if c in rctx.range_map]
|
||||
if x in rctx.realize_map:
|
||||
# if this is in the realize_map, we create new ranges (at the output)
|
||||
out_rngs = [rctx.new_range(s) for s in x.shape]
|
||||
out_rngs = [rctx.new_range(s) if not isinstance(s, UOp) or s.op is not Ops.RANGE else s for s in x.shape]
|
||||
# all ranges are ended now
|
||||
ending_ranges[x] = False
|
||||
elif x.op in {Ops.MSTACK, Ops.MSELECT}:
|
||||
@@ -207,7 +207,8 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
||||
|
||||
# apply movement ops
|
||||
if x.op in GroupOp.Movement: rngs = apply_movement_op(x, rngs)
|
||||
if x.op is Ops.EXPAND: ending_ranges[x] = True
|
||||
# if the EXPAND is used to inject a range, we don't mark it as ending_ranges. otherwise we do.
|
||||
if x.op is Ops.EXPAND and all(isinstance(y, int) or y.op is not Ops.RANGE for y in x.shape): ending_ranges[x] = True
|
||||
|
||||
# REDUCE_AXIS creates ranges for the axes it is reducing
|
||||
if x.op is Ops.REDUCE_AXIS:
|
||||
|
||||
Reference in New Issue
Block a user