multi copy_to_device: return the copy on the same device if possible (#5117)

previously it always returned the copy from the first device
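A minimal, hypothetical sketch of the selection logic described above, assuming the replicated (no sharding axis) case that the new test below exercises; the Shard class, the device strings, and the standalone copy_to_device signature here are illustrative only and not taken from the actual diff:

from dataclasses import dataclass

@dataclass
class Shard:
  device: str
  data: list

def copy_to_device(shards: list[Shard], device: str) -> Shard:
  # new behavior: if the target device already holds one of the shards, return it as-is
  for s in shards:
    if s.device == device: return s
  # old behavior (still the fallback when no shard is on the target): copy from the first shard
  return Shard(device, list(shards[0].data))

shards = [Shard(f"CLANG:{i}", [1, 2, 3, 4]) for i in range(4)]
assert copy_to_device(shards, "CLANG:2") is shards[2]      # covered device: no copy made
assert copy_to_device(shards, "CLANG:5") is not shards[0]  # uncovered device: a fresh copy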
chenyu
2024-06-23 20:25:56 -04:00
committed by GitHub
parent b563cd52ed
commit c0ba5e0dfb
2 changed files with 27 additions and 1 deletion

@@ -537,6 +537,28 @@ class TestMultiTensor(unittest.TestCase):
    assert ast.src[0].src[0].op is BufferOps.LOAD
    assert ast.src[0].src[1].op is BufferOps.CONST and ast.src[0].src[1].arg.val == 3

@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
class TestHandleData(unittest.TestCase):
  def test_copied_to_device(self):
    device = [f"{Device.DEFAULT}:{i}" for i in range(4)]
    t = Tensor([1, 2, 3, 4]).shard(device).realize()
    not_covered = t.to(f"{Device.DEFAULT}:5")
    sched = create_schedule([not_covered.lazydata])
    assert len(sched) == 1
    # setup again because create_schedule has side effect
    t = Tensor([1, 2, 3, 4]).shard(device).realize()
    not_covered = t.to(f"{Device.DEFAULT}:5")
    assert not_covered.realize().tolist() == [1, 2, 3, 4]

    t = Tensor([1, 2, 3, 4]).shard(device).realize()
    covered = t.to(f"{Device.DEFAULT}:2")
    sched = create_schedule([covered.lazydata])
    assert len(sched) == 0
    # setup again because create_schedule has side effect
    t = Tensor([1, 2, 3, 4]).shard(device).realize()
    covered = t.to(f"{Device.DEFAULT}:2")
    assert covered.realize().tolist() == [1, 2, 3, 4]

@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
  # shrink a multitensor on sharded axis
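The two create_schedule checks in the new test capture the change directly: copying the sharded tensor to :5, which holds none of the shards, produces one scheduled item (a real copy), while copying to :2, which already holds a shard, produces an empty schedule because the existing buffer on that device is returned as-is.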