delete children tracking from uop (#12491)

* delete children tracking from uop

* uop children no longer exists

* no tracked children

* that test is flaky too
This commit is contained in:
George Hotz
2025-10-08 09:04:14 +08:00
committed by GitHub
parent 648e5bb223
commit 945cc46475
6 changed files with 15 additions and 121 deletions

View File

@@ -1137,6 +1137,7 @@ class TestMultiRamUsage(unittest.TestCase):
del _
self.assertUsed(0)
@unittest.skip("flaky")
def test_zeros_copy(self):
_ = Tensor.zeros(self.N, self.N).contiguous().to(devices_2).realize()
# NOTE: the first one on the DEFAULT device should be freed

View File

@@ -544,30 +544,6 @@ class TestUopsObject(unittest.TestCase):
with Timing("create 10k uops:"): ret = [UOp(Ops.CONST, dtypes.int, arg=10000000+i) for i in range(10000)]
assert len(ret) == 10000
class TestUOpChildren(unittest.TestCase):
  """Tests for the `children` set on a UOp: weak references to the UOps that use it as a source."""

  def test_children_exist(self):
    # one product built from the variable -> exactly one child, and it dereferences to that product
    var = UOp.variable("weird_name_234", 0, 10)
    square = var*var
    self.assertEqual(len(var.children), 1)
    only_ref = next(iter(var.children))
    self.assertIs(only_ref(), square)

  def test_children_cleaned_up(self):
    # dropping the only reference to the child must also remove it from the parent's children
    var = UOp.variable("weird_name_235", 0, 10)
    square = var*var
    self.assertEqual(len(var.children), 1)
    del square
    self.assertEqual(len(var.children), 0)

  def test_children_cleaned_up_two(self):
    # two distinct children are removed one at a time as each one is deleted
    var = UOp.variable("weird_name_236", 0, 10)
    square = var*var
    doubled = var*2
    self.assertEqual(len(var.children), 2)
    del square
    self.assertEqual(len(var.children), 1)
    del doubled
    self.assertEqual(len(var.children), 0)
class TestUOpRender(unittest.TestCase):
def test_render_vectorize_same(self):
u = UOp(Ops.VECTORIZE, src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0)))

View File

@@ -1,63 +0,0 @@
import unittest
from tinygrad import Tensor
from tinygrad.uop.ops import PatternMatcher, Ops, UPat, graph_rewrite, RewriteContext, UOp
from tinygrad.schedule.kernelize import kernelize_sym, merge_views
class TestRewriteTrackedChildren(unittest.TestCase):
  """Tests of children lookups around graph_rewrite; the two track_children cases are skipped."""

  @unittest.skip("track_children no longer supported")
  def test_children_in_context(self):
    # NOTE: the parameter names `ctx` and `sink` are part of the matcher contract (UPat name="sink")
    def print_children(ctx:RewriteContext, sink:UOp):
      root = sink.src[0]
      view_w_child = root.src[0].src[0]
      assert view_w_child.op is Ops.VIEW
      assert {x.arg for x in ctx.children[view_w_child]} == {2, 3}
      ctx.update_children()
      assert {x.arg for x in ctx.children[view_w_child]} == {3, 4}
      # this is the 3
      children_of_three = ctx.children[sink.src[0].src[1]]
      assert len(children_of_three) == 1
      assert next(iter(children_of_three)).op is Ops.ADD
      # this is the 4
      children_of_four = ctx.children[sink.src[0].src[0]]
      assert len(children_of_four) == 1
      assert next(iter(children_of_four)).op is Ops.ADD

    rules = [
      (UPat(Ops.CONST, arg=2, name="x"), lambda x: x.replace(arg=4)),
      (UPat(Ops.SINK, name="sink"), print_children),
    ]
    rewrite = PatternMatcher(rules)
    lhs, rhs = Tensor(2), Tensor(3)
    total = lhs + rhs
    sink = total.uop.sink()
    sink = graph_rewrite(sink, rewrite, track_children=True)

  def test_simple_child(self):
    rewrite = PatternMatcher([(UPat(Ops.CONST, arg=2, name="x"), lambda x: x.replace(arg=4))])
    lhs, rhs = Tensor(2), Tensor(3)
    total = lhs + rhs
    sink = total.uop
    view_w_child = lhs.uop.src[0]

    def weak_args():
      # args of the children reachable through the weakref set stored on the UOp itself
      return [x().arg for x in view_w_child.children]

    def mapped_args(root):
      # args of the children recorded in a freshly computed children map of `root`
      return [x.arg for x in root.get_children_map()[view_w_child]]

    print(weak_args())
    print(mapped_args(sink))
    self.assertSetEqual(set(mapped_args(sink)), {2, 3})

    # graph_rewrite can both add children to the map and remove them from it
    # additions are easy to detect: hook the UOp constructor
    # when are children removed?
    # * if a rewrite rule returns a UOp, the matched node is removed from the graph
    sink = graph_rewrite(sink, rewrite)
    print(weak_args())
    print(mapped_args(sink))
    self.assertSetEqual(set(mapped_args(sink)), {3, 4})

  @unittest.skip("track_children no longer supported")
  def test_child_after_parent_update(self):
    # NOTE: the parameter name `r` is bound by UPat(name="r")
    def print_children(ctx, r):
      ctx.update_children()
      print(ctx.children[r])

    extra = PatternMatcher([(UPat(Ops.REDUCE_AXIS, name="r"), print_children)])
    inp = Tensor.empty(3, 3)
    reduced = (inp+0).sum()
    graph_rewrite(reduced.uop, merge_views+kernelize_sym+extra, track_children=True)
