remove contiguous from MSELECT 2 (#10522)

* remove contiguous from MSELECT

* test_shrink_on_shard_axis

---------

Co-authored-by: George Hotz <geohot@gmail.com>
This commit is contained in:
qazal
2025-05-26 19:19:01 +03:00
committed by GitHub
parent 602a145f8f
commit 6d07087fe1
4 changed files with 28 additions and 4 deletions

View File

@@ -163,6 +163,20 @@ class TestMultiTensor(unittest.TestCase):
O = X.shrink(((0, 2), None)) * W.shrink(((0, 2), None)) < 2
np.testing.assert_allclose(O.numpy(), X.numpy()[0:2]*W.numpy()[0:2] < 2)
def test_shrink_on_shard_axis(self):
X = Tensor.arange(4*4).reshape(4,4).realize()
X_np = X.numpy()
X.shard_(devices_2, 0)
# only shrink on the device that owns the shard, this is enabled by the mselect simplifier
for i in range(2):
xt = X[i*2:i*2+2].contiguous()
sched = xt.schedule()
kernels = [s for s in sched if s.ast.op is Ops.SINK]
self.assertEqual(len(kernels), 1)
self.assertEqual(kernels[0].bufs[0].device, devices_2[i])
run_schedule(sched)
np.testing.assert_equal(xt.numpy(), X_np[i*2:i*2+2])
@given(strat.sampled_from((4, 5)), strat.sampled_from((devices_2, devices_3)),
strat.sampled_from((Ops.ADD, Ops.MUL, Ops.MAX)),
strat.sampled_from((None, 0, 1)), strat.sampled_from((None, 0, 1)), strat.sampled_from((1, 0, -1)))