diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 4430c048b7..f878010e8a 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -14,9 +14,10 @@ from tinygrad.ops import LoadOps
 
 # An instantiation of the Function is the Context
 class Function:
   def __init__(self, device:str, *tensors:Tensor):
-    self.device, self.parents = device, tensors
-    self.needs_input_grad = [t.requires_grad for t in self.parents]
+    self.device = device
+    self.needs_input_grad = [t.requires_grad for t in tensors]
     self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False
+    if self.requires_grad: self.parents = tensors
   def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
   def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implemented for {type(self)}")
@@ -213,8 +214,8 @@ class Tensor:
     self.grad = Tensor(1, device=self.device, requires_grad=False)
 
     for t0 in reversed(self.deepwalk()):
-      if not any(x.requires_grad for x in t0._ctx.parents):
-        del t0._ctx # TODO: does it help to delete this here ever?
+      if not t0.requires_grad:
+        del t0._ctx # TODO: does it help to delete this here ever?
         continue
      assert (t0.grad is not None)
      grads = t0._ctx.backward(t0.grad.lazydata)
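
For context, the effect of the first hunk is that a `Function` context only keeps references to its parent tensors when a gradient can actually flow through it, so an inference-only graph no longer keeps its inputs alive. The second hunk then checks the cheaper `t0.requires_grad` flag in `backward()` instead of re-scanning `t0._ctx.parents`, which may no longer exist. A minimal standalone sketch of that pattern (toy classes for illustration only, not tinygrad's actual `Tensor`/`Function`, and using a plain boolean instead of the True/None/False tri-state):

```python
# Toy illustration of "only retain parents when a gradient is needed".
class ToyTensor:
  def __init__(self, data, requires_grad=False):
    self.data, self.requires_grad, self._ctx = data, requires_grad, None

class ToyContext:
  def __init__(self, *tensors):
    self.needs_input_grad = [t.requires_grad for t in tensors]
    self.requires_grad = any(self.needs_input_grad)
    # Parents are only kept alive when a backward pass will need them.
    if self.requires_grad: self.parents = tensors

def toy_mul(a, b):
  ctx = ToyContext(a, b)
  out = ToyTensor(a.data * b.data, requires_grad=ctx.requires_grad)
  out._ctx = ctx
  return out

# Inference: no parents retained, so the inputs can be freed as soon as they go out of scope.
x, y = ToyTensor(2.0), ToyTensor(3.0)
assert not hasattr(toy_mul(x, y)._ctx, "parents")

# Training: parents are kept for the backward pass.
w = ToyTensor(4.0, requires_grad=True)
assert toy_mul(x, w)._ctx.parents == (x, w)
```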