diff --git a/tinygrad/ops.py b/tinygrad/ops.py index ab04ab4259..74feb025be 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -590,6 +590,7 @@ class TrackedRewriteContext: kernel: Optional[Kernel] = None # the kernel being rewritten rewrites: List[Tuple[UOp, UOp, UPat]] = field(default_factory=list) # all rewrites of sparents. (before, after, UPat) contexts: List[TrackedRewriteContext] = [] +rewrite_stack: List[TrackedRewriteContext] = [] class TrackedPatternMatcher(PatternMatcher): def __init__(self, patterns:List[Tuple[UPat, Callable]]): super().__init__(patterns) @@ -610,7 +611,7 @@ class TrackedPatternMatcher(PatternMatcher): match_stats[p][2] += (et:=time.perf_counter()-st) match_stats[p][3] += et if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable()) - if TRACK_MATCH_STATS >= 2 and contexts and isinstance(ret, UOp): contexts[-1].rewrites.append((uop, ret, p)) + if TRACK_MATCH_STATS >= 2 and len(rewrite_stack) != 0 and isinstance(ret, UOp): rewrite_stack[-1].rewrites.append((uop, ret, p)) return ret # NOTE: if it returns None, we keep trying to match match_stats[p][2] += time.perf_counter()-st return None @@ -656,8 +657,10 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp: # get Kernel we are rewriting in the context of frm_walk: Optional[FrameType] = frm while frm_walk is not None and not isinstance(kernel:=frm_walk.f_locals.get("self", None), Kernel): kernel, frm_walk = None, frm_walk.f_back - contexts.append(TrackedRewriteContext((frm.f_code.co_filename, frm.f_lineno), sink, kernel)) - return RewriteContext(pm, ctx).rewrite(sink) + rewrite_stack.append(TrackedRewriteContext((frm.f_code.co_filename, frm.f_lineno), sink, kernel)) + ret = RewriteContext(pm, ctx).rewrite(sink) + if TRACK_MATCH_STATS >= 2: contexts.append(rewrite_stack.pop()) + return ret # ***** uop type spec ***** diff --git a/viz/serve.py b/viz/serve.py index b325959a2d..1edd6f9590 100755 --- a/viz/serve.py +++ b/viz/serve.py @@ -5,22 +5,23 @@ import pickle, os, sys, time, threading, webbrowser, json, difflib, contextlib, from dataclasses import asdict from urllib.parse import parse_qs, urlparse from http.server import HTTPServer, BaseHTTPRequestHandler -from tinygrad.helpers import getenv, to_function_name -from tinygrad.ops import TrackedRewriteContext, UOp, UOps, UPat, lines +from tinygrad.helpers import getenv, to_function_name, tqdm +from tinygrad.ops import TrackedRewriteContext, UOp, UOps, lines from tinygrad.engine.graph import uops_colors, word_wrap from viz.spec import GraphRewriteDetails, GraphRewriteMetadata -def reconstruct_graph(sink:UOp, rewrites:List[Tuple[UOp, UOp, UPat]]) -> Tuple[List[UOp], List[List[str]], List[List[int]]]: - uops: List[UOp] = [sink] +def reconstruct_graph(ctx:TrackedRewriteContext) -> Tuple[List[UOp], List[List[str]], List[List[int]]]: + uops: List[UOp] = [ctx.sink] diffs: List[List[str]] = [] changed_nodes: List[List[int]] = [] seen_replaces: Dict[UOp, UOp] = {} - for i, (first, rewritten, _) in enumerate(rewrites): + for i, (first, rewritten, upat) in enumerate(ctx.rewrites): # first, rewrite this UOp with the current rewrite + all the seen rewrites before this seen_replaces[first] = rewritten new_sink = replace_uop(uops[-1], {**seen_replaces}) # sanity check - assert new_sink is not uops[-1], f"rewritten sink wasn't rewritten! {i}\n{new_sink}\n{uops[-1]}" + if new_sink is uops[-1]: + raise AssertionError(f"rewritten sink wasn't rewritten! {i} {upat.location}") # update ret data changed_nodes.append([id(x) for x in rewritten.sparents if x.op is not UOps.CONST]) diffs.append(list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines()))) @@ -43,7 +44,7 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]: def replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp: if (found:=replaces.get(base)) is not None: return found - replaces[base] = ret = UOp(base.op, base.dtype, tuple(replace_uop(x, replaces) for x in base.src), base.arg) + replaces[base] = ret = base.replace(src=tuple(replace_uop(x, replaces) for x in base.src)) return ret def load_kernels(contexts) -> DefaultDict[str, List[Tuple[GraphRewriteMetadata, TrackedRewriteContext]]]: @@ -79,7 +80,7 @@ class Handler(BaseHTTPRequestHandler): query = parse_qs(url.query) if (qkernel:=query.get("kernel")) is not None: metadata, ctx = list(kernels.values())[int(qkernel[0])][int(query["idx"][0])] - graphs, diffs, changed_nodes = reconstruct_graph(ctx.sink, ctx.rewrites) + graphs, diffs, changed_nodes = reconstruct_graph(ctx) ret = json.dumps(asdict(GraphRewriteDetails(**asdict(metadata), graphs=list(map(uop_to_json, graphs)), diffs=diffs, changed_nodes=changed_nodes, kernel_code=get_src(ctx.kernel)))).encode() else: ret = json.dumps([list(map(lambda x:asdict(x[0]), v)) for v in kernels.values()]).encode() @@ -104,6 +105,9 @@ if __name__ == "__main__": with open("/tmp/rewrites.pkl", "rb") as f: contexts: List[TrackedRewriteContext] = pickle.load(f) print("*** unpickled saved rewrites") kernels = load_kernels(contexts) + if getenv("FUZZ_VIZ"): + for v in tqdm(kernels.values()): + for _,ctx in v: reconstruct_graph(ctx) print("*** loaded kernels") server = HTTPServer(('', 8000), Handler) st = time.perf_counter() diff --git a/viz/test_viz.py b/viz/test_viz.py index bddbaf2196..44832f1bf6 100644 --- a/viz/test_viz.py +++ b/viz/test_viz.py @@ -1,19 +1,17 @@ from typing import List import unittest import os, itertools - -from viz.spec import GraphRewriteMetadata os.environ["TRACK_MATCH_STATS"] = "2" os.environ["PRINT_MATCH_STATS"] = "0" -from extra.models.resnet import ResNet50 from tinygrad import Tensor from tinygrad.engine.realize import lower_schedule from tinygrad.ops import TrackedRewriteContext, UOp, UOps, graph_rewrite, PatternMatcher, UPat, contexts, KernelInfo, BinaryOps from tinygrad.dtype import dtypes, PtrDType -from tinygrad.helpers import CI, Context, all_same, DEBUG, colored, getenv +from tinygrad.helpers import Context, all_same, DEBUG, colored, getenv from tinygrad.codegen.uopgraph import sym, devectorize, float4_folding from test.external.process_replay.helpers import print_diff from viz.serve import reconstruct_graph, uop_to_json, load_kernels +from viz.spec import GraphRewriteMetadata def group_rewrites(kernels:List[GraphRewriteMetadata]): return {k:list(v) for k,v in itertools.groupby(kernels, lambda x:x.loc)} @@ -25,7 +23,7 @@ class TestViz(unittest.TestCase): def assert_valid_ctx(self, contexts:List[TrackedRewriteContext]): assert len(contexts) != 0 for i,ctx in enumerate(contexts): - try: graphs,_,_ = reconstruct_graph(ctx.sink, ctx.rewrites) + try: graphs,_,_ = reconstruct_graph(ctx) except Exception as e: print(colored(f"failed to create graph for ctx {i}", "red")) raise e @@ -72,7 +70,7 @@ class TestViz(unittest.TestCase): ]) ret = graph_rewrite(sink, pm) if DEBUG >= 4: print_diff(sink, ret) - graphs,_,_ = reconstruct_graph(contexts[0].sink, contexts[0].rewrites) + graphs,_,_ = reconstruct_graph(contexts[0]) assert graphs[-1].key == ret.key self.assert_valid_ctx(contexts) @@ -106,15 +104,6 @@ class TestViz(unittest.TestCase): self.assert_valid_ctx(contexts) assert all(ctx.loc[0].split("/")[-1] == __file__.split("/")[-1] for ctx in contexts) - @unittest.skipIf(CI, "slow, it's generating diffs for 36202 rules") - def test_fuzz_resnet(self): - mdl = ResNet50() - img = Tensor.empty(64, 3, 224, 224) - out = mdl(img) - sched = out.schedule() - list(lower_schedule(sched)) - self.assert_valid_ctx(contexts) - def test_no_ctx(self): simple_pm = PatternMatcher([(UPat(UOps.CONST), lambda:True)]) simple_pm.rewrite(UOp.const(dtypes.int, 2))