remove deepwalk args (#243)

This commit is contained in:
Göktuğ Karakaşlı
2021-01-31 19:30:17 +03:00
committed by GitHub
parent ce77dda805
commit eabe0b9017

View File

@@ -121,12 +121,14 @@ class Tensor:
# ***** toposort and backward pass *****
def deepwalk(self, visited: set, nodes: list):
visited.add(self)
if self._ctx:
[i.deepwalk(visited, nodes) for i in self._ctx.parents if i not in visited]
nodes.append(self)
return nodes
def deepwalk(self):
def _deepwalk(node, visited, nodes):
visited.add(node)
if node._ctx:
[_deepwalk(i, visited, nodes) for i in node._ctx.parents if i not in visited]
nodes.append(node)
return nodes
return _deepwalk(self, set(), [])
def backward(self):
assert self.shape == (1,)
@@ -135,7 +137,7 @@ class Tensor:
# this is "implicit gradient creation"
self.grad = Tensor(np.ones(self.shape, dtype=self.dtype), device=self.device, requires_grad=False)
for t0 in reversed(self.deepwalk(set(), [])):
for t0 in reversed(self.deepwalk()):
assert (t0.grad is not None)
with ProfileOp(t0._ctx.__class__.__name__, [t0.grad], backward=True):
grads = t0._ctx.backward(t0._ctx, t0.grad.data)