mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
viz with bottom_up=True (#7894)
* add failing test * single pass it * linter
This commit is contained in:
@@ -4,11 +4,11 @@ from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, graph_rewrite, contexts, track_rewrites
|
||||
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json
|
||||
|
||||
@track_rewrites()
|
||||
def rewrite(sink:UOp, pm:PatternMatcher, ctx=None): return graph_rewrite(sink, pm, ctx)
|
||||
@track_rewrites(named=True)
|
||||
def rewrite(sink:UOp, pm:PatternMatcher, **kwargs): return graph_rewrite(sink, pm, **kwargs)
|
||||
|
||||
def helper_test_viz(sink:UOp, pm:PatternMatcher, ctx=None) -> List[UOp]:
|
||||
rewrite(sink, pm, ctx)
|
||||
def helper_test_viz(sink:UOp, pm:PatternMatcher, **kwargs) -> List[UOp]:
|
||||
rewrite(sink, pm, **kwargs)
|
||||
assert len(contexts) == 1
|
||||
assert len(contexts[0][1]) == 1
|
||||
k = get_metadata(contexts)[0][0]
|
||||
@@ -52,7 +52,7 @@ class TestViz(unittest.TestCase):
|
||||
pm = PatternMatcher([
|
||||
(UPat(Ops.LOAD, name="x"), store_load),
|
||||
])
|
||||
uops = helper_test_viz(a+b, pm, {})
|
||||
uops = helper_test_viz(a+b, pm, ctx={})
|
||||
self.assertEqual(len(uops), 2)
|
||||
self.assertEqual(uops[-1], graph_rewrite(a+b, pm, {}))
|
||||
|
||||
@@ -89,5 +89,25 @@ class TestViz(unittest.TestCase):
|
||||
assert not any(v[0].startswith("CONST") for v in graph.values())
|
||||
assert len([x for x in graph.values() if "CONST" in x[0]]) == 1
|
||||
|
||||
def test_bottom_up_rewrite(self):
|
||||
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
|
||||
n1 = a.sin()
|
||||
uop = n1.sin()
|
||||
pm = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
|
||||
ret = helper_test_viz(uop, pm, ctx={a.sin():a.sqrt(), n1.sin():n1.sqrt()}, bottom_up=True)
|
||||
self.assertEqual(len(ret), 2)
|
||||
self.assertIs(ret[0], a.sin().sqrt()) # first rewrite
|
||||
self.assertIs(ret[1], a.sqrt().sqrt()) # second one
|
||||
|
||||
def test_top_down_rewrite(self):
|
||||
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
|
||||
n1 = a.sin()
|
||||
uop = n1.sin()
|
||||
pm = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
|
||||
# if it wasn't bottom_up, it's rewritten once
|
||||
ret = helper_test_viz(uop, pm, ctx={a.sin():a.sqrt(), n1.sin():n1.sqrt()}, bottom_up=False)
|
||||
self.assertEqual(len(ret), 1)
|
||||
self.assertIs(ret[0], a.sqrt().sin()) # only rewrite
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -688,6 +688,7 @@ match_stats:Dict[UPat, List[Union[int, float]]] = dict()
|
||||
class TrackedRewriteContext:
|
||||
loc: Tuple[str, int] # location that called graph_rewrite
|
||||
sink: UOp # the sink passed into the rewrite
|
||||
bottom_up: bool
|
||||
matches: List[Tuple[UOp, Optional[UOp], Optional[UPat], float]] = field(default_factory=list) # all matches of sparents
|
||||
|
||||
rewrite_stack: List[Tuple[Any, List[TrackedRewriteContext]]] = []
|
||||
@@ -775,7 +776,7 @@ class RewriteContext:
|
||||
|
||||
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> UOp:
|
||||
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0:
|
||||
rewrite_stack[-1][1].append(TrackedRewriteContext(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink))
|
||||
rewrite_stack[-1][1].append(TrackedRewriteContext(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up))
|
||||
return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).rewrite(sink)
|
||||
|
||||
# ***** uop type spec *****
|
||||
|
||||
@@ -82,14 +82,14 @@ def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata)
|
||||
replaces: Dict[UOp, UOp] = {}
|
||||
sink = ctx.sink
|
||||
for i,(u0,u1,upat,_) in enumerate(ctx.matches):
|
||||
if ctx.bottom_up: replaces = {} # if it's bottom_up it's single pass
|
||||
replaces[u0] = u0 if u1 is None else u1
|
||||
# if the match didn't result in a rewrite we move forward
|
||||
if u1 is None: continue
|
||||
# first, rewrite this UOp with the current rewrite + all the seen matches before this
|
||||
# first, rewrite this UOp with the current rewrite + all the matches in replaces
|
||||
new_sink = _replace_uop(sink, {**replaces})
|
||||
# sanity check
|
||||
if new_sink is sink:
|
||||
raise AssertionError(f"rewritten sink wasn't rewritten! {i} {unwrap(upat).location}")
|
||||
if new_sink is sink: raise AssertionError(f"rewritten sink wasn't rewritten! {i} {unwrap(upat).location}")
|
||||
# update ret data
|
||||
g.changed_nodes.append([id(x) for x in u1.sparents if x.op is not Ops.CONST])
|
||||
g.diffs.append(list(difflib.unified_diff(pcall(str, u0).splitlines(), pcall(str, u1).splitlines())))
|
||||
|
||||
Reference in New Issue
Block a user