fix typing in compute_gradient (#13852)

This commit is contained in:
chenyu
2025-12-28 11:52:14 -05:00
committed by GitHub
parent cba05acadf
commit a03cd43e78

View File

@@ -59,7 +59,7 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp
grads = {root: root_grad}
for t0 in reversed(_deepwalk(root, targets)):
if t0 not in grads: continue
lgrads: tuple[UOp|None, ...]|None = cast(tuple[UOp, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
lgrads: tuple[UOp|None, ...]|None = cast(tuple[UOp|None, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}\n\nin {str(t0)[0:1000]}...")
assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradient, expected {len(t0.src)}"
for k,v in zip(t0.src, lgrads):