mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-18 18:35:12 -05:00
Don't save parents unless needed (#1142)
* don't save parents unless requires grad * keep del ctx since idk
This commit is contained in:
@@ -14,9 +14,10 @@ from tinygrad.ops import LoadOps
|
||||
# An instantiation of the Function is the Context
|
||||
class Function:
|
||||
def __init__(self, device:str, *tensors:Tensor):
|
||||
self.device, self.parents = device, tensors
|
||||
self.needs_input_grad = [t.requires_grad for t in self.parents]
|
||||
self.device = device
|
||||
self.needs_input_grad = [t.requires_grad for t in tensors]
|
||||
self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False
|
||||
if self.requires_grad: self.parents = tensors
|
||||
|
||||
def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
|
||||
def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implemented for {type(self)}")
|
||||
@@ -213,8 +214,8 @@ class Tensor:
|
||||
self.grad = Tensor(1, device=self.device, requires_grad=False)
|
||||
|
||||
for t0 in reversed(self.deepwalk()):
|
||||
if not any(x.requires_grad for x in t0._ctx.parents):
|
||||
del t0._ctx # TODO: does it help to delete this here ever?
|
||||
if not t0.requires_grad:
|
||||
del t0._ctx # TODO: does it help to delete this here ever?
|
||||
continue
|
||||
assert (t0.grad is not None)
|
||||
grads = t0._ctx.backward(t0.grad.lazydata)
|
||||
|
||||
Reference in New Issue
Block a user