mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
vmap on full model (#13340)
* vmap on full model
* vmap gemm
* reduce sums on end
* outer reduce
* only if there's ranges
* put those rules in symbolic
* ranges
* do opt later
* add zero range
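Every test added here follows the same vmap-by-hand idiom: open an OUTER range over the batch, index one element with it, run the unbatched model, then scatter the result back into place with pad and sum the range out with an ADD reduce. A minimal sketch of that idiom, distilled from test_vmap_gemm in the diff below (the two-layer shapes are illustrative):

import numpy as np
from tinygrad import Tensor, UOp, nn
from tinygrad.uop.ops import AxisType, Ops

layers = [nn.Linear(16, 16, bias=False), Tensor.relu]
x = Tensor.randn(4, 16).realize(*nn.state.get_parameters(layers))

a = UOp.range(4, -1, AxisType.OUTER)          # symbolic index over the batch of 4
out = x[a:a+1].sequential(layers)             # run the model on one row at a time
out = out.pad(((a, (4-a)-1), None))           # place the row back at position a
out = Tensor(out.uop.reduce(a, arg=Ops.ADD))  # sum over a to rebuild the batch
np.testing.assert_allclose(out.numpy(), x.sequential(layers).numpy(), atol=1e-6)

Because the pads for different values of a never overlap, the ADD reduce is pure reassembly, so the result matches running the model on the full batch.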
@@ -1,6 +1,6 @@
 import unittest
 import numpy as np
-from tinygrad import Tensor, UOp
+from tinygrad import Tensor, UOp, nn
 from tinygrad.uop.ops import AxisType, Ops
 
 class TestOuterworldReduce(unittest.TestCase):
@@ -156,7 +156,7 @@ class TestVmap(unittest.TestCase):
 
     # vmap across axis 0
     a = UOp.range(3, -1, axis_type)
-    out = x @ Tensor(mats.uop.reduce_backward(a, arg=Ops.ADD))[a]
+    out = x @ mats[a]
     out = out.reshape(1, 10).pad(((a,(3-a)-1), None))
     out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
     if fuse: out = out * 2
@@ -175,5 +175,56 @@ class TestVmap(unittest.TestCase):
   def test_vmap_inner_fuse_grad(self): self.test_vmap_inner(fuse=True, grad=True)
   def test_vmap_outer_grad(self): self.test_vmap_inner(AxisType.OUTER, grad=True)
 
+  def test_vmap_convs(self):
+    layers = [
+      nn.Conv2d(1, 8, 3), Tensor.relu,
+      nn.Conv2d(8, 8, 3), Tensor.relu]
+    img = Tensor.randn(4, 1, 16, 16).realize(*nn.state.get_parameters(layers))
+    a = UOp.range(4, -1, AxisType.OUTER)
+    out = img[a:a+1].sequential(layers)
+    out = out.pad(((a,(4-a)-1), None, None, None))
+    out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
+    out.realize()
+    np.testing.assert_allclose(out.numpy(), img.sequential(layers).numpy(), atol=1e-6)
+
+  def test_vmap_gemm(self):
+    layers = [
+      nn.Linear(16, 16, bias=False), Tensor.relu,
+      nn.Linear(16, 16, bias=False), Tensor.relu]
+    img = Tensor.randn(4, 16).realize(*nn.state.get_parameters(layers))
+    a = UOp.range(4, -1, AxisType.OUTER)
+    out = img[a:a+1].sequential(layers)
+    out = out.pad(((a,(4-a)-1), None))
+    out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
+    out.realize()
+    np.testing.assert_allclose(out.numpy(), img.sequential(layers).numpy(), atol=1e-6)
+
+  @unittest.skip("this is broken, we need to lower the outer reduce in the outer graph")
+  def test_vmap_gemm_grad(self):
+    layers = [
+      nn.Linear(16, 16, bias=False), Tensor.relu,
+      nn.Linear(16, 16, bias=False), Tensor.relu]
+    layer_tensors = nn.state.get_parameters(layers)
+    img = Tensor.randn(4, 16).realize(*layer_tensors)
+    for l in layer_tensors: l.requires_grad_()
+    a = UOp.range(4, -1, AxisType.OUTER)
+    out = img[a:a+1].sequential(layers)
+    out = out.pad(((a,(4-a)-1), None))
+    out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
+    out.mean().backward()
+    grads = [l.grad for l in layer_tensors]
+    out.realize(*grads)
+    out_grads = [x.numpy() for x in grads]
+
+    # compute reference grads
+    for l in layer_tensors: l.grad = None
+    img.sequential(layers).mean().backward()
+    grads = [l.grad for l in layer_tensors]
+    out.realize(*grads)
+    ref_grads = [x.numpy() for x in grads]
+
+    # compare
+    for o,r in zip(out_grads, ref_grads): np.testing.assert_allclose(o, r, atol=1e-6)
+
 if __name__ == '__main__':
   unittest.main()
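For intuition on the pad-then-reduce step used throughout: padding the (1, N) slice with a rows of zeros above and (4-a)-1 rows below places it at row a, and summing over a rebuilds the batch. A NumPy picture of this (illustration only, not tinygrad code):

import numpy as np
rows = [np.random.randn(1, 16) for _ in range(4)]              # per-index results
out = sum(np.pad(r, ((a, (4-a)-1), (0, 0))) for a, r in enumerate(rows))
assert np.allclose(out, np.concatenate(rows, axis=0))          # batch reassembled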