mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
viz: no global state (#15705)
* start viz data * get_full_rewrites also moves * update ref_map * work * update consumers * cleaner cli * linter * cleanup tests * back * better * sqtt tests
This commit is contained in:
@@ -10,7 +10,7 @@ from tinygrad.helpers import VIZ, cpu_profile, ProfilePointEvent, unwrap
|
||||
from tinygrad.device import Buffer
|
||||
|
||||
from tinygrad.uop.ops import tracked_keys, tracked_ctxs, uop_fields, active_rewrites, active_group, _name_cnt, RewriteTrace
|
||||
from tinygrad.viz.serve import get_rewrites, get_full_rewrite, uop_to_json
|
||||
from tinygrad.viz.serve import load_rewrites, get_full_rewrite, uop_to_json, VizData
|
||||
|
||||
@track_rewrites(name=True)
|
||||
def exec_rewrite(sink:UOp, pm_lst:list[PatternMatcher], names:None|list[str]=None) -> UOp:
|
||||
@@ -21,19 +21,19 @@ def exec_rewrite(sink:UOp, pm_lst:list[PatternMatcher], names:None|list[str]=Non
|
||||
# small container class for the viz server module
|
||||
class VizTrace:
|
||||
# loader init
|
||||
def __init__(self): self._trace:RewriteTrace|None = None
|
||||
def __init__(self): self._data:VizData|None = None
|
||||
@property
|
||||
def trace(self) -> RewriteTrace: return unwrap(self._trace)
|
||||
def set_trace(self) -> None:
|
||||
self._trace = RewriteTrace(tracked_keys.copy(), tracked_ctxs.copy(), uop_fields.copy())
|
||||
import tinygrad.viz.serve as serve_module
|
||||
serve_module.trace = self._trace
|
||||
def data(self) -> VizData: return unwrap(self._data)
|
||||
def set_data(self) -> None:
|
||||
data = VizData(RewriteTrace(tracked_keys.copy(), tracked_ctxs.copy(), uop_fields.copy()))
|
||||
load_rewrites(data)
|
||||
self._data = data
|
||||
# the API
|
||||
def list_items(self) -> list[dict]: return get_rewrites(self.trace)
|
||||
def list_items(self) -> list[dict]:
|
||||
return self.data.ctxs
|
||||
def get_details(self, rewrite_idx:int, step:int) -> Generator[dict, None, None]:
|
||||
lst = self.list_items()
|
||||
assert len(lst) > rewrite_idx, f"only loaded {len(lst)} traces, expecting at least {rewrite_idx}"
|
||||
return get_full_rewrite(self.trace.rewrites[rewrite_idx][step])
|
||||
assert len(self.data.trace.rewrites) > rewrite_idx, f"only loaded {len(self.data.trace.rewrites)} traces, expecting at least {rewrite_idx}"
|
||||
return get_full_rewrite(self.data, self.data.trace.rewrites[rewrite_idx][step])
|
||||
|
||||
@contextlib.contextmanager
|
||||
def save_viz():
|
||||
@@ -52,7 +52,7 @@ def save_viz():
|
||||
try:
|
||||
yield viz
|
||||
finally:
|
||||
viz.set_trace()
|
||||
viz.set_data()
|
||||
TRACK_MATCH_STATS.value = prev_tms
|
||||
PROFILE.value = prev_profile
|
||||
VIZ.value = prev_viz
|
||||
@@ -194,7 +194,7 @@ class TestViz(unittest.TestCase):
|
||||
class TestStruct:
|
||||
colored_field: str
|
||||
a = UOp(Ops.CUSTOM, arg=TestStruct(colored("xyz", "magenta")+colored("12345", "blue")))
|
||||
a2 = uop_to_json(a)[id(a)]
|
||||
a2 = uop_to_json(a, VizData())[id(a)]
|
||||
self.assertEqual(ansistrip(a2["label"]), f"CUSTOM\n{TestStruct.__qualname__}(colored_field='xyz12345')")
|
||||
|
||||
def test_colored_label_multiline(self):
|
||||
@@ -217,11 +217,11 @@ class TestViz(unittest.TestCase):
|
||||
# use smaller stack limit for faster test (default is 250000)
|
||||
with Context(REWRITE_STACK_LIMIT=100): self.assertRaises(RuntimeError, exec_rewrite, a, [pm])
|
||||
graphs = flatten(x["graph"].values() for x in viz.get_details(0, 0))
|
||||
self.assertEqual(graphs[0], uop_to_json(a)[id(a)])
|
||||
self.assertEqual(graphs[1], uop_to_json(b)[id(b)])
|
||||
self.assertEqual(graphs[0], uop_to_json(a, VizData())[id(a)])
|
||||
self.assertEqual(graphs[1], uop_to_json(b, VizData())[id(b)])
|
||||
# fallback to NOOP with the error message
|
||||
nop = UOp(Ops.NOOP, arg="infinite loop in fixed_point_rewrite")
|
||||
self.assertEqual(graphs[2], uop_to_json(nop)[id(nop)])
|
||||
self.assertEqual(graphs[2], uop_to_json(nop, VizData())[id(nop)])
|
||||
|
||||
def test_const_node_visibility(self):
|
||||
with save_viz() as viz:
|
||||
@@ -241,7 +241,7 @@ class TestViz(unittest.TestCase):
|
||||
c = UOp.const(dtypes.float, 1.0, device="CPU", shape=(3,4)) # creates CONST->RESHAPE->EXPAND chain
|
||||
a = UOp(Ops.DEFINE_VAR, dtypes.float, arg=("a", 0.0, 10.0))
|
||||
alu = a + c
|
||||
graph = uop_to_json(alu)
|
||||
graph = uop_to_json(alu, VizData())
|
||||
# the RESHAPE and EXPAND nodes from the const should not appear in the graph
|
||||
labels = {v["label"].split("\n")[0] for v in graph.values()}
|
||||
self.assertNotIn("RESHAPE", labels)
|
||||
|
||||
Reference in New Issue
Block a user