mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
fix typing in compute_gradient (#13852)
This commit is contained in:
@@ -59,7 +59,7 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp
|
|||||||
grads = {root: root_grad}
|
grads = {root: root_grad}
|
||||||
for t0 in reversed(_deepwalk(root, targets)):
|
for t0 in reversed(_deepwalk(root, targets)):
|
||||||
if t0 not in grads: continue
|
if t0 not in grads: continue
|
||||||
lgrads: tuple[UOp|None, ...]|None = cast(tuple[UOp, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
|
lgrads: tuple[UOp|None, ...]|None = cast(tuple[UOp|None, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
|
||||||
if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}\n\nin {str(t0)[0:1000]}...")
|
if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}\n\nin {str(t0)[0:1000]}...")
|
||||||
assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradient, expected {len(t0.src)}"
|
assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradient, expected {len(t0.src)}"
|
||||||
for k,v in zip(t0.src, lgrads):
|
for k,v in zip(t0.src, lgrads):
|
||||||
|
|||||||
Reference in New Issue
Block a user