mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
fix typing in compute_gradient (#13852)
This commit is contained in:
@@ -59,7 +59,7 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp
|
||||
grads = {root: root_grad}
|
||||
for t0 in reversed(_deepwalk(root, targets)):
|
||||
if t0 not in grads: continue
|
||||
lgrads: tuple[UOp|None, ...]|None = cast(tuple[UOp, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
|
||||
lgrads: tuple[UOp|None, ...]|None = cast(tuple[UOp|None, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
|
||||
if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}\n\nin {str(t0)[0:1000]}...")
|
||||
assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradient, expected {len(t0.src)}"
|
||||
for k,v in zip(t0.src, lgrads):
|
||||
|
||||
Reference in New Issue
Block a user