use recursive_property in UOp tracing (#12469)

* test

* simple passing
This commit is contained in:
qazal
2025-10-06 21:10:52 +03:00
committed by GitHub
parent 1af05dae77
commit f664bcc8bd
2 changed files with 18 additions and 8 deletions

View File

@@ -290,6 +290,13 @@ class TestVizIntegration(BaseTestViz):
self.assertEqual(list(next(get_viz_details(1, 0))["graph"]), [id(c)])
self.assertEqual(list(next(get_viz_details(1, 1))["graph"]), [id(c+2)])
def test_recurse(self):
a = Tensor.empty(10)
for _ in range(10_000): a += a
graph_rewrite(a.uop, PatternMatcher([]))
lst = get_viz_list()
assert len(lst) == 1
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry
from tinygrad.viz.serve import get_profile

View File

@@ -281,6 +281,16 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
with Context(TRACK_MATCH_STATS=(0 if name is None else TRACK_MATCH_STATS.value)):
return graph_rewrite(self, _substitute, dvars, bottom_up=True, name=name)
# *** uop tracing stuff ***
@recursive_property
def trace_num(self):
num = next(ucount)
# KERNEL also has a UOp in the arg
arg = type(self.arg)(self.arg.ast.trace_num, self.arg.metadata) if self.op is Ops.KERNEL else self.arg
uop_fields[num] = (self.op, self.dtype, tuple(s.trace_num for s in self.src), arg, self.tag)+((self.metadata,) if TRACEMETA>=2 else ())
return num
# *** uop syntactic sugar ***
@property
@@ -905,15 +915,8 @@ class PatternMatcher:
# *** non-blocking UOp tracker ***
ucount = itertools.count()
uop_number:weakref.WeakKeyDictionary[UOp, int] = weakref.WeakKeyDictionary()
uop_fields:dict[int, tuple] = {}
def track_uop(u:UOp):
if (cret:=uop_number.get(u)) is not None: return cret
uop_number[u] = num = next(ucount)
# KERNEL also has a UOp in the arg
arg = type(u.arg)(track_uop(u.arg.ast), u.arg.metadata) if u.op is Ops.KERNEL else u.arg
uop_fields[num] = (u.op, u.dtype, tuple(track_uop(s) for s in u.src), arg, u.tag)+((u.metadata,) if TRACEMETA>=2 else ())
return num
def track_uop(u:UOp): return u.trace_num
# *** tracking pattern matcher ***