mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
@@ -290,6 +290,13 @@ class TestVizIntegration(BaseTestViz):
|
||||
self.assertEqual(list(next(get_viz_details(1, 0))["graph"]), [id(c)])
|
||||
self.assertEqual(list(next(get_viz_details(1, 1))["graph"]), [id(c+2)])
|
||||
|
||||
def test_recurse(self):
|
||||
a = Tensor.empty(10)
|
||||
for _ in range(10_000): a += a
|
||||
graph_rewrite(a.uop, PatternMatcher([]))
|
||||
lst = get_viz_list()
|
||||
assert len(lst) == 1
|
||||
|
||||
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry
|
||||
from tinygrad.viz.serve import get_profile
|
||||
|
||||
|
||||
@@ -281,6 +281,16 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
with Context(TRACK_MATCH_STATS=(0 if name is None else TRACK_MATCH_STATS.value)):
|
||||
return graph_rewrite(self, _substitute, dvars, bottom_up=True, name=name)
|
||||
|
||||
# *** uop tracing stuff ***
|
||||
|
||||
@recursive_property
|
||||
def trace_num(self):
|
||||
num = next(ucount)
|
||||
# KERNEL also has a UOp in the arg
|
||||
arg = type(self.arg)(self.arg.ast.trace_num, self.arg.metadata) if self.op is Ops.KERNEL else self.arg
|
||||
uop_fields[num] = (self.op, self.dtype, tuple(s.trace_num for s in self.src), arg, self.tag)+((self.metadata,) if TRACEMETA>=2 else ())
|
||||
return num
|
||||
|
||||
# *** uop syntactic sugar ***
|
||||
|
||||
@property
|
||||
@@ -905,15 +915,8 @@ class PatternMatcher:
|
||||
# *** non-blocking UOp tracker ***
|
||||
|
||||
ucount = itertools.count()
|
||||
uop_number:weakref.WeakKeyDictionary[UOp, int] = weakref.WeakKeyDictionary()
|
||||
uop_fields:dict[int, tuple] = {}
|
||||
def track_uop(u:UOp):
|
||||
if (cret:=uop_number.get(u)) is not None: return cret
|
||||
uop_number[u] = num = next(ucount)
|
||||
# KERNEL also has a UOp in the arg
|
||||
arg = type(u.arg)(track_uop(u.arg.ast), u.arg.metadata) if u.op is Ops.KERNEL else u.arg
|
||||
uop_fields[num] = (u.op, u.dtype, tuple(track_uop(s) for s in u.src), arg, u.tag)+((u.metadata,) if TRACEMETA>=2 else ())
|
||||
return num
|
||||
def track_uop(u:UOp): return u.trace_num
|
||||
|
||||
# *** tracking pattern matcher ***
|
||||
|
||||
|
||||
Reference in New Issue
Block a user