From fd23738d9d48f608c50f32e028c4fbc50de9deda Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 17 Dec 2024 13:52:30 +0200 Subject: [PATCH] start viz better typing, unsupport bottom_up=True [pr] (#8284) * start viz refactor * delete bottom_up tracking * more cleanup * early continue --- test/test_viz.py | 1 + tinygrad/ops.py | 5 ++--- tinygrad/viz/serve.py | 32 ++++++++++++++++---------------- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/test/test_viz.py b/test/test_viz.py index 2edbea3d8a..d1e58ab724 100644 --- a/test/test_viz.py +++ b/test/test_viz.py @@ -89,6 +89,7 @@ class TestViz(unittest.TestCase): assert not any(v[0].startswith("CONST") for v in graph.values()) assert len([x for x in graph.values() if "CONST" in x[0]]) == 1 + @unittest.skip("TODO: bring this back with better testing") def test_bottom_up_rewrite(self): a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0))) n1 = a.sin() diff --git a/tinygrad/ops.py b/tinygrad/ops.py index d4b1a60b42..98bd60192a 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -843,7 +843,6 @@ match_stats:Dict[UPat, List[Union[int, float]]] = dict() class TrackedRewriteContext: loc: Tuple[str, int] # location that called graph_rewrite sink: bytes # sanpshot of the sink passed into the rewrite - bottom_up: bool matches: List[Tuple[bytes, Optional[bytes], Optional[UPat], float]] = field(default_factory=list) # before+after snapshot of all the matches rewrite_stack: List[Tuple[Any, List[TrackedRewriteContext]]] = [] @@ -928,9 +927,9 @@ class RewriteContext: return ret def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> UOp: - if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0: + if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0 and not bottom_up: # TODO: make viz work with bottom_up=True with Context(PICKLE_BUFFERS=0): - rewrite_stack[-1][1].append(TrackedRewriteContext(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), pickle.dumps(sink), bottom_up)) + rewrite_stack[-1][1].append(TrackedRewriteContext(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), pickle.dumps(sink))) return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).rewrite(sink) # ***** uop type spec ***** diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index fc38b84890..cdc03e6a15 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -3,7 +3,7 @@ import multiprocessing, pickle, functools, difflib, os, threading, json, time, s from http.server import HTTPServer, BaseHTTPRequestHandler from urllib.parse import parse_qs, urlparse from dataclasses import asdict, dataclass -from typing import Any, Dict, List, Tuple, Optional +from typing import Any, Callable, Dict, List, Tuple, Optional from tinygrad.helpers import colored, getenv, to_function_name, tqdm, unwrap, word_wrap from tinygrad.ops import TrackedRewriteContext, UOp, Ops, lines, GroupOp from tinygrad.codegen.kernel import Kernel @@ -18,12 +18,12 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.PRELOAD: "#ffc0c0", Ops.STORE: "#87CEEB" @dataclass class GraphRewriteMetadata: - """Specifies metadata about a single call to graph_rewrite""" + """Overview of a tracked rewrite to viz the sidebar""" loc: Tuple[str, int] """File_path, Lineno""" code_line: str """The Python line calling graph_rewrite""" - kernel_name: Optional[str] + kernel_name: str """The kernel calling graph_rewrite""" upats: List[Tuple[Tuple[str, int], str, float]] """List of all the applied UPats""" @@ -42,19 +42,19 @@ class GraphRewriteDetails(GraphRewriteMetadata): # ** API functions -def pcall(fxn, *args, **kwargs): +# NOTE: if any extra rendering in VIZ fails, we don't crash +def pcall(fxn:Callable[..., str], *args, **kwargs) -> str: try: return fxn(*args, **kwargs) except Exception as e: return f"ERROR: {e}" def get_metadata(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]) -> List[List[Tuple[Any, TrackedRewriteContext, GraphRewriteMetadata]]]: - kernels: Dict[Optional[str], List[Tuple[Any, TrackedRewriteContext, GraphRewriteMetadata]]] = {} + kernels: Dict[str, List[Tuple[Any, TrackedRewriteContext, GraphRewriteMetadata]]] = {} for k,ctxs in contexts: - name = to_function_name(k.name) if isinstance(k, Kernel) else k + name = to_function_name(k.name) if isinstance(k, Kernel) else str(k) for ctx in ctxs: if pickle.loads(ctx.sink).op is Ops.CONST: continue upats = [(upat.location, upat.printable(), tm) for _,_,upat,tm in ctx.matches if upat is not None] - if name not in kernels: kernels[name] = [] - kernels[name].append((k, ctx, GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats))) + kernels.setdefault(name, []).append((k, ctx, GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats))) return list(kernels.values()) def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]: @@ -76,19 +76,19 @@ def _replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp: replaces[base] = ret return ret @functools.lru_cache(None) -def _prg(k:Optional[Kernel]) -> Optional[str]: - try: return k.to_program().src if isinstance(k, Kernel) else None - except Exception: return None +def _prg(k:Kernel): return k.to_program().src def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata) -> GraphRewriteDetails: - g = GraphRewriteDetails(**asdict(metadata), graphs=[pickle.loads(ctx.sink)], diffs=[], changed_nodes=[], kernel_code=pcall(_prg, k)) + g = GraphRewriteDetails(**asdict(metadata), graphs=[pickle.loads(ctx.sink)], diffs=[], changed_nodes=[], + kernel_code=pcall(_prg, k) if isinstance(k, Kernel) else None) replaces: Dict[UOp, UOp] = {} sink = g.graphs[0] for i,(u0_b,u1_b,upat,_) in enumerate(ctx.matches): - if ctx.bottom_up: replaces = {} # if it's bottom_up it's single pass - u0, u1 = pickle.loads(u0_b), None if u1_b is None else pickle.loads(u1_b) - replaces[u0] = u0 if u1 is None else u1 + u0 = pickle.loads(u0_b) # if the match didn't result in a rewrite we move forward - if u1 is None: continue + if u1_b is None: + replaces[u0] = u0 + continue + replaces[u0] = u1 = pickle.loads(u1_b) # first, rewrite this UOp with the current rewrite + all the matches in replaces new_sink = _replace_uop(sink, {**replaces}) # sanity check