viz: no global state (#15705)

* start viz data

* get_full_rewrites also moves

* update ref_map

* work

* update consumers

* cleaner cli

* linter

* cleanup tests

* back

* better

* sqtt tests
This commit is contained in:
qazal
2026-04-13 15:35:20 +03:00
committed by GitHub
parent 4c1fb18a09
commit ac027055ef
5 changed files with 111 additions and 99 deletions

View File

@@ -1,15 +1,16 @@
import unittest, contextlib
from tinygrad import Device, Tensor, Context, TinyJit
from tinygrad.device import Compiled, ProfileProgramEvent, ProfileDeviceEvent
from tinygrad.viz.serve import load_amd_counters
from tinygrad.viz.serve import load_amd_counters, VizData
@contextlib.contextmanager
def save_sqtt():
yield (ret:=[])
data = VizData()
yield data.ctxs
Device[Device.DEFAULT].synchronize()
Device[Device.DEFAULT]._at_profile_finalize()
load_amd_counters(ret, Compiled.profile_events)
ret[:] = [r for r in ret if r["name"].startswith("SQTT")]
load_amd_counters(data, Compiled.profile_events)
data.ctxs[:] = [r for r in data.ctxs if r["name"].startswith("SQTT")]
@unittest.skipUnless(Device.DEFAULT == "AMD", "only runs on AMD")
class TestSQTTProfiler(unittest.TestCase):

View File

@@ -10,7 +10,7 @@ from tinygrad.helpers import VIZ, cpu_profile, ProfilePointEvent, unwrap
from tinygrad.device import Buffer
from tinygrad.uop.ops import tracked_keys, tracked_ctxs, uop_fields, active_rewrites, active_group, _name_cnt, RewriteTrace
from tinygrad.viz.serve import get_rewrites, get_full_rewrite, uop_to_json
from tinygrad.viz.serve import load_rewrites, get_full_rewrite, uop_to_json, VizData
@track_rewrites(name=True)
def exec_rewrite(sink:UOp, pm_lst:list[PatternMatcher], names:None|list[str]=None) -> UOp:
@@ -21,19 +21,19 @@ def exec_rewrite(sink:UOp, pm_lst:list[PatternMatcher], names:None|list[str]=Non
# small container class for the viz server module
class VizTrace:
# loader init
def __init__(self): self._trace:RewriteTrace|None = None
def __init__(self): self._data:VizData|None = None
@property
def trace(self) -> RewriteTrace: return unwrap(self._trace)
def set_trace(self) -> None:
self._trace = RewriteTrace(tracked_keys.copy(), tracked_ctxs.copy(), uop_fields.copy())
import tinygrad.viz.serve as serve_module
serve_module.trace = self._trace
def data(self) -> VizData: return unwrap(self._data)
def set_data(self) -> None:
data = VizData(RewriteTrace(tracked_keys.copy(), tracked_ctxs.copy(), uop_fields.copy()))
load_rewrites(data)
self._data = data
# the API
def list_items(self) -> list[dict]: return get_rewrites(self.trace)
def list_items(self) -> list[dict]:
return self.data.ctxs
def get_details(self, rewrite_idx:int, step:int) -> Generator[dict, None, None]:
lst = self.list_items()
assert len(lst) > rewrite_idx, f"only loaded {len(lst)} traces, expecting at least {rewrite_idx}"
return get_full_rewrite(self.trace.rewrites[rewrite_idx][step])
assert len(self.data.trace.rewrites) > rewrite_idx, f"only loaded {len(self.data.trace.rewrites)} traces, expecting at least {rewrite_idx}"
return get_full_rewrite(self.data, self.data.trace.rewrites[rewrite_idx][step])
@contextlib.contextmanager
def save_viz():
@@ -52,7 +52,7 @@ def save_viz():
try:
yield viz
finally:
viz.set_trace()
viz.set_data()
TRACK_MATCH_STATS.value = prev_tms
PROFILE.value = prev_profile
VIZ.value = prev_viz
@@ -194,7 +194,7 @@ class TestViz(unittest.TestCase):
class TestStruct:
colored_field: str
a = UOp(Ops.CUSTOM, arg=TestStruct(colored("xyz", "magenta")+colored("12345", "blue")))
a2 = uop_to_json(a)[id(a)]
a2 = uop_to_json(a, VizData())[id(a)]
self.assertEqual(ansistrip(a2["label"]), f"CUSTOM\n{TestStruct.__qualname__}(colored_field='xyz12345')")
def test_colored_label_multiline(self):
@@ -217,11 +217,11 @@ class TestViz(unittest.TestCase):
# use smaller stack limit for faster test (default is 250000)
with Context(REWRITE_STACK_LIMIT=100): self.assertRaises(RuntimeError, exec_rewrite, a, [pm])
graphs = flatten(x["graph"].values() for x in viz.get_details(0, 0))
self.assertEqual(graphs[0], uop_to_json(a)[id(a)])
self.assertEqual(graphs[1], uop_to_json(b)[id(b)])
self.assertEqual(graphs[0], uop_to_json(a, VizData())[id(a)])
self.assertEqual(graphs[1], uop_to_json(b, VizData())[id(b)])
# fallback to NOOP with the error message
nop = UOp(Ops.NOOP, arg="infinite loop in fixed_point_rewrite")
self.assertEqual(graphs[2], uop_to_json(nop)[id(nop)])
self.assertEqual(graphs[2], uop_to_json(nop, VizData())[id(nop)])
def test_const_node_visibility(self):
with save_viz() as viz:
@@ -241,7 +241,7 @@ class TestViz(unittest.TestCase):
c = UOp.const(dtypes.float, 1.0, device="CPU", shape=(3,4)) # creates CONST->RESHAPE->EXPAND chain
a = UOp(Ops.DEFINE_VAR, dtypes.float, arg=("a", 0.0, 10.0))
alu = a + c
graph = uop_to_json(alu)
graph = uop_to_json(alu, VizData())
# the RESHAPE and EXPAND nodes from the const should not appear in the graph
labels = {v["label"].split("\n")[0] for v in graph.values()}
self.assertNotIn("RESHAPE", labels)