From ac39f27ae6aa0e075e5dd68d07fab47ff1ce8095 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 23 Jun 2025 19:59:28 +0300 Subject: [PATCH] viz: non blocking UOp tracing (#10913) * viz: non blocking UOp tracing * u.arg * no if Ops.KENREL * drop replace * switch to weakref.WeakKeyDictionary * back * remove ram usage skips, viz works here * cache on reconstruct --- test/test_multitensor.py | 5 +---- test/test_schedule.py | 1 - test/unit/test_viz.py | 29 +++++++++++++++++++++++++++++ tinygrad/uop/ops.py | 23 ++++++++++++++++++----- tinygrad/viz/serve.py | 16 +++++++++++----- 5 files changed, 59 insertions(+), 15 deletions(-) diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 7e218b8bcb..7168339d4b 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -1,4 +1,4 @@ -import unittest, functools, random, os +import unittest, functools, random from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes, Variable from tinygrad.device import is_dtype_supported from tinygrad.uop.ops import Ops, UOp @@ -1101,7 +1101,6 @@ class TestTensorOps(unittest.TestCase): def test_bitcast(self): helper_test_shard_op([(256,), (256,)], lambda x: x.bitcast(dtypes.int)) -# TODO: make these tests pass with VIZ=1 @unittest.skipIf(not_support_multi_device(), "no multi") class TestMultiRamUsage(unittest.TestCase): def setUp(self): @@ -1129,13 +1128,11 @@ class TestMultiRamUsage(unittest.TestCase): def test_zeros_shard(self, devices=(d1, d2)): _ = Tensor.zeros(self.N, self.N).contiguous().shard(devices, axis=0).realize() - assert int(os.getenv("VIZ", "0")) == 0 self.assertUsed(self.N*self.N*4) # sharding should not increase total ram usage def test_zeros_shard_self(self): self.test_zeros_shard((d0, d1)) def test_zeros_contiguous_shard(self): _ = Tensor.zeros(self.N, self.N).contiguous().shard(devices_2, axis=0).contiguous().realize() - assert int(os.getenv("VIZ", "0")) == 0 self.assertUsed(self.N*self.N*4) # sharding should not increase total ram usage @unittest.skipIf(not_support_multi_device(), "need multi") diff --git a/test/test_schedule.py b/test/test_schedule.py index de63056016..1ff279a157 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1631,7 +1631,6 @@ class TestSchedule(unittest.TestCase): @unittest.expectedFailure def test_conv2d_fused_half(self): _test_conv2d(5, dtype=dtypes.half) - @unittest.skipIf(getenv("VIZ"), "TODO: VIZ blocks gc") def test_schedule_mem_used(self): base = GlobalCounters.mem_used Tensor.ones(256).contiguous().realize() diff --git a/test/unit/test_viz.py b/test/unit/test_viz.py index 4f3fa74a64..a2fe5e8012 100644 --- a/test/unit/test_viz.py +++ b/test/unit/test_viz.py @@ -142,6 +142,35 @@ class TestVizTree(TestViz): self.assertStepEqual(steps[5], {"name":"leaf_left", "depth":2, "match_count":1}) self.assertStepEqual(steps[6], {"name":"leaf_right", "depth":2, "match_count":1}) +import gc +from tinygrad.device import Buffer + +def bufs_allocated() -> int: + gc.collect() + return sum([isinstance(x, Buffer) for x in gc.get_objects()]) + +class TestVizGC(TestViz): + def test_gc(self): + init = bufs_allocated() + a = UOp.new_buffer("NULL", 10, dtypes.char) + a.buffer.allocate() + exec_rewrite(a, [PatternMatcher([])]) + del a + self.assertEqual(bufs_allocated()-init, 0) + lst = get_viz_list() + self.assertEqual(len(lst), 1) + + @unittest.skip("it's not generic enough to handle arbitrary UOps in arg") + def test_gc_uop_in_arg(self): + init = bufs_allocated() + a = UOp.new_buffer("NULL", 10, dtypes.char) + a.buffer.allocate() + exec_rewrite(UOp(Ops.CUSTOM, src=(a,), arg=a), [PatternMatcher([])]) + del a + self.assertEqual(bufs_allocated()-init, 0) + lst = get_viz_list() + self.assertEqual(len(lst), 1) + # VIZ integrates with other parts of tinygrad from tinygrad import Tensor, Device diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index af214e70f3..25d8192e8e 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -721,6 +721,19 @@ class PatternMatcher: while new_n is not None: last_n, new_n = new_n, self.rewrite(new_n, ctx) return last_n +# *** non-blocking UOp tracker *** + +ucount = itertools.count() +uop_number:weakref.WeakKeyDictionary[UOp, int] = weakref.WeakKeyDictionary() +uop_fields:dict[int, tuple] = {} +def track_uop(u:UOp): + if (cret:=uop_number.get(u)) is not None: return cret + uop_number[u] = num = next(ucount) + # KERNEL also has a UOp in the arg + arg = type(u.arg)(track_uop(u.arg.ast), u.arg.metadata) if u.op is Ops.KERNEL else u.arg + uop_fields[num] = (u.op, u.dtype, tuple(track_uop(s) for s in u.src), arg, u.tag) + return num + # *** tracking pattern matcher *** VIZ = ContextVar("VIZ", 0) @@ -729,8 +742,8 @@ match_stats:dict[UPat, list[Union[int, float]]] = dict() @dataclass(frozen=True) class TrackedGraphRewrite: loc: tuple[str, int] # location that called graph_rewrite - sink: UOp # the sink input to graph_rewrite - matches: list[tuple[UOp, UOp, UPat]] # before+after of all the matches + sink: int # the sink input to graph_rewrite + matches: list[tuple[int, int, UPat]] # before+after of all the matches name: str|None # optional name of the rewrite depth: int # depth if it's a subrewrite bottom_up: bool @@ -774,7 +787,7 @@ def track_matches(func): if tracking:=(TRACK_MATCH_STATS >= 2 and tracked_ctxs): loc = ((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno) depth = len(active_rewrites) - tracked_ctxs[-1].append(ctx:=TrackedGraphRewrite(loc, args[0], [], kwargs.get("name", None), depth, kwargs.get("bottom_up", False))) + tracked_ctxs[-1].append(ctx:=TrackedGraphRewrite(loc, track_uop(args[0]), [], kwargs.get("name", None), depth, kwargs.get("bottom_up", False))) active_rewrites.append(ctx) ret = func(*args, **kwargs) if tracking: active_rewrites.pop() @@ -796,7 +809,7 @@ class TrackedPatternMatcher(PatternMatcher): match_stats[p][0] += 1 match_stats[p][3] += (et:=time.perf_counter()-st) if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable()) - if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and active_rewrites: active_rewrites[-1].matches.append((uop, ret, p)) + if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and active_rewrites: active_rewrites[-1].matches.append((track_uop(uop),track_uop(ret), p)) return ret match_stats[p][2] += time.perf_counter()-st return None @@ -809,7 +822,7 @@ if TRACK_MATCH_STATS or PROFILE: if TRACK_MATCH_STATS >= 2: with open(fn:=temp("rewrites.pkl", append_user=True), "wb") as f: print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}") - with Context(PICKLE_BUFFERS=0): pickle.dump((tracked_keys, tracked_ctxs), f) + with Context(PICKLE_BUFFERS=0): pickle.dump((tracked_keys, tracked_ctxs, uop_fields), f) if VIZ: launch_viz("VIZ", temp("rewrites.pkl", append_user=True)) if getenv("PRINT_MATCH_STATS", 1): ret = [0,0,0.0,0.0] diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 0ec548e034..1b58120975 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, decimal, socketserver +import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, decimal, socketserver, functools from http.server import BaseHTTPRequestHandler from urllib.parse import parse_qs, urlparse from typing import Any, TypedDict, Generator @@ -74,11 +74,17 @@ def uop_to_json(x:UOp) -> dict[int, dict]: "ref":id(u.arg.ast) if u.op is Ops.KERNEL else None, "tag":u.tag} return graph +@functools.cache +def _reconstruct(a:int): + op, dtype, src, arg, tag = contexts[2][a] + arg = type(arg)(_reconstruct(arg.ast), arg.metadata) if op is Ops.KERNEL else arg + return UOp(op, dtype, tuple(_reconstruct(s) for s in src), arg, tag) + def get_details(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]: - yield {"graph":uop_to_json(next_sink:=ctx.sink), "uop":str(ctx.sink), "changed_nodes":None, "diff":None, "upat":None} + yield {"graph":uop_to_json(next_sink:=_reconstruct(ctx.sink)), "uop":str(next_sink), "changed_nodes":None, "diff":None, "upat":None} replaces: dict[UOp, UOp] = {} - for u0,u1,upat in tqdm(ctx.matches): - replaces[u0] = u1 + for u0_num,u1_num,upat in tqdm(ctx.matches): + replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num) try: new_sink = next_sink.substitute(replaces) except RecursionError as e: new_sink = UOp(Ops.NOOP, arg=str(e)) yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":str(new_sink), "changed_nodes":[id(x) for x in u1.toposort() if id(x) in sink_json], @@ -189,7 +195,7 @@ if __name__ == "__main__": contexts, profile = load_pickle(args.kernels), load_pickle(args.profile) # NOTE: this context is a tuple of list[keys] and list[values] - ctxs = get_metadata(*contexts) if contexts is not None else [] + ctxs = get_metadata(*contexts[:2]) if contexts is not None else [] perfetto_profile = to_perfetto(profile) if profile is not None else None