# run this module's test cases when the file is executed directly as a script
if __name__ == '__main__':
  unittest.main()

View File

@@ -23,33 +23,17 @@ from tinygrad.schedule.kernelize import get_kernelize_map
# *** all in scope Tensors are here. this gets relevant UOps ***
all_tensors: dict[weakref.ref[Tensor], None] = {}
def _find_all_tensors_for_uops(all_uops: set[UOp]) -> list[Tensor]:
  """Return the live Tensors registered in `all_tensors` whose current `uop` is a member of `all_uops`."""
  found: list[Tensor] = []
  for tref in all_tensors:
    tensor = tref()
    # a weakref whose Tensor has been collected dereferences to None; skip those entries
    if tensor is None: continue
    if tensor.uop in all_uops: found.append(tensor)
  return found
def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str|None=None) -> None:
# get all children of keys in applied_map
all_uops: set[UOp] = set()
search_uops = list(applied_map)
while len(search_uops):
x = search_uops.pop()
if x in all_uops: continue
all_uops.add(x)
search_uops.extend([u for c in x.children if (u:=c()) is not None])
fixed_tensors = [t for tref in all_tensors if (t:=tref()) is not None and (t.uop in applied_map or any(x in t.uop.parents for x in applied_map))]
# link the found UOps back to Tensors. exit early if there's no Tensors to realize
# NOTE: this uses all_tensors, but it's fast
if len(fixed_tensors := _find_all_tensors_for_uops(all_uops)):
# potentially rewrite all the discovered Tensors
sink = UOp.sink(*[t.uop for t in fixed_tensors])
new_sink = sink.substitute(applied_map, name=name)
# get all Tensors and apply the map
sink = UOp.sink(*[t.uop for t in fixed_tensors])
new_sink = sink.substitute(applied_map, name=name)
# NOTE: you can check the Tensor graph early here
#if __debug__: type_verify(list(new_sink.toposort()), tensor_uop_spec)
# set the relevant uop to the realized UOps
for t,s,ns in zip(fixed_tensors, sink.src, new_sink.src):
if s is ns: continue
t.uop = ns
# set the relevant uop to the realized UOps
for t,s,ns in zip(fixed_tensors, sink.src, new_sink.src):
if s is ns: continue
t.uop = ns
# **** Tensor helper functions ****

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from typing import Any, Callable, cast, TYPE_CHECKING, Type, Sequence
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref, collections
from dataclasses import dataclass, field
from dataclasses import dataclass
from enum import Enum, auto
from tinygrad.uop import Ops, GroupOp
from tinygrad.uop.mathtraits import MathTrait
@@ -62,8 +62,7 @@ class UOpMetaClass(type):
def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, tag:Any=None,
metadata:tuple[Metadata,...]|None=None, _buffer:Buffer|None=None):
if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg, tag), None)) is not None and (ret:=wret()) is not None: return ret
UOpMetaClass.ucache[key] = ref = weakref.ref(created:=super().__call__(*key))
for s in src: s.children.add(ref)
UOpMetaClass.ucache[key] = weakref.ref(created:=super().__call__(*key))
if metadata is not None: all_metadata[created] = metadata
# NOTE: this value is set by pickle when pickling a realized tensor
if _buffer is not None:
@@ -101,13 +100,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
src:tuple[UOp, ...] = tuple()
arg:Any = None
tag:Any = None
children:set[weakref.ref[UOp]] = field(default_factory=set)
def __del__(self):
if Ops is not None and self.op is Ops.BUFFER and (buffer:=buffers.get(self)) is not None: buffer.ref(-1)
try:
if (ref:=UOpMetaClass.ucache.get(k:=(self.op, self.dtype, self.src, self.arg, self.tag))) is not None:
for s in self.src: s.children.discard(ref)
del UOpMetaClass.ucache[k]
try: del UOpMetaClass.ucache[(self.op, self.dtype, self.src, self.arg, self.tag)]
except AttributeError: pass
def __reduce__(self):
args = [self.op, self.dtype, self.src, self.arg, self.tag, self.metadata]

View File

@@ -64,8 +64,9 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
# always exclude DEVICE/CONST/UNIQUE
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE} and u is not x: excluded.add(u)
# only exclude CONST VIEW source if it has no other children in the graph
if u.op is Ops.CONST and len(u.src) != 0 and all(cr.op is Ops.CONST for c in u.src[0].children if (cr:=c()) is not None and cr in toposort):
excluded.update(u.src)
# TODO: find a different way to do this, children isn't tracked
#if u.op is Ops.CONST and len(u.src) != 0 and all(cr.op is Ops.CONST for c in u.src[0].children if (cr:=c()) is not None and cr in toposort):
# excluded.update(u.src)
for u in toposort:
if u in excluded: continue
argst = codecs.decode(str(u.arg), "unicode_escape")