gradient is a set [pr] (#8626)

* gradient is a set [pr]

* typing for deepwalk
George Hotz
2025-01-14 20:48:23 -08:00
committed by GitHub
parent 7fb1c7af61
commit e1f7c90459
2 changed files with 13 additions and 12 deletions

--- a/tinygrad/gradient.py
+++ b/tinygrad/gradient.py

@@ -1,4 +1,4 @@
-from typing import cast
+from typing import cast, Iterator
 import math, functools
 from tinygrad.dtype import dtypes, sum_acc_dtype
 from tinygrad.ops import UOp, PatternMatcher, UPat, Ops
@@ -44,10 +44,10 @@ pm_gradient = PatternMatcher([
 ])

 # copied from tensor.py, get relevant toposort of gradients
-def _deepwalk(root:UOp, targets:list[UOp]):
+def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
   @functools.lru_cache(None)
   def is_in_target_path(x:UOp) -> bool: return any(u in targets or is_in_target_path(u) for u in x.src)
-  def _walk(node:UOp, visited:set[UOp]):
+  def _walk(node:UOp, visited:set[UOp]) -> Iterator[UOp]:
     visited.add(node)
     if node.op is Ops.DETACH: return
     if is_in_target_path(node):
@@ -56,7 +56,7 @@ def _deepwalk(root:UOp, targets:list[UOp]):
       yield node
   return list(_walk(root, set()))

-def compute_gradient(root:UOp, root_grad:UOp, targets:list[UOp]) -> dict[UOp, UOp]:
+def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
   grads = {root: root_grad}
   for t0 in reversed(_deepwalk(root, targets)):
     if t0 not in grads: continue
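
The core of this commit is the `targets` parameter changing from `list[UOp]` to `set[UOp]`: `is_in_target_path` runs `u in targets` once per source edge of every walked node, and a list pays O(len(targets)) for each of those checks while a set pays O(1) on average (UOps are hashable, so they are valid set members). A minimal sketch of the difference, with plain ints standing in for UOps (illustration only, not tinygrad code):

import timeit

targets_list = list(range(10_000))
targets_set = set(targets_list)
probe = 9_999  # worst case for the list: the last element

# list membership scans elements one by one: O(len(targets)) per check
print(timeit.timeit(lambda: probe in targets_list, number=1_000))
# set membership hashes once: O(1) average per check
print(timeit.timeit(lambda: probe in targets_set, number=1_000))
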

--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py

@@ -920,7 +920,7 @@ class Tensor(SimpleMathTrait):
     rets = []
     for i,(uop,grad) in enumerate(zip(self.lazydata.lbs, gradient.lazydata.lbs)):
       target_uops = [x.lazydata.lbs[i] for x in targets]
-      grads = compute_gradient(uop, grad, target_uops)
+      grads = compute_gradient(uop, grad, set(target_uops))
       ret = []
       for x in target_uops:
         if (y:=grads.get(x)) is None: raise RuntimeError(f"{x}\n\nnot found in\n\n{uop}")
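
Note how this call site keeps both collections: the ordered `target_uops` list still drives the output (the loop calls `grads.get(x)` per target), and only the copy handed to `compute_gradient` is wrapped in `set(...)`, so switching to a set loses no ordering. A stand-alone sketch of that pattern, where `compute` is a hypothetical stand-in for `compute_gradient`:

def compute(targets:set[int]) -> dict[int, str]:  # stand-in: one result per target
  return {t: f"grad_{t}" for t in targets}

target_list = [3, 1, 3, 2]                 # caller's order, duplicates allowed
grads = compute(set(target_list))          # the set only serves membership tests
ordered = [grads[t] for t in target_list]  # order recovered from the list, not the set
print(ordered)                             # ['grad_3', 'grad_1', 'grad_3', 'grad_2']
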
@@ -931,13 +931,13 @@ class Tensor(SimpleMathTrait):
     return [Tensor(MultiLazyBuffer(list(u), cast(MultiLazyBuffer, t.lazydata).axis, cast(MultiLazyBuffer, t.lazydata).real),
                    device=t.device) for t,u in zip(targets, zip(*rets))]

-  def _deepwalk(self):
-    def _walk(node, visited):
+  def _deepwalk(self) -> list[Tensor]:
+    def _walk(node:Tensor, visited:set[Tensor]):
       visited.add(node)
       # if tensor is not leaf, reset grad
       if (ctx := getattr(node, "_ctx", None)) is not None and len(ctx.parents) != 0: node.grad = None
       if ctx:
-        for i in node._ctx.parents:
+        for i in cast(Function, node._ctx).parents:
           if i not in visited: yield from _walk(i, visited)
         yield node
     return list(_walk(self, set()))
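
`Tensor._deepwalk` above is the same recursive-generator toposort as the UOp version in gradient.py: a node is yielded only after everything it depends on, so the postorder list ends at the root and `reversed(...)` feeds backprop from the root down. A self-contained sketch of the pattern, with a hypothetical `Node` class standing in for Tensor/UOp:

from __future__ import annotations
from dataclasses import dataclass, field

@dataclass(eq=False)  # eq=False keeps identity hashing, so nodes work in a visited set
class Node:
  src: list[Node] = field(default_factory=list)

def deepwalk(root:Node) -> list[Node]:
  def _walk(node:Node, visited:set[Node]):
    visited.add(node)
    for s in node.src:
      if s not in visited: yield from _walk(s, visited)
    yield node  # emitted only after all of its inputs: postorder
  return list(_walk(root, set()))

a, b = Node(), Node()
c = Node(src=[a, b])
assert deepwalk(c) == [a, b, c]  # root sorts last; reversed() starts backprop there
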
@@ -965,12 +965,13 @@ class Tensor(SimpleMathTrait):
     self.grad = gradient
     for t0 in reversed(toposorted):
       if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
-      token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := t0._ctx.metadata) is not None else None)
-      grads = t0._ctx.backward(t0.grad.lazydata)
+      ctx = cast(Function, t0._ctx)
+      token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := ctx.metadata) is not None else None)
+      grads = ctx.backward(t0.grad.lazydata)
       _METADATA.reset(token)
       grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
-        for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
-      for t, g in zip(t0._ctx.parents, grads):
+        for g in ([grads] if len(ctx.parents) == 1 else grads)]
+      for t, g in zip(ctx.parents, grads):
         if g is not None and t.requires_grad:
           assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
           assert t.lazydata in toposort_uop or (isinstance(t.lazydata, MultiLazyBuffer) and any(x in toposort_uop for x in t.lazydata.lbs)), \
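
The `ctx = cast(Function, t0._ctx)` hoist is purely for the type checker: `_ctx` is attached to tensors dynamically, so its static type is unknown, and `typing.cast` asserts the type once into a local that the following lines reuse, instead of repeating the cast at every use. It does nothing at runtime. A minimal sketch with stand-in classes (not tinygrad's):

from typing import cast

class Function:                # stand-in for an autograd context with parents/backward
  parents: tuple = ()
  def backward(self, grad): return grad

class T: pass                  # stand-in for Tensor

t = T()
t._ctx = Function()            # attribute set at runtime, invisible to static typing
ctx = cast(Function, t._ctx)   # no runtime check or cost; just narrows the type once
print(ctx.backward(1.0))       # attribute access on ctx now type-checks
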