Don't save parents unless needed (#1142)

* don't save parents unless requires grad

* keep del ctx since idk
This commit is contained in:
cheeetoo
2023-07-05 20:11:57 -05:00
committed by GitHub
parent 801564f31b
commit f109af3cbb

View File

@@ -14,9 +14,10 @@ from tinygrad.ops import LoadOps
# An instantiation of the Function is the Context
class Function:
def __init__(self, device:str, *tensors:Tensor):
self.device, self.parents = device, tensors
self.needs_input_grad = [t.requires_grad for t in self.parents]
self.device = device
self.needs_input_grad = [t.requires_grad for t in tensors]
self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False
if self.requires_grad: self.parents = tensors
def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implemented for {type(self)}")
@@ -213,8 +214,8 @@ class Tensor:
self.grad = Tensor(1, device=self.device, requires_grad=False)
for t0 in reversed(self.deepwalk()):
if not any(x.requires_grad for x in t0._ctx.parents):
del t0._ctx # TODO: does it help to delete this here ever?
if not t0.requires_grad:
del t0._ctx # TODO: does it help to delete this here ever?
continue
assert (t0.grad is not None)
grads = t0._ctx.backward(t0.grad.lazydata)