From a03cd43e788288e97ab96cfb210c9a8e2f7b16c4 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sun, 28 Dec 2025 11:52:14 -0500 Subject: [PATCH] fix typing in compute_gradient (#13852) --- tinygrad/gradient.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py index 1447c69113..a1d44c6ce2 100644 --- a/tinygrad/gradient.py +++ b/tinygrad/gradient.py @@ -59,7 +59,7 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp grads = {root: root_grad} for t0 in reversed(_deepwalk(root, targets)): if t0 not in grads: continue - lgrads: tuple[UOp|None, ...]|None = cast(tuple[UOp, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0])) + lgrads: tuple[UOp|None, ...]|None = cast(tuple[UOp|None, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0])) if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}\n\nin {str(t0)[0:1000]}...") assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradient, expected {len(t0.src)}" for k,v in zip(t0.src, lgrads):