this is the right way to write vmap (#13328)

This commit is contained in:
George Hotz
2025-11-17 20:20:52 -08:00
committed by GitHub
parent 8e8e53c886
commit 583560ab72
2 changed files with 17 additions and 0 deletions

View File

@@ -144,5 +144,22 @@ class TestOuterworld(unittest.TestCase):
out = out.reshape(1, 3).expand(a, 3).contiguous().realize()
self.assertListEqual([[0,4,8],[4,8,12],[8,12,16]], out.tolist())
class TestVmap(unittest.TestCase):
def test_vmap_inner(self):
x = Tensor.ones(1, 10).contiguous().requires_grad_()
mats = Tensor.ones(3, 10, 10).contiguous().requires_grad_()
ref = x @ mats
# vmap across axis 0
a = UOp.range(3, -1)
out = x @ mats[a]
out = out.reshape(1, 10).pad(((a,(3-a)-1), None))
out = Tensor(UOp(Ops.REDUCE, dtype=out.uop.dtype, src=(out.uop, a), arg=Ops.ADD))
out.realize()
# TODO: testing allclose
assert Tensor.allclose(ref, out, atol=1e-6), f"{ref.numpy()=}, {out.numpy()=}"
if __name__ == '__main__':
unittest.main()

View File