vmap on full model (#13340)
* vmap on full model
* vmap gemm
* reduce sums on end
* outer reduce
* only if there's ranges
* put those rules in symbolic
* ranges
* do opt later
* add zero range
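The new tests below exercise one pattern end to end: open a symbolic OUTER range over the batch, run the model on a single slice, pad the result back into its batch slot, and close the range with an ADD reduce. A condensed sketch of that flow, assembled from test_vmap_gemm in this diff (trimmed to a single Linear layer for brevity):

```python
from tinygrad import Tensor, UOp, nn
from tinygrad.uop.ops import AxisType, Ops

layers = [nn.Linear(16, 16, bias=False), Tensor.relu]
img = Tensor.randn(4, 16).realize(*nn.state.get_parameters(layers))

a = UOp.range(4, -1, AxisType.OUTER)          # symbolic index over the batch of 4
out = img[a:a+1].sequential(layers)           # run the model on one (1, 16) slice
out = out.pad(((a, (4-a)-1), None))           # place the row at batch position a
out = Tensor(out.uop.reduce(a, arg=Ops.ADD))  # close the range: sum the slots
out.realize()                                 # equals img.sequential(layers)
```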
@@ -1,6 +1,6 @@
 import unittest
 import numpy as np
-from tinygrad import Tensor, UOp
+from tinygrad import Tensor, UOp, nn
 from tinygrad.uop.ops import AxisType, Ops

 class TestOuterworldReduce(unittest.TestCase):
@@ -156,7 +156,7 @@ class TestVmap(unittest.TestCase):

     # vmap across axis 0
     a = UOp.range(3, -1, axis_type)
-    out = x @ Tensor(mats.uop.reduce_backward(a, arg=Ops.ADD))[a]
+    out = x @ mats[a]
     out = out.reshape(1, 10).pad(((a,(3-a)-1), None))
     out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
     if fuse: out = out * 2
@@ -175,5 +175,56 @@
   def test_vmap_inner_fuse_grad(self): self.test_vmap_inner(fuse=True, grad=True)
   def test_vmap_outer_grad(self): self.test_vmap_inner(AxisType.OUTER, grad=True)

+  def test_vmap_convs(self):
+    layers = [
+      nn.Conv2d(1, 8, 3), Tensor.relu,
+      nn.Conv2d(8, 8, 3), Tensor.relu]
+    img = Tensor.randn(4, 1, 16, 16).realize(*nn.state.get_parameters(layers))
+    a = UOp.range(4, -1, AxisType.OUTER)
+    out = img[a:a+1].sequential(layers)
+    out = out.pad(((a,(4-a)-1), None, None, None))
+    out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
+    out.realize()
+    np.testing.assert_allclose(out.numpy(), img.sequential(layers).numpy(), atol=1e-6)
+
+  def test_vmap_gemm(self):
+    layers = [
+      nn.Linear(16, 16, bias=False), Tensor.relu,
+      nn.Linear(16, 16, bias=False), Tensor.relu]
+    img = Tensor.randn(4, 16).realize(*nn.state.get_parameters(layers))
+    a = UOp.range(4, -1, AxisType.OUTER)
+    out = img[a:a+1].sequential(layers)
+    out = out.pad(((a,(4-a)-1), None))
+    out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
+    out.realize()
+    np.testing.assert_allclose(out.numpy(), img.sequential(layers).numpy(), atol=1e-6)
+
+  @unittest.skip("this is broken, we need to lower the outer reduce in the outer graph")
+  def test_vmap_gemm_grad(self):
+    layers = [
+      nn.Linear(16, 16, bias=False), Tensor.relu,
+      nn.Linear(16, 16, bias=False), Tensor.relu]
+    layer_tensors = nn.state.get_parameters(layers)
+    img = Tensor.randn(4, 16).realize(*layer_tensors)
+    for l in layer_tensors: l.requires_grad_()
+    a = UOp.range(4, -1, AxisType.OUTER)
+    out = img[a:a+1].sequential(layers)
+    out = out.pad(((a,(4-a)-1), None))
+    out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
+    out.mean().backward()
+    grads = [l.grad for l in layer_tensors]
+    out.realize(*grads)
+    out_grads = [x.numpy() for x in grads]
+
+    # compute reference grads
+    for l in layer_tensors: l.grad = None
+    img.sequential(layers).mean().backward()
+    grads = [l.grad for l in layer_tensors]
+    out.realize(*grads)
+    ref_grads = [x.numpy() for x in grads]
+
+    # compare
+    for o,r in zip(out_grads, ref_grads): np.testing.assert_allclose(o, r, atol=1e-6)
+
 if __name__ == '__main__':
   unittest.main()
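A note on the pad-then-reduce idiom all three tests share: `pad(((a,(4-a)-1), None))` embeds the (1, N) per-iteration row at offset `a` of a length-4 axis, zero everywhere else, so the ADD reduce over `a` scatters each iteration into its own slot with no overlap. The same identity in plain numpy (a sketch of the semantics, not tinygrad API):

```python
import numpy as np

rows = np.random.randn(4, 16)        # stand-in per-iteration results
out = np.zeros((4, 16))
for a in range(4):                   # the OUTER range
    # zeros everywhere except row a: pad(((a,(4-a)-1), None)) in the tests
    out += np.pad(rows[a:a+1], ((a, (4-a)-1), (0, 0)))
np.testing.assert_allclose(out, rows)
```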
@@ -31,7 +31,6 @@ pm_gradient = PatternMatcher([
   (UPat(Ops.WHERE, name="ret"), lambda ctx, ret: (None, ret.src[0].where(ctx, ctx.const_like(0)), ret.src[0].where(ctx.const_like(0), ctx))),
   (UPat(Ops.REDUCE_AXIS, name="ret"), lambda ctx, ret: reduce_gradient(ctx, ret, ret.arg[0])),
   (UPat(Ops.REDUCE, name="ret"), lambda ctx, ret: reduce_gradient(ctx, ret, ret.arg) + (None,)*(len(ret.src)-1)),
-  (UPat(Ops.REDUCE_BACKWARD, name="ret"), lambda ctx, ret: (ctx.reduce(*ret.src[1:], arg=ret.arg),) + (None,)*(len(ret.src)-1)),
   (UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)),
   (UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)),
   (UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape), None)),
@@ -71,4 +70,8 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]
       # we add the backward metadata to everything new in the graph
       for bw_uop in v.toposort(lambda x: x not in (t0, *t0.src, grads[t0])):
         all_metadata[bw_uop] = all_metadata.get(bw_uop, ())+backward_metadata
+  # end any ranges on grads with a reduce sum
+  for k,v in grads.items():
+    if len(v.ranges):
+      grads[k] = v.reduce(*v.ranges, arg=Ops.ADD)
   return grads
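The added loop closes any ranges still free on a gradient ("reduce sums on end" in the commit message). If a weight is read inside a vmapped region indexed by range `a`, its gradient still mentions `a`; summing over `a` gives the correct batched gradient because gradients are additive over independent uses of a parameter. A numpy check of that identity (a hypothetical one-layer relu model, not tinygrad code):

```python
import numpy as np

x, W = np.random.randn(4, 16), np.random.randn(16, 16)
def grad_W(xs):  # d/dW of relu(xs @ W).sum() is xs.T @ 1[xs @ W > 0]
    return xs.T @ (xs @ W > 0).astype(float)

# per-sample grads (one per vmap iteration), summed over the outer range...
per_sample = sum(grad_W(x[i:i+1]) for i in range(4))
# ...match the batched gradient: closing the range with an ADD reduce is correct
np.testing.assert_allclose(per_sample, grad_W(x), atol=1e-9)
```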
@@ -65,7 +65,7 @@ mop_cleanup = PatternMatcher([

 earliest_rewrites = mop_cleanup+PatternMatcher([
   # just removing it works...
-  (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.REDUCE_BACKWARD), name="x"), lambda x: x.src[0]),
+  (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),

   # remove CONTIGUOUS if the BUFFER is already contiguous
   (UPat(Ops.BUFFER).f(Ops.RESHAPE, allow_any_len=True, name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
@@ -325,6 +325,18 @@ def bufferize_to_store(ctx:itertools.count|None, x:UOp, idx:UOp, allow_locals=Tr
   for m in mops[::-1]: ret = ret._mop(*m)
   return ret

+  # lower outerworld reduce here
+  if x.src[0].op is Ops.REDUCE and len(x.src[0].src) == 2 and x.src[0].src[1].arg[-1] == AxisType.OUTER:
+    assert sdtype.addrspace == AddrSpace.GLOBAL
+    outer_range = x.src[0].src[1]
+    buf = UOp.new_buffer(x.arg.device, size, x.dtype)
+    # NOTE: this has the same number as the outer range, we need string ranges!
+    zero_range = outer_range.replace(src=(UOp.const(dtypes.index, size),), arg=outer_range.arg[:-1]+(AxisType.LOOP,))
+    buf = buf.after(buf.index(zero_range).store(0).end(zero_range))
+    bufi = buf.index(idx, dtype=sdtype)
+    do_store = bufi.store(bufi.load() + x.src[0].src[0], tag=x.tag).end(*rngs).end(outer_range)
+    return buf.after(do_store)
+
   # NOTE: the DEFINE_LOCAL needs to be disambiguated here
   if sdtype.addrspace == AddrSpace.GLOBAL:
     buf = UOp.new_buffer(x.arg.device, size, x.dtype)
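The new branch lowers an outerworld ADD reduce directly to buffer traffic: allocate a global buffer, zero it under a fresh LOOP range (the `zero_range` from the commit message's "add zero range"), then load-add-store at the indexed location inside the OUTER range. Roughly this semantics, sketched in numpy terms rather than the actual UOp graph:

```python
import numpy as np

def outer_reduce_add(vals, size):
    # vals: one partial result per OUTER-range iteration
    buf = np.empty(size)
    buf[:] = 0.0              # zero_range: a LOOP over the buffer storing 0
    for v in vals:            # outer_range: each iteration read-modify-writes
        buf += v              # bufi.store(bufi.load() + value)
    return buf

parts = [np.random.randn(8) for _ in range(4)]
np.testing.assert_allclose(outer_reduce_add(parts, 8), sum(parts))
```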
@@ -500,12 +512,7 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
     raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in kernel.src)}")
   return kernel

-def split_inner_and_outer_end(x: UOp):
-  outer_ranges, inner_ranges = partition(x.src[1:], lambda r: r.arg[-1] == AxisType.OUTER)
-  if len(outer_ranges) and len(inner_ranges): return x.src[0].end(*inner_ranges).end(*outer_ranges)
-
 split_kernels = PatternMatcher([
-  (UPat(Ops.END, name="x"), split_inner_and_outer_end),
   (UPat((Ops.STORE, Ops.END), name="x"), split_store),
 ])

@@ -537,18 +544,6 @@ replace_contiguous = PatternMatcher([
   (UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
 ])

-pm_fix_vmap = PatternMatcher([
-  # x>=y and x<(y+1) means x==y (this can go in symbolic)
-  ((UPat.var("x", dtype=dtypes.index) >= UPat.var("y")) & (UPat.var("x") < (UPat.var("y")+1)), lambda x,y: x.eq(y)),
-  # remove the reduce if it's compare reduce (keep the outer range)
-  (UPat(Ops.BUFFERIZE, name="buf", src=(
-    (UPat.var("r1", dtype=dtypes.index) != UPat.var("r2")).where(0, UPat.var("val")).reduce(UPat.var("r2"), arg=Ops.ADD),), allow_any_len=True),
-   lambda r1,r2,val,buf: buf.replace(src=(val,)+buf.src[1:]).substitute({r1:r2}) if r1 in buf.src[1:] and r2.arg[-1] == AxisType.OUTER else None),
-  # remove the reduce if it's compare reduce
-  ((UPat.var("r1", dtype=dtypes.index) != UPat.var("r2")).where(0, UPat.var("val")).reduce(UPat.var("r2"), arg=Ops.ADD),
-   lambda r1,r2,val: val.substitute({r2:r1}) if r2.arg[-1] != AxisType.OUTER else None),
-])
-
 @disable_gc()
 @track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len([u for u in UOp.sink(*ret.values()).toposort() if u.op is Ops.KERNEL]))}", True)
 def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
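For reference, the deleted pm_fix_vmap rules implement the one-hot collapse: summing `(r != a) ? 0 : val(r)` over `r` selects exactly `val(a)`, so the reduce can be replaced by substituting `r := a`. Per the commit message ("put those rules in symbolic") they now live with the other index simplifications instead of here. The identity itself, checked numerically:

```python
import numpy as np

n, a = 5, 2
val = np.random.randn(n)
# sum over r of (0 if r != a else val[r]) picks out exactly val[a],
# so the reduce over r collapses to the substitution r := a
total = sum(0.0 if r != a else val[r] for r in range(n))
assert total == val[a]
```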
@@ -561,9 +556,9 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
   # convert movement ops to ranges
   tsink, rctx = run_rangeify(tsink, DEBUG_RANGEIFY)

-  tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding+pm_fix_vmap, name="symbolic+reduce_collapse")
+  tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding, name="symbolic+reduce_collapse")
   tsink = graph_rewrite(tsink, pm_remove_bufferize, bottom_up=True, name="remove bufferize with cost function")
-  tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding+pm_fix_vmap, name="symbolic+reduce_collapse pt 2")
+  tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding, name="symbolic+reduce_collapse pt 2")
   tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rctx, name="limit buffers")

   # rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph
@@ -91,7 +91,6 @@ class Ops(FastEnum):

   # reduce
   REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto()
-  REDUCE_BACKWARD = auto()

   # errors/placeholders
   REWRITE_ERROR = auto(); SENTINEL = auto()
@@ -219,7 +219,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
       case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)

       # passthrough ops
-      case Ops.REDUCE | Ops.REDUCE_BACKWARD | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
+      case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
         return self.src[0]._shape

       # ops with custom handling
@@ -442,7 +442,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     assert self.dtype.scalar() is dtypes.index, "Can only call get_valid on index dtype"
     return self.src[0] if self.op is Ops.WHERE and self.src[2].arg is Invalid else UOp.const(dtypes.bool, self.arg is not Invalid)
   def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
-  def reduce_backward(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE_BACKWARD, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)

   def is_contiguous(self):
     # TODO: this is is_realized
@@ -108,7 +108,7 @@ _tensor_spec = PatternMatcher([
   (UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) >= 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),

   # REDUCE with an outerworld range
-  (UPat((Ops.REDUCE, Ops.REDUCE_BACKWARD), src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
+  (UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),

   # AFTER if things were kernelized
   (UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True),