fix n^2 _apply_map_to_tensors [pr] (#13443)
* clean up slow rules
* fix rule
* non n^2 toposort
* topovisit
* state dict profile_marker
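The old scope check walked every tensor's full backward_slice and intersected it with applied_map, so the work grew with (number of tensors × graph size). The new topovisit threads one caller-owned cache through every root, so each node in the shared graph is scored once. Below is a minimal standalone sketch of that idea; Node and the helper names here are stand-ins invented for illustration, not the real UOp API.

from typing import Callable

class Node:
  # toy stand-in for UOp: identity-hashed node with a tuple of sources
  def __init__(self, name:str, src:tuple["Node", ...]=()):
    self.name, self.src = name, src

def topovisit(root:Node, visitor:Callable[[Node], bool], cache:dict[Node, bool]) -> bool:
  # iterative post-order walk; results are memoized in a caller-owned cache so
  # repeated calls over roots that share structure never revisit a node
  stack: list[tuple[Node, bool]] = [(root, False)]
  while stack:
    node, visited = stack.pop()
    if node in cache: continue
    if not visited:
      stack.append((node, True))  # revisit after sources are processed
      stack.extend((s, False) for s in reversed(node.src))
    else: cache[node] = visitor(node)  # sources are already in cache here
  return cache[root]

# toy graph: two tensors hang off one long shared chain that bottoms out in a replaced node
replaced = Node("replaced")
chain = replaced
for i in range(100): chain = Node(f"c{i}", (chain,))
tensors = [Node("t0", (chain,)), Node("t1", (chain,)), Node("t2", ())]
applied_map = {replaced}

cache: dict[Node, bool] = {}
def in_scope(n:Node) -> bool: return n in applied_map or any(cache.get(s, False) for s in n.src)
scoped = [t for t in tensors if topovisit(t, in_scope, cache)]
print([t.name for t in scoped])  # ['t0', 't1']; the shared chain is walked only once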
@@ -273,7 +273,9 @@ if __name__ == "__main__":
   with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
     if not args.fakeweights:
       model_bin = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt')
-      load_state_dict(model, torch_load(model_bin)['state_dict'], verbose=False, strict=False, realize=False)
+      state_dict = torch_load(model_bin)['state_dict']
+      profile_marker("state dict loaded")
+      load_state_dict(model, state_dict, verbose=False, strict=False, realize=False)
 
     if args.fp16:
       for k,v in get_state_dict(model).items():
@@ -1,7 +1,7 @@
 from dataclasses import dataclass, field
 import itertools
 from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
-from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, ssimplify, KernelInfo
+from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
 from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
 from tinygrad.uop.symbolic import symbolic
 from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
@@ -350,7 +350,7 @@ def flatten_bufferize(x:UOp):
   rngs = x.src[1:]
   ret = ret.forced_reshape(x.shape)
   if any(r.op is Ops.RANGE and r.src[0].op is not Ops.CONST for r in rngs):
-    sym_shape = tuple([ssimplify(r.src[0]) if r.op is not Ops.CONST else 1 for r in rngs])
+    sym_shape = tuple([r.src[0] if r.op is not Ops.CONST else 1 for r in rngs])
     ret = ret.shrink(tuple([(0,x) for x in sym_shape]))
   return ret.rtag(x.tag)
 pm_flatten_bufferize = PatternMatcher([(UPat(Ops.BUFFERIZE, name="x"), flatten_bufferize)])
@@ -28,8 +28,10 @@ def canonicalize_device(device:str|None) -> str: return Device.canonicalize(devi
 all_tensors: dict[weakref.ref[Tensor], None] = {}
 def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str) -> None:
   with cpu_profile(TracingKey(name), "TINY"):
-    scope_tensors = [t for tref in tuple(all_tensors) if (t:=tref()) is not None and
-      (t.uop in applied_map or len(applied_map.keys() & t.uop.backward_slice.keys()))]
+    # get tensors in scope
+    in_scope: dict[UOp, bool] = {}
+    def visitor(node: UOp) -> bool: return True if node in applied_map else any(in_scope.get(s, False) for s in node.src)
+    scope_tensors = [t for tref in list(all_tensors) if (t:=tref()) is not None and t.uop.topovisit(visitor, in_scope)]
 
     # get all Tensors and apply the map
     sink = UOp.sink(*[t.uop for t in scope_tensors])
@@ -157,17 +157,29 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
   def op_in_backward_slice_with_self(self, *ops:Ops): return any(x.op in ops for x in self.backward_slice_with_self)
 
   def toposort(self, gate:Callable|None=None) -> dict[UOp, None]:
-    ret: dict[UOp, None] = {}
+    cache: dict[UOp, None] = {}
     stack: list[tuple[UOp, bool]] = [(self, False)] # each stack entry is (node, visited_flag)
     while stack:
       node, visited = stack.pop()
-      if node in ret: continue
+      if node in cache: continue
       if not visited:
         if gate is None or gate(node):
           stack.append((node, True)) # push node back on stack to process after its srcs
           for s in reversed(node.src): stack.append((s, False)) # push srcs on the stack
-      else: ret[node] = None # second time i'm seeing this node, add it to returned toposort
-    return ret
+      else: cache[node] = None # second time i'm seeing this node, add it to returned toposort
+    return cache
+
+  def topovisit(self, visitor:Callable[[UOp], T], cache:dict[UOp, T]) -> T:
+    # NOTE: this shares a lot of code with toposort
+    stack: list[tuple[UOp, bool]] = [(self, False)]
+    while stack:
+      node, visited = stack.pop()
+      if node in cache: continue
+      if not visited:
+        stack.append((node, True))
+        for s in reversed(node.src): stack.append((s, False))
+      else: cache[node] = visitor(node)
+    return cache[self]
 
   # returns map of UOps to their consumers in the graph rooted by self
   def get_consumer_map(self) -> dict[UOp, dict[UOp, None]]: return consumer_map_from_toposort(self.toposort())