tensor universe

George Hotz
2024-12-29 19:23:43 -05:00
parent 24a906aa50
commit c21301852a

tinygrad/tensor.py

@@ -1,6 +1,6 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
 from __future__ import annotations
-import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses, hashlib, weakref
+import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses, hashlib, weakref, contextlib
 from contextlib import ContextDecorator
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
@@ -14,11 +14,6 @@ from tinygrad.engine.realize import run_schedule
 from tinygrad.engine.memory import memory_planner
 from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
-# *** all in scope Tensors are here. this is the only way to get children ***
-# TODO: different "universes" for disconnected Tensors
-all_tensors: weakref.WeakSet[Tensor] = weakref.WeakSet()
 # **** start with two base classes, Tensor and Function ****
 class Function:
@@ -38,6 +33,16 @@ class Function:
     ret = Tensor.__new__(Tensor)
     ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None
     ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None # used by autograd engine
+    # merge the Tensor universe of all in x
+    unis:list[Tensor] = sorted(dedup(x), key=lambda x: -len(x.universe))
+    # choose the biggest universe to merge into
+    merged_universe = unis[0].universe
+    merged_universe[ret.ref] = None
+    for t in unis[1:]:
+      merged_universe.update(t.universe)
+      for s in t.universe:
+        if (tt:=s()) is not None: tt.universe = merged_universe
+    ret.universe = merged_universe
     return ret
 import tinygrad.function as F
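
The added block keeps one invariant: Tensors connected through an op share a single universe dict. Below is a minimal standalone sketch of the same merge logic, using a hypothetical Toy class in place of tinygrad's Tensor; a dict with None values serves as an insertion-ordered set of weakrefs.

import weakref

class Toy:
  def __init__(self):
    self.ref = weakref.ref(self)
    self.universe = {self.ref: None}  # every Toy starts alone in its own universe

def merge_universes(out, xs):
  # the largest universe absorbs the others, so the fewest members get repointed
  unis = sorted(set(xs), key=lambda t: -len(t.universe))
  merged = unis[0].universe
  merged[out.ref] = None              # the new output joins the merged universe
  for t in unis[1:]:
    merged.update(t.universe)         # absorb the smaller universe's members
    for s in t.universe:              # repoint every live member at the merged dict
      if (tt:=s()) is not None: tt.universe = merged
  out.universe = merged

a, b, out = Toy(), Toy(), Toy()
merge_universes(out, [a, b])
assert a.universe is b.universe is out.universe and len(out.universe) == 3

Sorting by descending universe size means the biggest dict is reused in place and only the members of the smaller universes need to be repointed.
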
@@ -126,11 +131,10 @@ class Tensor(SimpleMathTrait):
   training: ClassVar[bool] = False
   no_grad: ClassVar[bool] = False
-  def __new__(cls, *args, **kwargs):
-    instance = super().__new__(cls)
-    all_tensors.add(instance)
-    return instance
+  @functools.cached_property
+  def ref(self) -> weakref.ref[Tensor]: return weakref.ref(self)
+  def __del__(self):
+    with contextlib.suppress(AttributeError): del self.universe[self.ref]
   def __init__(self, data:Union[None, ConstType, bytes, List, Tuple, UOp, MultiLazyBuffer, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
                device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
     if dtype is not None: dtype = to_dtype(dtype)
@@ -181,6 +185,9 @@ class Tensor(SimpleMathTrait):
       assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
       self.lazydata = data
+    # all Tensors in the same universe as this one. if this is a realized Tensor it doesn't have to be in own universe
+    self.universe = {self.ref:None}
   def requires_grad_(self, requires_grad=True) -> Tensor:
     self.requires_grad = requires_grad
     return self
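
Together, the cached ref and __del__ handle the lifecycle: each Tensor carries a stable weakref usable as a dict key, registers itself in a universe, and drops out of that universe when collected. A standalone sketch with the same hypothetical Toy stand-in, assuming CPython's reference counting collects objects immediately:

import contextlib, functools, weakref

class Toy:
  @functools.cached_property
  def ref(self): return weakref.ref(self)  # computed once, so the dict key is stable
  def __init__(self): self.universe = {self.ref: None}
  def __del__(self):
    # universe may be missing if __init__ never ran (e.g. objects made via __new__)
    with contextlib.suppress(AttributeError): del self.universe[self.ref]

a, b = Toy(), Toy()
a.universe[b.ref] = None; b.universe = a.universe  # move b into a's universe
del b                                              # b removes itself when collected
assert list(a.universe) == [a.ref]

The dict lookup in __del__ still succeeds even once the weakref is dead, because the exact same weakref object is both the stored key and the lookup key, and weakrefs cache their hash on first use.
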
@@ -230,7 +237,7 @@ class Tensor(SimpleMathTrait):
     # TODO: becomes_map should be returned from create_schedule_with_vars
     # NOTE: this is potentially a lot of Tensors. see above about the universes
-    fixed_tensors: list[Tensor] = list(all_tensors)
+    fixed_tensors: list[Tensor] = dedup(flatten([[x for xref in t.universe if (x:=xref()) is not None] for t in (self,)+lst]))
     sink = UOp.sink(*[UOp.sink(*t.lazydata.lbs) if isinstance(t.lazydata, MultiLazyBuffer) else t.lazydata for t in fixed_tensors])
     new_sink = sink.substitute(becomes_map)
     becomes_map.clear()
@@ -238,7 +245,7 @@ class Tensor(SimpleMathTrait):
       if s is ns: continue
       if isinstance(t.lazydata, MultiLazyBuffer): t.lazydata.lbs = list(ns.src)
       else: t.lazydata = ns
+    # TODO: we can update the universe here to reflect the realization
     return memory_planner(schedule), var_vals
   def schedule(self, *lst:Tensor) -> list[ScheduleItem]:
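
With the global all_tensors WeakSet gone, scheduling only considers Tensors reachable through the universes of self and lst. dedup and flatten are tinygrad helpers; this plain-loop sketch shows roughly what the one-liner computes:

def gather_fixed(tensors):
  # walk each scheduled Tensor's universe, keeping live, not-yet-seen members
  seen, out = set(), []
  for t in tensors:
    for xref in t.universe:
      x = xref()                   # weakref -> Tensor, or None if already collected
      if x is not None and id(x) not in seen:
        seen.add(id(x)); out.append(x)
  return out
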
@@ -260,6 +267,7 @@ class Tensor(SimpleMathTrait):
     assert getattr(self, '_ctx', None) is None
     assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
     self.lazydata = x.lazydata
+    self.universe.update(x.universe)
     return self
   def assign(self, x) -> Tensor:
@@ -279,6 +287,7 @@ class Tensor(SimpleMathTrait):
     assert not x.requires_grad # self requires_grad is okay?
     if not self.lazydata.is_realized: return self.replace(x)
     self.lazydata = self.lazydata.assign(x.lazydata)
+    self.universe.update(x.universe)
     return self
   def detach(self) -> Tensor:
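
Note the asymmetry with Function.apply: replace and assign only fold x's members into self's universe and do not repoint x's members at a shared dict. Since a universe only needs to over-approximate which Tensors might be connected (see the TODO above about updating it on realization), a one-sided update appears to be enough. In miniature, with plain dicts standing in for dicts of weakrefs:

a_uni, b_uni = {"a": None}, {"b": None}
a_uni.update(b_uni)   # replace/assign style: self's universe absorbs x's members
assert "b" in a_uni and "a" not in b_uni  # x's own universe dict is left untouched
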