gradient is a set [pr] (#8626)
* gradient is a set [pr]
* typing for deepwalk
@@ -1,4 +1,4 @@
-from typing import cast
+from typing import cast, Iterator
 import math, functools
 from tinygrad.dtype import dtypes, sum_acc_dtype
 from tinygrad.ops import UOp, PatternMatcher, UPat, Ops
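The added Iterator import is only needed for the return annotation on the nested generator further down. As a minimal sketch (generic Python, not tinygrad code): a generator function is annotated with Iterator[T] for the values it yields, not with the yielded type itself.

from typing import Iterator

def countdown(n: int) -> Iterator[int]:
  # a generator function returns an iterator, so annotate Iterator[int], not int
  while n > 0:
    yield n
    n -= 1

assert list(countdown(3)) == [3, 2, 1]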
@@ -44,10 +44,10 @@ pm_gradient = PatternMatcher([
 ])

 # copied from tensor.py, get relevant toposort of gradients
-def _deepwalk(root:UOp, targets:list[UOp]):
+def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
   @functools.lru_cache(None)
   def is_in_target_path(x:UOp) -> bool: return any(u in targets or is_in_target_path(u) for u in x.src)
-  def _walk(node:UOp, visited:set[UOp]):
+  def _walk(node:UOp, visited:set[UOp]) -> Iterator[UOp]:
     visited.add(node)
     if node.op is Ops.DETACH: return
     if is_in_target_path(node):
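Making targets a set[UOp] turns the `u in targets` test inside is_in_target_path into an O(1) hash lookup instead of a linear scan, and the lru_cache means each node's reachability is computed at most once. A rough standalone sketch of the same idea, using a toy Node class in place of UOp; here targets is passed explicitly as a frozenset so it is hashable for the cache, whereas the real code closes over it (hypothetical names, not the tinygrad implementation):

import functools

class Node:
  def __init__(self, *src): self.src = src  # parents/inputs of this node

@functools.lru_cache(None)
def is_in_target_path(x: Node, targets: frozenset) -> bool:
  # True if any input of x is a target, or transitively leads to one
  return any(u in targets or is_in_target_path(u, targets) for u in x.src)

a, b = Node(), Node()
c = Node(a)
d = Node(b)
assert is_in_target_path(c, frozenset({a}))
assert not is_in_target_path(d, frozenset({a}))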
@@ -56,7 +56,7 @@ def _deepwalk(root:UOp, targets:list[UOp]):
       yield node
   return list(_walk(root, set()))

-def compute_gradient(root:UOp, root_grad:UOp, targets:list[UOp]) -> dict[UOp, UOp]:
+def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
   grads = {root: root_grad}
   for t0 in reversed(_deepwalk(root, targets)):
     if t0 not in grads: continue
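compute_gradient only uses targets through _deepwalk and the membership test above, so the list-to-set change affects how targets are probed, not the result. For orientation, a heavily simplified sketch of the reverse-mode pattern it follows: walk the topological order backwards from the root and accumulate gradients in a dict keyed by node. Toy scalar ops and hypothetical names only; the real code derives the per-op gradients through pm_gradient/PatternMatcher.

def compute_gradient_sketch(toposorted, root, root_grad, local_grads):
  # local_grads(node) -> list of (input_node, d_node/d_input) pairs for that node
  grads = {root: root_grad}
  for t0 in reversed(toposorted):
    if t0 not in grads: continue                      # node does not influence the root
    for src, dgrad in local_grads(t0):
      contribution = grads[t0] * dgrad                # chain rule for one input
      grads[src] = grads.get(src, 0.0) + contribution # sum over all paths
  return grads

# toy graph: z = x * y with x=3, y=4; nodes are just strings here
vals = {"x": 3.0, "y": 4.0}
def local_grads(node):
  if node == "z": return [("x", vals["y"]), ("y", vals["x"])]
  return []
g = compute_gradient_sketch(["x", "y", "z"], "z", 1.0, local_grads)
assert g["x"] == 4.0 and g["y"] == 3.0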
@@ -920,7 +920,7 @@ class Tensor(SimpleMathTrait):
     rets = []
     for i,(uop,grad) in enumerate(zip(self.lazydata.lbs, gradient.lazydata.lbs)):
       target_uops = [x.lazydata.lbs[i] for x in targets]
-      grads = compute_gradient(uop, grad, target_uops)
+      grads = compute_gradient(uop, grad, set(target_uops))
       ret = []
       for x in target_uops:
        if (y:=grads.get(x)) is None: raise RuntimeError(f"{x}\n\nnot found in\n\n{uop}")
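At the call site the per-shard target uops stay in an ordered list (the later zip with targets depends on that order), and only the copy handed to compute_gradient is wrapped in set(). The set collapses duplicates and drops order, which is harmless here because compute_gradient returns a dict and each target is fetched individually with grads.get(x). A tiny sketch of the same pattern (generic Python, not tinygrad code):

def grads_for(all_grads: dict, targets: list):
  # keep the ordered list for output, pass a set where only membership matters
  wanted = set(targets)
  found = {k: v for k, v in all_grads.items() if k in wanted}
  return [found.get(t) for t in targets]   # ordered lookup, duplicates allowed

assert grads_for({"a": 1, "b": 2}, ["b", "a", "b"]) == [2, 1, 2]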
@@ -931,13 +931,13 @@ class Tensor(SimpleMathTrait):
     return [Tensor(MultiLazyBuffer(list(u), cast(MultiLazyBuffer, t.lazydata).axis, cast(MultiLazyBuffer, t.lazydata).real),
                    device=t.device) for t,u in zip(targets, zip(*rets))]

-  def _deepwalk(self):
-    def _walk(node, visited):
+  def _deepwalk(self) -> list[Tensor]:
+    def _walk(node:Tensor, visited:set[Tensor]):
       visited.add(node)
       # if tensor is not leaf, reset grad
       if (ctx := getattr(node, "_ctx", None)) is not None and len(ctx.parents) != 0: node.grad = None
       if ctx:
-        for i in node._ctx.parents:
+        for i in cast(Function, node._ctx).parents:
          if i not in visited: yield from _walk(i, visited)
        yield node
    return list(_walk(self, set()))
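The `cast(Function, node._ctx)` is purely for the type checker: `_ctx` is fetched with getattr and may be None on leaf tensors, and after the `if ctx:` guard the cast tells the checker that attributes like .parents are available. typing.cast has no runtime effect, as this sketch shows (toy classes, not tinygrad's Function):

from typing import Optional, cast

class Ctx:
  def __init__(self, parents): self.parents = parents

class Node:
  def __init__(self, ctx: Optional[Ctx] = None): self._ctx = ctx

def parent_count(node: Node) -> int:
  ctx = getattr(node, "_ctx", None)
  if ctx is None: return 0
  # cast only narrows the static type; at runtime it returns ctx unchanged
  return len(cast(Ctx, ctx).parents)

assert parent_count(Node()) == 0
assert parent_count(Node(Ctx([Node(), Node()]))) == 2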
@@ -965,12 +965,13 @@ class Tensor(SimpleMathTrait):
     self.grad = gradient
     for t0 in reversed(toposorted):
       if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
-      token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := t0._ctx.metadata) is not None else None)
-      grads = t0._ctx.backward(t0.grad.lazydata)
+      ctx = cast(Function, t0._ctx)
+      token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := ctx.metadata) is not None else None)
+      grads = ctx.backward(t0.grad.lazydata)
       _METADATA.reset(token)
       grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
-               for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
-      for t, g in zip(t0._ctx.parents, grads):
+               for g in ([grads] if len(ctx.parents) == 1 else grads)]
+      for t, g in zip(ctx.parents, grads):
         if g is not None and t.requires_grad:
           assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
           assert t.lazydata in toposort_uop or (isinstance(t.lazydata, MultiLazyBuffer) and any(x in toposort_uop for x in t.lazydata.lbs)), \
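Hoisting the repeated `cast(Function, t0._ctx)` into a local `ctx` removes four separate casts and keeps the backward loop readable. The surrounding `_METADATA.set(...)` / `_METADATA.reset(token)` pair follows the usual contextvars token pattern for temporarily overriding a context-local value during one step. A generic sketch of that set/reset pattern (hypothetical variable, not tinygrad's _METADATA):

import contextvars

_LABEL = contextvars.ContextVar("label", default="idle")

def run_step(name: str) -> str:
  token = _LABEL.set(name)          # override for the duration of this step
  try:
    return f"running with label={_LABEL.get()}"
  finally:
    _LABEL.reset(token)             # restore whatever value was active before

assert run_step("backward") == "running with label=backward"
assert _LABEL.get() == "idle"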