load all traces before asserting in test_viz (#12004)

This commit is contained in:
qazal
2025-09-04 21:34:48 +03:00
committed by GitHub
parent 9dee724fc4
commit 4996bb668b

View File

@@ -1,5 +1,6 @@
import unittest, decimal, json, struct
from dataclasses import dataclass
from typing import Generator
from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, TrackedPatternMatcher
from tinygrad.uop.ops import graph_rewrite, track_rewrites, TRACK_MATCH_STATS
@@ -19,6 +20,10 @@ from tinygrad.uop.ops import tracked_keys, tracked_ctxs, uop_fields, active_rewr
traces = [(tracked_keys, tracked_ctxs, uop_fields)]
from tinygrad.viz.serve import get_metadata, uop_to_json, get_details
def get_viz_list(): return get_metadata(traces)
def get_viz_details(rewrite_idx:int, step:int) -> Generator[dict, None, None]:
lst = get_viz_list()
assert len(lst) > rewrite_idx, "only loaded {len(lst)} traces, expecting at least {idx}"
return get_details(tracked_ctxs[rewrite_idx][step])
class BaseTestViz(unittest.TestCase):
def setUp(self):
@@ -129,7 +134,7 @@ class TestViz(BaseTestViz):
(UPat(Ops.CONST, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
])
with self.assertRaises(RuntimeError): exec_rewrite(a, [pm])
graphs = flatten(x["graph"].values() for x in get_details(tracked_ctxs[0][0]))
graphs = flatten(x["graph"].values() for x in get_viz_details(0, 0))
self.assertEqual(graphs[0], uop_to_json(a)[id(a)])
self.assertEqual(graphs[1], uop_to_json(b)[id(b)])
# fallback to NOOP with the error message
@@ -143,7 +148,7 @@ class TestViz(BaseTestViz):
exec_rewrite(alu, [sym])
lst = get_viz_list()
self.assertEqual(len(lst), 1)
graphs = [x["graph"] for x in get_details(tracked_ctxs[0][0])]
graphs = [x["graph"] for x in get_viz_details(0, 0)]
# embed const in the parent node when possible
self.assertEqual(list(graphs[0]), [id(a), id(alu)])
self.assertEqual(list(graphs[1]), [id(z)])
@@ -247,7 +252,7 @@ class TestVizIntegration(BaseTestViz):
b = Tensor.empty(1)
metadata = (alu:=a+b).uop.metadata
alu.kernelize()
graph = next(get_details(tracked_ctxs[0][0]))["graph"]
graph = next(get_viz_details(0, 0))["graph"]
self.assertEqual(len([n for n in graph.values() if repr(metadata) in n["label"]]), 1)
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry