mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
124 lines
8.1 KiB
Python
124 lines
8.1 KiB
Python
from typing import cast
|
|
import math, dataclasses, itertools
|
|
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata, graph_rewrite
|
|
from tinygrad.helpers import argsort
|
|
from tinygrad.dtype import sum_acc_dtype
|
|
|
|
def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
  """Gradient rule for the reduction `ret` whose reduce op is `op`.

  ctx: gradient arriving at `ret`.
  ret: the reduction UOp; `ret.src[0]` is the input that was reduced.
  op: the reduce op (ADD, MAX and MUL are handled).

  Returns a 1-tuple with the gradient for `ret.src[0]`, or None when `op` is not handled.
  """
  def broadcast_to_input(x):
    # append size-1 axes until x has the rank of the reduced input, then expand to the input's shape
    src_shape = ret.src[0].shape
    n_missing = len(src_shape) - len(x.shape)
    return x.reshape(x.shape + (1,)*n_missing).expand(src_shape)

  if op == Ops.ADD:
    # every input element contributes equally to a sum
    return (broadcast_to_input(ctx),)

  if op == Ops.MUL:
    # d(prod)/d(x_i) = prod / x_i
    return (broadcast_to_input(ctx * ret) / ret.src[0],)

  if op == Ops.MAX:
    assert ret.op is Ops.REDUCE_AXIS, "only works on REDUCE_AXIS"
    # 1 where the input equals the reduced max, 0 elsewhere
    mask = ret.src[0].eq(broadcast_to_input(ret)).cast(ctx.dtype)
    # number of elements tied for the max (ret.arg[1] is presumably the reduced axes -- confirm)
    count = mask._rop(Ops.ADD, ret.arg[1])
    # split the gradient evenly between the tied elements
    share = mask / broadcast_to_input(count)
    return (share * broadcast_to_input(ctx),)

  return None
|
|
|
|
def _compact_params(body:UOp, all_args:tuple[UOp, ...]) -> tuple[UOp, tuple[UOp, ...]]:
  """Drop PARAMs that `body` never uses and renumber the remaining ones densely from 0.

  Returns the rewritten body together with the entries of `all_args` that the surviving
  PARAMs referred to, ordered by their original PARAM index.
  """
  # original PARAM index -> PARAM node found in the body
  param_by_index = {}
  for node in body.toposort():
    if node.op is Ops.PARAM: param_by_index[node.arg] = node

  kept = sorted(param_by_index)
  renumber = {param_by_index[old]: param_by_index[old].replace(arg=new) for new, old in enumerate(kept)}
  compact_body = body.substitute(renumber, walk=True)
  return compact_body, tuple(all_args[old] for old in kept)
|
|
|
|
def call_gradient(ctx:UOp, k:UOp, needed:set[int]) -> tuple[UOp|None, ...]:
  """Compute the gradients of the call `k` with respect to its sources.

  ctx: gradient arriving at `k`. When differentiating the body, `ctx.src` holds one gradient per
       output of the body, with a NOOP standing in for an output that has no gradient.
  k: the call; `k.src[0]` is its body and `k.src[1:]` are its arguments.
  needed: indices into `k.src[1:]` of the arguments whose gradient is wanted.

  Returns a tuple that starts with None (the body gets no gradient). With a custom `grad_fxn` the rest
  is whatever that function returns; otherwise it is one gradient or None per argument.
  """
  fxn, args = k.src[0], k.src[1:]

  # a custom grad_fxn on the call takes priority over differentiating the body
  custom_grad = k.arg.grad_fxn
  if custom_grad is not None:
    if ctx.op is not Ops.TUPLE: return (None,) + custom_grad(ctx, k)
    # only hand over the output gradients that actually exist
    live = [g for g in ctx.src if g.op is not Ops.NOOP]
    if len(live) > 1: return (None,) + custom_grad(*live, call=k)
    return (None,) + custom_grad(live[0], k)

  assert fxn.op is Ops.TUPLE, f"expected TUPLE body for gradient, got {fxn.op}"
  # PARAM index -> PARAM node of the forward body
  params = {}
  for node in fxn.toposort(enter_calls=False):
    if node.op == Ops.PARAM: params[node.arg] = node

  grad_args = ctx.src
  n_args = len(args)

  # the incoming gradients become extra params of the backward body, numbered after the forward args
  seeds = []
  for i, g in enumerate(grad_args):
    seeds.append(UOp(Ops.NOOP) if g.op is Ops.NOOP else g.param_like(n_args+i))
  grads = compute_gradient(fxn, UOp(Ops.TUPLE, src=tuple(seeds)), set(params.values()))

  # for precompiled calls, substitute forward outputs with params so intermediates aren't recomputed
  if k.arg.precompile:
    first_fwd_param = n_args + len(grad_args)
    fwd_subs = {src: src.param_like(first_fwd_param+i) for i, src in enumerate(fxn.src)}
    fwd_outs = tuple(k.gettuple(i) for i in range(len(fxn.src)))
  else:
    fwd_subs, fwd_outs = {}, ()

  # collect needed gradient bodies, compact unused params, create a single backward CALL
  grad_bodies = []
  for i in needed:
    p = params.get(i)
    if p is not None and p in grads: grad_bodies.append((i, grads[p]))
  bwd_body = UOp.maketuple(*(gb for _, gb in grad_bodies)).substitute(fwd_subs, walk=True)
  bwd_body, compact_args = _compact_params(bwd_body, (*args, *grad_args, *fwd_outs))

  # TODO: is this okay here?
  from tinygrad.function import pm_transform_unique_const
  bwd_body = graph_rewrite(bwd_body, pm_transform_unique_const, ctx=(None, itertools.count(0)))
  bwd_name = (k.arg.name or "")+"_backward"
  bwd_call = bwd_body.call(*compact_args, name=bwd_name, precompile=k.arg.precompile_backward)

  # argument index -> position of its gradient in the backward call's output tuple
  gb_map = {i: idx for idx, (i, _) in enumerate(grad_bodies)}
  ret: list[UOp|None] = [None]
  for i in range(n_args):
    ret.append(bwd_call.gettuple(gb_map[i]) if i in gb_map else None)
  return tuple(ret)
