Files
tinygrad/test/unit/test_viz.py
qazal ac39f27ae6 viz: non blocking UOp tracing (#10913)
* viz: non blocking UOp tracing

* u.arg

* no if Ops.KERNEL

* drop replace

* switch to weakref.WeakKeyDictionary

* back

* remove ram usage skips, viz works here

* cache on reconstruct
2025-06-23 19:59:28 +03:00

273 lines
11 KiB
Python

import unittest, decimal, json
from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, TrackedPatternMatcher
from tinygrad.uop.ops import graph_rewrite, track_rewrites, TRACK_MATCH_STATS
from tinygrad.uop.symbolic import sym
from tinygrad.dtype import dtypes
@track_rewrites(name=True)
def exec_rewrite(sink:UOp, pm_lst:list[PatternMatcher], names:None|list[str]=None) -> UOp:
  """Rewrite `sink` with every matcher in `pm_lst`, in order, and return the final UOp.

  Each matcher's patterns are wrapped in a TrackedPatternMatcher before rewriting.
  If `names` is given, names[i] is passed as the name of the i-th rewrite;
  otherwise every rewrite is passed name=None.
  """
  for idx, matcher in enumerate(pm_lst):
    tracked = TrackedPatternMatcher(matcher.patterns)
    step_name = names[idx] if names else None
    sink = graph_rewrite(sink, tracked, name=step_name)
  return sink
# real VIZ=1 pickles these tracked values
from tinygrad.viz.serve import get_metadata
from tinygrad.uop.ops import tracked_keys, tracked_ctxs, active_rewrites, _name_cnt
def get_viz_list():
  """Return what get_metadata builds from the global tracked keys and tracked contexts."""
  return get_metadata(tracked_keys, tracked_ctxs)
class TestViz(unittest.TestCase):
  """Tests for the metadata VIZ exposes about tracked graph rewrites."""

  def setUp(self):
    # clear the global context so each test only sees its own rewrites
    for tracked in (tracked_keys, tracked_ctxs, active_rewrites, _name_cnt):
      tracked.clear()
    # remember the current TRACK_MATCH_STATS value and run the test with it set to 2
    self.tms = TRACK_MATCH_STATS.value
    TRACK_MATCH_STATS.value = 2

  def tearDown(self):
    # put back the value saved in setUp
    TRACK_MATCH_STATS.value = self.tms

  def test_simple(self):
    a = UOp.variable("a", 0, 10)
    exec_rewrite((a+0)*1, [sym])
    groups = get_viz_list()
    # VIZ displays rewrites in groups of tracked functions
    self.assertEqual(len(groups), 1)
    # each group has a list of steps
    steps = groups[0]["steps"]
    self.assertEqual(len(steps), 1)
    # each step has a list of matches
    self.assertEqual(steps[0]["match_count"], 2)

  def test_rewrites(self):
    a = UOp.variable("a", 0, 10)
    exec_rewrite(a*1, [sym])
    exec_rewrite(a*2, [sym])
    groups = get_viz_list()
    self.assertEqual(len(groups), 2)
    # names dedup using a counter
    first, second = groups[0], groups[1]
    self.assertEqual(first["name"], "exec_rewrite n1")
    self.assertEqual(second["name"], "exec_rewrite n2")

  def test_steps(self):
    a = UOp.variable("a", 0, 10)
    matchers = [PatternMatcher([]), PatternMatcher([])]
    exec_rewrite(a+1, matchers, ["x", "y"])
    steps = get_viz_list()[0]["steps"]
    # steps can optionally have a name
    self.assertEqual(steps[0]["name"], "x")
    self.assertEqual(steps[1]["name"], "y")

  def test_rewrite_location(self):
    # NOTE: inner is deliberately a one-liner. The assertion below compares the step's
    # line number with inner's first line, so the graph_rewrite call must sit on that line.
    def inner(sink): return graph_rewrite(sink, PatternMatcher([]))
    @track_rewrites(name=True)
    def outer(sink): return inner(sink)
    outer(UOp.variable("a", 1, 10))
    first_step = get_viz_list()[0]["steps"][0]
    # step location comes from inner rewrite
    fp, lineno = first_step["loc"]
    self.assertEqual(fp, inner.__code__.co_filename)
    self.assertEqual(lineno, inner.__code__.co_firstlineno)

  def test_exceptions(self):
    # VIZ tracks rewrites up to the error
    def count_3(x:UOp):
      assert x.arg <= 3
      return x.replace(arg=x.arg+1)
    err_pm = PatternMatcher([(UPat.cvar("x"), count_3),])
    start = UOp.const(dtypes.int, 1)
    with self.assertRaises(AssertionError):
      exec_rewrite(start, [err_pm])
    err_step = get_viz_list()[0]["steps"][0]
    self.assertEqual(err_step["match_count"], 3)

  def test_default_name(self):
    a = UOp.variable("a", 1, 10)
    @track_rewrites()
    def name_default(): return graph_rewrite(a, PatternMatcher([]))
    name_default()
    group = get_viz_list()[0]
    self.assertEqual(group["name"], "name_default n1")

  # name can also be the first arg
  def test_self_name(self):
    @track_rewrites()
    def name_is_self(s:UOp): return graph_rewrite(s, PatternMatcher([]))
    arg = UOp.variable("a", 1, 10)
    name_is_self(arg)
    group = get_viz_list()[0]
    self.assertEqual(group["name"], str(arg))

  # name can also come from a function
  def test_dyn_name_fxn(self):
    @track_rewrites(name=lambda a,ret: a.render())
    def name_from_fxn(s:UOp): return graph_rewrite(s, PatternMatcher([]))
    name_from_fxn(UOp.variable("a", 1, 10)+1)
    group = get_viz_list()[0]
    self.assertEqual(group["name"], "(a+1) n1")
# VIZ displays nested graph_rewrites in a tree view
def leaf_rewrite(x:UOp):
  """Tag an untagged UOp with 1; return None (no rewrite) when it already has a tag."""
  if x.tag is not None:
    return None
  return x.rtag(1)

# matcher applying leaf_rewrite to DEFINE_VAR UOps
leaf = TrackedPatternMatcher([(UPat(Ops.DEFINE_VAR, name="x"), leaf_rewrite)])
def branch_rewrite(x:UOp, y:UOp):
  """Rewrite both operands with the `leaf` matcher and return their product.

  Returns None (no rewrite) when `x` already carries a tag.
  """
  if x.tag is None:
    left = graph_rewrite(x, leaf, name="leaf_left")
    right = graph_rewrite(y, leaf, name="leaf_right")
    return left * right
  return None

# matcher applying branch_rewrite to x+y
branch = TrackedPatternMatcher([(UPat.var("x")+UPat.var("y"), branch_rewrite)])
def root_rewrite(root:UOp):
  """Rewrite each source of `root` with the `branch` matcher and return `root` with the new sources.

  The i-th source is rewritten under the name "branch_<i>".
  """
  rewritten = []
  for i,b in enumerate(root.src):
    rewritten.append(graph_rewrite(b, branch, name=f"branch_{i}"))
  return root.replace(src=tuple(rewritten))

# matcher applying root_rewrite to a SINK whose sources are ADDs
root = TrackedPatternMatcher([(UPat(Ops.SINK, src=UPat(Ops.ADD), name="root"), root_rewrite),])
class TestVizTree(TestViz):
  """Checks the names, depths and match counts of steps recorded for nested graph_rewrites."""

  def assertStepEqual(self, step:dict, want:dict):
    """Assert that `step` agrees with `want` on every key present in `want`."""
    for k,v in want.items():
      self.assertEqual(step[k], v, f"failed at '{k}': {v} != {step[k]}\n{step=}")

  def test_tree_view(self):
    a, b, c, d = (UOp.variable(nm, 0, 10) for nm in ("a", "b", "c", "d"))
    sink = UOp.sink(a+b, c+d)
    @track_rewrites()
    def tree_rewrite(): return graph_rewrite(sink, root, name="root")
    tree_rewrite()
    steps = get_viz_list()[0]["steps"]
    # one root, two branches, and a left/right leaf under each branch
    self.assertEqual(len(steps), 1+2+4)
    # expected (name, depth) of every step, in recorded order
    expected = [
      ("root", 0),
      ("branch_0", 1), ("leaf_left", 2), ("leaf_right", 2),
      ("branch_1", 1), ("leaf_left", 2), ("leaf_right", 2),
    ]
    for step, (name, depth) in zip(steps, expected):
      self.assertStepEqual(step, {"name":name, "depth":depth, "match_count":1})
import gc
from tinygrad.device import Buffer
def bufs_allocated() -> int:
  """Run a garbage collection, then count the Buffer instances among the objects gc tracks."""
  gc.collect()
  count = 0
  for obj in gc.get_objects():
    if isinstance(obj, Buffer):
      count += 1
  return count
class TestVizGC(TestViz):
  """Checks that the Buffer count returns to its baseline after a tracked rewrite once the UOp is deleted."""

  def test_gc(self):
    baseline = bufs_allocated()
    buf_uop = UOp.new_buffer("NULL", 10, dtypes.char)
    buf_uop.buffer.allocate()
    exec_rewrite(buf_uop, [PatternMatcher([])])
    # drop the only local reference before counting Buffers again
    del buf_uop
    leaked = bufs_allocated() - baseline
    self.assertEqual(leaked, 0)
    # the rewrite is still listed
    groups = get_viz_list()
    self.assertEqual(len(groups), 1)

  @unittest.skip("it's not generic enough to handle arbitrary UOps in arg")
  def test_gc_uop_in_arg(self):
    baseline = bufs_allocated()
    buf_uop = UOp.new_buffer("NULL", 10, dtypes.char)
    buf_uop.buffer.allocate()
    # same as test_gc, but the buffer UOp is also placed in the arg of a CUSTOM UOp
    exec_rewrite(UOp(Ops.CUSTOM, src=(buf_uop,), arg=buf_uop), [PatternMatcher([])])
    del buf_uop
    leaked = bufs_allocated() - baseline
    self.assertEqual(leaked, 0)
    groups = get_viz_list()
    self.assertEqual(len(groups), 1)
# VIZ integrates with other parts of tinygrad
from tinygrad import Tensor, Device
from tinygrad.engine.realize import get_program
class TestVizIntegration(TestViz):
  """Checks the VIZ entries produced by Tensor.kernelize, Tensor.schedule and get_program."""

  # kernelize has a custom name function in VIZ
  def test_kernelize_tracing(self):
    a = Tensor.empty(4, 4)
    Tensor.kernelize(a+1, a+2)
    groups = get_viz_list()
    self.assertEqual(len(groups), 1)
    self.assertEqual(groups[0]["name"], "Schedule 2 Kernels n1")

  # codegen supports rendering of code blocks
  def test_codegen_tracing(self):
    schedule = Tensor.schedule(Tensor.empty(4)+Tensor.empty(4))
    ast = schedule[0].ast
    renderer = Device[Device.DEFAULT].renderer
    prg = get_program(ast, renderer)
    groups = get_viz_list()
    self.assertEqual(len(groups), 2)
    # the schedule entry comes first, then one entry named after the program
    self.assertEqual(groups[0]["name"], "Schedule 1 Kernel n1")
    self.assertEqual(groups[1]["name"], prg.name)
from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry
from tinygrad.viz.serve import to_perfetto
class TextVizProfiler(unittest.TestCase):
  """Checks the JSON that to_perfetto produces for lists of profiler events."""

  def test_perfetto_node(self):
    range_ev = ProfileRangeEvent(device='NV', name='E_2', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=False)
    device_ev = ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100))
    events = json.loads(to_perfetto([range_ev, device_ev]))['traceEvents']
    proc, compute, copy, kernel = events[0], events[1], events[2], events[3]
    # Device regs always first
    self.assertEqual(proc['name'], 'process_name')
    self.assertEqual(proc['ph'], 'M')
    self.assertEqual(proc['args']['name'], 'NV')

    # thread 0 of the device process is COMPUTE
    self.assertEqual(compute['name'], 'thread_name')
    self.assertEqual(compute['ph'], 'M')
    self.assertEqual(compute['pid'], proc['pid'])
    self.assertEqual(compute['tid'], 0)
    self.assertEqual(compute['args']['name'], 'COMPUTE')

    # thread 1 of the device process is COPY
    self.assertEqual(copy['name'], 'thread_name')
    self.assertEqual(copy['ph'], 'M')
    self.assertEqual(copy['pid'], proc['pid'])
    self.assertEqual(copy['tid'], 1)
    self.assertEqual(copy['args']['name'], 'COPY')

    # the range event itself, expected on tid 0
    self.assertEqual(kernel['name'], 'E_2')
    self.assertEqual(kernel['ts'], 0)
    self.assertEqual(kernel['dur'], 10)
    self.assertEqual(kernel['ph'], 'X')
    self.assertEqual(kernel['pid'], proc['pid'])
    self.assertEqual(kernel['tid'], 0)

  def test_perfetto_copy_node(self):
    range_ev = ProfileRangeEvent(device='NV', name='COPYxx', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=True)
    device_ev = ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100))
    events = json.loads(to_perfetto([range_ev, device_ev]))['traceEvents']
    copy_ev = events[3]
    self.assertEqual(copy_ev['name'], 'COPYxx')
    self.assertEqual(copy_ev['ts'], 900) # diff clock
    self.assertEqual(copy_ev['dur'], 10)
    self.assertEqual(copy_ev['ph'], 'X')
    self.assertEqual(copy_ev['tid'], 1)

  def test_perfetto_graph(self):
    dev0 = ProfileDeviceEvent(device='NV', comp_tdiff=decimal.Decimal(-1000), copy_tdiff=decimal.Decimal(-100))
    dev1 = ProfileDeviceEvent(device='NV:1', comp_tdiff=decimal.Decimal(-500), copy_tdiff=decimal.Decimal(-50))
    graph_ev = ProfileGraphEvent(
      ents=[ProfileGraphEntry(device='NV', name='E_25_4n2', st_id=0, en_id=1, is_copy=False),
            ProfileGraphEntry(device='NV:1', name='NV -> NV:1', st_id=2, en_id=3, is_copy=True)],
      deps=[[], [0]],
      sigs=[decimal.Decimal(1000), decimal.Decimal(1002), decimal.Decimal(1004), decimal.Decimal(1008)])
    events = json.loads(to_perfetto([dev0, dev1, graph_ev]))['traceEvents']

    # Device regs always first
    reg_names = [ev['args']['name'] for ev in events[:6]]
    self.assertEqual(reg_names, ['NV', 'COMPUTE', 'COPY', 'NV:1', 'COMPUTE', 'COPY'])

    # first graph entry, expected in the process registered at index 0
    compute_ev = events[6]
    self.assertEqual(compute_ev['name'], 'E_25_4n2')
    self.assertEqual(compute_ev['ts'], 0)
    self.assertEqual(compute_ev['dur'], 2)
    self.assertEqual(compute_ev['pid'], events[0]['pid'])

    # second graph entry, expected in the process registered at index 3
    copy_ev = events[7]
    self.assertEqual(copy_ev['name'], 'NV -> NV:1')
    self.assertEqual(copy_ev['ts'], 954)
    self.assertEqual(copy_ev['dur'], 4)
    self.assertEqual(copy_ev['pid'], events[3]['pid'])
# run the tests when this file is executed as a script
if __name__ == "__main__": unittest.main()