mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
delete children tracking from uop (#12491)
* delete children tracking from uop
* uop children no longer exists
* no tracked children
* that test is flaky too
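For context on what is being deleted: every UOp kept a set of weak references to the UOps built on top of it (its "children"), registered when a consumer was constructed. A minimal standalone sketch of that pattern, using a stand-in Node class rather than tinygrad's UOp:

import weakref

class Node:
  def __init__(self, name, src=()):
    self.name, self.src = name, tuple(src)
    self.children = set()                    # weak back-references to consumer Nodes
    ref = weakref.ref(self)
    for s in self.src: s.children.add(ref)   # register self as a child of each source

a = Node("a")
b = Node("b", src=(a,))
assert [r() for r in a.children] == [b]      # a now knows that b consumes it
del b                                        # CPython frees b promptly via refcounting
assert all(r() is None for r in a.children)  # the back-reference is dead

The hunks below delete exactly this bookkeeping from UOp: the children field on the dataclass, the registration in UOpMetaClass.__call__, and the discard in __del__.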
@@ -1137,6 +1137,7 @@ class TestMultiRamUsage(unittest.TestCase):
     del _
     self.assertUsed(0)
 
+  @unittest.skip("flaky")
   def test_zeros_copy(self):
     _ = Tensor.zeros(self.N, self.N).contiguous().to(devices_2).realize()
     # NOTE: the first one on the DEFAULT device should be freed
@@ -544,30 +544,6 @@ class TestUopsObject(unittest.TestCase):
     with Timing("create 10k uops:"): ret = [UOp(Ops.CONST, dtypes.int, arg=10000000+i) for i in range(10000)]
     assert len(ret) == 10000
 
-class TestUOpChildren(unittest.TestCase):
-  def test_children_exist(self):
-    a = UOp.variable("weird_name_234", 0, 10)
-    b = a*a
-    self.assertEqual(len(a.children), 1)
-    self.assertIs(list(a.children)[0](), b)
-
-  def test_children_cleaned_up(self):
-    a = UOp.variable("weird_name_235", 0, 10)
-    b = a*a
-    self.assertEqual(len(a.children), 1)
-    del b
-    self.assertEqual(len(a.children), 0)
-
-  def test_children_cleaned_up_two(self):
-    a = UOp.variable("weird_name_236", 0, 10)
-    b = a*a
-    c = a*2
-    self.assertEqual(len(a.children), 2)
-    del b
-    self.assertEqual(len(a.children), 1)
-    del c
-    self.assertEqual(len(a.children), 0)
-
 class TestUOpRender(unittest.TestCase):
   def test_render_vectorize_same(self):
     u = UOp(Ops.VECTORIZE, src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 0)))
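The tests removed above pinned down the cleanup side of that bookkeeping: a parent's child count drops as soon as the consumer is garbage collected. A stand-in illustration with weakref.WeakSet (again a toy Node, not tinygrad's UOp), relying on CPython's prompt refcount-based collection:

import weakref

class Node:
  def __init__(self, src=()):
    self.src, self.children = tuple(src), weakref.WeakSet()
    for s in self.src: s.children.add(self)

a = Node()
b, c = Node(src=(a,)), Node(src=(a,))
assert len(a.children) == 2
del b
assert len(a.children) == 1   # the WeakSet entry vanishes once b is collected
del c
assert len(a.children) == 0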
@@ -1,63 +0,0 @@
-import unittest
-from tinygrad import Tensor
-from tinygrad.uop.ops import PatternMatcher, Ops, UPat, graph_rewrite, RewriteContext, UOp
-from tinygrad.schedule.kernelize import kernelize_sym, merge_views
-
-class TestRewriteTrackedChildren(unittest.TestCase):
-  @unittest.skip("track_children no longer supported")
-  def test_children_in_context(self):
-    def print_children(ctx:RewriteContext, sink:UOp):
-      view_w_child = sink.src[0].src[0].src[0]
-      assert view_w_child.op is Ops.VIEW
-      assert set([x.arg for x in ctx.children[view_w_child]]) == set((2,3))
-      ctx.update_children()
-      assert set([x.arg for x in ctx.children[view_w_child]]) == set((3,4))
-      # this is the 3
-      assert len(ctx.children[sink.src[0].src[1]]) == 1
-      assert next(iter(ctx.children[sink.src[0].src[1]])).op is Ops.ADD
-      # this is the 4
-      assert len(ctx.children[sink.src[0].src[0]]) == 1
-      assert next(iter(ctx.children[sink.src[0].src[0]])).op is Ops.ADD
-    rewrite = PatternMatcher([
-      (UPat(Ops.CONST, arg=2, name="x"), lambda x: x.replace(arg=4)),
-      (UPat(Ops.SINK, name="sink"), print_children)
-    ])
-    a = Tensor(2)
-    b = Tensor(3)
-    c = a + b
-    sink = c.uop.sink()
-    sink = graph_rewrite(sink, rewrite, track_children=True)
-
-  def test_simple_child(self):
-    rewrite = PatternMatcher([
-      (UPat(Ops.CONST, arg=2, name="x"), lambda x: x.replace(arg=4)),
-    ])
-    a = Tensor(2)
-    b = Tensor(3)
-    c = a + b
-    sink = c.uop
-    view_w_child = a.uop.src[0]
-    print([x().arg for x in view_w_child.children])
-    print([x.arg for x in sink.get_children_map()[view_w_child]])
-    self.assertSetEqual(set([x.arg for x in sink.get_children_map()[view_w_child]]), set((2,3)))
-    # children can either be added to or removed from the map with graph_rewrite
-    # added to is easy to detect, just hook the UOp constructor
-    # when are children removed?
-    # * if a rewrite rule returns a UOp, the matched node is removed from the graph
-    sink = graph_rewrite(sink, rewrite)
-    print([x().arg for x in view_w_child.children])
-    print([x.arg for x in sink.get_children_map()[view_w_child]])
-    self.assertSetEqual(set([x.arg for x in sink.get_children_map()[view_w_child]]), set((3,4)))
-
-  @unittest.skip("track_children no longer supported")
-  def test_child_after_parent_update(self):
-    def print_children(ctx, r):
-      ctx.update_children()
-      print(ctx.children[r])
-    extra = PatternMatcher([(UPat(Ops.REDUCE_AXIS, name="r"), print_children)])
-    a = Tensor.empty(3, 3)
-    r = (a+0).sum()
-    graph_rewrite(r.uop, merge_views+kernelize_sym+extra, track_children=True)
-
-if __name__ == '__main__':
-  unittest.main()
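The file deleted above exercised graph_rewrite with track_children=True and RewriteContext.children. The rewrite behaviour itself (a CONST 2 becoming a 4) needs no children bookkeeping; a sketch of the same rewrite using only the PatternMatcher/graph_rewrite API that appears in the deleted test:

from tinygrad import Tensor
from tinygrad.uop.ops import PatternMatcher, Ops, UPat, graph_rewrite

rewrite = PatternMatcher([(UPat(Ops.CONST, arg=2, name="x"), lambda x: x.replace(arg=4))])
c = Tensor(2) + Tensor(3)
new_sink = graph_rewrite(c.uop.sink(), rewrite)
# the deleted test checked the same outcome through the children map: the 2 is rewritten to a 4
assert 4 in [u.arg for u in new_sink.toposort() if u.op is Ops.CONST]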
@@ -23,33 +23,17 @@ from tinygrad.schedule.kernelize import get_kernelize_map
 # *** all in scope Tensors are here. this gets relevant UOps ***
 
 all_tensors: dict[weakref.ref[Tensor], None] = {}
-def _find_all_tensors_for_uops(all_uops: set[UOp]) -> list[Tensor]:
-  return [t for tref in all_tensors if (t:=tref()) is not None and t.uop in all_uops]
 
 def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str|None=None) -> None:
-  # get all children of keys in applied_map
-  all_uops: set[UOp] = set()
-  search_uops = list(applied_map)
-  while len(search_uops):
-    x = search_uops.pop()
-    if x in all_uops: continue
-    all_uops.add(x)
-    search_uops.extend([u for c in x.children if (u:=c()) is not None])
-
-  # link the found UOps back to Tensors. exit early if there's no Tensors to realize
-  # NOTE: this uses all_tensors, but it's fast
-  if len(fixed_tensors := _find_all_tensors_for_uops(all_uops)):
-    # potentially rewrite all the discovered Tensors
-    sink = UOp.sink(*[t.uop for t in fixed_tensors])
-    new_sink = sink.substitute(applied_map, name=name)
+  fixed_tensors = [t for tref in all_tensors if (t:=tref()) is not None and (t.uop in applied_map or any(x in t.uop.parents for x in applied_map))]
+  # get all Tensors and apply the map
+  sink = UOp.sink(*[t.uop for t in fixed_tensors])
+  new_sink = sink.substitute(applied_map, name=name)
 
-    # NOTE: you can check the Tensor graph early here
-    #if __debug__: type_verify(list(new_sink.toposort()), tensor_uop_spec)
-
-    # set the relevant uop to the realized UOps
-    for t,s,ns in zip(fixed_tensors, sink.src, new_sink.src):
-      if s is ns: continue
-      t.uop = ns
+  # set the relevant uop to the realized UOps
+  for t,s,ns in zip(fixed_tensors, sink.src, new_sink.src):
+    if s is ns: continue
+    t.uop = ns
 
 # **** Tensor helper functions ****
 
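The old code above walked forward through children from the keys of applied_map to find every affected Tensor; the replacement asks the question in the other direction: a Tensor is affected if its uop is a key of applied_map, or if any key shows up among the ancestors of its uop. A toy illustration of that membership test, assuming UOp.parents contains a node's ancestors as the added line implies:

from tinygrad.uop.ops import UOp

a = UOp.variable("a", 0, 10)
b = a*a
c = a*2
applied_map = {a: UOp.variable("a_new", 0, 10)}   # hypothetical substitution
affected = [u for u in (b, c) if u in applied_map or any(x in u.parents for x in applied_map)]
assert affected == [b, c]   # both expressions depend on a, so both would be rewritten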
@@ -1,7 +1,7 @@
 from __future__ import annotations
 from typing import Any, Callable, cast, TYPE_CHECKING, Type, Sequence
 import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref, collections
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from enum import Enum, auto
 from tinygrad.uop import Ops, GroupOp
 from tinygrad.uop.mathtraits import MathTrait
@@ -62,8 +62,7 @@ class UOpMetaClass(type):
   def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, tag:Any=None,
                metadata:tuple[Metadata,...]|None=None, _buffer:Buffer|None=None):
     if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg, tag), None)) is not None and (ret:=wret()) is not None: return ret
-    UOpMetaClass.ucache[key] = ref = weakref.ref(created:=super().__call__(*key))
-    for s in src: s.children.add(ref)
+    UOpMetaClass.ucache[key] = weakref.ref(created:=super().__call__(*key))
     if metadata is not None: all_metadata[created] = metadata
     # NOTE: this value is set by pickle when pickling a realized tensor
     if _buffer is not None:
@@ -101,13 +100,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   src:tuple[UOp, ...] = tuple()
   arg:Any = None
   tag:Any = None
-  children:set[weakref.ref[UOp]] = field(default_factory=set)
   def __del__(self):
     if Ops is not None and self.op is Ops.BUFFER and (buffer:=buffers.get(self)) is not None: buffer.ref(-1)
-    try:
-      if (ref:=UOpMetaClass.ucache.get(k:=(self.op, self.dtype, self.src, self.arg, self.tag))) is not None:
-        for s in self.src: s.children.discard(ref)
-        del UOpMetaClass.ucache[k]
+    try: del UOpMetaClass.ucache[(self.op, self.dtype, self.src, self.arg, self.tag)]
     except AttributeError: pass
   def __reduce__(self):
     args = [self.op, self.dtype, self.src, self.arg, self.tag, self.metadata]
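Note that the two hunks above only strip the children bookkeeping; the weakref interning cache stays as it was: __call__ still returns the existing live UOp for an identical (op, dtype, src, arg, tag) key, and __del__ still evicts that key when the last strong reference dies. A standalone sketch of the caching pattern with a stand-in class, not tinygrad's implementation:

import weakref

class Interned:
  _cache: dict = {}                # key -> weakref to the live instance
  def __new__(cls, *key):
    if (wret := cls._cache.get(key)) is not None and (ret := wret()) is not None: return ret
    created = super().__new__(cls)
    created.key = key
    cls._cache[key] = weakref.ref(created)
    return created
  def __del__(self):
    # drop the cache entry once the last strong reference goes away
    # (the AttributeError guard mirrors the patch: modules may already be torn down at interpreter exit)
    try: del type(self)._cache[self.key]
    except (AttributeError, KeyError): pass

x = Interned("ADD", 1, 2)
y = Interned("ADD", 1, 2)
assert x is y                      # identical key -> the same interned object
del x, y
assert ("ADD", 1, 2) not in Interned._cache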
@@ -64,8 +64,9 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
     # always exclude DEVICE/CONST/UNIQUE
     if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE} and u is not x: excluded.add(u)
     # only exclude CONST VIEW source if it has no other children in the graph
-    if u.op is Ops.CONST and len(u.src) != 0 and all(cr.op is Ops.CONST for c in u.src[0].children if (cr:=c()) is not None and cr in toposort):
-      excluded.update(u.src)
+    # TODO: find a different way to do this, children isn't tracked
+    #if u.op is Ops.CONST and len(u.src) != 0 and all(cr.op is Ops.CONST for c in u.src[0].children if (cr:=c()) is not None and cr in toposort):
+    #  excluded.update(u.src)
   for u in toposort:
     if u in excluded: continue
     argst = codecs.decode(str(u.arg), "unicode_escape")
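The TODO in the hunk above asks for a different way to make the CONST/VIEW exclusion decision now that children aren't tracked on the UOp. One hypothetical approach, not something this patch implements, is to invert the edges of the graph being rendered, since the full toposort is already in hand:

from collections import defaultdict
from tinygrad.uop.ops import UOp

def consumers_map(root: UOp) -> dict:
  # map every node reachable from root to the nodes that list it as a source
  cmap: dict = defaultdict(list)
  for u in root.toposort():
    for s in u.src: cmap[s].append(u)
  return cmap

a = UOp.variable("a", 0, 10)
cmap = consumers_map(UOp.sink(a*a, a*2))
assert set(cmap[a]) == {a*a, a*2}   # UOps are interned (ucache above), so a*a resolves to the same node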