mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00

tensor universe
@@ -1,6 +1,6 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
 from __future__ import annotations
-import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses, hashlib, weakref
+import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses, hashlib, weakref, contextlib
 from contextlib import ContextDecorator
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
@@ -14,11 +14,6 @@ from tinygrad.engine.realize import run_schedule
 from tinygrad.engine.memory import memory_planner
 from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
 
-# *** all in scope Tensors are here. this is the only way to get children ***
-# TODO: different "universes" for disconnected Tensors
-
-all_tensors: weakref.WeakSet[Tensor] = weakref.WeakSet()
-
 # **** start with two base classes, Tensor and Function ****
 
 class Function:
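
For context, a minimal standalone sketch (not tinygrad code; `Node`, `registry`, and `make_node` are made-up names) of the global weak-reference registry this hunk removes: every object adds itself to one WeakSet at construction, so any code that wants "all live Tensors" has to scan everything that is still alive.

import gc, weakref

class Node: pass

registry: weakref.WeakSet = weakref.WeakSet()

def make_node() -> Node:
  n = Node()
  registry.add(n)   # roughly what Tensor.__new__ did with all_tensors
  return n

a, b = make_node(), make_node()
assert len(registry) == 2   # both nodes are still strongly referenced
del a; gc.collect()         # on CPython the refcount drop alone frees `a`
assert len(registry) == 1   # dead entries vanish from the WeakSet

The per-tensor `universe` dicts added in the hunks below replace this single global set, so scheduling only needs to look at tensors reachable from the ones being realized.
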
@@ -38,6 +33,16 @@ class Function:
     ret = Tensor.__new__(Tensor)
     ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None
     ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None # used by autograd engine
+    # merge the Tensor universe of all in x
+    unis:list[Tensor] = sorted(dedup(x), key=lambda x: -len(x.universe))
+    # choose the biggest universe to merge into
+    merged_universe = unis[0].universe
+    merged_universe[ret.ref] = None
+    for t in unis[1:]:
+      merged_universe.update(t.universe)
+      for s in t.universe:
+        if (tt:=s()) is not None: tt.universe = merged_universe
+    ret.universe = merged_universe
     return ret
 
 import tinygrad.function as F
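
The block above merges the per-tensor universes of the inputs whenever a new Tensor is produced. A self-contained sketch of the same strategy (hypothetical `Thing` class and `merge` helper, not tinygrad API): every object carries a dict keyed by weak references to the members of its universe, and smaller universes are folded into the largest one so that all members end up sharing a single dict object.

import weakref

class Thing:
  def __init__(self):
    self.ref = weakref.ref(self)
    self.universe: dict = {self.ref: None}   # a fresh object starts alone in its own universe

def merge(*inputs: Thing) -> dict:
  # biggest universe first, mirroring sorted(dedup(x), key=lambda x: -len(x.universe))
  unis = sorted(set(inputs), key=lambda t: -len(t.universe))
  merged = unis[0].universe
  for t in unis[1:]:
    merged.update(t.universe)
    for r in t.universe:                     # repoint every still-alive member
      if (obj := r()) is not None: obj.universe = merged
  return merged

a, b, c = Thing(), Thing(), Thing()
u = merge(a, b)          # a and b now share one dict
assert a.universe is b.universe and len(u) == 2
u2 = merge(c, a)         # a's universe is bigger, so c is folded into it
assert c.universe is a.universe and len(u2) == 3

Folding into the largest universe keeps the re-pointing work proportional to the smaller inputs, similar in spirit to union-by-size in a disjoint-set structure.
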
@@ -126,11 +131,10 @@ class Tensor(SimpleMathTrait):
   training: ClassVar[bool] = False
   no_grad: ClassVar[bool] = False
 
-  def __new__(cls, *args, **kwargs):
-    instance = super().__new__(cls)
-    all_tensors.add(instance)
-    return instance
-
+  @functools.cached_property
+  def ref(self) -> weakref.ref[Tensor]: return weakref.ref(self)
+  def __del__(self):
+    with contextlib.suppress(AttributeError): del self.universe[self.ref]
   def __init__(self, data:Union[None, ConstType, bytes, List, Tuple, UOp, MultiLazyBuffer, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
                device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
     if dtype is not None: dtype = to_dtype(dtype)
@@ -181,6 +185,9 @@
       assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
     self.lazydata = data
 
+    # all Tensors in the same universe as this one. if this is a realized Tensor it doesn't have to be in own universe
+    self.universe = {self.ref:None}
+
   def requires_grad_(self, requires_grad=True) -> Tensor:
     self.requires_grad = requires_grad
     return self
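
A small sketch of the per-tensor bookkeeping introduced in the two hunks above (made-up `Thing`, not tinygrad's classes): each object caches one `weakref.ref` to itself, keys the universe dict with it, and deletes that exact key when it is collected, so a universe never accumulates dead entries. Caching the ref matters because the same key object has to be looked up again from `__del__`.

import contextlib, weakref

class Thing:
  def __init__(self):
    self.ref = weakref.ref(self)       # cached once, so __del__ removes the exact same key
    self.universe = {self.ref: None}   # a fresh object starts in its own universe
  def __del__(self):
    # universe may be missing if __init__ never ran, hence the suppress, as in the diff
    with contextlib.suppress(AttributeError): del self.universe[self.ref]

a = Thing()
b = Thing()
b.universe = a.universe                # pretend an op merged the two universes
a.universe[b.ref] = None
assert len(a.universe) == 2
del b                                  # CPython frees it immediately and runs __del__
assert len(a.universe) == 1            # b removed itself from the shared dict
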
@@ -230,7 +237,7 @@
     # TODO: becomes_map should be returned from create_schedule_with_vars
 
     # NOTE: this is potentially a lot of Tensors. see above about the universes
-    fixed_tensors: list[Tensor] = list(all_tensors)
+    fixed_tensors: list[Tensor] = dedup(flatten([[x for xref in t.universe if (x:=xref()) is not None] for t in (self,)+lst]))
     sink = UOp.sink(*[UOp.sink(*t.lazydata.lbs) if isinstance(t.lazydata, MultiLazyBuffer) else t.lazydata for t in fixed_tensors])
     new_sink = sink.substitute(becomes_map)
     becomes_map.clear()
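
For illustration, here is what the new `fixed_tensors` expression computes, as a standalone sketch (hypothetical `Thing`; `dedup` and `flatten` are simple stand-ins for tinygrad's helpers): starting from the tensors being scheduled, dereference every weak ref in their universes, drop the dead ones, and deduplicate, instead of walking a global set of every live Tensor.

import weakref

class Thing:
  def __init__(self, name: str):
    self.name = name
    self.ref = weakref.ref(self)
    self.universe = {self.ref: None}

def dedup(xs):   # order-preserving dedup
  return list(dict.fromkeys(xs))

def flatten(ls):
  return [item for sub in ls for item in sub]

def live_members(roots):
  # mirrors: dedup(flatten([[x for xref in t.universe if (x:=xref()) is not None] for t in roots]))
  return dedup(flatten([[x for xref in t.universe if (x := xref()) is not None] for t in roots]))

a, b = Thing("a"), Thing("b")
shared = {**a.universe, **b.universe}
a.universe = b.universe = shared                 # pretend an op merged them
print([t.name for t in live_members([a, b])])    # ['a', 'b'], each appearing once
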
@@ -238,7 +245,7 @@
       if s is ns: continue
       if isinstance(t.lazydata, MultiLazyBuffer): t.lazydata.lbs = list(ns.src)
       else: t.lazydata = ns
-
+    # TODO: we can update the universe here to reflect the realization
     return memory_planner(schedule), var_vals
 
   def schedule(self, *lst:Tensor) -> list[ScheduleItem]:
@@ -260,6 +267,7 @@
     assert getattr(self, '_ctx', None) is None
     assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
     self.lazydata = x.lazydata
+    self.universe.update(x.universe)
     return self
 
   def assign(self, x) -> Tensor:
@@ -279,6 +287,7 @@
     assert not x.requires_grad # self requires_grad is okay?
     if not self.lazydata.is_realized: return self.replace(x)
     self.lazydata = self.lazydata.assign(x.lazydata)
+    self.universe.update(x.universe)
     return self
 
   def detach(self) -> Tensor:
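
The `self.universe.update(x.universe)` lines added to `replace` and `assign` above pull the source tensor's universe members into the destination's dict. A minimal sketch of the effect (hypothetical `Thing` and `assign_from`, not tinygrad API):

import weakref

class Thing:
  def __init__(self):
    self.ref = weakref.ref(self)
    self.universe = {self.ref: None}
  def assign_from(self, other: "Thing"):
    # analogous to the update in Tensor.assign/replace: adopt everything reachable from `other`
    self.universe.update(other.universe)

dst, src = Thing(), Thing()
dst.assign_from(src)
assert src.ref in dst.universe        # dst's universe now covers src's members
assert dst.ref not in src.universe    # dict.update is one-way; src's own dict is untouched
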