diff --git a/test/test_viz.py b/test/test_viz.py index d14fd7d97a..140041b30c 100644 --- a/test/test_viz.py +++ b/test/test_viz.py @@ -15,7 +15,7 @@ def helper_test_viz(sink:UOp, pm:PatternMatcher, **kwargs) -> List[UOp]: assert len(contexts[0]) == 1 k = get_metadata(keys, contexts)[0][0] g = get_details(*k) - return g.graphs[1:] + return g.uops[1:] class TestViz(unittest.TestCase): def setUp(self): diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index c7dc838768..c6e324b1dd 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -34,8 +34,9 @@ class GraphRewriteMetadata: @dataclass class GraphRewriteDetails(GraphRewriteMetadata): """Full details about a single call to graph_rewrite""" - graphs: list[UOp] - """Sink at every step of graph_rewrite""" + uops: list[UOp] + graphs: list[dict] + """Sink at every step of graph_rewrite + the json serialized version""" diffs: list[list[str]] """.diff style before and after of the rewritten UOp child""" changed_nodes: list[list[int]] @@ -88,10 +89,10 @@ def _replace_uop(base:UOp, replaces:dict[UOp, UOp]) -> UOp: @functools.lru_cache(None) def _prg(k:Kernel): return k.to_program().src def get_details(k:Any, ctx:TrackedGraphRewrite, metadata:GraphRewriteMetadata) -> GraphRewriteDetails: - g = GraphRewriteDetails(**asdict(metadata), graphs=[pickle.loads(ctx.sink)], diffs=[], changed_nodes=[], - kernel_code=pcall(_prg, k) if isinstance(k, Kernel) else None) + g = GraphRewriteDetails(**asdict(metadata), uops=[pickle.loads(ctx.sink)], diffs=[], changed_nodes=[], + kernel_code=pcall(_prg, k) if isinstance(k, Kernel) else None, graphs=[]) replaces: dict[UOp, UOp] = {} - sink = g.graphs[0] + g.graphs.append(uop_to_json(sink:=g.uops[0])) for i,(u0_b,u1_b,upat,_) in enumerate(ctx.matches): u0 = pickle.loads(u0_b) # if the match didn't result in a rewrite we move forward @@ -104,9 +105,10 @@ def get_details(k:Any, ctx:TrackedGraphRewrite, metadata:GraphRewriteMetadata) - # sanity check if new_sink is sink: raise AssertionError(f"rewritten sink wasn't rewritten! {i} {unwrap(upat).location}") # update ret data - g.changed_nodes.append([id(x) for x in u1.toposort if x.op is not Ops.CONST]) + g.graphs.append(new_sink_js:=uop_to_json(new_sink)) + g.changed_nodes.append([id(x) for x in u1.toposort if id(x) in new_sink_js]) g.diffs.append(list(difflib.unified_diff(pcall(str, u0).splitlines(), pcall(str, u1).splitlines()))) - g.graphs.append(sink:=new_sink) + g.uops.append(sink:=new_sink) return g # Profiler API @@ -158,7 +160,7 @@ class Handler(BaseHTTPRequestHandler): query = parse_qs(url.query) if (qkernel:=query.get("kernel")) is not None: g = get_details(*kernels[int(qkernel[0])][int(query["idx"][0])]) - jret: Any = {**asdict(g), "graphs": [uop_to_json(x) for x in g.graphs], "uops": [pcall(str,x) for x in g.graphs]} + jret: Any = {**asdict(g), "uops": [pcall(str,x) for x in g.uops]} else: jret = [list(map(lambda x:asdict(x[2]), v)) for v in kernels] ret, content_type = json.dumps(jret).encode(), "application/json" elif url.path == "/get_profile" and perfetto_profile is not None: ret, content_type = perfetto_profile, "application/json"