mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
@@ -6,7 +6,7 @@ from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, Suppor
|
||||
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
|
||||
from tinygrad.dtype import _from_np_dtype, _to_np_dtype
|
||||
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten
|
||||
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, DEBUG, is_numpy_ndarray, SPEC
|
||||
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, DEBUG, is_numpy_ndarray, SPEC, TracingKey, cpu_profile
|
||||
from tinygrad.helpers import suppress_finalizing
|
||||
from tinygrad.gradient import compute_gradient
|
||||
from tinygrad.mixin import OpMixin
|
||||
@@ -26,18 +26,19 @@ def canonicalize_device(device:str|None) -> str: return Device.canonicalize(devi
|
||||
# *** all in scope Tensors are here. this gets relevant UOps ***
|
||||
|
||||
all_tensors: dict[weakref.ref[Tensor], None] = {}
|
||||
def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str|None=None) -> None:
|
||||
scope_tensors = [t for tref in tuple(all_tensors) if (t:=tref()) is not None and
|
||||
(t.uop in applied_map or len(applied_map.keys() & t.uop.backward_slice.keys()))]
|
||||
def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str) -> None:
|
||||
with cpu_profile(TracingKey(name), "TINY"):
|
||||
scope_tensors = [t for tref in tuple(all_tensors) if (t:=tref()) is not None and
|
||||
(t.uop in applied_map or len(applied_map.keys() & t.uop.backward_slice.keys()))]
|
||||
|
||||
# get all Tensors and apply the map
|
||||
sink = UOp.sink(*[t.uop for t in scope_tensors])
|
||||
new_sink = sink.substitute(applied_map, name=name)
|
||||
# get all Tensors and apply the map
|
||||
sink = UOp.sink(*[t.uop for t in scope_tensors])
|
||||
new_sink = sink.substitute(applied_map, name=f"substitute {name}")
|
||||
|
||||
# set the relevant uop to the realized UOps
|
||||
for t,s,ns in zip(scope_tensors, sink.src, new_sink.src):
|
||||
if s is ns: continue
|
||||
t.uop = ns
|
||||
# set the relevant uop to the realized UOps
|
||||
for t,s,ns in zip(scope_tensors, sink.src, new_sink.src):
|
||||
if s is ns: continue
|
||||
t.uop = ns
|
||||
|
||||
# **** Tensor helper functions ****
|
||||
|
||||
@@ -229,7 +230,7 @@ class Tensor(OpMixin):
|
||||
if SPEC: type_verify(big_sink, tensor_spec)
|
||||
|
||||
if any(isinstance(x._device, tuple) for x in big_sink.toposort()):
|
||||
_apply_map_to_tensors(get_multi_map(big_sink), "Apply Multi Map")
|
||||
_apply_map_to_tensors(get_multi_map(big_sink), name="Apply Multi Map")
|
||||
big_sink = UOp.sink(*flatten([x.uop.src if x.uop.op is Ops.MULTI else [x.uop] for x in (self,)+lst]))
|
||||
|
||||
becomes_map = get_rangeify_map(big_sink)
|
||||
|
||||
Reference in New Issue
Block a user