track apply map to tensors (#13399)

* track apply map to tensors

* sub
This commit is contained in:
George Hotz
2025-11-20 14:24:55 -08:00
committed by GitHub
parent 647fde64e6
commit a74f6020d5

View File

@@ -6,7 +6,7 @@ from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, Suppor
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
from tinygrad.dtype import _from_np_dtype, _to_np_dtype
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, DEBUG, is_numpy_ndarray, SPEC
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, DEBUG, is_numpy_ndarray, SPEC, TracingKey, cpu_profile
from tinygrad.helpers import suppress_finalizing
from tinygrad.gradient import compute_gradient
from tinygrad.mixin import OpMixin
@@ -26,18 +26,19 @@ def canonicalize_device(device:str|None) -> str: return Device.canonicalize(devi
# *** all in scope Tensors are here. this gets relevant UOps ***
all_tensors: dict[weakref.ref[Tensor], None] = {}
def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str|None=None) -> None:
scope_tensors = [t for tref in tuple(all_tensors) if (t:=tref()) is not None and
(t.uop in applied_map or len(applied_map.keys() & t.uop.backward_slice.keys()))]
def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str) -> None:
with cpu_profile(TracingKey(name), "TINY"):
scope_tensors = [t for tref in tuple(all_tensors) if (t:=tref()) is not None and
(t.uop in applied_map or len(applied_map.keys() & t.uop.backward_slice.keys()))]
# get all Tensors and apply the map
sink = UOp.sink(*[t.uop for t in scope_tensors])
new_sink = sink.substitute(applied_map, name=name)
# get all Tensors and apply the map
sink = UOp.sink(*[t.uop for t in scope_tensors])
new_sink = sink.substitute(applied_map, name=f"substitute {name}")
# set the relevant uop to the realized UOps
for t,s,ns in zip(scope_tensors, sink.src, new_sink.src):
if s is ns: continue
t.uop = ns
# set the relevant uop to the realized UOps
for t,s,ns in zip(scope_tensors, sink.src, new_sink.src):
if s is ns: continue
t.uop = ns
# **** Tensor helper functions ****
@@ -229,7 +230,7 @@ class Tensor(OpMixin):
if SPEC: type_verify(big_sink, tensor_spec)
if any(isinstance(x._device, tuple) for x in big_sink.toposort()):
_apply_map_to_tensors(get_multi_map(big_sink), "Apply Multi Map")
_apply_map_to_tensors(get_multi_map(big_sink), name="Apply Multi Map")
big_sink = UOp.sink(*flatten([x.uop.src if x.uop.op is Ops.MULTI else [x.uop] for x in (self,)+lst]))
becomes_map = get_rangeify_map(big_sink)