diff --git a/test/test_outerworld.py b/test/test_outerworld.py index b3f8cf83aa..1d3f39cf53 100644 --- a/test/test_outerworld.py +++ b/test/test_outerworld.py @@ -1,6 +1,6 @@ import unittest import numpy as np -from tinygrad import Tensor, UOp +from tinygrad import Tensor, UOp, nn from tinygrad.uop.ops import AxisType, Ops class TestOuterworldReduce(unittest.TestCase): @@ -156,7 +156,7 @@ class TestVmap(unittest.TestCase): # vmap across axis 0 a = UOp.range(3, -1, axis_type) - out = x @ Tensor(mats.uop.reduce_backward(a, arg=Ops.ADD))[a] + out = x @ mats[a] out = out.reshape(1, 10).pad(((a,(3-a)-1), None)) out = Tensor(out.uop.reduce(a, arg=Ops.ADD)) if fuse: out = out * 2 @@ -175,5 +175,56 @@ class TestVmap(unittest.TestCase): def test_vmap_inner_fuse_grad(self): self.test_vmap_inner(fuse=True, grad=True) def test_vmap_outer_grad(self): self.test_vmap_inner(AxisType.OUTER, grad=True) + def test_vmap_convs(self): + layers = [ + nn.Conv2d(1, 8, 3), Tensor.relu, + nn.Conv2d(8, 8, 3), Tensor.relu] + img = Tensor.randn(4, 1, 16, 16).realize(*nn.state.get_parameters(layers)) + a = UOp.range(4, -1, AxisType.OUTER) + out = img[a:a+1].sequential(layers) + out = out.pad(((a,(4-a)-1), None, None, None)) + out = Tensor(out.uop.reduce(a, arg=Ops.ADD)) + out.realize() + np.testing.assert_allclose(out.numpy(), img.sequential(layers).numpy(), atol=1e-6) + + def test_vmap_gemm(self): + layers = [ + nn.Linear(16, 16, bias=False), Tensor.relu, + nn.Linear(16, 16, bias=False), Tensor.relu] + img = Tensor.randn(4, 16).realize(*nn.state.get_parameters(layers)) + a = UOp.range(4, -1, AxisType.OUTER) + out = img[a:a+1].sequential(layers) + out = out.pad(((a,(4-a)-1), None)) + out = Tensor(out.uop.reduce(a, arg=Ops.ADD)) + out.realize() + np.testing.assert_allclose(out.numpy(), img.sequential(layers).numpy(), atol=1e-6) + + @unittest.skip("this is broken, we need to lower the outer reduce in the outer graph") + def test_vmap_gemm_grad(self): + layers = [ + nn.Linear(16, 16, bias=False), Tensor.relu, + nn.Linear(16, 16, bias=False), Tensor.relu] + layer_tensors = nn.state.get_parameters(layers) + img = Tensor.randn(4, 16).realize(*layer_tensors) + for l in layer_tensors: l.requires_grad_() + a = UOp.range(4, -1, AxisType.OUTER) + out = img[a:a+1].sequential(layers) + out = out.pad(((a,(4-a)-1), None)) + out = Tensor(out.uop.reduce(a, arg=Ops.ADD)) + out.mean().backward() + grads = [l.grad for l in layer_tensors] + out.realize(*grads) + out_grads = [x.numpy() for x in grads] + + # compute reference grads + for l in layer_tensors: l.grad = None + img.sequential(layers).mean().backward() + grads = [l.grad for l in layer_tensors] + out.realize(*grads) + ref_grads = [x.numpy() for x in grads] + + # compare + for o,r in zip(out_grads, ref_grads): np.testing.assert_allclose(o, r, atol=1e-6) + if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py index d210f60024..0bcd63d4ee 100644 --- a/tinygrad/gradient.py +++ b/tinygrad/gradient.py @@ -31,7 +31,6 @@ pm_gradient = PatternMatcher([ (UPat(Ops.WHERE, name="ret"), lambda ctx, ret: (None, ret.src[0].where(ctx, ctx.const_like(0)), ret.src[0].where(ctx.const_like(0), ctx))), (UPat(Ops.REDUCE_AXIS, name="ret"), lambda ctx, ret: reduce_gradient(ctx, ret, ret.arg[0])), (UPat(Ops.REDUCE, name="ret"), lambda ctx, ret: reduce_gradient(ctx, ret, ret.arg) + (None,)*(len(ret.src)-1)), - (UPat(Ops.REDUCE_BACKWARD, name="ret"), lambda ctx, ret: (ctx.reduce(*ret.src[1:], arg=ret.arg),) + (None,)*(len(ret.src)-1)), (UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)), (UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)), (UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape), None)), @@ -71,4 +70,8 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp # we add the backward metadata to everything new in the graph for bw_uop in v.toposort(lambda x: x not in (t0, *t0.src, grads[t0])): all_metadata[bw_uop] = all_metadata.get(bw_uop, ())+backward_metadata + # end any ranges on grads with a reduce sum + for k,v in grads.items(): + if len(v.ranges): + grads[k] = v.reduce(*v.ranges, arg=Ops.ADD) return grads diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index fdef6446dd..c989da46b5 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -65,7 +65,7 @@ mop_cleanup = PatternMatcher([ earliest_rewrites = mop_cleanup+PatternMatcher([ # just removing it works... - (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.REDUCE_BACKWARD), name="x"), lambda x: x.src[0]), + (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]), # remove CONTIGUOUS if the BUFFER is already contiguous (UPat(Ops.BUFFER).f(Ops.RESHAPE, allow_any_len=True, name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)), @@ -325,6 +325,18 @@ def bufferize_to_store(ctx:itertools.count|None, x:UOp, idx:UOp, allow_locals=Tr for m in mops[::-1]: ret = ret._mop(*m) return ret + # lower outerworld reduce here + if x.src[0].op is Ops.REDUCE and len(x.src[0].src) == 2 and x.src[0].src[1].arg[-1] == AxisType.OUTER: + assert sdtype.addrspace == AddrSpace.GLOBAL + outer_range = x.src[0].src[1] + buf = UOp.new_buffer(x.arg.device, size, x.dtype) + # NOTE: this has the same number as the outer range, we need string ranges! + zero_range = outer_range.replace(src=(UOp.const(dtypes.index, size),), arg=outer_range.arg[:-1]+(AxisType.LOOP,)) + buf = buf.after(buf.index(zero_range).store(0).end(zero_range)) + bufi = buf.index(idx, dtype=sdtype) + do_store = bufi.store(bufi.load() + x.src[0].src[0], tag=x.tag).end(*rngs).end(outer_range) + return buf.after(do_store) + # NOTE: the DEFINE_LOCAL needs to be disambiguated here if sdtype.addrspace == AddrSpace.GLOBAL: buf = UOp.new_buffer(x.arg.device, size, x.dtype) @@ -500,12 +512,7 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None: raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in kernel.src)}") return kernel -def split_inner_and_outer_end(x: UOp): - outer_ranges, inner_ranges = partition(x.src[1:], lambda r: r.arg[-1] == AxisType.OUTER) - if len(outer_ranges) and len(inner_ranges): return x.src[0].end(*inner_ranges).end(*outer_ranges) - split_kernels = PatternMatcher([ - (UPat(Ops.END, name="x"), split_inner_and_outer_end), (UPat((Ops.STORE, Ops.END), name="x"), split_store), ]) @@ -537,18 +544,6 @@ replace_contiguous = PatternMatcher([ (UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None), ]) -pm_fix_vmap = PatternMatcher([ - # x>=y and x<(y+1) means x==y (this can go in symbolic) - ((UPat.var("x", dtype=dtypes.index) >= UPat.var("y")) & (UPat.var("x") < (UPat.var("y")+1)), lambda x,y: x.eq(y)), - # remove the reduce if it's compare reduce (keep the outer range) - (UPat(Ops.BUFFERIZE, name="buf", src=( - (UPat.var("r1", dtype=dtypes.index) != UPat.var("r2")).where(0, UPat.var("val")).reduce(UPat.var("r2"), arg=Ops.ADD),), allow_any_len=True), - lambda r1,r2,val,buf: buf.replace(src=(val,)+buf.src[1:]).substitute({r1:r2}) if r1 in buf.src[1:] and r2.arg[-1] == AxisType.OUTER else None), - # remove the reduce if it's compare reduce - ((UPat.var("r1", dtype=dtypes.index) != UPat.var("r2")).where(0, UPat.var("val")).reduce(UPat.var("r2"), arg=Ops.ADD), - lambda r1,r2,val: val.substitute({r2:r1}) if r2.arg[-1] != AxisType.OUTER else None), -]) - @disable_gc() @track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len([u for u in UOp.sink(*ret.values()).toposort() if u.op is Ops.KERNEL]))}", True) def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]: @@ -561,9 +556,9 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]: # convert movement ops to ranges tsink, rctx = run_rangeify(tsink, DEBUG_RANGEIFY) - tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding+pm_fix_vmap, name="symbolic+reduce_collapse") + tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding, name="symbolic+reduce_collapse") tsink = graph_rewrite(tsink, pm_remove_bufferize, bottom_up=True, name="remove bufferize with cost function") - tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding+pm_fix_vmap, name="symbolic+reduce_collapse pt 2") + tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding, name="symbolic+reduce_collapse pt 2") tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rctx, name="limit buffers") # rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph diff --git a/tinygrad/uop/__init__.py b/tinygrad/uop/__init__.py index 706de5e930..322cd2323f 100644 --- a/tinygrad/uop/__init__.py +++ b/tinygrad/uop/__init__.py @@ -91,7 +91,6 @@ class Ops(FastEnum): # reduce REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto() - REDUCE_BACKWARD = auto() # errors/placeholders REWRITE_ERROR = auto(); SENTINEL = auto() diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 2aea713416..2f5877e772 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -219,7 +219,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,) # passthrough ops - case Ops.REDUCE | Ops.REDUCE_BACKWARD | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END: + case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END: return self.src[0]._shape # ops with custom handling @@ -442,7 +442,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass): assert self.dtype.scalar() is dtypes.index, "Can only call get_valid on index dtype" return self.src[0] if self.op is Ops.WHERE and self.src[2].arg is Invalid else UOp.const(dtypes.bool, self.arg is not Invalid) def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs) - def reduce_backward(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE_BACKWARD, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs) def is_contiguous(self): # TODO: this is is_realized diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index 92cb9ec232..8fcd09a041 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -108,7 +108,7 @@ _tensor_spec = PatternMatcher([ (UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) >= 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}), # REDUCE with an outerworld range - (UPat((Ops.REDUCE, Ops.REDUCE_BACKWARD), src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])), + (UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])), # AFTER if things were kernelized (UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True),