|
|
|
|
# gradient rules. ctx is grad_output, the gradient arriving at the matched UOp
# every rule returns a tuple with one entry per source of the matched UOp, None meaning "no gradient for this source"
pm_gradient = PatternMatcher([
  # ** unary ops **
  (UPat(Ops.CAST, name="ret"), lambda ctx, ret: (ctx.cast(ret.src[0].dtype),)),
  # d(1/x) = -(1/x)^2
  (UPat(Ops.RECIPROCAL, name="ret"), lambda ctx, ret: (-ctx * ret * ret,)),
  # cos(x) is written as sin(pi/2 - x)
  (UPat(Ops.SIN, name="ret"), lambda ctx, ret: ((math.pi/2 - ret.src[0]).sin() * ctx,)),
  (UPat(Ops.LOG2, name="ret"), lambda ctx, ret: (ctx / (ret.src[0] * math.log(2)),)),
  (UPat(Ops.EXP2, name="ret"), lambda ctx, ret: (ret * ctx * math.log(2),)),
  # d(sqrt(x)) = 1 / (2*sqrt(x))
  (UPat(Ops.SQRT, name="ret"), lambda ctx, ret: (ctx / (ret*2),)),

  # ** binary and ternary ops **
  # comparisons have no gradient
  (UPat((Ops.CMPLT, Ops.CMPNE)), lambda: (None, None)),
  (UPat(Ops.ADD), lambda ctx: (ctx, ctx)),
  # base gradient is e*b^(e-1), except at b == e == 0; exponent gradient is ret*ln(b), special cased at b == 0
  (UPat(Ops.POW, name="ret", src=(UPat.var("b"), UPat.var("e"))), lambda ctx, ret, b, e: (
    ctx * (b.eq(0)&e.eq(0)).where(e, e*b.pow(e-1)),
    ctx * b.eq(0).where((e<0).where(ret.const_like(-math.inf), 0), ret*b.log2()*math.log(2.0)))),
  # the larger side gets the whole gradient, a tie splits it in half
  (UPat(Ops.MAX, src=(UPat.var("x"), UPat.var("y"))), lambda ctx, x, y: (
    (x>y).where(ctx, x.eq(y).where(ctx * 0.5, 0)),
    (x<y).where(ctx, x.eq(y).where(ctx * 0.5, 0)))),
  (UPat(Ops.MUL, name="ret"), lambda ctx, ret: (ret.src[1]*ctx, ret.src[0]*ctx)),
  # the condition gets no gradient, each branch gets ctx where it was selected and 0 elsewhere
  (UPat(Ops.WHERE, name="ret"), lambda ctx, ret: (
    None,
    ret.src[0].where(ctx, ctx.const_like(0)),
    ret.src[0].where(ctx.const_like(0), ctx))),

  # ** reduce **
  (UPat(Ops.REDUCE_AXIS, name="ret"), lambda ctx, ret: reduce_gradient(ctx, ret, ret.arg[0])),

  # ** contiguous **
  (UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)),
  (UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)),

  # ** movement ops (the extra None entries are for non-data sources) **
  (UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape), None)),
  # sum over every axis whose size changed, accumulating in sum_acc_dtype before casting back
  (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret: (
    ctx.cast(sum_acc_dtype(ctx.dtype))
       ._rop(Ops.ADD, tuple(i for i,(s,n) in enumerate(zip(ret.src[0].shape, ret.shape)) if s!=n))
       .cast(ctx.dtype),
    None)),
  # PAD and SHRINK undo each other (ret.marg is presumably the per-axis movement argument -- confirm)
  (UPat(Ops.PAD, name="ret"), lambda ctx, ret: (
    ctx.shrink(tuple((p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.marg))), None, None)),
  (UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (
    ctx.pad(tuple((p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.marg))), None, None)),
  # argsort of the permutation is the inverse permutation
  (UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.marg)),)),
  (UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip([i for i,x in enumerate(ret.marg) if x]),)),

  # ** device and structural ops **
  # send the gradient back to the device the input lives on
  (UPat(Ops.COPY, name="ret"), lambda ctx, ret: (ctx.copy_to_device(ret.src[0].device), None)),
  (UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
  (UPat(Ops.TUPLE), lambda ctx: ctx.src),
  # NOTE: this is only correct when the KERNEL has a single output
  (UPat(Ops.AFTER), lambda ctx: (ctx, ctx)),
  # there's no gradient for bitcast
  (UPat(Ops.BITCAST), lambda: (None,)),
])
|
|
|
|
def _deepwalk(root:UOp, targets:set[UOp]) -> tuple[list[UOp], dict[UOp, bool]]:
  """Find the UOps under `root` that a gradient has to flow through to reach `targets`.

  Returns `(walk, in_target_path)`. `in_target_path` is the dict filled in by `root.topovisit`, holding
  for each visited UOp whether any of its sources is a target or is itself on a path to one.
  `walk` lists, in that dict's order, the UOps flagged True, leaving out DETACH.
  """
  # compute the target path (top down)
  in_target_path: dict[UOp, bool] = {}

  def leads_to_target(u:UOp) -> bool:
    # sources are looked up in the dict, so they are expected to be filled in before their users
    return any(in_target_path[s] or s in targets for s in u.src)

  root.topovisit(leads_to_target, in_target_path)

  # don't flow through DETACH or anything not in target path
  walk = [u for u, on_path in in_target_path.items() if on_path and u.op is not Ops.DETACH]
  return walk, in_target_path
