From 8e8fec408ec458ff9127108cbdb9b3f603c56aed Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Mon, 24 Nov 2025 18:59:16 -0800
Subject: [PATCH] fix n^2 _apply_map_to_tensors [pr] (#13443)

* clean up slow rules

* fix rule

* non n^2 toposort

* topovisit

* state dict profile_marker
---
 examples/stable_diffusion.py  |  4 +++-
 tinygrad/schedule/rangeify.py |  4 ++--
 tinygrad/tensor.py            |  6 ++++--
 tinygrad/uop/ops.py           | 20 ++++++++++++++++----
 4 files changed, 25 insertions(+), 9 deletions(-)

diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py
index 73a40c5604..61d3ec2a45 100644
--- a/examples/stable_diffusion.py
+++ b/examples/stable_diffusion.py
@@ -273,7 +273,9 @@ if __name__ == "__main__":
   with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
     if not args.fakeweights:
       model_bin = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt')
-      load_state_dict(model, torch_load(model_bin)['state_dict'], verbose=False, strict=False, realize=False)
+      state_dict = torch_load(model_bin)['state_dict']
+      profile_marker("state dict loaded")
+      load_state_dict(model, state_dict, verbose=False, strict=False, realize=False)
 
   if args.fp16:
     for k,v in get_state_dict(model).items():
diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py
index 044561522d..2caae91c21 100644
--- a/tinygrad/schedule/rangeify.py
+++ b/tinygrad/schedule/rangeify.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass, field
 import itertools
 from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
-from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, ssimplify, KernelInfo
+from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
 from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
 from tinygrad.uop.symbolic import symbolic
 from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
@@ -350,7 +350,7 @@ def flatten_bufferize(x:UOp):
   rngs = x.src[1:]
   ret = ret.forced_reshape(x.shape)
   if any(r.op is Ops.RANGE and r.src[0].op is not Ops.CONST for r in rngs):
-    sym_shape = tuple([ssimplify(r.src[0]) if r.op is not Ops.CONST else 1 for r in rngs])
+    sym_shape = tuple([r.src[0] if r.op is not Ops.CONST else 1 for r in rngs])
     ret = ret.shrink(tuple([(0,x) for x in sym_shape]))
   return ret.rtag(x.tag)
 pm_flatten_bufferize = PatternMatcher([(UPat(Ops.BUFFERIZE, name="x"), flatten_bufferize)])
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 341483498a..d0ffe374af 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -28,8 +28,10 @@ def canonicalize_device(device:str|None) -> str: return Device.canonicalize(devi
 all_tensors: dict[weakref.ref[Tensor], None] = {}
 def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str) -> None:
   with cpu_profile(TracingKey(name), "TINY"):
-    scope_tensors = [t for tref in tuple(all_tensors) if (t:=tref()) is not None and
-                     (t.uop in applied_map or len(applied_map.keys() & t.uop.backward_slice.keys()))]
+    # get tensors in scope
+    in_scope: dict[UOp, bool] = {}
+    def visitor(node: UOp) -> bool: return True if node in applied_map else any(in_scope.get(s, False) for s in node.src)
+    scope_tensors = [t for tref in list(all_tensors) if (t:=tref()) is not None and t.uop.topovisit(visitor, in_scope)]
 
     # get all Tensors and apply the map
     sink = UOp.sink(*[t.uop for t in scope_tensors])
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 2c684b9b57..d7016daf3f 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -157,17 +157,29 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
   def op_in_backward_slice_with_self(self, *ops:Ops): return any(x.op in ops for x in self.backward_slice_with_self)
 
   def toposort(self, gate:Callable|None=None) -> dict[UOp, None]:
-    ret: dict[UOp, None] = {}
+    cache: dict[UOp, None] = {}
     stack: list[tuple[UOp, bool]] = [(self, False)]  # each stack entry is (node, visited_flag)
     while stack:
       node, visited = stack.pop()
-      if node in ret: continue
+      if node in cache: continue
       if not visited:
         if gate is None or gate(node):
           stack.append((node, True))  # push node back on stack to process after its srcs
           for s in reversed(node.src): stack.append((s, False))  # push srcs on the stack
-      else: ret[node] = None  # second time i'm seeing this node, add it to returned toposort
-    return ret
+      else: cache[node] = None  # second time i'm seeing this node, add it to returned toposort
+    return cache
+
+  def topovisit(self, visitor:Callable[[UOp], T], cache:dict[UOp, T]) -> T:
+    # NOTE: this shares a lot of code with toposort
+    stack: list[tuple[UOp, bool]] = [(self, False)]
+    while stack:
+      node, visited = stack.pop()
+      if node in cache: continue
+      if not visited:
+        stack.append((node, True))
+        for s in reversed(node.src): stack.append((s, False))
+      else: cache[node] = visitor(node)
+    return cache[self]
 
   # returns map of UOps to their consumers in the graph rooted by self
   def get_consumer_map(self) -> dict[UOp, dict[UOp, None]]: return consumer_map_from_toposort(self.toposort())
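
Editor's note (not part of the patch): the old scope check walked t.uop.backward_slice for every live tensor and intersected it with applied_map, so n tensors sharing a graph of m UOps cost O(n*m). topovisit instead threads one memoized cache through all the traversals, so each UOp's answer is computed once no matter how many tensor roots ask. Below is a minimal runnable sketch of the same pattern; the Node class and the names applied/in_scope are hypothetical stand-ins for UOp and the patch's dictionaries, not tinygrad API.

from typing import Callable, TypeVar

T = TypeVar("T")

class Node:
  # hypothetical stand-in for UOp: a name plus source edges
  def __init__(self, name:str, src:tuple["Node", ...]=()): self.name, self.src = name, src

  def topovisit(self, visitor:Callable[["Node"], T], cache:dict["Node", T]) -> T:
    # iterative post-order DFS; visitor runs once per node, results are memoized in the shared cache
    stack: list[tuple[Node, bool]] = [(self, False)]
    while stack:
      node, visited = stack.pop()
      if node in cache: continue
      if not visited:
        stack.append((node, True))  # revisit after all srcs are computed
        for s in reversed(node.src): stack.append((s, False))
      else: cache[node] = visitor(node)  # every src is already in cache here
    return cache[self]

# toy graph: d consumes b and c, which both consume a
a = Node("a"); b = Node("b", (a,)); c = Node("c", (a,)); d = Node("d", (b, c))
applied = {a}  # pretend 'a' is being rewritten
in_scope: dict[Node, bool] = {}
def visitor(n:Node) -> bool: return True if n in applied else any(in_scope.get(s, False) for s in n.src)

print(d.topovisit(visitor, in_scope))  # True: 'a' is reachable from 'd'
print(b.topovisit(visitor, in_scope))  # True, answered straight from the shared cache

The second call returns without touching the graph because b's entry was filled in while visiting d, which is what makes the tensor scan in _apply_map_to_tensors linear instead of quadratic.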