mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
remove contiguous from MSELECT 2 (#10522)
* remove contiguous from MSELECT * test_shrink_on_shard_axis --------- Co-authored-by: George Hotz <geohot@gmail.com>
This commit is contained in:
@@ -163,6 +163,20 @@ class TestMultiTensor(unittest.TestCase):
|
||||
O = X.shrink(((0, 2), None)) * W.shrink(((0, 2), None)) < 2
|
||||
np.testing.assert_allclose(O.numpy(), X.numpy()[0:2]*W.numpy()[0:2] < 2)
|
||||
|
||||
def test_shrink_on_shard_axis(self):
|
||||
X = Tensor.arange(4*4).reshape(4,4).realize()
|
||||
X_np = X.numpy()
|
||||
X.shard_(devices_2, 0)
|
||||
# only shrink on the device that owns the shard, this is enabled by the mselect simplifier
|
||||
for i in range(2):
|
||||
xt = X[i*2:i*2+2].contiguous()
|
||||
sched = xt.schedule()
|
||||
kernels = [s for s in sched if s.ast.op is Ops.SINK]
|
||||
self.assertEqual(len(kernels), 1)
|
||||
self.assertEqual(kernels[0].bufs[0].device, devices_2[i])
|
||||
run_schedule(sched)
|
||||
np.testing.assert_equal(xt.numpy(), X_np[i*2:i*2+2])
|
||||
|
||||
@given(strat.sampled_from((4, 5)), strat.sampled_from((devices_2, devices_3)),
|
||||
strat.sampled_from((Ops.ADD, Ops.MUL, Ops.MAX)),
|
||||
strat.sampled_from((None, 0, 1)), strat.sampled_from((None, 0, 1)), strat.sampled_from((1, 0, -1)))
|
||||
|
||||
Reference in New Issue
Block a user