vmap on full model (#13340)

* vmap on full model

* vmap gemm

* reduce sums on end

* outer reduce

* only if there's ranges

* put those rules in symbolic

* ranges

* do opt later

* add zero range
This commit is contained in:
George Hotz
2025-11-18 16:06:06 -08:00
committed by GitHub
parent 46cb65e692
commit 1afa3c0877
6 changed files with 74 additions and 27 deletions

View File

@@ -1,6 +1,6 @@
import unittest
import numpy as np
from tinygrad import Tensor, UOp
from tinygrad import Tensor, UOp, nn
from tinygrad.uop.ops import AxisType, Ops
class TestOuterworldReduce(unittest.TestCase):
@@ -156,7 +156,7 @@ class TestVmap(unittest.TestCase):
# vmap across axis 0
a = UOp.range(3, -1, axis_type)
out = x @ Tensor(mats.uop.reduce_backward(a, arg=Ops.ADD))[a]
out = x @ mats[a]
out = out.reshape(1, 10).pad(((a,(3-a)-1), None))
out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
if fuse: out = out * 2
@@ -175,5 +175,56 @@ class TestVmap(unittest.TestCase):
def test_vmap_inner_fuse_grad(self): self.test_vmap_inner(fuse=True, grad=True)
def test_vmap_outer_grad(self): self.test_vmap_inner(AxisType.OUTER, grad=True)
def test_vmap_convs(self):
layers = [
nn.Conv2d(1, 8, 3), Tensor.relu,
nn.Conv2d(8, 8, 3), Tensor.relu]
img = Tensor.randn(4, 1, 16, 16).realize(*nn.state.get_parameters(layers))
a = UOp.range(4, -1, AxisType.OUTER)
out = img[a:a+1].sequential(layers)
out = out.pad(((a,(4-a)-1), None, None, None))
out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
out.realize()
np.testing.assert_allclose(out.numpy(), img.sequential(layers).numpy(), atol=1e-6)
def test_vmap_gemm(self):
layers = [
nn.Linear(16, 16, bias=False), Tensor.relu,
nn.Linear(16, 16, bias=False), Tensor.relu]
img = Tensor.randn(4, 16).realize(*nn.state.get_parameters(layers))
a = UOp.range(4, -1, AxisType.OUTER)
out = img[a:a+1].sequential(layers)
out = out.pad(((a,(4-a)-1), None))
out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
out.realize()
np.testing.assert_allclose(out.numpy(), img.sequential(layers).numpy(), atol=1e-6)
@unittest.skip("this is broken, we need to lower the outer reduce in the outer graph")
def test_vmap_gemm_grad(self):
layers = [
nn.Linear(16, 16, bias=False), Tensor.relu,
nn.Linear(16, 16, bias=False), Tensor.relu]
layer_tensors = nn.state.get_parameters(layers)
img = Tensor.randn(4, 16).realize(*layer_tensors)
for l in layer_tensors: l.requires_grad_()
a = UOp.range(4, -1, AxisType.OUTER)
out = img[a:a+1].sequential(layers)
out = out.pad(((a,(4-a)-1), None))
out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
out.mean().backward()
grads = [l.grad for l in layer_tensors]
out.realize(*grads)
out_grads = [x.numpy() for x in grads]
# compute reference grads
for l in layer_tensors: l.grad = None
img.sequential(layers).mean().backward()
grads = [l.grad for l in layer_tensors]
out.realize(*grads)
ref_grads = [x.numpy() for x in grads]
# compare
for o,r in zip(out_grads, ref_grads): np.testing.assert_allclose(o, r, atol=1e-6)
if __name__ == '__main__':
unittest.main()

View File

@@ -31,7 +31,6 @@ pm_gradient = PatternMatcher([
(UPat(Ops.WHERE, name="ret"), lambda ctx, ret: (None, ret.src[0].where(ctx, ctx.const_like(0)), ret.src[0].where(ctx.const_like(0), ctx))),
(UPat(Ops.REDUCE_AXIS, name="ret"), lambda ctx, ret: reduce_gradient(ctx, ret, ret.arg[0])),
(UPat(Ops.REDUCE, name="ret"), lambda ctx, ret: reduce_gradient(ctx, ret, ret.arg) + (None,)*(len(ret.src)-1)),
(UPat(Ops.REDUCE_BACKWARD, name="ret"), lambda ctx, ret: (ctx.reduce(*ret.src[1:], arg=ret.arg),) + (None,)*(len(ret.src)-1)),
(UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)),
(UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)),
(UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape), None)),
@@ -71,4 +70,8 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp
# we add the backward metadata to everything new in the graph
for bw_uop in v.toposort(lambda x: x not in (t0, *t0.src, grads[t0])):
all_metadata[bw_uop] = all_metadata.get(bw_uop, ())+backward_metadata
# end any ranges on grads with a reduce sum
for k,v in grads.items():
if len(v.ranges):
grads[k] = v.reduce(*v.ranges, arg=Ops.ADD)
return grads

View File

