start viz better typing, unsupport bottom_up=True [pr] (#8284)

* start viz refactor

* delete bottom_up tracking

* more cleanup

* early continue
This commit is contained in:
qazal
2024-12-17 13:52:30 +02:00
committed by GitHub
parent 856c068172
commit fd23738d9d
3 changed files with 19 additions and 19 deletions

View File

@@ -89,6 +89,7 @@ class TestViz(unittest.TestCase):
assert not any(v[0].startswith("CONST") for v in graph.values())
assert len([x for x in graph.values() if "CONST" in x[0]]) == 1
@unittest.skip("TODO: bring this back with better testing")
def test_bottom_up_rewrite(self):
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
n1 = a.sin()

View File

@@ -843,7 +843,6 @@ match_stats:Dict[UPat, List[Union[int, float]]] = dict()
class TrackedRewriteContext:
loc: Tuple[str, int] # location that called graph_rewrite
sink: bytes # sanpshot of the sink passed into the rewrite
bottom_up: bool
matches: List[Tuple[bytes, Optional[bytes], Optional[UPat], float]] = field(default_factory=list) # before+after snapshot of all the matches
rewrite_stack: List[Tuple[Any, List[TrackedRewriteContext]]] = []
@@ -928,9 +927,9 @@ class RewriteContext:
return ret
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> UOp:
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0:
if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0 and not bottom_up: # TODO: make viz work with bottom_up=True
with Context(PICKLE_BUFFERS=0):
rewrite_stack[-1][1].append(TrackedRewriteContext(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), pickle.dumps(sink), bottom_up))
rewrite_stack[-1][1].append(TrackedRewriteContext(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), pickle.dumps(sink)))
return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).rewrite(sink)
# ***** uop type spec *****

View File

@@ -3,7 +3,7 @@ import multiprocessing, pickle, functools, difflib, os, threading, json, time, s
from http.server import HTTPServer, BaseHTTPRequestHandler
from urllib.parse import parse_qs, urlparse
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Tuple, Optional
from typing import Any, Callable, Dict, List, Tuple, Optional
from tinygrad.helpers import colored, getenv, to_function_name, tqdm, unwrap, word_wrap
from tinygrad.ops import TrackedRewriteContext, UOp, Ops, lines, GroupOp
from tinygrad.codegen.kernel import Kernel
@@ -18,12 +18,12 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.PRELOAD: "#ffc0c0", Ops.STORE: "#87CEEB"
@dataclass
class GraphRewriteMetadata:
"""Specifies metadata about a single call to graph_rewrite"""
"""Overview of a tracked rewrite to viz the sidebar"""
loc: Tuple[str, int]
"""File_path, Lineno"""
code_line: str
"""The Python line calling graph_rewrite"""
kernel_name: Optional[str]
kernel_name: str
"""The kernel calling graph_rewrite"""
upats: List[Tuple[Tuple[str, int], str, float]]
"""List of all the applied UPats"""
@@ -42,19 +42,19 @@ class GraphRewriteDetails(GraphRewriteMetadata):
# ** API functions
def pcall(fxn, *args, **kwargs):
# NOTE: if any extra rendering in VIZ fails, we don't crash
def pcall(fxn:Callable[..., str], *args, **kwargs) -> str:
try: return fxn(*args, **kwargs)
except Exception as e: return f"ERROR: {e}"
def get_metadata(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]) -> List[List[Tuple[Any, TrackedRewriteContext, GraphRewriteMetadata]]]:
kernels: Dict[Optional[str], List[Tuple[Any, TrackedRewriteContext, GraphRewriteMetadata]]] = {}
kernels: Dict[str, List[Tuple[Any, TrackedRewriteContext, GraphRewriteMetadata]]] = {}
for k,ctxs in contexts:
name = to_function_name(k.name) if isinstance(k, Kernel) else k
name = to_function_name(k.name) if isinstance(k, Kernel) else str(k)
for ctx in ctxs:
if pickle.loads(ctx.sink).op is Ops.CONST: continue
upats = [(upat.location, upat.printable(), tm) for _,_,upat,tm in ctx.matches if upat is not None]
if name not in kernels: kernels[name] = []
kernels[name].append((k, ctx, GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats)))
kernels.setdefault(name, []).append((k, ctx, GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats)))
return list(kernels.values())
def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
@@ -76,19 +76,19 @@ def _replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp:
replaces[base] = ret
return ret
@functools.lru_cache(None)
def _prg(k:Optional[Kernel]) -> Optional[str]:
try: return k.to_program().src if isinstance(k, Kernel) else None
except Exception: return None
def _prg(k:Kernel): return k.to_program().src
def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata) -> GraphRewriteDetails:
g = GraphRewriteDetails(**asdict(metadata), graphs=[pickle.loads(ctx.sink)], diffs=[], changed_nodes=[], kernel_code=pcall(_prg, k))
g = GraphRewriteDetails(**asdict(metadata), graphs=[pickle.loads(ctx.sink)], diffs=[], changed_nodes=[],
kernel_code=pcall(_prg, k) if isinstance(k, Kernel) else None)
replaces: Dict[UOp, UOp] = {}
sink = g.graphs[0]
for i,(u0_b,u1_b,upat,_) in enumerate(ctx.matches):
if ctx.bottom_up: replaces = {} # if it's bottom_up it's single pass
u0, u1 = pickle.loads(u0_b), None if u1_b is None else pickle.loads(u1_b)
replaces[u0] = u0 if u1 is None else u1
u0 = pickle.loads(u0_b)
# if the match didn't result in a rewrite we move forward
if u1 is None: continue
if u1_b is None:
replaces[u0] = u0
continue
replaces[u0] = u1 = pickle.loads(u1_b)
# first, rewrite this UOp with the current rewrite + all the matches in replaces
new_sink = _replace_uop(sink, {**replaces})
# sanity check