mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
outer vmap works (#13334)
* outer vmap works * fuse works * vmap outer works * outer ranges work * grad work * should be good to merge
This commit is contained in:
@@ -77,7 +77,8 @@ class TestOuterScan(unittest.TestCase):
|
||||
# 3 matmuls with SCAN
|
||||
i = UOp.range(3, -100, AxisType.OUTER)
|
||||
out = Tensor.empty(3, 1, 10)
|
||||
comp = Tensor(i.eq(0).where(vec.uop, out[(i-1).maximum(0)].uop)) @ mats[i]
|
||||
phi = Tensor(i.eq(0).where(vec.uop, out[(i-1).maximum(0)].uop))
|
||||
comp = phi @ mats[i]
|
||||
store = out[i].uop.store(comp.uop).end(i)
|
||||
out = Tensor(out.uop.after(store))
|
||||
out.realize()
|
||||
@@ -145,21 +146,26 @@ class TestOuterworld(unittest.TestCase):
|
||||
self.assertListEqual([[0,4,8],[4,8,12],[8,12,16]], out.tolist())
|
||||
|
||||
class TestVmap(unittest.TestCase):
|
||||
def test_vmap_inner(self):
|
||||
def test_vmap_inner(self, axis_type=AxisType.LOOP, fuse=False, grad=False):
|
||||
x = Tensor.ones(1, 10).contiguous().requires_grad_()
|
||||
mats = Tensor.ones(3, 10, 10).contiguous().requires_grad_()
|
||||
|
||||
ref = x @ mats
|
||||
if fuse: ref = ref * 2
|
||||
|
||||
# vmap across axis 0
|
||||
a = UOp.range(3, -1)
|
||||
a = UOp.range(3, -1, axis_type)
|
||||
out = x @ mats[a]
|
||||
out = out.reshape(1, 10).pad(((a,(3-a)-1), None))
|
||||
out = Tensor(UOp(Ops.REDUCE, dtype=out.uop.dtype, src=(out.uop, a), arg=Ops.ADD))
|
||||
if fuse: out = out * 2
|
||||
out.realize()
|
||||
|
||||
# TODO: testing allclose
|
||||
assert Tensor.allclose(ref, out, atol=1e-6), f"{ref.numpy()=}, {out.numpy()=}"
|
||||
def test_vmap_inner_fuse(self): self.test_vmap_inner(fuse=True)
|
||||
def test_vmap_outer(self): self.test_vmap_inner(AxisType.OUTER)
|
||||
def test_vmap_outer_fuse(self): self.test_vmap_inner(AxisType.OUTER, fuse=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user