viz with bottom_up=True (#7894)

* add failing test

* single pass it

* linter
This commit is contained in:
qazal
2024-11-25 04:56:48 -05:00
committed by GitHub
parent 2ca41d6a44
commit e823de3828
3 changed files with 30 additions and 9 deletions

View File

@@ -4,11 +4,11 @@ from tinygrad.dtype import dtypes
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, graph_rewrite, contexts, track_rewrites
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json
@track_rewrites()
def rewrite(sink:UOp, pm:PatternMatcher, ctx=None): return graph_rewrite(sink, pm, ctx)
@track_rewrites(named=True)
def rewrite(sink:UOp, pm:PatternMatcher, **kwargs): return graph_rewrite(sink, pm, **kwargs)
def helper_test_viz(sink:UOp, pm:PatternMatcher, ctx=None) -> List[UOp]:
rewrite(sink, pm, ctx)
def helper_test_viz(sink:UOp, pm:PatternMatcher, **kwargs) -> List[UOp]:
rewrite(sink, pm, **kwargs)
assert len(contexts) == 1
assert len(contexts[0][1]) == 1
k = get_metadata(contexts)[0][0]
@@ -52,7 +52,7 @@ class TestViz(unittest.TestCase):
pm = PatternMatcher([
(UPat(Ops.LOAD, name="x"), store_load),
])
uops = helper_test_viz(a+b, pm, {})
uops = helper_test_viz(a+b, pm, ctx={})
self.assertEqual(len(uops), 2)
self.assertEqual(uops[-1], graph_rewrite(a+b, pm, {}))
@@ -89,5 +89,25 @@ class TestViz(unittest.TestCase):
assert not any(v[0].startswith("CONST") for v in graph.values())
assert len([x for x in graph.values() if "CONST" in x[0]]) == 1
def test_bottom_up_rewrite(self):
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
n1 = a.sin()
uop = n1.sin()
pm = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
ret = helper_test_viz(uop, pm, ctx={a.sin():a.sqrt(), n1.sin():n1.sqrt()}, bottom_up=True)
self.assertEqual(len(ret), 2)
self.assertIs(ret[0], a.sin().sqrt()) # first rewrite
self.assertIs(ret[1], a.sqrt().sqrt()) # second one
def test_top_down_rewrite(self):
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
n1 = a.sin()
uop = n1.sin()
pm = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
# if it wasn't bottom_up, it's rewritten once
ret = helper_test_viz(uop, pm, ctx={a.sin():a.sqrt(), n1.sin():n1.sqrt()}, bottom_up=False)
self.assertEqual(len(ret), 1)
self.assertIs(ret[0], a.sqrt().sin()) # only rewrite
if __name__ == "__main__":
unittest.main()

View File

@@ -688,6 +688,7 @@ match_stats:Dict[UPat, List[Union[int, float]]] = dict()
class TrackedRewriteContext:
loc: Tuple[str, int] # location that called graph_rewrite
sink: UOp # the sink passed into the rewrite
bottom_up: bool
matches: List[Tuple[UOp, Optional[UOp], Optional[UPat], float]] = field(default_factory=list) # all matches of sparents
rewrite_stack: List[Tuple[Any, List[TrackedRewriteContext]]] = []
@@ -775,7 +776,7 @@ class RewriteContext:
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> UOp:
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0:
rewrite_stack[-1][1].append(TrackedRewriteContext(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink))
rewrite_stack[-1][1].append(TrackedRewriteContext(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up))
return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).rewrite(sink)
# ***** uop type spec *****

View File

@@ -82,14 +82,14 @@ def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata)
replaces: Dict[UOp, UOp] = {}
sink = ctx.sink
for i,(u0,u1,upat,_) in enumerate(ctx.matches):
if ctx.bottom_up: replaces = {} # if it's bottom_up it's single pass
replaces[u0] = u0 if u1 is None else u1
# if the match didn't result in a rewrite we move forward
if u1 is None: continue
# first, rewrite this UOp with the current rewrite + all the seen matches before this
# first, rewrite this UOp with the current rewrite + all the matches in replaces
new_sink = _replace_uop(sink, {**replaces})
# sanity check
if new_sink is sink:
raise AssertionError(f"rewritten sink wasn't rewritten! {i} {unwrap(upat).location}")
if new_sink is sink: raise AssertionError(f"rewritten sink wasn't rewritten! {i} {unwrap(upat).location}")
# update ret data
g.changed_nodes.append([id(x) for x in u1.sparents if x.op is not Ops.CONST])
g.diffs.append(list(difflib.unified_diff(pcall(str, u0).splitlines(), pcall(str, u1).splitlines())))