mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
toposort recursive_property is faster (#13446)
This commit is contained in:
@@ -98,7 +98,6 @@ buffers:weakref.WeakKeyDictionary[UOp, Buffer|MultiBuffer] = weakref.WeakKeyDict
|
||||
all_metadata:weakref.WeakKeyDictionary[UOp, tuple[Metadata, ...]] = weakref.WeakKeyDictionary() # TODO: should this be here?
|
||||
|
||||
# recursive_property replaces functools.cached_property in recursive UOp functions to prevent RecursionError
|
||||
_NOT_FOUND = object()
|
||||
class recursive_property(property):
|
||||
def __init__(self, fxn):
|
||||
self.fxn = fxn
|
||||
@@ -106,10 +105,16 @@ class recursive_property(property):
|
||||
self.__doc__ = fxn.__doc__
|
||||
def __get__(self, x:UOp|None, owner=None):
|
||||
if x is None: return self
|
||||
if (val:=x.__dict__.get(self.nm, _NOT_FOUND)) is _NOT_FOUND:
|
||||
for s in x.toposort(lambda z: not hasattr(z, self.nm)):
|
||||
s.__dict__[self.nm] = val = self.fxn(s)
|
||||
return val
|
||||
# this is very similar to toposort/topovisit
|
||||
stack: list[tuple[UOp, bool]] = [(x, False)]
|
||||
while stack:
|
||||
node, visited = stack.pop()
|
||||
if self.nm in node.__dict__: continue
|
||||
if not visited:
|
||||
stack.append((node, True))
|
||||
for s in reversed(node.src): stack.append((s, False))
|
||||
else: node.__dict__[self.nm] = self.fxn(node)
|
||||
return x.__dict__[self.nm]
|
||||
|
||||
# we import this late so we can use resolve/smax in mixins
|
||||
from tinygrad.mixin import OpMixin
|
||||
|
||||
Reference in New Issue
Block a user