diff --git a/test/unit/test_viz.py b/test/unit/test_viz.py index a8e02268eb..82020c87c5 100644 --- a/test/unit/test_viz.py +++ b/test/unit/test_viz.py @@ -1,5 +1,6 @@ import unittest, decimal, json, struct from dataclasses import dataclass +from typing import Generator from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, TrackedPatternMatcher from tinygrad.uop.ops import graph_rewrite, track_rewrites, TRACK_MATCH_STATS @@ -19,6 +20,10 @@ from tinygrad.uop.ops import tracked_keys, tracked_ctxs, uop_fields, active_rewr traces = [(tracked_keys, tracked_ctxs, uop_fields)] from tinygrad.viz.serve import get_metadata, uop_to_json, get_details def get_viz_list(): return get_metadata(traces) +def get_viz_details(rewrite_idx:int, step:int) -> Generator[dict, None, None]: + lst = get_viz_list() + assert len(lst) > rewrite_idx, "only loaded {len(lst)} traces, expecting at least {idx}" + return get_details(tracked_ctxs[rewrite_idx][step]) class BaseTestViz(unittest.TestCase): def setUp(self): @@ -129,7 +134,7 @@ class TestViz(BaseTestViz): (UPat(Ops.CONST, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)), ]) with self.assertRaises(RuntimeError): exec_rewrite(a, [pm]) - graphs = flatten(x["graph"].values() for x in get_details(tracked_ctxs[0][0])) + graphs = flatten(x["graph"].values() for x in get_viz_details(0, 0)) self.assertEqual(graphs[0], uop_to_json(a)[id(a)]) self.assertEqual(graphs[1], uop_to_json(b)[id(b)]) # fallback to NOOP with the error message @@ -143,7 +148,7 @@ class TestViz(BaseTestViz): exec_rewrite(alu, [sym]) lst = get_viz_list() self.assertEqual(len(lst), 1) - graphs = [x["graph"] for x in get_details(tracked_ctxs[0][0])] + graphs = [x["graph"] for x in get_viz_details(0, 0)] # embed const in the parent node when possible self.assertEqual(list(graphs[0]), [id(a), id(alu)]) self.assertEqual(list(graphs[1]), [id(z)]) @@ -247,7 +252,7 @@ class TestVizIntegration(BaseTestViz): b = Tensor.empty(1) metadata = (alu:=a+b).uop.metadata alu.kernelize() - graph = next(get_details(tracked_ctxs[0][0]))["graph"] + graph = next(get_viz_details(0, 0))["graph"] self.assertEqual(len([n for n in graph.values() if repr(metadata) in n["label"]]), 1) from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry