diff --git a/test/test_gc.py b/test/test_gc.py
index 0929a81394..cf90dc6201 100644
--- a/test/test_gc.py
+++ b/test/test_gc.py
@@ -33,7 +33,7 @@ class TestGC(unittest.TestCase):
     base = tensors_allocated()
     a = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True)
     b = Tensor.rand(4, 4, requires_grad=True)
-    assert (tensors_allocated()-base == 5)
+    assert (tensors_allocated()-base == 4)
     (a*b).mean().backward()
     assert (tensors_allocated()-base == 6)
     del b
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 826ade9242..1bcd22af53 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -53,14 +53,12 @@ class Function:
     self.metadata = metadata
 
   def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
-  def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implemented for {type(self)}")
 
   @classmethod
   def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
     ctx = fxn(x[0].device, *x, metadata=_METADATA.get())
     ret = Tensor.__new__(Tensor)
     ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None
-    ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None # used by autograd engine
     return ret
 
 import tinygrad.function as F
@@ -147,8 +145,7 @@ class Tensor(SimpleMathTrait):
   np.set_printoptions(precision=4)
   ```
   """
-  __slots__ = "lazydata", "requires_grad", "grad", "_ctx"
-  __deletable__ = ('_ctx',)
+  __slots__ = "lazydata", "requires_grad", "grad"
   training: ClassVar[bool] = False
   no_grad: ClassVar[bool] = False
 
@@ -171,9 +168,6 @@ class Tensor(SimpleMathTrait):
     # None (the default) will be updated to True if it's put in an optimizer
     self.requires_grad: Optional[bool] = requires_grad
 
-    # internal variable used for autograd graph construction
-    self._ctx: Optional[Function] = None
-
     # create a LazyBuffer from the different types of inputs
     if isinstance(data, UOp):
       assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported"
@@ -281,7 +275,6 @@ class Tensor(SimpleMathTrait):
     Replaces the data of this tensor with the data of another tensor. Only the shape of the tensors must match.
""" # used for replacing a Tensor with a new version of it (potentially with a different device and dtype) - assert getattr(self, '_ctx', None) is None assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}" self.lazydata = x.lazydata return self @@ -378,7 +371,6 @@ class Tensor(SimpleMathTrait): """ ret = Tensor(self.lazydata.clone(), self.device, requires_grad=self.requires_grad) if self.grad is not None: ret.grad = self.grad.clone() - if hasattr(self, '_ctx'): ret._ctx = self._ctx return ret def to(self, device:Optional[Union[str, tuple[str, ...]]]) -> Tensor: @@ -390,7 +382,6 @@ class Tensor(SimpleMathTrait): if not isinstance(device, str): return self.shard(device) ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad) if self.grad is not None: ret.grad = self.grad.to(device) - if hasattr(self, '_ctx'): ret._ctx = self._ctx return ret def to_(self, device:Optional[Union[str, tuple[str, ...]]]): @@ -944,7 +935,6 @@ class Tensor(SimpleMathTrait): tensors_need_grad: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and \ t.lazydata in all_uops and t.requires_grad and not Tensor.no_grad] # clear contexts - for t in tensors_need_grad: t._ctx = None for t,g in zip(tensors_need_grad, self.gradient(*tensors_need_grad, gradient=gradient, materialize_grads=True)): assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}" t.grad = g if t.grad is None else (t.grad + g)