diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py index c4e995ae06..c56b635b98 100644 --- a/tinygrad/schedule/indexing.py +++ b/tinygrad/schedule/indexing.py @@ -3,6 +3,7 @@ import functools, operator, itertools from dataclasses import dataclass, field from tinygrad.dtype import dtypes, AddrSpace from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType, profile_matches +from tinygrad.uop.ops import consumer_map_from_toposort from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored @@ -163,13 +164,13 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]: # get ops to realize graph_rewrite(tsink, pm_generate_realize_map, ctx=rctx.realize_map, name="get realize") - # get the traversal order - with cpu_profile("reverse toposort", "TINY"): - tsink_reverse_toposort = tsink.reverse_toposort(consumer_map:=tsink.get_consumer_map()) + # get the consumer map + with cpu_profile("consumer map in rangeify", "TINY"): + consumer_map = consumer_map_from_toposort(tsink_toposort:=tsink.toposort()) # explicit rangeify ending_ranges: dict[UOp, list[UOp]] = {} - for x in tsink_reverse_toposort: + for x in reversed(tsink_toposort): if x.op in {Ops.DEVICE, Ops.UNIQUE}: continue # no ranges on kernels, they are internal diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index e23b8e5943..529cfa0295 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -190,18 +190,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass): # returns map of UOps to their consumers in the graph rooted by self def get_consumer_map(self) -> dict[UOp, dict[UOp, None]]: return consumer_map_from_toposort(self.toposort()) - def reverse_toposort(self, consumer_map) -> dict[UOp, None]: - ret: dict[UOp, None] = {} - stack: list[tuple[UOp, bool]] = [(x, False) for x in consumer_map if len(x.src) == 0] - while stack: - node, visited = stack.pop() - if node in ret: continue - if not visited: - stack.append((node, True)) # push node back on stack to process after its srcs - for s in consumer_map[node]: stack.append((s, False)) # push srcs on the stack - else: ret[node] = None # second time i'm seeing this node, add it to returned toposort - return ret - @functools.cached_property def tuplize(self:UOp) -> tuple: return (self.op.value, self.arg, self.dtype,)+tuple([x.tuplize for x in self.src])