remove unused ctx [pr] (#8751)

* remove unused ctx [pr]

* fix test
Author: George Hotz
Date: 2025-01-26 17:59:15 +09:00
Committed by: GitHub
Parent: 06b58aa7ec
Commit: b53fe7c2fc
2 changed files with 2 additions and 12 deletions
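
In short: `Tensor._ctx` (the per-tensor back-pointer to the `Function` that created it) is no longer needed, since `backward()` already derives gradients via `Tensor.gradient()`, so the slot and every place that set, copied, or cleared it are removed. A small usage sketch of the user-visible (non-)effect, assuming the post-commit API:

```python
# Hedged usage sketch (not part of the diff): the public API is unchanged by
# this commit; only the internal `_ctx` bookkeeping attribute disappears.
from tinygrad import Tensor

a = Tensor.rand(4, 4, requires_grad=True)
b = Tensor.rand(4, 4, requires_grad=True)
loss = (a * b).mean()
loss.backward()             # gradients still flow; they come from Tensor.gradient()
print(a.grad.shape)         # (4, 4)
print(hasattr(a, "_ctx"))   # False after this change: the slot no longer exists
```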

test/test_gc.py

@@ -33,7 +33,7 @@ class TestGC(unittest.TestCase):
     base = tensors_allocated()
     a = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True)
     b = Tensor.rand(4, 4, requires_grad=True)
-    assert (tensors_allocated()-base == 5)
+    assert (tensors_allocated()-base == 4)
     (a*b).mean().backward()
     assert (tensors_allocated()-base == 6)
     del b
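
The only test change is the expected live-tensor count before `backward()`: with `_ctx` gone, one fewer intermediate `Tensor` stays reachable after building `a` and `b`. A plausible reconstruction of the counting helper this test relies on (the real definition lives in the test file):

```python
import gc
from tinygrad import Tensor

def tensors_allocated():
  # count Tensor instances currently reachable by the garbage collector
  return sum(isinstance(x, Tensor) for x in gc.get_objects())
```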

tinygrad/tensor.py

@@ -53,14 +53,12 @@ class Function:
     self.metadata = metadata
   def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
   def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implemented for {type(self)}")
   @classmethod
   def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
     ctx = fxn(x[0].device, *x, metadata=_METADATA.get())
     ret = Tensor.__new__(Tensor)
     ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None
-    ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None  # used by autograd engine
     return ret
 import tinygrad.function as F
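
`Function.apply` still constructs the op (`ctx`) and the result `Tensor`; the only thing dropped is the back-pointer from the result to its creating `Function`. For orientation, a toy subclass in the shape `apply` expects — illustrative only, not code from `tinygrad.function`:

```python
# Illustrative sketch, not the real tinygrad.function contents.
class Identity(Function):
  # forward() receives the inputs' lazydata and returns the result's lazydata
  def forward(self, x): return x
  # backward() maps the output gradient back to the input gradient
  def backward(self, grad_output): return grad_output

# hypothetical usage: out = Identity.apply(some_tensor)
```
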
@@ -147,8 +145,7 @@ class Tensor(SimpleMathTrait):
   np.set_printoptions(precision=4)
   ```
   """
-  __slots__ = "lazydata", "requires_grad", "grad", "_ctx"
-  __deletable__ = ('_ctx',)
+  __slots__ = "lazydata", "requires_grad", "grad"
   training: ClassVar[bool] = False
   no_grad: ClassVar[bool] = False
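
`__deletable__ = ('_ctx',)` presumably existed only for the `_ctx` machinery, so it is removed together with the slot. The `hasattr`/`getattr` guards removed further down exist because of how `__slots__` behaves; a standalone illustration (not tinygrad code):

```python
# Standalone illustration of __slots__ behaviour (not tinygrad code):
# a slotted class accepts only the declared attributes, and a declared but
# never-assigned slot looks like a missing attribute -- which is why the
# removed code guarded _ctx with hasattr()/getattr().
class Slotted:
  __slots__ = ("a", "b")

s = Slotted()
s.a = 1
print(hasattr(s, "b"))    # False: declared slot, never assigned
try:
  s.c = 3                 # rejected: "c" is not a declared slot
except AttributeError as e:
  print("AttributeError:", e)
```
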
@@ -171,9 +168,6 @@
     # None (the default) will be updated to True if it's put in an optimizer
     self.requires_grad: Optional[bool] = requires_grad
-    # internal variable used for autograd graph construction
-    self._ctx: Optional[Function] = None
     # create a LazyBuffer from the different types of inputs
     if isinstance(data, UOp):
       assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported"
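
Visible in the context above: `requires_grad` defaults to `None`, meaning "undecided", and only becomes `True` when the tensor is handed to an optimizer. A hedged sketch of that convention (assuming the optimizer import path below):

```python
from tinygrad import Tensor
from tinygrad.nn.optim import SGD   # assumed import path

w = Tensor.rand(3, 3)       # requires_grad defaults to None ("undecided")
print(w.requires_grad)      # None
opt = SGD([w], lr=0.01)     # the optimizer flips undecided params to True
print(w.requires_grad)      # True
```
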
@@ -281,7 +275,6 @@
     Replaces the data of this tensor with the data of another tensor. Only the shape of the tensors must match.
     """
     # used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
-    assert getattr(self, '_ctx', None) is None
     assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
     self.lazydata = x.lazydata
     return self
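
The dropped assertion guarded `replace` against tensors that still carried an autograd context; with no context to worry about, only the shape check remains. A hedged usage sketch of `replace` (device/dtype may change, shape may not):

```python
from tinygrad import Tensor, dtypes

a = Tensor.rand(2, 2)
b = Tensor.rand(2, 2).cast(dtypes.float16)
a.replace(b)        # a now refers to b's data; shapes must match
print(a.dtype)      # dtypes.half
```
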
@@ -378,7 +371,6 @@
     """
     ret = Tensor(self.lazydata.clone(), self.device, requires_grad=self.requires_grad)
     if self.grad is not None: ret.grad = self.grad.clone()
-    if hasattr(self, '_ctx'): ret._ctx = self._ctx
     return ret
   def to(self, device:Optional[Union[str, tuple[str, ...]]]) -> Tensor:
@@ -390,7 +382,6 @@
     if not isinstance(device, str): return self.shard(device)
     ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad)
     if self.grad is not None: ret.grad = self.grad.to(device)
-    if hasattr(self, '_ctx'): ret._ctx = self._ctx
     return ret
   def to_(self, device:Optional[Union[str, tuple[str, ...]]]):
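
`clone()` and `to()` still carry `requires_grad` and any existing `grad` over to the new tensor; the only line dropped from each is the copy of the now-nonexistent context. A hedged sketch of the surviving behaviour:

```python
from tinygrad import Tensor, Device

a = Tensor.rand(2, 2, requires_grad=True)
a.sum().backward()                           # populates a.grad
c = a.clone()
print(c.requires_grad, c.grad is not None)   # True True
d = a.to(Device.DEFAULT)                     # new Tensor on the target device
print(d.device, d.grad is not None)          # <default device> True
```
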
@@ -944,7 +935,6 @@
     tensors_need_grad: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and \
       t.lazydata in all_uops and t.requires_grad and not Tensor.no_grad]
     # clear contexts
-    for t in tensors_need_grad: t._ctx = None
     for t,g in zip(tensors_need_grad, self.gradient(*tensors_need_grad, gradient=gradient, materialize_grads=True)):
       assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
       t.grad = g if t.grad is None else (t.grad + g)
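
With contexts gone, `backward()` reduces to: pick the leaf tensors that need gradients, ask `gradient()` for them, and accumulate into `.grad`. A hedged sketch of both entry points:

```python
from tinygrad import Tensor

x = Tensor.rand(3, requires_grad=True)
loss = (x * x).sum()
(gx,) = loss.gradient(x)    # functional form: returns grads, leaves x.grad alone
print(x.grad is None)       # True
loss.backward()             # imperative form: accumulates into x.grad
print(x.grad.shape)         # (3,)
```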