mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
this is the right way to write vmap (#13328)
This commit is contained in:
@@ -144,5 +144,22 @@ class TestOuterworld(unittest.TestCase):
|
||||
out = out.reshape(1, 3).expand(a, 3).contiguous().realize()
|
||||
self.assertListEqual([[0,4,8],[4,8,12],[8,12,16]], out.tolist())
|
||||
|
||||
class TestVmap(unittest.TestCase):
|
||||
def test_vmap_inner(self):
|
||||
x = Tensor.ones(1, 10).contiguous().requires_grad_()
|
||||
mats = Tensor.ones(3, 10, 10).contiguous().requires_grad_()
|
||||
|
||||
ref = x @ mats
|
||||
|
||||
# vmap across axis 0
|
||||
a = UOp.range(3, -1)
|
||||
out = x @ mats[a]
|
||||
out = out.reshape(1, 10).pad(((a,(3-a)-1), None))
|
||||
out = Tensor(UOp(Ops.REDUCE, dtype=out.uop.dtype, src=(out.uop, a), arg=Ops.ADD))
|
||||
out.realize()
|
||||
|
||||
# TODO: testing allclose
|
||||
assert Tensor.allclose(ref, out, atol=1e-6), f"{ref.numpy()=}, {out.numpy()=}"
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user