From e1f7c90459200fd93ef142f899f4d9a4167625c4 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Tue, 14 Jan 2025 20:48:23 -0800
Subject: [PATCH] gradient is a set [pr] (#8626)

* gradient is a set [pr]

* typing for deepwalk
---
 tinygrad/gradient.py |  8 ++++----
 tinygrad/tensor.py   | 17 +++++++++--------
 2 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py
index 1b93e9374b..a2fa71a98d 100644
--- a/tinygrad/gradient.py
+++ b/tinygrad/gradient.py
@@ -1,4 +1,4 @@
-from typing import cast
+from typing import cast, Iterator
 import math, functools
 from tinygrad.dtype import dtypes, sum_acc_dtype
 from tinygrad.ops import UOp, PatternMatcher, UPat, Ops
@@ -44,10 +44,10 @@ pm_gradient = PatternMatcher([
 ])
 
 # copied from tensor.py, get relevant toposort of gradients
-def _deepwalk(root:UOp, targets:list[UOp]):
+def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
   @functools.lru_cache(None)
   def is_in_target_path(x:UOp) -> bool: return any(u in targets or is_in_target_path(u) for u in x.src)
-  def _walk(node:UOp, visited:set[UOp]):
+  def _walk(node:UOp, visited:set[UOp]) -> Iterator[UOp]:
     visited.add(node)
     if node.op is Ops.DETACH: return
     if is_in_target_path(node):
@@ -56,7 +56,7 @@ def _deepwalk(root:UOp, targets:list[UOp]):
       yield node
   return list(_walk(root, set()))
 
-def compute_gradient(root:UOp, root_grad:UOp, targets:list[UOp]) -> dict[UOp, UOp]:
+def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
   grads = {root: root_grad}
   for t0 in reversed(_deepwalk(root, targets)):
     if t0 not in grads: continue
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 7d00e441d4..e3d87f51eb 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -920,7 +920,7 @@ class Tensor(SimpleMathTrait):
     rets = []
     for i,(uop,grad) in enumerate(zip(self.lazydata.lbs, gradient.lazydata.lbs)):
       target_uops = [x.lazydata.lbs[i] for x in targets]
-      grads = compute_gradient(uop, grad, target_uops)
+      grads = compute_gradient(uop, grad, set(target_uops))
       ret = []
       for x in target_uops:
         if (y:=grads.get(x)) is None: raise RuntimeError(f"{x}\n\nnot found in\n\n{uop}")
@@ -931,13 +931,13 @@ class Tensor(SimpleMathTrait):
     return [Tensor(MultiLazyBuffer(list(u), cast(MultiLazyBuffer, t.lazydata).axis, cast(MultiLazyBuffer, t.lazydata).real), device=t.device)
       for t,u in zip(targets, zip(*rets))]
 
-  def _deepwalk(self):
-    def _walk(node, visited):
+  def _deepwalk(self) -> list[Tensor]:
+    def _walk(node:Tensor, visited:set[Tensor]):
       visited.add(node)
       # if tensor is not leaf, reset grad
       if (ctx := getattr(node, "_ctx", None)) is not None and len(ctx.parents) != 0: node.grad = None
       if ctx:
-        for i in node._ctx.parents:
+        for i in cast(Function, node._ctx).parents:
           if i not in visited: yield from _walk(i, visited)
       yield node
     return list(_walk(self, set()))
@@ -965,12 +965,13 @@ class Tensor(SimpleMathTrait):
     self.grad = gradient
     for t0 in reversed(toposorted):
       if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
-      token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := t0._ctx.metadata) is not None else None)
-      grads = t0._ctx.backward(t0.grad.lazydata)
+      ctx = cast(Function, t0._ctx)
+      token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := ctx.metadata) is not None else None)
+      grads = ctx.backward(t0.grad.lazydata)
       _METADATA.reset(token)
       grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
-        for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
-      for t, g in zip(t0._ctx.parents, grads):
+        for g in ([grads] if len(ctx.parents) == 1 else grads)]
+      for t, g in zip(ctx.parents, grads):
         if g is not None and t.requires_grad:
           assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
           assert t.lazydata in toposort_uop or (isinstance(t.lazydata, MultiLazyBuffer) and any(x in toposort_uop for x in t.lazydata.lbs)), \
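
Reviewer note (not part of the patch): the set type matters because `is_in_target_path` probes `u in targets` for every source of every node it visits, and on a set that membership test is O(1) rather than a linear scan, which is presumably the motivation for the signature change. The call site in tensor.py keeps `target_uops` as a list only because the returned gradients must come back in order, converting once with `set(target_uops)`. Below is a minimal, self-contained sketch of the same deepwalk pattern on a toy DAG; `Node` is a hypothetical stand-in for UOp (this is not tinygrad's API), and the Ops.DETACH early-out is omitted.

from __future__ import annotations
import functools
from dataclasses import dataclass
from typing import Iterator

@dataclass(frozen=True)
class Node:
  name: str
  src: tuple[Node, ...] = ()

def deepwalk(root:Node, targets:set[Node]) -> list[Node]:
  # keep only nodes that lie on a path from root to something in targets;
  # `u in targets` is an O(1) probe because targets is a set
  @functools.lru_cache(None)
  def is_in_target_path(x:Node) -> bool: return any(u in targets or is_in_target_path(u) for u in x.src)
  def _walk(node:Node, visited:set[Node]) -> Iterator[Node]:
    visited.add(node)
    if is_in_target_path(node):
      for i in node.src:
        if i not in visited: yield from _walk(i, visited)
      yield node
  return list(_walk(root, set()))

a, b = Node("a"), Node("b")
mul = Node("mul", (a, b))
loss = Node("loss", (mul, Node("unused")))
# targets themselves are not yielded: compute_gradient fills in a target's
# grad while processing its consumer, so only the consumers need to appear
print([n.name for n in deepwalk(loss, {a})])  # -> ['mul', 'loss']

Reversing that list gives the processing order `compute_gradient` uses, consumers before producers, so each target's gradient has been written by the time the caller reads it out of `grads`.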