outer vmap works (#13334)

* outer vmap works

* fuse works

* vmap outer works

* outer ranges work

* grad work

* should be good to merge
This commit is contained in:
George Hotz
2025-11-18 09:27:48 -08:00
committed by GitHub
parent 805de27e07
commit 06e39a88a9
5 changed files with 39 additions and 11 deletions

View File

@@ -77,7 +77,8 @@ class TestOuterScan(unittest.TestCase):
# 3 matmuls with SCAN
i = UOp.range(3, -100, AxisType.OUTER)
out = Tensor.empty(3, 1, 10)
comp = Tensor(i.eq(0).where(vec.uop, out[(i-1).maximum(0)].uop)) @ mats[i]
phi = Tensor(i.eq(0).where(vec.uop, out[(i-1).maximum(0)].uop))
comp = phi @ mats[i]
store = out[i].uop.store(comp.uop).end(i)
out = Tensor(out.uop.after(store))
out.realize()
@@ -145,21 +146,26 @@ class TestOuterworld(unittest.TestCase):
self.assertListEqual([[0,4,8],[4,8,12],[8,12,16]], out.tolist())
class TestVmap(unittest.TestCase):
def test_vmap_inner(self):
def test_vmap_inner(self, axis_type=AxisType.LOOP, fuse=False, grad=False):
x = Tensor.ones(1, 10).contiguous().requires_grad_()
mats = Tensor.ones(3, 10, 10).contiguous().requires_grad_()
ref = x @ mats
if fuse: ref = ref * 2
# vmap across axis 0
a = UOp.range(3, -1)
a = UOp.range(3, -1, axis_type)
out = x @ mats[a]
out = out.reshape(1, 10).pad(((a,(3-a)-1), None))
out = Tensor(UOp(Ops.REDUCE, dtype=out.uop.dtype, src=(out.uop, a), arg=Ops.ADD))
if fuse: out = out * 2
out.realize()
# TODO: testing allclose
assert Tensor.allclose(ref, out, atol=1e-6), f"{ref.numpy()=}, {out.numpy()=}"
def test_vmap_inner_fuse(self): self.test_vmap_inner(fuse=True)
def test_vmap_outer(self): self.test_vmap_inner(AxisType.OUTER)
def test_vmap_outer_fuse(self): self.test_vmap_inner(AxisType.OUTER, fuse=True)
if __name__ == '__main__':
unittest.main()