vmap on full model (#13340)

* vmap on full model

* vmap gemm

* reduce sums on end

* outer reduce

* only if there's ranges

* put those rules in symbolic

* ranges

* do opt later

* add zero range
George Hotz, 2025-11-18 16:06:06 -08:00 (committed by GitHub)
parent 46cb65e692, commit 1afa3c0877
6 changed files with 74 additions and 27 deletions
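
For context, the pattern these changes enable (distilled from the new tests further down in this diff) is: open an OUTER range over the batch, run the model on the single example selected by that range, pad the per-example result back into its batch slot, and close the range with an ADD reduce. A minimal sketch, assuming tinygrad at this commit; it mirrors the new test_vmap_gemm test:

import numpy as np
from tinygrad import Tensor, UOp, nn
from tinygrad.uop.ops import AxisType, Ops

layers = [nn.Linear(16, 16, bias=False), Tensor.relu]
x = Tensor.randn(4, 16).realize(*nn.state.get_parameters(layers))

a = UOp.range(4, -1, AxisType.OUTER)          # symbolic index over the batch of 4
out = x[a:a+1].sequential(layers)             # run the model on the one selected example
out = out.pad(((a, (4-a)-1), None))           # place the (1, 16) result at row a of a (4, 16) output
out = Tensor(out.uop.reduce(a, arg=Ops.ADD))  # rows don't overlap, so the sum reassembles the batch
np.testing.assert_allclose(out.numpy(), x.sequential(layers).numpy(), atol=1e-6)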

View File

@@ -1,6 +1,6 @@
 import unittest
 import numpy as np
-from tinygrad import Tensor, UOp
+from tinygrad import Tensor, UOp, nn
 from tinygrad.uop.ops import AxisType, Ops
 
 class TestOuterworldReduce(unittest.TestCase):
@@ -156,7 +156,7 @@ class TestVmap(unittest.TestCase):
     # vmap across axis 0
     a = UOp.range(3, -1, axis_type)
-    out = x @ Tensor(mats.uop.reduce_backward(a, arg=Ops.ADD))[a]
+    out = x @ mats[a]
     out = out.reshape(1, 10).pad(((a,(3-a)-1), None))
     out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
     if fuse: out = out * 2
@@ -175,5 +175,56 @@ class TestVmap(unittest.TestCase):
   def test_vmap_inner_fuse_grad(self): self.test_vmap_inner(fuse=True, grad=True)
   def test_vmap_outer_grad(self): self.test_vmap_inner(AxisType.OUTER, grad=True)
+
+  def test_vmap_convs(self):
+    layers = [
+      nn.Conv2d(1, 8, 3), Tensor.relu,
+      nn.Conv2d(8, 8, 3), Tensor.relu]
+    img = Tensor.randn(4, 1, 16, 16).realize(*nn.state.get_parameters(layers))
+    a = UOp.range(4, -1, AxisType.OUTER)
+    out = img[a:a+1].sequential(layers)
+    out = out.pad(((a,(4-a)-1), None, None, None))
+    out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
+    out.realize()
+    np.testing.assert_allclose(out.numpy(), img.sequential(layers).numpy(), atol=1e-6)
+
+  def test_vmap_gemm(self):
+    layers = [
+      nn.Linear(16, 16, bias=False), Tensor.relu,
+      nn.Linear(16, 16, bias=False), Tensor.relu]
+    img = Tensor.randn(4, 16).realize(*nn.state.get_parameters(layers))
+    a = UOp.range(4, -1, AxisType.OUTER)
+    out = img[a:a+1].sequential(layers)
+    out = out.pad(((a,(4-a)-1), None))
+    out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
+    out.realize()
+    np.testing.assert_allclose(out.numpy(), img.sequential(layers).numpy(), atol=1e-6)
+
+  @unittest.skip("this is broken, we need to lower the outer reduce in the outer graph")
+  def test_vmap_gemm_grad(self):
+    layers = [
+      nn.Linear(16, 16, bias=False), Tensor.relu,
+      nn.Linear(16, 16, bias=False), Tensor.relu]
+    layer_tensors = nn.state.get_parameters(layers)
+    img = Tensor.randn(4, 16).realize(*layer_tensors)
+    for l in layer_tensors: l.requires_grad_()
+    a = UOp.range(4, -1, AxisType.OUTER)
+    out = img[a:a+1].sequential(layers)
+    out = out.pad(((a,(4-a)-1), None))
+    out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
+    out.mean().backward()
+    grads = [l.grad for l in layer_tensors]
+    out.realize(*grads)
+    out_grads = [x.numpy() for x in grads]
+    # compute reference grads
+    for l in layer_tensors: l.grad = None
+    img.sequential(layers).mean().backward()
+    grads = [l.grad for l in layer_tensors]
+    out.realize(*grads)
+    ref_grads = [x.numpy() for x in grads]
+    # compare
+    for o,r in zip(out_grads, ref_grads): np.testing.assert_allclose(o, r, atol=1e-6)
 
 if __name__ == '__main__':
   unittest.main()

View File

@@ -31,7 +31,6 @@ pm_gradient = PatternMatcher([
   (UPat(Ops.WHERE, name="ret"), lambda ctx, ret: (None, ret.src[0].where(ctx, ctx.const_like(0)), ret.src[0].where(ctx.const_like(0), ctx))),
   (UPat(Ops.REDUCE_AXIS, name="ret"), lambda ctx, ret: reduce_gradient(ctx, ret, ret.arg[0])),
   (UPat(Ops.REDUCE, name="ret"), lambda ctx, ret: reduce_gradient(ctx, ret, ret.arg) + (None,)*(len(ret.src)-1)),
-  (UPat(Ops.REDUCE_BACKWARD, name="ret"), lambda ctx, ret: (ctx.reduce(*ret.src[1:], arg=ret.arg),) + (None,)*(len(ret.src)-1)),
   (UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)),
   (UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)),
   (UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape), None)),
@@ -71,4 +70,8 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp
     # we add the backward metadata to everything new in the graph
     for bw_uop in v.toposort(lambda x: x not in (t0, *t0.src, grads[t0])):
       all_metadata[bw_uop] = all_metadata.get(bw_uop, ())+backward_metadata
+  # end any ranges on grads with a reduce sum
+  for k,v in grads.items():
+    if len(v.ranges):
+      grads[k] = v.reduce(*v.ranges, arg=Ops.ADD)
   return grads
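
The new loop at the end of compute_gradient closes any vmap ranges still open on a gradient by summing over them. The reasoning: a parameter shared across every iteration of the OUTER range receives the sum of the per-iteration gradients. A hedged numpy sketch of that identity (illustrative only, not tinygrad code):

import numpy as np
rng = np.random.default_rng(0)
X = rng.standard_normal((4, 16))  # batch of inputs to a shared weight W of shape (16, 16)
n = 4 * 16                        # number of elements averaged by mean(X @ W)
# contribution of iteration i of the range to d mean(X @ W) / dW
per_iter_grads = [np.outer(X[i], np.ones(16)) / n for i in range(4)]
# the full gradient of the shared W is the sum over the range
full_grad = np.outer(X.sum(0), np.ones(16)) / n
np.testing.assert_allclose(sum(per_iter_grads), full_grad, atol=1e-6)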

View File

@@ -65,7 +65,7 @@ mop_cleanup = PatternMatcher([
 earliest_rewrites = mop_cleanup+PatternMatcher([
   # just removing it works...
-  (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.REDUCE_BACKWARD), name="x"), lambda x: x.src[0]),
+  (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
   # remove CONTIGUOUS if the BUFFER is already contiguous
   (UPat(Ops.BUFFER).f(Ops.RESHAPE, allow_any_len=True, name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
@@ -325,6 +325,18 @@ def bufferize_to_store(ctx:itertools.count|None, x:UOp, idx:UOp, allow_locals=Tr
     for m in mops[::-1]: ret = ret._mop(*m)
     return ret
+
+  # lower outerworld reduce here
+  if x.src[0].op is Ops.REDUCE and len(x.src[0].src) == 2 and x.src[0].src[1].arg[-1] == AxisType.OUTER:
+    assert sdtype.addrspace == AddrSpace.GLOBAL
+    outer_range = x.src[0].src[1]
+    buf = UOp.new_buffer(x.arg.device, size, x.dtype)
+    # NOTE: this has the same number as the outer range, we need string ranges!
+    zero_range = outer_range.replace(src=(UOp.const(dtypes.index, size),), arg=outer_range.arg[:-1]+(AxisType.LOOP,))
+    buf = buf.after(buf.index(zero_range).store(0).end(zero_range))
+    bufi = buf.index(idx, dtype=sdtype)
+    do_store = bufi.store(bufi.load() + x.src[0].src[0], tag=x.tag).end(*rngs).end(outer_range)
+    return buf.after(do_store)
   # NOTE: the DEFINE_LOCAL needs to be disambiguated here
   if sdtype.addrspace == AddrSpace.GLOBAL:
     buf = UOp.new_buffer(x.arg.device, size, x.dtype)
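
The new branch above lowers an outerworld ADD reduce into a zero-initialized global buffer plus a read-modify-write accumulation over the outer range. A plain-Python sketch of the emitted semantics (illustrative only; the real code builds a UOp graph, not loops):

import numpy as np

size, outer_len = 8, 4
partials = np.arange(outer_len * size, dtype=np.float32).reshape(outer_len, size)  # stand-in for x.src[0] per iteration

buf = np.empty(size, dtype=np.float32)
for zero_range in range(size):        # the new LOOP range: buf.index(zero_range).store(0)
  buf[zero_range] = 0.0
for outer_range in range(outer_len):  # the OUTER range being reduced
  for idx in range(size):             # the inner ranges (rngs) addressing the buffer
    buf[idx] = buf[idx] + partials[outer_range, idx]   # bufi.store(bufi.load() + ...)

np.testing.assert_allclose(buf, partials.sum(axis=0))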
@@ -500,12 +512,7 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
     raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in kernel.src)}")
   return kernel
 
-def split_inner_and_outer_end(x: UOp):
-  outer_ranges, inner_ranges = partition(x.src[1:], lambda r: r.arg[-1] == AxisType.OUTER)
-  if len(outer_ranges) and len(inner_ranges): return x.src[0].end(*inner_ranges).end(*outer_ranges)
-
 split_kernels = PatternMatcher([
-  (UPat(Ops.END, name="x"), split_inner_and_outer_end),
   (UPat((Ops.STORE, Ops.END), name="x"), split_store),
 ])
@@ -537,18 +544,6 @@ replace_contiguous = PatternMatcher([
   (UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
 ])
 
-pm_fix_vmap = PatternMatcher([
-  # x>=y and x<(y+1) means x==y (this can go in symbolic)
-  ((UPat.var("x", dtype=dtypes.index) >= UPat.var("y")) & (UPat.var("x") < (UPat.var("y")+1)), lambda x,y: x.eq(y)),
-  # remove the reduce if it's compare reduce (keep the outer range)
-  (UPat(Ops.BUFFERIZE, name="buf", src=(
-    (UPat.var("r1", dtype=dtypes.index) != UPat.var("r2")).where(0, UPat.var("val")).reduce(UPat.var("r2"), arg=Ops.ADD),), allow_any_len=True),
-   lambda r1,r2,val,buf: buf.replace(src=(val,)+buf.src[1:]).substitute({r1:r2}) if r1 in buf.src[1:] and r2.arg[-1] == AxisType.OUTER else None),
-  # remove the reduce if it's compare reduce
-  ((UPat.var("r1", dtype=dtypes.index) != UPat.var("r2")).where(0, UPat.var("val")).reduce(UPat.var("r2"), arg=Ops.ADD),
-   lambda r1,r2,val: val.substitute({r2:r1}) if r2.arg[-1] != AxisType.OUTER else None),
-])
-
 @disable_gc()
 @track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len([u for u in UOp.sink(*ret.values()).toposort() if u.op is Ops.KERNEL]))}", True)
 def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
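
The deleted pm_fix_vmap matcher encoded two identities that, per the commit message ("put those rules in symbolic"), are now expected to be handled by the symbolic rewrites: for integer x, (x >= y) and (x < y+1) is equivalent to x == y, and a sum over r2 of (r2 != r1 ? 0 : val) collapses to val with r2 substituted by r1. A small numeric check of both identities (illustrative sketch, not the tinygrad rules themselves):

import numpy as np

# identity 1: for integers, x >= y and x < y+1  <=>  x == y
for x in range(-3, 4):
  for y in range(-3, 4):
    assert ((x >= y) and (x < y + 1)) == (x == y)

# identity 2: sum over r2 of (0 if r2 != r1 else val[r2]) == val[r1]
val, r1 = np.arange(10.0), 7
assert sum(0.0 if r2 != r1 else val[r2] for r2 in range(10)) == val[r1]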
@@ -561,9 +556,9 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
   # convert movement ops to ranges
   tsink, rctx = run_rangeify(tsink, DEBUG_RANGEIFY)
-  tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding+pm_fix_vmap, name="symbolic+reduce_collapse")
+  tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding, name="symbolic+reduce_collapse")
   tsink = graph_rewrite(tsink, pm_remove_bufferize, bottom_up=True, name="remove bufferize with cost function")
-  tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding+pm_fix_vmap, name="symbolic+reduce_collapse pt 2")
+  tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding, name="symbolic+reduce_collapse pt 2")
   tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rctx, name="limit buffers")
   # rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph

View File

@@ -91,7 +91,6 @@ class Ops(FastEnum):
   # reduce
   REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto()
-  REDUCE_BACKWARD = auto()
   # errors/placeholders
   REWRITE_ERROR = auto(); SENTINEL = auto()

View File

@@ -219,7 +219,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
       case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)
       # passthrough ops
-      case Ops.REDUCE | Ops.REDUCE_BACKWARD | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
+      case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
         return self.src[0]._shape
       # ops with custom handling
@@ -442,7 +442,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     assert self.dtype.scalar() is dtypes.index, "Can only call get_valid on index dtype"
     return self.src[0] if self.op is Ops.WHERE and self.src[2].arg is Invalid else UOp.const(dtypes.bool, self.arg is not Invalid)
   def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
-  def reduce_backward(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE_BACKWARD, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
   def is_contiguous(self):
     # TODO: this is is_realized

View File

@@ -108,7 +108,7 @@ _tensor_spec = PatternMatcher([
   (UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) >= 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
   # REDUCE with an outerworld range
-  (UPat((Ops.REDUCE, Ops.REDUCE_BACKWARD), src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
+  (UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
   # AFTER if things were kernelized
   (UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True),