mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
rtoposort is fast, can replace rangeify with this
This commit is contained in:
@@ -11,7 +11,7 @@ class ChildrenContext:
|
||||
# this is a generic child labeller
|
||||
def extract_children(ctx:ChildrenContext, x:UOp):
|
||||
if ctx.children is not None: return
|
||||
ctx.children = {k:list(v.keys()) for k,v in x.get_children_map().items() if len(v) > 1}
|
||||
ctx.children = {k:list(v.keys()) for k,v in x.get_consumer_map().items() if len(v) > 1}
|
||||
|
||||
def mark_children(ctx:ChildrenContext, x:UOp):
|
||||
new_srcs = [(UOp(Ops.CHILD, s.dtype, src=(s,), arg=(ctx.children[s].index(x), len(ctx.children[s]))) if s in ctx.children else s) for s in x.src]
|
||||
|
||||
@@ -136,7 +136,7 @@ remove_contig_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.r
|
||||
class ChildrenContext: children: dict[UOp, list[UOp]]|None = None
|
||||
def extract_children(ctx:ChildrenContext, x:UOp):
|
||||
if ctx.children is not None: return
|
||||
children_map = x.get_children_map()
|
||||
children_map = x.get_consumer_map()
|
||||
ctx.children = {}
|
||||
for k,v in children_map.items():
|
||||
# NOTE: we treat mstack children like sink here
|
||||
|
||||
@@ -146,14 +146,27 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
else: ret[node] = None # second time i'm seeing this node, add it to returned toposort
|
||||
return ret
|
||||
|
||||
# returns map of UOps to their children in the graph rooted by self
|
||||
def get_children_map(self) -> dict[UOp, dict[UOp, None]]:
|
||||
# returns map of UOps to their consumers in the graph rooted by self
|
||||
def get_consumer_map(self) -> dict[UOp, dict[UOp, None]]:
|
||||
ret: dict[UOp, dict[UOp, None]] = {}
|
||||
for u in self.toposort():
|
||||
ret[u] = {}
|
||||
for s in u.src: ret[s][u] = None
|
||||
return ret
|
||||
|
||||
def reverse_toposort(self) -> dict[UOp, None]:
|
||||
consumer_map = self.get_consumer_map()
|
||||
ret: dict[UOp, None] = {}
|
||||
stack: list[tuple[UOp, bool]] = [(x, False) for x in consumer_map if len(x.src) == 0]
|
||||
while stack:
|
||||
node, visited = stack.pop()
|
||||
if node in ret: continue
|
||||
if not visited:
|
||||
stack.append((node, True)) # push node back on stack to process after its srcs
|
||||
for s in consumer_map[node]: stack.append((s, False)) # push srcs on the stack
|
||||
else: ret[node] = None # second time i'm seeing this node, add it to returned toposort
|
||||
return ret
|
||||
|
||||
@functools.cached_property
|
||||
def tuplize(self:UOp) -> tuple:
|
||||
return (self.op.value, self.arg, self.dtype,)+tuple([x.tuplize for x in self.src])
|
||||
@@ -1189,7 +1202,7 @@ pm_pyrender = PatternMatcher([
|
||||
|
||||
@Context(SPEC=0)
|
||||
def pyrender(ast:UOp) -> list[str]:
|
||||
cmap = ast.get_children_map()
|
||||
cmap = ast.get_consumer_map()
|
||||
to_render = set()
|
||||
for u in ast.toposort():
|
||||
if u.op is Ops.STORE: to_render.add(u.src[1])
|
||||
|
||||
Reference in New Issue
Block a user