@@ -65,7 +65,7 @@ mop_cleanup = PatternMatcher([
earliest_rewrites = mop_cleanup+PatternMatcher([
# just removing it works...
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.REDUCE_BACKWARD), name="x"), lambda x: x.src[0]),
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
# remove CONTIGUOUS if the BUFFER is already contiguous
(UPat(Ops.BUFFER).f(Ops.RESHAPE, allow_any_len=True, name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
@@ -325,6 +325,18 @@ def bufferize_to_store(ctx:itertools.count|None, x:UOp, idx:UOp, allow_locals=Tr
for m in mops[::-1]: ret = ret._mop(*m)
return ret
# lower outerworld reduce here
if x.src[0].op is Ops.REDUCE and len(x.src[0].src) == 2 and x.src[0].src[1].arg[-1] == AxisType.OUTER:
assert sdtype.addrspace == AddrSpace.GLOBAL
outer_range = x.src[0].src[1]
buf = UOp.new_buffer(x.arg.device, size, x.dtype)
# NOTE: this has the same number as the outer range, we need string ranges!
zero_range = outer_range.replace(src=(UOp.const(dtypes.index, size),), arg=outer_range.arg[:-1]+(AxisType.LOOP,))
buf = buf.after(buf.index(zero_range).store(0).end(zero_range))
bufi = buf.index(idx, dtype=sdtype)
do_store = bufi.store(bufi.load() + x.src[0].src[0], tag=x.tag).end(*rngs).end(outer_range)
return buf.after(do_store)
# NOTE: the DEFINE_LOCAL needs to be disambiguated here
if sdtype.addrspace == AddrSpace.GLOBAL:
buf = UOp.new_buffer(x.arg.device, size, x.dtype)
@@ -500,12 +512,7 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in kernel.src)}")
return kernel
def split_inner_and_outer_end(x: UOp):
outer_ranges, inner_ranges = partition(x.src[1:], lambda r: r.arg[-1] == AxisType.OUTER)
if len(outer_ranges) and len(inner_ranges): return x.src[0].end(*inner_ranges).end(*outer_ranges)
split_kernels = PatternMatcher([
(UPat(Ops.END, name="x"), split_inner_and_outer_end),
(UPat((Ops.STORE, Ops.END), name="x"), split_store),
])
@@ -537,18 +544,6 @@ replace_contiguous = PatternMatcher([
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
])
pm_fix_vmap = PatternMatcher([
# x>=y and x<(y+1) means x==y (this can go in symbolic)
((UPat.var("x", dtype=dtypes.index) >= UPat.var("y")) & (UPat.var("x") < (UPat.var("y")+1)), lambda x,y: x.eq(y)),
# remove the reduce if it's compare reduce (keep the outer range)
(UPat(Ops.BUFFERIZE, name="buf", src=(
(UPat.var("r1", dtype=dtypes.index) != UPat.var("r2")).where(0, UPat.var("val")).reduce(UPat.var("r2"), arg=Ops.ADD),), allow_any_len=True),
lambda r1,r2,val,buf: buf.replace(src=(val,)+buf.src[1:]).substitute({r1:r2}) if r1 in buf.src[1:] and r2.arg[-1] == AxisType.OUTER else None),
# remove the reduce if it's compare reduce
((UPat.var("r1", dtype=dtypes.index) != UPat.var("r2")).where(0, UPat.var("val")).reduce(UPat.var("r2"), arg=Ops.ADD),
lambda r1,r2,val: val.substitute({r2:r1}) if r2.arg[-1] != AxisType.OUTER else None),
])
@disable_gc()
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len([u for u in UOp.sink(*ret.values()).toposort() if u.op is Ops.KERNEL]))}", True)
def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
@@ -561,9 +556,9 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
# convert movement ops to ranges
tsink, rctx = run_rangeify(tsink, DEBUG_RANGEIFY)
tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding+pm_fix_vmap, name="symbolic+reduce_collapse")
tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding, name="symbolic+reduce_collapse")
tsink = graph_rewrite(tsink, pm_remove_bufferize, bottom_up=True, name="remove bufferize with cost function")
tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding+pm_fix_vmap, name="symbolic+reduce_collapse pt 2")
tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding, name="symbolic+reduce_collapse pt 2")
tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rctx, name="limit buffers")
# rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph

View File

@@ -91,7 +91,6 @@ class Ops(FastEnum):
# reduce
REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto()
REDUCE_BACKWARD = auto()
# errors/placeholders
REWRITE_ERROR = auto(); SENTINEL = auto()

View File

@@ -219,7 +219,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)
# passthrough ops
case Ops.REDUCE | Ops.REDUCE_BACKWARD | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
return self.src[0]._shape
# ops with custom handling
@@ -442,7 +442,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
assert self.dtype.scalar() is dtypes.index, "Can only call get_valid on index dtype"
return self.src[0] if self.op is Ops.WHERE and self.src[2].arg is Invalid else UOp.const(dtypes.bool, self.arg is not Invalid)
def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
def reduce_backward(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE_BACKWARD, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
def is_contiguous(self):
# TODO: this is is_realized

View File

@@ -108,7 +108,7 @@ _tensor_spec = PatternMatcher([
(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) >= 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
# REDUCE with an outerworld range
(UPat((Ops.REDUCE, Ops.REDUCE_BACKWARD), src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
# AFTER if things were kernelized
(UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True),