|
|
|
|
def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
  """Backpropagate `root_grad` from `root` towards `targets`.

  Walks the UOps on the path between `root` and `targets` in reverse order, asks the matching gradient
  rule for the gradient of each source, and sums gradients that reach the same UOp more than once.

  Returns a dict from UOp to its accumulated gradient (it also contains `root` itself).
  Raises RuntimeError when no gradient rule produces a result for a UOp on the path.
  """
  walk, in_target_path = _deepwalk(root, targets)
  grads: dict[UOp, UOp] = {root: root_grad}

  for t0 in reversed(walk):
    # nothing flows through a UOp that never received a gradient, or only got the NOOP placeholder
    out_grad = grads.get(t0)
    if out_grad is None or out_grad.op is Ops.NOOP: continue

    # GETTUPLE: accumulate gradient into a TUPLE UOp on the FUNCTION, process when we hit the FUNCTION
    if t0.op is Ops.GETTUPLE:
      fn = t0.src[0]  # the FUNCTION
      assert fn.op is Ops.FUNCTION and fn.src[0].op is Ops.TUPLE
      n_outputs = len(fn.src[0].src)
      prev = grads[fn].src if fn in grads else tuple(UOp(Ops.NOOP) for _ in range(n_outputs))
      merged = []
      for i in range(n_outputs):
        if i != t0.arg: merged.append(prev[i])                  # other outputs keep what they had
        elif prev[i].op is Ops.NOOP: merged.append(out_grad)    # first gradient for this output
        else: merged.append(prev[i] + out_grad)                 # output used more than once: sum
      grads[fn] = UOp.maketuple(*merged)
      continue

    # FUNCTION/CALL: pass needed param set so backward only computes required gradients
    # (FUNCTION uses implicit TUPLE gradient or grad_fxn; CALL requires an explicit grad_fxn)
    lgrads:tuple[UOp|None, ...]|None
    if t0.op in (Ops.FUNCTION, Ops.CALL):
      needed = {i for i, arg in enumerate(t0.src[1:]) if arg in targets or in_target_path.get(arg, False)}
      lgrads = call_gradient(out_grad, t0, needed)
    else:
      lgrads = cast(tuple[UOp|None, ...]|None, pm_gradient.rewrite(t0, ctx=out_grad))
    if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}\n\nin {str(t0)[0:1000]}...")
    assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradient, expected {len(t0.src)}"

    for src, src_grad in zip(t0.src, lgrads):
      if src_grad is None: continue
      # sum with what is already there, unless that is only the NOOP placeholder
      if src in grads and grads[src].op is not Ops.NOOP: grads[src] = grads[src] + src_grad
      else: grads[src] = src_grad

      forward_metadata = all_metadata.get(t0, ())
      if len(forward_metadata) > 0:
        backward_metadata = tuple(dataclasses.replace(m, backward=True) for m in forward_metadata)

        # stop at the forward graph so only UOps created for the backward pass are tagged
        def is_new(x:UOp) -> bool: return x not in (t0, *t0.src, grads[t0])

        # we add the backward metadata to everything new in the graph
        for bw_uop in src_grad.toposort(is_new):
          all_metadata[bw_uop] = all_metadata.get(bw_uop, ())+backward_metadata

  return grads
|