fix n^2 _apply_map_to_tensors [pr] (#13443)

* clean up slow rules

* fix rule

* non n^2 toposort

* topovisit

* state dict profile_marker
George Hotz
2025-11-24 18:59:16 -08:00
committed by GitHub
parent 249553a119
commit 8e8fec408e
4 changed files with 25 additions and 9 deletions


@@ -273,7 +273,9 @@ if __name__ == "__main__":
   with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
     if not args.fakeweights:
       model_bin = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt')
-      load_state_dict(model, torch_load(model_bin)['state_dict'], verbose=False, strict=False, realize=False)
+      state_dict = torch_load(model_bin)['state_dict']
+      profile_marker("state dict loaded")
+      load_state_dict(model, state_dict, verbose=False, strict=False, realize=False)
   if args.fp16:
     for k,v in get_state_dict(model).items():
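
Hoisting the torch_load result into a local and dropping a marker between the two steps lets the profile attribute checkpoint deserialization and state-dict application separately. profile_marker itself is not shown in this diff; as a rough mental model (an assumption, not tinygrad's actual implementation) it behaves like a labeled timestamp:

import time

_start = time.perf_counter()
def profile_marker(name: str) -> None:
  # hypothetical stand-in: record a labeled point in time so the spans before and
  # after it (torch_load vs load_state_dict here) can be told apart in a trace
  print(f"[{time.perf_counter() - _start:8.3f}s] {name}")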


@@ -1,7 +1,7 @@
 from dataclasses import dataclass, field
 import itertools
 from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
-from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, ssimplify, KernelInfo
+from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
 from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
 from tinygrad.uop.symbolic import symbolic
 from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
@@ -350,7 +350,7 @@ def flatten_bufferize(x:UOp):
   rngs = x.src[1:]
   ret = ret.forced_reshape(x.shape)
   if any(r.op is Ops.RANGE and r.src[0].op is not Ops.CONST for r in rngs):
-    sym_shape = tuple([ssimplify(r.src[0]) if r.op is not Ops.CONST else 1 for r in rngs])
+    sym_shape = tuple([r.src[0] if r.op is not Ops.CONST else 1 for r in rngs])
     ret = ret.shrink(tuple([(0,x) for x in sym_shape]))
   return ret.rtag(x.tag)
 pm_flatten_bufferize = PatternMatcher([(UPat(Ops.BUFFERIZE, name="x"), flatten_bufferize)])
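
For reference, shrink crops each axis of the reshaped buffer back to its (possibly symbolic) extent; the removed ssimplify ran a symbolic simplification over every range bound just to build that tuple. With plain ints standing in for sints, the constructed argument looks like:

# plain-int sketch: each RANGE contributes its bound r.src[0], CONST sources
# contribute extent 1, and shrink crops every axis back to [0, bound)
sym_shape = (4, 1, 7)
shrink_arg = tuple([(0, x) for x in sym_shape])
assert shrink_arg == ((0, 4), (0, 1), (0, 7))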


@@ -28,8 +28,10 @@ def canonicalize_device(device:str|None) -> str: return Device.canonicalize(device)
 all_tensors: dict[weakref.ref[Tensor], None] = {}
 def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str) -> None:
   with cpu_profile(TracingKey(name), "TINY"):
-    scope_tensors = [t for tref in tuple(all_tensors) if (t:=tref()) is not None and
-                     (t.uop in applied_map or len(applied_map.keys() & t.uop.backward_slice.keys()))]
+    # get tensors in scope
+    in_scope: dict[UOp, bool] = {}
+    def visitor(node: UOp) -> bool: return True if node in applied_map else any(in_scope.get(s, False) for s in node.src)
+    scope_tensors = [t for tref in list(all_tensors) if (t:=tref()) is not None and t.uop.topovisit(visitor, in_scope)]
     # get all Tensors and apply the map
     sink = UOp.sink(*[t.uop for t in scope_tensors])
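
This hunk is the n^2 fix from the title: the old comprehension intersected applied_map with each tensor's full backward_slice, so the scan cost grew with (live tensors) x (graph size) and shared subgraphs were re-walked once per tensor. The new version threads a single in_scope memo through topovisit (added in the ops.py hunk below): a UOp is in scope iff applied_map rewrites it directly or it transitively consumes one that is, and because every tensor shares the memo, each UOp is decided exactly once. A minimal sketch of that propagation, with a hypothetical Node standing in for UOp and recursion standing in for the real iterative walk:

class Node:
  def __init__(self, *src): self.src = src

def in_scope(node, applied_map, memo):
  # memoized fixpoint: in scope if rewritten directly, or if any source is in scope
  if node not in memo:
    memo[node] = node in applied_map or any(in_scope(s, applied_map, memo) for s in node.src)
  return memo[node]

base = Node(); top = Node(Node(base), base); other = Node()
memo: dict = {}
assert in_scope(top, {base}, memo)        # consumes base transitively
assert not in_scope(other, {base}, memo)  # the memo from the first query is reused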


@@ -157,17 +157,29 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
   def op_in_backward_slice_with_self(self, *ops:Ops): return any(x.op in ops for x in self.backward_slice_with_self)
   def toposort(self, gate:Callable|None=None) -> dict[UOp, None]:
-    ret: dict[UOp, None] = {}
+    cache: dict[UOp, None] = {}
     stack: list[tuple[UOp, bool]] = [(self, False)] # each stack entry is (node, visited_flag)
     while stack:
       node, visited = stack.pop()
-      if node in ret: continue
+      if node in cache: continue
       if not visited:
         if gate is None or gate(node):
           stack.append((node, True)) # push node back on stack to process after its srcs
           for s in reversed(node.src): stack.append((s, False)) # push srcs on the stack
-      else: ret[node] = None # second time i'm seeing this node, add it to returned toposort
-    return ret
+      else: cache[node] = None # second time i'm seeing this node, add it to returned toposort
+    return cache
+  def topovisit(self, visitor:Callable[[UOp], T], cache:dict[UOp, T]) -> T:
+    # NOTE: this shares a lot of code with toposort
+    stack: list[tuple[UOp, bool]] = [(self, False)]
+    while stack:
+      node, visited = stack.pop()
+      if node in cache: continue
+      if not visited:
+        stack.append((node, True))
+        for s in reversed(node.src): stack.append((s, False))
+      else: cache[node] = visitor(node)
+    return cache[self]
   # returns map of UOps to their consumers in the graph rooted by self
   def get_consumer_map(self) -> dict[UOp, dict[UOp, None]]: return consumer_map_from_toposort(self.toposort())
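
toposort is unchanged apart from the ret -> cache rename; topovisit generalizes the same two-phase stack into a memoized post-order fold. Sources are always finalized before their consumer pops in the visited=True phase, so a visitor can read their results straight out of cache, and passing one cache across calls (as _apply_map_to_tensors does with in_scope) dedups work between roots. A toy usage sketch, with a hypothetical Node in place of UOp:

class Node:
  def __init__(self, *src): self.src = src
  def topovisit(self, visitor, cache):  # same algorithm as the method above
    stack = [(self, False)]
    while stack:
      node, visited = stack.pop()
      if node in cache: continue
      if not visited:
        stack.append((node, True))
        for s in reversed(node.src): stack.append((s, False))
      else: cache[node] = visitor(node)
    return cache[self]

a = Node(); b = Node(a); c = Node(a, b)   # diamond: a is consumed twice
cache: dict = {}
height = c.topovisit(lambda n: 1 + max((cache[s] for s in n.src), default=0), cache)
assert height == 3 and len(cache) == 3    # the shared node was classified once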