fix n^2 _apply_map_to_tensors [pr] (#13443)

* clean up slow rules
* fix rule
* non n^2 toposort
* topovisit
* state dict profile_marker
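In short: the stable diffusion example now loads the checkpoint's state dict into a local and emits a profile_marker before load_state_dict; flatten_bufferize stops calling ssimplify on each range end (and drops the import); _apply_map_to_tensors replaces its per-tensor backward_slice intersection with the new UOp.topovisit, sharing one in_scope cache across all live tensors; and toposort's internals are renamed to match topovisit's shape.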
@@ -273,7 +273,9 @@ if __name__ == "__main__":
   with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
     if not args.fakeweights:
       model_bin = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt')
-      load_state_dict(model, torch_load(model_bin)['state_dict'], verbose=False, strict=False, realize=False)
+      state_dict = torch_load(model_bin)['state_dict']
+      profile_marker("state dict loaded")
+      load_state_dict(model, state_dict, verbose=False, strict=False, realize=False)
 
   if args.fp16:
     for k,v in get_state_dict(model).items():
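Splitting torch_load out of the load_state_dict call gives the new profile_marker a spot between checkpoint deserialization and weight assignment, so the two phases show up separately in a trace (the "state dict profile_marker" bullet above).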
@@ -1,7 +1,7 @@
 from dataclasses import dataclass, field
 import itertools
 from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
-from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, ssimplify, KernelInfo
+from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
 from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
 from tinygrad.uop.symbolic import symbolic
 from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
@@ -350,7 +350,7 @@ def flatten_bufferize(x:UOp):
   rngs = x.src[1:]
   ret = ret.forced_reshape(x.shape)
   if any(r.op is Ops.RANGE and r.src[0].op is not Ops.CONST for r in rngs):
-    sym_shape = tuple([ssimplify(r.src[0]) if r.op is not Ops.CONST else 1 for r in rngs])
+    sym_shape = tuple([r.src[0] if r.op is not Ops.CONST else 1 for r in rngs])
     ret = ret.shrink(tuple([(0,x) for x in sym_shape]))
   return ret.rtag(x.tag)
 pm_flatten_bufferize = PatternMatcher([(UPat(Ops.BUFFERIZE, name="x"), flatten_bufferize)])
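The import hunk above matches this change: with ssimplify gone, the symbolic shape uses each range end r.src[0] directly. Presumably this per-range simplification is one of the "slow rules" cleaned up by this commit, with the range ends already in simplified form by the time flatten_bufferize runs.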
@@ -28,8 +28,10 @@ def canonicalize_device(device:str|None) -> str: return Device.canonicalize(devi
 all_tensors: dict[weakref.ref[Tensor], None] = {}
 def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str) -> None:
   with cpu_profile(TracingKey(name), "TINY"):
-    scope_tensors = [t for tref in tuple(all_tensors) if (t:=tref()) is not None and
-      (t.uop in applied_map or len(applied_map.keys() & t.uop.backward_slice.keys()))]
+    # get tensors in scope
+    in_scope: dict[UOp, bool] = {}
+    def visitor(node: UOp) -> bool: return True if node in applied_map else any(in_scope.get(s, False) for s in node.src)
+    scope_tensors = [t for tref in list(all_tensors) if (t:=tref()) is not None and t.uop.topovisit(visitor, in_scope)]
 
     # get all Tensors and apply the map
     sink = UOp.sink(*[t.uop for t in scope_tensors])
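Why this fixes the n^2 behavior: the old comprehension intersected applied_map with every tensor's full backward_slice, so n live tensors over a shared graph of n UOps cost O(n^2). The new visitor memoizes scope membership per UOp in in_scope, and since that dict is shared across all tensors, each node is decided exactly once. A minimal self-contained sketch of the idea (the Node class, in_scope_shared helper, and chain graph are illustrative stand-ins, not tinygrad code):

# Illustrative model of the fix: one memoized reachability pass shared by all roots.
class Node:
  def __init__(self, src=()): self.src = tuple(src)

def in_scope_shared(root, targets, cache):
  # post-order DFS; a node is in scope if it is a target or any of its srcs is
  stack = [(root, False)]
  while stack:
    node, visited = stack.pop()
    if node in cache: continue
    if not visited:
      stack.append((node, True))
      for s in reversed(node.src): stack.append((s, False))
    else: cache[node] = node in targets or any(cache.get(s, False) for s in node.src)
  return cache[root]

# a chain of 1000 nodes, each standing in for a live tensor's uop
chain = [Node()]
for _ in range(1000): chain.append(Node((chain[-1],)))
targets, cache = {chain[0]}, {}
# the shared cache means ~1000 total node visits instead of ~500k
assert all(in_scope_shared(n, targets, cache) for n in chain)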
@@ -157,17 +157,29 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
   def op_in_backward_slice_with_self(self, *ops:Ops): return any(x.op in ops for x in self.backward_slice_with_self)
 
   def toposort(self, gate:Callable|None=None) -> dict[UOp, None]:
-    ret: dict[UOp, None] = {}
+    cache: dict[UOp, None] = {}
     stack: list[tuple[UOp, bool]] = [(self, False)] # each stack entry is (node, visited_flag)
     while stack:
       node, visited = stack.pop()
-      if node in ret: continue
+      if node in cache: continue
       if not visited:
         if gate is None or gate(node):
           stack.append((node, True)) # push node back on stack to process after its srcs
           for s in reversed(node.src): stack.append((s, False)) # push srcs on the stack
-      else: ret[node] = None # second time i'm seeing this node, add it to returned toposort
-    return ret
+      else: cache[node] = None # second time i'm seeing this node, add it to returned toposort
+    return cache
+
+  def topovisit(self, visitor:Callable[[UOp], T], cache:dict[UOp, T]) -> T:
+    # NOTE: this shares a lot of code with toposort
+    stack: list[tuple[UOp, bool]] = [(self, False)]
+    while stack:
+      node, visited = stack.pop()
+      if node in cache: continue
+      if not visited:
+        stack.append((node, True))
+        for s in reversed(node.src): stack.append((s, False))
+      else: cache[node] = visitor(node)
+    return cache[self]
 
   # returns map of UOps to their consumers in the graph rooted by self
   def get_consumer_map(self) -> dict[UOp, dict[UOp, None]]: return consumer_map_from_toposort(self.toposort())
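A usage sketch of the new topovisit contract (the standalone Node mirror and the depth visitor are illustrative, not from the repo): the visitor runs in post-order, so all of a node's srcs are already in cache when it is called, and passing the same cache to several roots amortizes the traversal, which is exactly what the tensor-scope check above relies on.

# Standalone mirror of UOp.topovisit: a memoized post-order fold over a DAG.
class Node:
  def __init__(self, src=()): self.src = tuple(src)
  def topovisit(self, visitor, cache):
    stack = [(self, False)]
    while stack:
      node, visited = stack.pop()
      if node in cache: continue
      if not visited:
        stack.append((node, True))
        for s in reversed(node.src): stack.append((s, False))
      else: cache[node] = visitor(node)  # srcs are already cached here (post-order)
    return cache[self]

a = Node(); b = Node((a,)); c = Node((a, b))
cache: dict = {}
depth = lambda n: 1 + max((cache[s] for s in n.src), default=0)
assert c.topovisit(depth, cache) == 3 and cache[b] == 2  # cache is reusable for later roots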