multi like on full_like as well as rand_like (#13402)

* multi like on full_like as well as rand_like

* add test and fix bug

* mismatch, optim match

* one line
This commit is contained in:
George Hotz
2025-11-20 20:46:48 -08:00
committed by GitHub
parent fa3def2f12
commit e1051d00d7
3 changed files with 26 additions and 14 deletions

View File

@@ -765,6 +765,16 @@ class TestMultiTensor(unittest.TestCase):
with self.assertRaises(RuntimeError):
Tensor.rand_like(t, device=(d3, d4))
def test_full_like_on_shard(self, axis=None):
t = Tensor.empty((16, 16)).shard(devices_2, axis=axis)
t2 = Tensor.full_like(t, 1.0)
self.assertEqual(t.shape, t2.shape)
self.assertEqual(t.device, t2.device)
self.assertEqual(t.dtype, t2.dtype)
self.assertEqual(t.uop.axis, t2.uop.axis)
t2.realize()
def test_full_like_on_shard_axis(self): self.test_full_like_on_shard(0)
def test_dropout_on_shard(self):
with Tensor.train():
X = Tensor.ones(256).to(devices_2)