mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
refactor to @track_matches + add failing test_nested_rewrite (#9592)
* test_nested_rewrite * refactor to track_matches * positional arg
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import unittest, decimal, json
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher, UOp, graph_rewrite, track_rewrites
|
||||
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher, UOp, graph_rewrite, track_rewrites, UPat
|
||||
from tinygrad.codegen.symbolic import symbolic
|
||||
from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys, _name_cnt, _substitute
|
||||
from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry
|
||||
@@ -10,6 +10,10 @@ from tinygrad.viz.serve import get_metadata, uop_to_json, to_perfetto
|
||||
symbolic = TrackedPatternMatcher(symbolic.patterns)
|
||||
substitute = TrackedPatternMatcher(_substitute.patterns)
|
||||
|
||||
inner_rewrite = TrackedPatternMatcher([
|
||||
(UPat.cvar("x"), lambda x: None if x.dtype == dtypes.float32 else UOp.const(dtypes.float32, x.arg)),
|
||||
])
|
||||
|
||||
class TestViz(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clear the global context
|
||||
@@ -99,6 +103,13 @@ class TestViz(unittest.TestCase):
|
||||
key = get_metadata(keys, contexts)[1][0]
|
||||
self.assertEqual(key, "output_(a+b) n2")
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_name_in_positional_arg(self):
|
||||
@track_rewrites(named=True)
|
||||
def test(sink): return graph_rewrite(sink, symbolic, None, False, "name")
|
||||
test(UOp.variable("a", 0, 1))
|
||||
self.assertEqual(contexts[0].pop().name, "name")
|
||||
|
||||
# NOTE: CONST UOps do not get nodes in the graph
|
||||
def test_dont_create_const_nodes(self):
|
||||
a = UOp.variable("a", 0, 10)
|
||||
@@ -145,6 +156,25 @@ class TestViz(unittest.TestCase):
|
||||
self.assertEqual(lineno, inner_rewrite.__code__.co_firstlineno)
|
||||
self.assertEqual(fp, inner_rewrite.__code__.co_filename)
|
||||
|
||||
def test_nested_rewrite(self):
|
||||
def make_float(x:UOp, y:UOp):
|
||||
if x.dtype == dtypes.float: return None
|
||||
x2 = graph_rewrite(x, inner_rewrite, name="inner_x")
|
||||
y2 = graph_rewrite(y, inner_rewrite, name="inner_y")
|
||||
return None if (x2 is x and y2 is y) else x2+y2
|
||||
outer_rewrite = TrackedPatternMatcher([(UPat.cvar("x")+UPat.cvar("y"), make_float),])
|
||||
@track_rewrites(named=True)
|
||||
def rewrite(u:UOp): return graph_rewrite(u, outer_rewrite, name="outer")
|
||||
a = UOp.const(dtypes.int, 1)+UOp.const(dtypes.int, 2)
|
||||
rewrite(a)
|
||||
self.assertEqual(len(contexts), 1)
|
||||
tracked = contexts[0]
|
||||
self.assertEqual(len(tracked), 3)
|
||||
# TODO: reorder nested rewrites by the deepest one
|
||||
with self.assertRaises(AssertionError):
|
||||
self.assertEqual([x.name for x in tracked], ["inner_x", "inner_y", "outer"])
|
||||
self.assertEqual([len(x.matches) for x in tracked], [1, 1, 1])
|
||||
|
||||
class TextVizProfiler(unittest.TestCase):
|
||||
def test_perfetto_node(self):
|
||||
prof = [ProfileRangeEvent(device='NV', name='E_2', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=False),
|
||||
|
||||
@@ -840,8 +840,8 @@ class TrackedGraphRewrite:
|
||||
loc: tuple[str, int] # location that called graph_rewrite
|
||||
sink: UOp # the sink input to graph_rewrite
|
||||
bottom_up: bool
|
||||
matches: list[tuple[UOp, UOp, UPat]] = field(default_factory=list) # before+after of all the matches
|
||||
name: Optional[str] = None
|
||||
matches: list[tuple[UOp, UOp, UPat]] # before+after of all the matches
|
||||
name: str|None
|
||||
tracked_keys:list[Any] = []
|
||||
tracked_ctxs:list[list[TrackedGraphRewrite]] = []
|
||||
_name_cnt:dict[str, int] = {}
|
||||
@@ -858,6 +858,15 @@ def track_rewrites(named=False, name_fxn:Callable|None=None):
|
||||
return __wrapper
|
||||
return _decorator
|
||||
|
||||
def track_matches(func):
|
||||
def _track_func(*args, **kwargs):
|
||||
if TRACK_MATCH_STATS >= 2 and tracked_ctxs:
|
||||
loc = ((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno)
|
||||
tracked_ctxs[-1].append(TrackedGraphRewrite(loc, args[0], kwargs.get("bottom_up", False), [], kwargs.get("name", None)))
|
||||
ret = func(*args, **kwargs)
|
||||
return ret
|
||||
return _track_func
|
||||
|
||||
class TrackedPatternMatcher(PatternMatcher):
|
||||
def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
|
||||
ret = None
|
||||
@@ -939,15 +948,13 @@ class RewriteContext:
|
||||
self.replace[n] = ret = last_n if new_src == last_n.src else self.bottom_up_rewrite(UOp(last_n.op, last_n.dtype, new_src, last_n.arg))
|
||||
return ret
|
||||
|
||||
@track_matches
|
||||
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, track_children=False) -> UOp:
|
||||
if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
|
||||
tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up, name=name))
|
||||
rewrite_ctx = RewriteContext(pm, ctx, children=sink.get_children_map() if track_children else None)
|
||||
return rewrite_ctx.bottom_up_rewrite(sink) if bottom_up else rewrite_ctx.top_down_rewrite(sink)
|
||||
|
||||
@track_matches
|
||||
def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, track_children=False) -> dict[UOp, UOp]:
|
||||
if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
|
||||
tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up, name=name))
|
||||
rewrite_ctx = RewriteContext(pm, ctx, children=sink.get_children_map() if track_children else None)
|
||||
return {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in list(sink.toposort)[::-1]}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user