vmap on full model (#13340)

* vmap on full model

* vmap gemm

* reduce sums on end

* outer reduce

* only if there's ranges

* put those rules in symbolic

* ranges

* do opt later

* add zero range
This commit is contained in:
George Hotz
2025-11-18 16:06:06 -08:00
committed by GitHub
parent 46cb65e692
commit 1afa3c0877
6 changed files with 74 additions and 27 deletions

View File

@@ -1,6 +1,6 @@
import unittest
import numpy as np
from tinygrad import Tensor, UOp
from tinygrad import Tensor, UOp, nn
from tinygrad.uop.ops import AxisType, Ops
class TestOuterworldReduce(unittest.TestCase):
@@ -156,7 +156,7 @@ class TestVmap(unittest.TestCase):
# vmap across axis 0
a = UOp.range(3, -1, axis_type)
out = x @ Tensor(mats.uop.reduce_backward(a, arg=Ops.ADD))[a]
out = x @ mats[a]
out = out.reshape(1, 10).pad(((a,(3-a)-1), None))
out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
if fuse: out = out * 2
@@ -175,5 +175,56 @@ class TestVmap(unittest.TestCase):
def test_vmap_inner_fuse_grad(self): self.test_vmap_inner(fuse=True, grad=True)
def test_vmap_outer_grad(self): self.test_vmap_inner(AxisType.OUTER, grad=True)
def test_vmap_convs(self):
layers = [
nn.Conv2d(1, 8, 3), Tensor.relu,
nn.Conv2d(8, 8, 3), Tensor.relu]
img = Tensor.randn(4, 1, 16, 16).realize(*nn.state.get_parameters(layers))
a = UOp.range(4, -1, AxisType.OUTER)
out = img[a:a+1].sequential(layers)
out = out.pad(((a,(4-a)-1), None, None, None))
out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
out.realize()
np.testing.assert_allclose(out.numpy(), img.sequential(layers).numpy(), atol=1e-6)
def test_vmap_gemm(self):
layers = [
nn.Linear(16, 16, bias=False), Tensor.relu,
nn.Linear(16, 16, bias=False), Tensor.relu]
img = Tensor.randn(4, 16).realize(*nn.state.get_parameters(layers))
a = UOp.range(4, -1, AxisType.OUTER)
out = img[a:a+1].sequential(layers)
out = out.pad(((a,(4-a)-1), None))
out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
out.realize()
np.testing.assert_allclose(out.numpy(), img.sequential(layers).numpy(), atol=1e-6)
@unittest.skip("this is broken, we need to lower the outer reduce in the outer graph")
def test_vmap_gemm_grad(self):
layers = [
nn.Linear(16, 16, bias=False), Tensor.relu,
nn.Linear(16, 16, bias=False), Tensor.relu]
layer_tensors = nn.state.get_parameters(layers)
img = Tensor.randn(4, 16).realize(*layer_tensors)
for l in layer_tensors: l.requires_grad_()
a = UOp.range(4, -1, AxisType.OUTER)
out = img[a:a+1].sequential(layers)
out = out.pad(((a,(4-a)-1), None))
out = Tensor(out.uop.reduce(a, arg=Ops.ADD))
out.mean().backward()
grads = [l.grad for l in layer_tensors]
out.realize(*grads)
out_grads = [x.numpy() for x in grads]
# compute reference grads
for l in layer_tensors: l.grad = None
img.sequential(layers).mean().backward()
grads = [l.grad for l in layer_tensors]
out.realize(*grads)
ref_grads = [x.numpy() for x in grads]
# compare
for o,r in zip(out_grads, ref_grads): np.testing.assert_allclose(o, r, atol=1e-6)
if __name__ == '__main__':
unittest.main()