fix n^2 _apply_map_to_tensors [pr] (#13443)

* clean up slow rules

* fix rule

* non n^2 toposort

* topovisit

* state dict profile_marker
George Hotz committed 2025-11-24 18:59:16 -08:00 (committed by GitHub)
parent 249553a119
commit 8e8fec408e
4 changed files with 25 additions and 9 deletions

@@ -273,7 +273,9 @@ if __name__ == "__main__":
   with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
     if not args.fakeweights:
       model_bin = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt')
-      load_state_dict(model, torch_load(model_bin)['state_dict'], verbose=False, strict=False, realize=False)
+      state_dict = torch_load(model_bin)['state_dict']
+      profile_marker("state dict loaded")
+      load_state_dict(model, state_dict, verbose=False, strict=False, realize=False)

     if args.fp16:
       for k,v in get_state_dict(model).items():
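
Note on the hunk above: splitting torch_load out of the load_state_dict call is what gives the new profile_marker a place to land, so the profile separates checkpoint parsing from weight assignment. A print-based stand-in for the pattern (the profile_marker defined here is illustrative only, not tinygrad's actual helper):

import time
def profile_marker(name: str): print(f"[{time.perf_counter():.3f}s] {name}")  # stand-in for the real helper

profile_marker("load start")
state_dict = {"w": [0.0] * 4}        # stands in for torch_load(model_bin)['state_dict']
profile_marker("state dict loaded")  # marks the boundary between the two phases
# phase two, load_state_dict(model, state_dict, ...), would run after the marker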

@@ -1,7 +1,7 @@
 from dataclasses import dataclass, field
 import itertools
 from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
-from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, ssimplify, KernelInfo
+from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
 from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
 from tinygrad.uop.symbolic import symbolic
 from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
@@ -350,7 +350,7 @@ def flatten_bufferize(x:UOp):
   rngs = x.src[1:]
   ret = ret.forced_reshape(x.shape)
   if any(r.op is Ops.RANGE and r.src[0].op is not Ops.CONST for r in rngs):
-    sym_shape = tuple([ssimplify(r.src[0]) if r.op is not Ops.CONST else 1 for r in rngs])
+    sym_shape = tuple([r.src[0] if r.op is not Ops.CONST else 1 for r in rngs])
     ret = ret.shrink(tuple([(0,x) for x in sym_shape]))
   return ret.rtag(x.tag)
 pm_flatten_bufferize = PatternMatcher([(UPat(Ops.BUFFERIZE, name="x"), flatten_bufferize)])
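
These two hunks are the "clean up slow rules" part of the commit message: a PatternMatcher rule body runs once per matching UOp, so the per-match ssimplify call on every range bound repeated symbolic simplification across the whole graph. The bound r.src[0] is now passed through untouched and the import goes away. The shape logic the rule keeps is small; a toy version with plain tuples standing in for UOps (assumed encoding: a dim is either a range with an end bound or a constant):

def sym_shape(rngs):
  # ranges contribute their (possibly symbolic) end bound, constant dims contribute 1
  return tuple(end if kind == "range" else 1 for kind, end in rngs)

assert sym_shape([("range", "N"), ("const", 4), ("range", 8)]) == ("N", 1, 8)
# the real rule then shrinks the flattened buffer to these extents:
# ret = ret.shrink(tuple((0, x) for x in sym_shape))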

@@ -28,8 +28,10 @@ def canonicalize_device(device:str|None) -> str: return Device.canonicalize(devi
 all_tensors: dict[weakref.ref[Tensor], None] = {}
 def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str) -> None:
   with cpu_profile(TracingKey(name), "TINY"):
-    scope_tensors = [t for tref in tuple(all_tensors) if (t:=tref()) is not None and
-                     (t.uop in applied_map or len(applied_map.keys() & t.uop.backward_slice.keys()))]
+    # get tensors in scope
+    in_scope: dict[UOp, bool] = {}
+    def visitor(node: UOp) -> bool: return True if node in applied_map else any(in_scope.get(s, False) for s in node.src)
+    scope_tensors = [t for tref in list(all_tensors) if (t:=tref()) is not None and t.uop.topovisit(visitor, in_scope)]
     # get all Tensors and apply the map
     sink = UOp.sink(*[t.uop for t in scope_tensors])
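
This hunk is the headline n^2 fix. The old scope test evaluated t.uop.backward_slice for every live tensor, and each backward_slice walks that tensor's whole graph, so total work scaled as O(tensors x graph size) when many tensors share one large graph. The new code shares a single in_scope memo across all tensors through topovisit (added in the next file), so every UOp answers "does my graph contain an updated node?" exactly once. A toy model of the memoization, with a hypothetical Node in place of UOp and recursion in place of the iterative traversal (the real version is iterative so deep graphs don't hit Python's recursion limit):

class Node:
  def __init__(self, *src): self.src = src

def in_scope(node, updated, memo):
  # memoized: each node is classified once, no matter how many roots share it
  if node not in memo:
    memo[node] = node in updated or any(in_scope(s, updated, memo) for s in node.src)
  return memo[node]

base = Node(); shared = Node(base)
t1, t2 = Node(shared), Node(shared)   # two "tensors" sharing most of their graph
memo: dict = {}
assert in_scope(t1, {base}, memo) and in_scope(t2, {base}, memo)  # shared/base computed once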

@@ -157,17 +157,29 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
   def op_in_backward_slice_with_self(self, *ops:Ops): return any(x.op in ops for x in self.backward_slice_with_self)
   def toposort(self, gate:Callable|None=None) -> dict[UOp, None]:
-    ret: dict[UOp, None] = {}
+    cache: dict[UOp, None] = {}
     stack: list[tuple[UOp, bool]] = [(self, False)] # each stack entry is (node, visited_flag)
     while stack:
       node, visited = stack.pop()
-      if node in ret: continue
+      if node in cache: continue
       if not visited:
         if gate is None or gate(node):
           stack.append((node, True)) # push node back on stack to process after its srcs
           for s in reversed(node.src): stack.append((s, False)) # push srcs on the stack
-      else: ret[node] = None # second time i'm seeing this node, add it to returned toposort
-    return ret
+      else: cache[node] = None # second time i'm seeing this node, add it to returned toposort
+    return cache
+  def topovisit(self, visitor:Callable[[UOp], T], cache:dict[UOp, T]) -> T:
+    # NOTE: this shares a lot of code with toposort
+    stack: list[tuple[UOp, bool]] = [(self, False)]
+    while stack:
+      node, visited = stack.pop()
+      if node in cache: continue
+      if not visited:
+        stack.append((node, True))
+        for s in reversed(node.src): stack.append((s, False))
+      else: cache[node] = visitor(node)
+    return cache[self]
   # returns map of UOps to their consumers in the graph rooted by self
   def get_consumer_map(self) -> dict[UOp, dict[UOp, None]]: return consumer_map_from_toposort(self.toposort())
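
topovisit generalizes toposort from "collect every node" to "fold a visitor over every node in postorder": a node is only processed after all of its src entries, so the visitor can read their cached results, and since the caller owns the cache, results persist across calls with different roots, which is exactly what _apply_map_to_tensors relies on. A self-contained sketch of the same traversal on a toy Node (the toy visitor takes the cache as an argument; the real one, as in the tensor.py hunk, closes over it):

class Node:
  def __init__(self, *src): self.src = src

def topovisit(root, visitor, cache):
  stack = [(root, False)]                  # (node, visited_flag), same shape as above
  while stack:
    node, visited = stack.pop()
    if node in cache: continue
    if not visited:
      stack.append((node, True))           # revisit after srcs are done
      for s in reversed(node.src): stack.append((s, False))
    else: cache[node] = visitor(node, cache)  # srcs are already in the cache here
  return cache[root]

leaf = Node(); mid = Node(leaf); root = Node(mid, leaf)
depths: dict = {}
depth = lambda n, c: 1 + max((c[s] for s in n.src), default=0)
assert topovisit(root, depth, depths) == 3 and depths[leaf] == 1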