mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
vmap on full model (#13340)
* vmap on full model * vmap gemm * reduce sums on end * outer reduce * only if there's ranges * put those rules in symbolic * ranges * do opt later * add zero range
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, UOp
|
||||
from tinygrad import Tensor, UOp, nn
|
||||
from tinygrad.uop.ops import AxisType, Ops
|
||||
|
||||
class TestOuterworldReduce(unittest.TestCase):
|
||||
@@ -156,7 +156,7 @@ class TestVmap(unittest.TestCase):
|
||||
|
||||
# vmap across axis 0
|
||||
a = UOp.range(3, -1, axis_type)
|
||||
out = x @ Tensor(mats.uop.reduce_backward(a, arg=Ops.ADD))[a]
|
||||
out = x @ mats[a]
|
||||
out = out.reshape(1, 10).pad(((a,(3-a)-1), None))
|
||||
out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
|
||||
if fuse: out = out * 2
|
||||
@@ -175,5 +175,56 @@ class TestVmap(unittest.TestCase):
|
||||
def test_vmap_inner_fuse_grad(self): self.test_vmap_inner(fuse=True, grad=True)
|
||||
def test_vmap_outer_grad(self): self.test_vmap_inner(AxisType.OUTER, grad=True)
|
||||
|
||||
def test_vmap_convs(self):
|
||||
layers = [
|
||||
nn.Conv2d(1, 8, 3), Tensor.relu,
|
||||
nn.Conv2d(8, 8, 3), Tensor.relu]
|
||||
img = Tensor.randn(4, 1, 16, 16).realize(*nn.state.get_parameters(layers))
|
||||
a = UOp.range(4, -1, AxisType.OUTER)
|
||||
out = img[a:a+1].sequential(layers)
|
||||
out = out.pad(((a,(4-a)-1), None, None, None))
|
||||
out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
|
||||
out.realize()
|
||||
np.testing.assert_allclose(out.numpy(), img.sequential(layers).numpy(), atol=1e-6)
|
||||
|
||||
def test_vmap_gemm(self):
|
||||
layers = [
|
||||
nn.Linear(16, 16, bias=False), Tensor.relu,
|
||||
nn.Linear(16, 16, bias=False), Tensor.relu]
|
||||
img = Tensor.randn(4, 16).realize(*nn.state.get_parameters(layers))
|
||||
a = UOp.range(4, -1, AxisType.OUTER)
|
||||
out = img[a:a+1].sequential(layers)
|
||||
out = out.pad(((a,(4-a)-1), None))
|
||||
out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
|
||||
out.realize()
|
||||
np.testing.assert_allclose(out.numpy(), img.sequential(layers).numpy(), atol=1e-6)
|
||||
|
||||
@unittest.skip("this is broken, we need to lower the outer reduce in the outer graph")
|
||||
def test_vmap_gemm_grad(self):
|
||||
layers = [
|
||||
nn.Linear(16, 16, bias=False), Tensor.relu,
|
||||
nn.Linear(16, 16, bias=False), Tensor.relu]
|
||||
layer_tensors = nn.state.get_parameters(layers)
|
||||
img = Tensor.randn(4, 16).realize(*layer_tensors)
|
||||
for l in layer_tensors: l.requires_grad_()
|
||||
a = UOp.range(4, -1, AxisType.OUTER)
|
||||
out = img[a:a+1].sequential(layers)
|
||||
out = out.pad(((a,(4-a)-1), None))
|
||||
out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
|
||||
out.mean().backward()
|
||||
grads = [l.grad for l in layer_tensors]
|
||||
out.realize(*grads)
|
||||
out_grads = [x.numpy() for x in grads]
|
||||
|
||||
# compute reference grads
|
||||
for l in layer_tensors: l.grad = None
|
||||
img.sequential(layers).mean().backward()
|
||||
grads = [l.grad for l in layer_tensors]
|
||||
out.realize(*grads)
|
||||
ref_grads = [x.numpy() for x in grads]
|
||||
|
||||
# compare
|
||||
for o,r in zip(out_grads, ref_grads): np.testing.assert_allclose(o, r, atol=1e-6)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -31,7 +31,6 @@ pm_gradient = PatternMatcher([
|
||||
(UPat(Ops.WHERE, name="ret"), lambda ctx, ret: (None, ret.src[0].where(ctx, ctx.const_like(0)), ret.src[0].where(ctx.const_like(0), ctx))),
|
||||
(UPat(Ops.REDUCE_AXIS, name="ret"), lambda ctx, ret: reduce_gradient(ctx, ret, ret.arg[0])),
|
||||
(UPat(Ops.REDUCE, name="ret"), lambda ctx, ret: reduce_gradient(ctx, ret, ret.arg) + (None,)*(len(ret.src)-1)),
|
||||
(UPat(Ops.REDUCE_BACKWARD, name="ret"), lambda ctx, ret: (ctx.reduce(*ret.src[1:], arg=ret.arg),) + (None,)*(len(ret.src)-1)),
|
||||
(UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)),
|
||||
(UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)),
|
||||
(UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape), None)),
|
||||
@@ -71,4 +70,8 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp
|
||||
# we add the backward metadata to everything new in the graph
|
||||
for bw_uop in v.toposort(lambda x: x not in (t0, *t0.src, grads[t0])):
|
||||
all_metadata[bw_uop] = all_metadata.get(bw_uop, ())+backward_metadata
|
||||
# end any ranges on grads with a reduce sum
|
||||
for k,v in grads.items():
|
||||
if len(v.ranges):
|
||||
grads[k] = v.reduce(*v.ranges, arg=Ops.ADD)
|
||||
return grads
|
||||
|
||||
@@ -65,7 +65,7 @@ mop_cleanup = PatternMatcher([
|
||||
|
||||
earliest_rewrites = mop_cleanup+PatternMatcher([
|
||||
# just removing it works...
|
||||
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.REDUCE_BACKWARD), name="x"), lambda x: x.src[0]),
|
||||
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
|
||||
|
||||
# remove CONTIGUOUS if the BUFFER is already contiguous
|
||||
(UPat(Ops.BUFFER).f(Ops.RESHAPE, allow_any_len=True, name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
|
||||
@@ -325,6 +325,18 @@ def bufferize_to_store(ctx:itertools.count|None, x:UOp, idx:UOp, allow_locals=Tr
|
||||
for m in mops[::-1]: ret = ret._mop(*m)
|
||||
return ret
|
||||
|
||||
# lower outerworld reduce here
|
||||
if x.src[0].op is Ops.REDUCE and len(x.src[0].src) == 2 and x.src[0].src[1].arg[-1] == AxisType.OUTER:
|
||||
assert sdtype.addrspace == AddrSpace.GLOBAL
|
||||
outer_range = x.src[0].src[1]
|
||||
buf = UOp.new_buffer(x.arg.device, size, x.dtype)
|
||||
# NOTE: this has the same number as the outer range, we need string ranges!
|
||||
zero_range = outer_range.replace(src=(UOp.const(dtypes.index, size),), arg=outer_range.arg[:-1]+(AxisType.LOOP,))
|
||||
buf = buf.after(buf.index(zero_range).store(0).end(zero_range))
|
||||
bufi = buf.index(idx, dtype=sdtype)
|
||||
do_store = bufi.store(bufi.load() + x.src[0].src[0], tag=x.tag).end(*rngs).end(outer_range)
|
||||
return buf.after(do_store)
|
||||
|
||||
# NOTE: the DEFINE_LOCAL needs to be disambiguated here
|
||||
if sdtype.addrspace == AddrSpace.GLOBAL:
|
||||
buf = UOp.new_buffer(x.arg.device, size, x.dtype)
|
||||
@@ -500,12 +512,7 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
|
||||
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in kernel.src)}")
|
||||
return kernel
|
||||
|
||||
def split_inner_and_outer_end(x: UOp):
|
||||
outer_ranges, inner_ranges = partition(x.src[1:], lambda r: r.arg[-1] == AxisType.OUTER)
|
||||
if len(outer_ranges) and len(inner_ranges): return x.src[0].end(*inner_ranges).end(*outer_ranges)
|
||||
|
||||
split_kernels = PatternMatcher([
|
||||
(UPat(Ops.END, name="x"), split_inner_and_outer_end),
|
||||
(UPat((Ops.STORE, Ops.END), name="x"), split_store),
|
||||
])
|
||||
|
||||
@@ -537,18 +544,6 @@ replace_contiguous = PatternMatcher([
|
||||
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
|
||||
])
|
||||
|
||||
pm_fix_vmap = PatternMatcher([
|
||||
# x>=y and x<(y+1) means x==y (this can go in symbolic)
|
||||
((UPat.var("x", dtype=dtypes.index) >= UPat.var("y")) & (UPat.var("x") < (UPat.var("y")+1)), lambda x,y: x.eq(y)),
|
||||
# remove the reduce if it's compare reduce (keep the outer range)
|
||||
(UPat(Ops.BUFFERIZE, name="buf", src=(
|
||||
(UPat.var("r1", dtype=dtypes.index) != UPat.var("r2")).where(0, UPat.var("val")).reduce(UPat.var("r2"), arg=Ops.ADD),), allow_any_len=True),
|
||||
lambda r1,r2,val,buf: buf.replace(src=(val,)+buf.src[1:]).substitute({r1:r2}) if r1 in buf.src[1:] and r2.arg[-1] == AxisType.OUTER else None),
|
||||
# remove the reduce if it's compare reduce
|
||||
((UPat.var("r1", dtype=dtypes.index) != UPat.var("r2")).where(0, UPat.var("val")).reduce(UPat.var("r2"), arg=Ops.ADD),
|
||||
lambda r1,r2,val: val.substitute({r2:r1}) if r2.arg[-1] != AxisType.OUTER else None),
|
||||
])
|
||||
|
||||
@disable_gc()
|
||||
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len([u for u in UOp.sink(*ret.values()).toposort() if u.op is Ops.KERNEL]))}", True)
|
||||
def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
@@ -561,9 +556,9 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
# convert movement ops to ranges
|
||||
tsink, rctx = run_rangeify(tsink, DEBUG_RANGEIFY)
|
||||
|
||||
tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding+pm_fix_vmap, name="symbolic+reduce_collapse")
|
||||
tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding, name="symbolic+reduce_collapse")
|
||||
tsink = graph_rewrite(tsink, pm_remove_bufferize, bottom_up=True, name="remove bufferize with cost function")
|
||||
tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding+pm_fix_vmap, name="symbolic+reduce_collapse pt 2")
|
||||
tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding, name="symbolic+reduce_collapse pt 2")
|
||||
tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rctx, name="limit buffers")
|
||||
|
||||
# rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph
|
||||
|
||||
@@ -91,7 +91,6 @@ class Ops(FastEnum):
|
||||
|
||||
# reduce
|
||||
REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto()
|
||||
REDUCE_BACKWARD = auto()
|
||||
|
||||
# errors/placeholders
|
||||
REWRITE_ERROR = auto(); SENTINEL = auto()
|
||||
|
||||
@@ -219,7 +219,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)
|
||||
|
||||
# passthrough ops
|
||||
case Ops.REDUCE | Ops.REDUCE_BACKWARD | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
|
||||
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
|
||||
return self.src[0]._shape
|
||||
|
||||
# ops with custom handling
|
||||
@@ -442,7 +442,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
assert self.dtype.scalar() is dtypes.index, "Can only call get_valid on index dtype"
|
||||
return self.src[0] if self.op is Ops.WHERE and self.src[2].arg is Invalid else UOp.const(dtypes.bool, self.arg is not Invalid)
|
||||
def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
|
||||
def reduce_backward(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE_BACKWARD, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
|
||||
|
||||
def is_contiguous(self):
|
||||
# TODO: this is is_realized
|
||||
|
||||
@@ -108,7 +108,7 @@ _tensor_spec = PatternMatcher([
|
||||
(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) >= 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
|
||||
|
||||
# REDUCE with an outerworld range
|
||||
(UPat((Ops.REDUCE, Ops.REDUCE_BACKWARD), src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
|
||||
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
|
||||
|
||||
# AFTER if things were kernelized
|
||||
(UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True),
|
||||
|
||||
Reference in New Issue
Block a user