From f664bcc8bd7298f570cae02184702d94f50771cd Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 6 Oct 2025 21:10:52 +0300 Subject: [PATCH] use recursive_property in UOp tracing (#12469) * test * simple passing --- test/unit/test_viz.py | 7 +++++++ tinygrad/uop/ops.py | 19 +++++++++++-------- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/test/unit/test_viz.py b/test/unit/test_viz.py index 7ecdbe4172..7212b56aff 100644 --- a/test/unit/test_viz.py +++ b/test/unit/test_viz.py @@ -290,6 +290,13 @@ class TestVizIntegration(BaseTestViz): self.assertEqual(list(next(get_viz_details(1, 0))["graph"]), [id(c)]) self.assertEqual(list(next(get_viz_details(1, 1))["graph"]), [id(c+2)]) + def test_recurse(self): + a = Tensor.empty(10) + for _ in range(10_000): a += a + graph_rewrite(a.uop, PatternMatcher([])) + lst = get_viz_list() + assert len(lst) == 1 + from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry from tinygrad.viz.serve import get_profile diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index f6f03fd0d5..f790672619 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -281,6 +281,16 @@ class UOp(MathTrait, metaclass=UOpMetaClass): with Context(TRACK_MATCH_STATS=(0 if name is None else TRACK_MATCH_STATS.value)): return graph_rewrite(self, _substitute, dvars, bottom_up=True, name=name) + # *** uop tracing stuff *** + + @recursive_property + def trace_num(self): + num = next(ucount) + # KERNEL also has a UOp in the arg + arg = type(self.arg)(self.arg.ast.trace_num, self.arg.metadata) if self.op is Ops.KERNEL else self.arg + uop_fields[num] = (self.op, self.dtype, tuple(s.trace_num for s in self.src), arg, self.tag)+((self.metadata,) if TRACEMETA>=2 else ()) + return num + # *** uop syntactic sugar *** @property @@ -905,15 +915,8 @@ class PatternMatcher: # *** non-blocking UOp tracker *** ucount = itertools.count() -uop_number:weakref.WeakKeyDictionary[UOp, int] = weakref.WeakKeyDictionary() uop_fields:dict[int, tuple] = {} -def track_uop(u:UOp): - if (cret:=uop_number.get(u)) is not None: return cret - uop_number[u] = num = next(ucount) - # KERNEL also has a UOp in the arg - arg = type(u.arg)(track_uop(u.arg.ast), u.arg.metadata) if u.op is Ops.KERNEL else u.arg - uop_fields[num] = (u.op, u.dtype, tuple(track_uop(s) for s in u.src), arg, u.tag)+((u.metadata,) if TRACEMETA>=2 else ()) - return num +def track_uop(u:UOp): return u.trace_num # *** tracking pattern matcher ***