mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
remove deepwalk args (#243)
This commit is contained in:
@@ -121,12 +121,14 @@ class Tensor:
|
||||
|
||||
# ***** toposort and backward pass *****
|
||||
|
||||
def deepwalk(self, visited: set, nodes: list):
|
||||
visited.add(self)
|
||||
if self._ctx:
|
||||
[i.deepwalk(visited, nodes) for i in self._ctx.parents if i not in visited]
|
||||
nodes.append(self)
|
||||
return nodes
|
||||
def deepwalk(self):
|
||||
def _deepwalk(node, visited, nodes):
|
||||
visited.add(node)
|
||||
if node._ctx:
|
||||
[_deepwalk(i, visited, nodes) for i in node._ctx.parents if i not in visited]
|
||||
nodes.append(node)
|
||||
return nodes
|
||||
return _deepwalk(self, set(), [])
|
||||
|
||||
def backward(self):
|
||||
assert self.shape == (1,)
|
||||
@@ -135,7 +137,7 @@ class Tensor:
|
||||
# this is "implicit gradient creation"
|
||||
self.grad = Tensor(np.ones(self.shape, dtype=self.dtype), device=self.device, requires_grad=False)
|
||||
|
||||
for t0 in reversed(self.deepwalk(set(), [])):
|
||||
for t0 in reversed(self.deepwalk()):
|
||||
assert (t0.grad is not None)
|
||||
with ProfileOp(t0._ctx.__class__.__name__, [t0.grad], backward=True):
|
||||
grads = t0._ctx.backward(t0._ctx, t0.grad.data)
|
||||
|
||||
Reference in New Issue
Block a user