rtoposort is fast, can replace rangeify with this

This commit is contained in:
George Hotz
2025-10-08 11:06:12 +08:00
parent a6d59a0b45
commit d23324779c
3 changed files with 18 additions and 5 deletions

View File

@@ -11,7 +11,7 @@ class ChildrenContext:
# this is a generic child labeller
def extract_children(ctx:ChildrenContext, x:UOp):
if ctx.children is not None: return
ctx.children = {k:list(v.keys()) for k,v in x.get_children_map().items() if len(v) > 1}
ctx.children = {k:list(v.keys()) for k,v in x.get_consumer_map().items() if len(v) > 1}
def mark_children(ctx:ChildrenContext, x:UOp):
new_srcs = [(UOp(Ops.CHILD, s.dtype, src=(s,), arg=(ctx.children[s].index(x), len(ctx.children[s]))) if s in ctx.children else s) for s in x.src]

View File

@@ -136,7 +136,7 @@ remove_contig_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.r
class ChildrenContext: children: dict[UOp, list[UOp]]|None = None
def extract_children(ctx:ChildrenContext, x:UOp):
if ctx.children is not None: return
children_map = x.get_children_map()
children_map = x.get_consumer_map()
ctx.children = {}
for k,v in children_map.items():
# NOTE: we treat mstack children like sink here

View File

@@ -146,14 +146,27 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
else: ret[node] = None # second time i'm seeing this node, add it to returned toposort
return ret
# returns map of UOps to their children in the graph rooted by self
def get_children_map(self) -> dict[UOp, dict[UOp, None]]:
# returns map of UOps to their consumers in the graph rooted by self
def get_consumer_map(self) -> dict[UOp, dict[UOp, None]]:
ret: dict[UOp, dict[UOp, None]] = {}
for u in self.toposort():
ret[u] = {}
for s in u.src: ret[s][u] = None
return ret
def reverse_toposort(self) -> dict[UOp, None]:
consumer_map = self.get_consumer_map()
ret: dict[UOp, None] = {}
stack: list[tuple[UOp, bool]] = [(x, False) for x in consumer_map if len(x.src) == 0]
while stack:
node, visited = stack.pop()
if node in ret: continue
if not visited:
stack.append((node, True)) # push node back on stack to process after its srcs
for s in consumer_map[node]: stack.append((s, False)) # push srcs on the stack
else: ret[node] = None # second time i'm seeing this node, add it to returned toposort
return ret
@functools.cached_property
def tuplize(self:UOp) -> tuple:
return (self.op.value, self.arg, self.dtype,)+tuple([x.tuplize for x in self.src])
@@ -1189,7 +1202,7 @@ pm_pyrender = PatternMatcher([
@Context(SPEC=0)
def pyrender(ast:UOp) -> list[str]:
cmap = ast.get_children_map()
cmap = ast.get_consumer_map()
to_render = set()
for u in ast.toposort():
if u.op is Ops.STORE: to_render.add(u.src[1])