_apply_map_to_tensors [pr] (#8683)

This commit is contained in:
George Hotz
2025-01-19 17:56:04 -08:00
committed by GitHub
parent 98d01a059d
commit 4198bce150

View File

@@ -18,6 +18,31 @@ from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
# registry of weak references to every live Tensor; scanned when a rewrite map must be applied
all_tensors: set[weakref.ref[Tensor]] = set()
def _apply_map_to_tensors(applied_map:dict[UOp, UOp]) -> None:
  """Apply a UOp substitution map to every live Tensor whose lazydata is affected.

  Walks the (weakly-referenced) child graph of the map's keys to find all reachable
  UOps, locates the Tensors in `all_tensors` that reference any of them, and rewrites
  those Tensors' lazydata via a single `substitute` pass. Mutates Tensors in place.
  """
  # collect every UOp reachable through .children starting from the map's keys
  reachable: set[UOp] = set()
  frontier = list(applied_map)
  while frontier:
    uop = frontier.pop()
    if uop in reachable: continue
    reachable.add(uop)
    # children are weakrefs; dead ones dereference to None and are skipped
    for child_ref in uop.children:
      if (child := child_ref()) is not None: frontier.append(child)
  # link the found UOps back to Tensors. exit early if there's no Tensors to realize
  # NOTE: this uses all_tensors, but it's fast
  affected: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and any(lb in reachable for lb in t.lazydata.lbs)]
  if not affected: return
  # wrap every affected Tensor's lazydata under one sink so substitute runs once
  old_sink = UOp.sink(*[UOp.sink(*t.lazydata.lbs) if isinstance(t.lazydata, MultiLazyBuffer) else t.lazydata for t in affected])
  new_sink = old_sink.substitute(applied_map)
  # point each Tensor at its rewritten UOp (skip ones substitute left untouched)
  for t, before, after in zip(affected, old_sink.src, new_sink.src):
    if before is after: continue
    if isinstance(t.lazydata, MultiLazyBuffer): t.lazydata.lbs = list(after.src)
    else: t.lazydata = after
# **** start with two base classes, Tensor and Function ****
class Function:
@@ -231,31 +256,7 @@ class Tensor(SimpleMathTrait):
"""
big_sink = UOp.sink(*flatten([x.lazydata.lbs for x in (self,)+lst]))
schedule, var_vals, becomes_map = create_schedule_with_vars(big_sink)
# get all children of keys in becomes_map
all_uops: set[UOp] = set()
search_uops = list(becomes_map)
while len(search_uops):
x = search_uops.pop(0)
if x in all_uops: continue
all_uops.add(x)
search_uops.extend([u for c in x.children if (u:=c()) is not None])
# link the found UOps back to Tensors. exit early if there's no Tensors to realize
# NOTE: this uses all_tensors, but it's fast
fixed_tensors: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and any(x in all_uops for x in t.lazydata.lbs)]
if len(fixed_tensors) == 0: return [], {}
# potentially rewrite all the discovered Tensors
sink = UOp.sink(*[UOp.sink(*t.lazydata.lbs) if isinstance(t.lazydata, MultiLazyBuffer) else t.lazydata for t in fixed_tensors])
new_sink = sink.substitute(becomes_map)
# set the relevant lazydata to the realized UOps
for t,s,ns in zip(fixed_tensors, sink.src, new_sink.src):
if s is ns: continue
if isinstance(t.lazydata, MultiLazyBuffer): t.lazydata.lbs = list(ns.src)
else: t.lazydata = ns
_apply_map_to_tensors(becomes_map)
return memory_planner(schedule), var_vals
def schedule(self, *lst:Tensor) -> list[ScheduleItem]: