mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
multi copy_to_device return the copy on same device if possible (#5117)
previously it always returns from the first device
This commit is contained in:
@@ -537,6 +537,28 @@ class TestMultiTensor(unittest.TestCase):
|
||||
assert ast.src[0].src[0].op is BufferOps.LOAD
|
||||
assert ast.src[0].src[1].op is BufferOps.CONST and ast.src[0].src[1].arg.val == 3
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
|
||||
class TestHandleData(unittest.TestCase):
|
||||
def test_copied_to_device(self):
|
||||
device = [f"{Device.DEFAULT}:{i}" for i in range(4)]
|
||||
t = Tensor([1, 2, 3, 4]).shard(device).realize()
|
||||
not_covered = t.to(f"{Device.DEFAULT}:5")
|
||||
sched = create_schedule([not_covered.lazydata])
|
||||
assert len(sched) == 1
|
||||
# setup again because create_schedule has side effect
|
||||
t = Tensor([1, 2, 3, 4]).shard(device).realize()
|
||||
not_covered = t.to(f"{Device.DEFAULT}:5")
|
||||
assert not_covered.realize().tolist() == [1, 2, 3, 4]
|
||||
|
||||
t = Tensor([1, 2, 3, 4]).shard(device).realize()
|
||||
covered = t.to(f"{Device.DEFAULT}:2")
|
||||
sched = create_schedule([covered.lazydata])
|
||||
assert len(sched) == 0
|
||||
# setup again because create_schedule has side effect
|
||||
t = Tensor([1, 2, 3, 4]).shard(device).realize()
|
||||
covered = t.to(f"{Device.DEFAULT}:2")
|
||||
assert covered.realize().tolist() == [1, 2, 3, 4]
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
|
||||
class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
|
||||
# shrink a multitensor on sharded axis
|
||||
|
||||
Reference in New Issue
Block a user