mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
load all traces before asserting in test_viz (#12004)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import unittest, decimal, json, struct
|
||||
from dataclasses import dataclass
|
||||
from typing import Generator
|
||||
|
||||
from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, TrackedPatternMatcher
|
||||
from tinygrad.uop.ops import graph_rewrite, track_rewrites, TRACK_MATCH_STATS
|
||||
@@ -19,6 +20,10 @@ from tinygrad.uop.ops import tracked_keys, tracked_ctxs, uop_fields, active_rewr
|
||||
traces = [(tracked_keys, tracked_ctxs, uop_fields)]
|
||||
from tinygrad.viz.serve import get_metadata, uop_to_json, get_details
|
||||
def get_viz_list(): return get_metadata(traces)
|
||||
def get_viz_details(rewrite_idx:int, step:int) -> Generator[dict, None, None]:
|
||||
lst = get_viz_list()
|
||||
assert len(lst) > rewrite_idx, "only loaded {len(lst)} traces, expecting at least {idx}"
|
||||
return get_details(tracked_ctxs[rewrite_idx][step])
|
||||
|
||||
class BaseTestViz(unittest.TestCase):
|
||||
def setUp(self):
|
||||
@@ -129,7 +134,7 @@ class TestViz(BaseTestViz):
|
||||
(UPat(Ops.CONST, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
|
||||
])
|
||||
with self.assertRaises(RuntimeError): exec_rewrite(a, [pm])
|
||||
graphs = flatten(x["graph"].values() for x in get_details(tracked_ctxs[0][0]))
|
||||
graphs = flatten(x["graph"].values() for x in get_viz_details(0, 0))
|
||||
self.assertEqual(graphs[0], uop_to_json(a)[id(a)])
|
||||
self.assertEqual(graphs[1], uop_to_json(b)[id(b)])
|
||||
# fallback to NOOP with the error message
|
||||
@@ -143,7 +148,7 @@ class TestViz(BaseTestViz):
|
||||
exec_rewrite(alu, [sym])
|
||||
lst = get_viz_list()
|
||||
self.assertEqual(len(lst), 1)
|
||||
graphs = [x["graph"] for x in get_details(tracked_ctxs[0][0])]
|
||||
graphs = [x["graph"] for x in get_viz_details(0, 0)]
|
||||
# embed const in the parent node when possible
|
||||
self.assertEqual(list(graphs[0]), [id(a), id(alu)])
|
||||
self.assertEqual(list(graphs[1]), [id(z)])
|
||||
@@ -247,7 +252,7 @@ class TestVizIntegration(BaseTestViz):
|
||||
b = Tensor.empty(1)
|
||||
metadata = (alu:=a+b).uop.metadata
|
||||
alu.kernelize()
|
||||
graph = next(get_details(tracked_ctxs[0][0]))["graph"]
|
||||
graph = next(get_viz_details(0, 0))["graph"]
|
||||
self.assertEqual(len([n for n in graph.values() if repr(metadata) in n["label"]]), 1)
|
||||
|
||||
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry
|
||||
|
||||
Reference in New Issue
Block a user