mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
start on test rewrite map [pr] (#8432)
* start on test rewrite map [pr] * chatgpt writes dumb tests * comment out failing * fix that test * fix gc issue * oh, frame 2 * remove uop mutability * map is only the map * simplier + more tests * test tiny passes * tests that need to pass * parent test passes * child test passes * remove uop mutability [pr] * test fixups * most tests pass * more tests pass * lil test fixups * them too * fix test * unneeded * err, that * fix test_hcq * fix test failures * fix that test * tensor universe * does this pass test * Revert "does this pass test" This reverts commited516b3169. * Revert "tensor universe" This reverts commitc21301852a. * test_mutate_add passes * this can pass * Revert "Merge remote-tracking branch 'origin/no_uop_mutability' into test_rewrite_map" This reverts commit657822dcdc, reversing changes made to2a126c145b. * Revert "test_mutate_add passes" This reverts commitab4fc4c78e. * correct enough * remove test_rewrite_map_schedule.py * viz * uops are immutable --------- Co-authored-by: qazal <qazal.software@gmail.com>
This commit is contained in:
218
test/unit/test_rewrite_map.py
Normal file
218
test/unit/test_rewrite_map.py
Normal file
@@ -0,0 +1,218 @@
|
||||
import unittest
|
||||
from tinygrad import dtypes, Tensor
|
||||
from tinygrad.ops import UOp, symbolic, graph_rewrite_map, _substitute
|
||||
from test.unit.test_tensor_uop_representation import is_pattern, realized_pattern, is_pattern_uop
|
||||
|
||||
class TestTensorMutates(unittest.TestCase):
|
||||
def test_mutate_add(self):
|
||||
a = Tensor([1,2,3])
|
||||
b = Tensor([4,5,6])
|
||||
ret = a+b
|
||||
pa = a.lazydata
|
||||
pb = b.lazydata
|
||||
pr = ret.lazydata
|
||||
ret.schedule()
|
||||
self.assertIsNot(pa, a.lazydata)
|
||||
self.assertIsNot(pb, b.lazydata)
|
||||
self.assertIsNot(pr, ret.lazydata)
|
||||
for t in [a,b,ret]: is_pattern(t, realized_pattern)
|
||||
|
||||
def test_reshape_is_same_parent(self):
|
||||
a = Tensor([1,2,3])
|
||||
b = Tensor([4,5,6])
|
||||
c = a+b
|
||||
d = (a+b).reshape(3,1)
|
||||
d.realize()
|
||||
is_pattern_uop(d.lazydata.base, realized_pattern)
|
||||
is_pattern_uop(c.lazydata.base, realized_pattern)
|
||||
|
||||
def test_reshape_is_same_child(self):
|
||||
a = Tensor([1,2,3])
|
||||
b = Tensor([4,5,6])
|
||||
c = a+b
|
||||
d = (a+b).reshape(3,1)
|
||||
c.realize()
|
||||
is_pattern_uop(c.lazydata.base, realized_pattern)
|
||||
is_pattern_uop(d.lazydata.base, realized_pattern)
|
||||
|
||||
class TestRewriteMap(unittest.TestCase):
|
||||
def test_substitute(self):
|
||||
a = UOp.variable('a', 0, 10)
|
||||
b = UOp.variable('b', 0, 10)
|
||||
c = UOp.variable('c', 0, 10)
|
||||
e = UOp.variable('e', 0, 10)
|
||||
ret = (a+b)*c
|
||||
sub = {a+b: e}
|
||||
sub_map = graph_rewrite_map(ret, _substitute, sub, bottom_up=True)
|
||||
self.assertIs(sub_map[a+b], e)
|
||||
self.assertIs(sub_map[(a+b)*c], e*c)
|
||||
|
||||
def test_substitute_depth_2(self):
|
||||
a = UOp.variable('a', 0, 10)
|
||||
b = UOp.variable('b', 0, 10)
|
||||
c = UOp.variable('c', 0, 10)
|
||||
d = UOp.variable('d', 0, 10)
|
||||
e = UOp.variable('e', 0, 10)
|
||||
f = UOp.variable('f', 0, 10)
|
||||
ret = (a+b)*c+d
|
||||
sub = {a+b: e, (a+b)*c: f}
|
||||
sub_map = graph_rewrite_map(ret, _substitute, sub, bottom_up=True)
|
||||
self.assertIs(sub_map[a+b], e)
|
||||
self.assertIs(sub_map[(a+b)*c], f)
|
||||
|
||||
def test_add_zero(self):
|
||||
# Build a small graph: add(0, add(const=0, const=5))
|
||||
zero_node = UOp.const(dtypes.int, 0)
|
||||
five_node = UOp.const(dtypes.int, 5)
|
||||
inner_add = zero_node + five_node
|
||||
root_add = zero_node + inner_add
|
||||
|
||||
# Perform top-down rewrite
|
||||
node_map = graph_rewrite_map(root_add, symbolic)
|
||||
|
||||
# We expect that add(0, add(0, 5)) -> add(0, 5) -> 5
|
||||
# Check the mapping
|
||||
assert node_map[root_add] == five_node
|
||||
assert node_map[inner_add] == five_node
|
||||
# zero_node and five_node map to themselves
|
||||
assert node_map[zero_node] == zero_node
|
||||
assert node_map[five_node] == five_node
|
||||
|
||||
def test_double_neg(self):
|
||||
"""
|
||||
Test rewriting neg(neg(5)) => 5 using symbolic.
|
||||
"""
|
||||
# In some versions of TinyGrad, you might do: (-(-five_node))
|
||||
five_node = UOp.const(dtypes.int, 5)
|
||||
# If your code allows UOp(...), do that; else you might do something like:
|
||||
# double_neg_five = -(-five_node)
|
||||
# But let's be explicit:
|
||||
neg_five = -five_node
|
||||
double_neg_five = -neg_five
|
||||
|
||||
node_map = graph_rewrite_map(double_neg_five, symbolic)
|
||||
|
||||
# node_map should map double_neg_five -> five_node
|
||||
self.assertEqual(node_map[double_neg_five], five_node)
|
||||
# five_node maps to itself
|
||||
self.assertEqual(node_map[five_node], five_node)
|
||||
|
||||
def test_add_zero_and_double_neg(self):
|
||||
"""
|
||||
Combine both rewrites: add(0, neg(neg(5))) => add(0, 5) => 5
|
||||
"""
|
||||
zero_node = UOp.const(dtypes.int, 0)
|
||||
five_node = UOp.const(dtypes.int, 5)
|
||||
neg_five = -five_node
|
||||
double_neg_five = -neg_five
|
||||
root_add = zero_node + double_neg_five
|
||||
|
||||
node_map = graph_rewrite_map(root_add, symbolic)
|
||||
|
||||
# node_map: root_add -> five_node, double_neg_five -> five_node
|
||||
self.assertEqual(node_map[root_add], five_node)
|
||||
self.assertEqual(node_map[double_neg_five], five_node)
|
||||
# zero_node, five_node map to themselves
|
||||
self.assertEqual(node_map[zero_node], zero_node)
|
||||
self.assertEqual(node_map[five_node], five_node)
|
||||
|
||||
def test_multi_var_rewrites(self):
|
||||
x_var = UOp.variable('x', 0, 10)
|
||||
y_var = UOp.variable('y', -5, 5)
|
||||
zero_node = UOp.const(dtypes.int, 0)
|
||||
|
||||
sum_with_zero = y_var + zero_node # (y + 0)
|
||||
combined = x_var + sum_with_zero # x + (y + 0)
|
||||
double_neg = -(-combined) # neg(neg(x + y))
|
||||
final_expr = zero_node + double_neg # 0 + (x + y)
|
||||
|
||||
node_map = graph_rewrite_map(final_expr, symbolic)
|
||||
|
||||
# The final root should be (x_var + y_var).
|
||||
expected = x_var + y_var
|
||||
|
||||
# Each sub-expression has its own "final" result.
|
||||
# (y + 0) -> y_var
|
||||
self.assertEqual(node_map[sum_with_zero], y_var)
|
||||
# (x + (y+0)) -> (x + y)
|
||||
self.assertEqual(node_map[combined], expected)
|
||||
# neg(neg(x+y)) -> (x + y)
|
||||
self.assertEqual(node_map[double_neg], expected)
|
||||
# 0 + (x+y) -> (x + y)
|
||||
self.assertEqual(node_map[final_expr], expected)
|
||||
|
||||
# x_var, y_var, zero_node remain unchanged
|
||||
self.assertEqual(node_map[x_var], x_var)
|
||||
self.assertEqual(node_map[y_var], y_var)
|
||||
self.assertEqual(node_map[zero_node], zero_node)
|
||||
|
||||
def test_complex_multi_var_edges(self):
|
||||
"""
|
||||
Build a multi-variable expression with multiple intermediates:
|
||||
|
||||
x_var = UOp.variable('x', 1, 10)
|
||||
y_var = UOp.variable('y', -5, 5)
|
||||
z_var = UOp.variable('z', 0, 5)
|
||||
zero_node = UOp.const(dtypes.int, 0)
|
||||
one_node = UOp.const(dtypes.int, 1)
|
||||
|
||||
yz_sum = y_var + z_var
|
||||
yz_sum_zero = yz_sum + zero_node -> rewrites to yz_sum
|
||||
yz_neg = -yz_sum_zero -> -(y+z)
|
||||
yz_dneg = -yz_neg -> y+z (double neg gone)
|
||||
x_plus_yz = x_var + yz_dneg -> x + (y+z)
|
||||
double_neg_x = -(-x_plus_yz) -> x + (y+z)
|
||||
final_expr = double_neg_x * one_node -> x + (y+z)
|
||||
|
||||
We expect the final result to be (x + (y+z)).
|
||||
Each original node should map to the final node that replaces it,
|
||||
which might be structurally equivalent but not the same reference.
|
||||
"""
|
||||
x_var = UOp.variable('x', 1, 10)
|
||||
y_var = UOp.variable('y', -5, 5)
|
||||
z_var = UOp.variable('z', 0, 5)
|
||||
zero_node = UOp.const(dtypes.int, 0)
|
||||
one_node = UOp.const(dtypes.int, 1)
|
||||
|
||||
# Build sub-expressions
|
||||
yz_sum = y_var + z_var # (y + z)
|
||||
yz_sum_zero = yz_sum + zero_node # (y + z) + 0
|
||||
yz_neg = -yz_sum_zero # -(y+z)
|
||||
yz_dneg = -yz_neg # -(-(y+z)) -> (y+z)
|
||||
x_plus_yz = x_var + yz_dneg # x + (y+z)
|
||||
double_neg_x = -(-x_plus_yz) # neg(neg(x+(y+z))) -> x+(y+z)
|
||||
final_expr = double_neg_x * one_node # (x+(y+z)) * 1 -> x+(y+z)
|
||||
|
||||
node_map = graph_rewrite_map(final_expr, symbolic)
|
||||
|
||||
# (y + z) is unchanged
|
||||
self.assertEqual(node_map[yz_sum], yz_sum)
|
||||
|
||||
# (y+z) + 0 => (y+z)
|
||||
self.assertEqual(node_map[yz_sum_zero], yz_sum)
|
||||
|
||||
# -(y+z) remains -(y+z), but might be a new UOp with updated children
|
||||
# Compare structurally to -(y_var + z_var).
|
||||
self.assertEqual(node_map[yz_neg], -yz_sum)
|
||||
|
||||
# -(-(y+z)) => (y+z)
|
||||
self.assertEqual(node_map[yz_dneg], yz_sum)
|
||||
|
||||
# x + (y+z) => might get recreated if yz_dneg was changed, so compare to x + yz_sum
|
||||
self.assertEqual(node_map[x_plus_yz], x_var + yz_sum)
|
||||
|
||||
# -(-(x+(y+z))) => x + (y+z)
|
||||
self.assertEqual(node_map[double_neg_x], x_var + yz_sum)
|
||||
|
||||
# (x+(y+z)) * 1 => x+(y+z)
|
||||
self.assertEqual(node_map[final_expr], x_var + yz_sum)
|
||||
|
||||
# Unchanged atomic nodes map to themselves
|
||||
self.assertEqual(node_map[x_var], x_var)
|
||||
self.assertEqual(node_map[y_var], y_var)
|
||||
self.assertEqual(node_map[z_var], z_var)
|
||||
self.assertEqual(node_map[zero_node], zero_node)
|
||||
self.assertEqual(node_map[one_node], one_node)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -876,7 +876,7 @@ def launch_viz(env_str:str, data:str):
|
||||
# *** simple graph rewrite engine ***
|
||||
|
||||
class RewriteContext:
|
||||
def __init__(self, pm, ctx):
|
||||
def __init__(self, pm, ctx=None):
|
||||
self.pm: PatternMatcher = pm
|
||||
self.ctx = ctx
|
||||
self.replace: dict[UOp, UOp] = {}
|
||||
@@ -899,6 +899,12 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> UOp
|
||||
tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink))
|
||||
return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).rewrite(sink)
|
||||
|
||||
def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> dict[UOp, UOp]:
|
||||
if TRACK_MATCH_STATS >= 2 and not bottom_up and len(tracked_ctxs) != 0: # TODO: make viz work with bottom_up=True
|
||||
tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink))
|
||||
rewrite_ctx = RewriteContext(pm, ctx)
|
||||
return {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.rewrite(k)) for k in list(sink.toposort)[::-1]}
|
||||
|
||||
# ***** uop type spec *****
|
||||
|
||||
# this is the matcher for the final rendered UOps
|
||||
|
||||
Reference in New Issue
Block a user