Mirror of https://github.com/tinygrad/tinygrad.git, last synced 2026-04-07 03:00:26 -04:00.
_apply_map_to_tensors [pr] (#8683)
This commit is contained in:
@@ -18,6 +18,31 @@ from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
|
||||
|
||||
# Global registry of weak references to live Tensor objects; scanned by
# _apply_map_to_tensors to find Tensors whose lazydata needs rewriting.
# Weakrefs keep this set from pinning Tensors alive; dead refs call back None.
# NOTE(review): presumably every Tensor registers itself here on construction —
# the registration site is not visible in this chunk, confirm against __init__.
all_tensors: set[weakref.ref[Tensor]] = set()
|
||||
|
||||
def _apply_map_to_tensors(applied_map:dict[UOp, UOp]) -> None:
  """Rewrite the lazydata of every live Tensor touched by applied_map.

  Collects all UOps reachable (via .children) from the keys of applied_map,
  finds the Tensors in the module-level all_tensors registry whose lazydata
  lbs intersect that set, and substitutes the mapped UOps into their graphs
  in place. Tensors unaffected by the map are left untouched.

  Args:
    applied_map: mapping from existing UOps to the UOps that replace them.
  """
  # get all children of keys in applied_map.
  # NOTE: only the reachable *set* matters, not visit order, so use a stack
  # (list.pop() is O(1)) rather than a queue (list.pop(0) is O(n) per pop,
  # which made the whole walk accidentally quadratic).
  all_uops: set[UOp] = set()
  search_uops = list(applied_map)
  while search_uops:
    x = search_uops.pop()
    if x in all_uops: continue
    all_uops.add(x)
    # children holds weakrefs; a dead ref calls back None and is skipped
    search_uops.extend([u for c in x.children if (u:=c()) is not None])

  # link the found UOps back to Tensors; nothing to do if no Tensor is affected
  # NOTE: this scans all_tensors, but it's fast
  fixed_tensors: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and any(x in all_uops for x in t.lazydata.lbs)]

  if len(fixed_tensors):
    # potentially rewrite all the discovered Tensors with a single substitute pass
    sink = UOp.sink(*[UOp.sink(*t.lazydata.lbs) if isinstance(t.lazydata, MultiLazyBuffer) else t.lazydata for t in fixed_tensors])
    new_sink = sink.substitute(applied_map)

    # set the relevant lazydata to the realized UOps
    for t,s,ns in zip(fixed_tensors, sink.src, new_sink.src):
      if s is ns: continue
      if isinstance(t.lazydata, MultiLazyBuffer): t.lazydata.lbs = list(ns.src)
      else: t.lazydata = ns
|
||||
|
||||
# **** start with two base classes, Tensor and Function ****
|
||||
|
||||
class Function:
|
||||
@@ -231,31 +256,7 @@ class Tensor(SimpleMathTrait):
|
||||
"""
|
||||
big_sink = UOp.sink(*flatten([x.lazydata.lbs for x in (self,)+lst]))
|
||||
schedule, var_vals, becomes_map = create_schedule_with_vars(big_sink)
|
||||
|
||||
# get all children of keys in becomes_map
|
||||
all_uops: set[UOp] = set()
|
||||
search_uops = list(becomes_map)
|
||||
while len(search_uops):
|
||||
x = search_uops.pop(0)
|
||||
if x in all_uops: continue
|
||||
all_uops.add(x)
|
||||
search_uops.extend([u for c in x.children if (u:=c()) is not None])
|
||||
|
||||
# link the found UOps back to Tensors. exit early if there's no Tensors to realize
|
||||
# NOTE: this uses all_tensors, but it's fast
|
||||
fixed_tensors: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and any(x in all_uops for x in t.lazydata.lbs)]
|
||||
if len(fixed_tensors) == 0: return [], {}
|
||||
|
||||
# potentially rewrite all the discovered Tensors
|
||||
sink = UOp.sink(*[UOp.sink(*t.lazydata.lbs) if isinstance(t.lazydata, MultiLazyBuffer) else t.lazydata for t in fixed_tensors])
|
||||
new_sink = sink.substitute(becomes_map)
|
||||
|
||||
# set the relevant lazydata to the realized UOps
|
||||
for t,s,ns in zip(fixed_tensors, sink.src, new_sink.src):
|
||||
if s is ns: continue
|
||||
if isinstance(t.lazydata, MultiLazyBuffer): t.lazydata.lbs = list(ns.src)
|
||||
else: t.lazydata = ns
|
||||
|
||||
_apply_map_to_tensors(becomes_map)
|
||||
return memory_planner(schedule), var_vals
|
||||
|
||||
def schedule(self, *lst:Tensor) -> list[ScheduleItem]:
|
||||
|
||||
Reference in New Issue
Block a user