don't view origin buffer when sharding (#5122)

* make buffer view optional with a flag

* do not view when sharding to save memory
This commit is contained in:
David Hou
2024-06-25 20:19:09 -07:00
committed by GitHub
parent 89e106686a
commit 666a9c1448
2 changed files with 7 additions and 1 deletions

View File

@@ -542,6 +542,12 @@ class TestMultiTensor(unittest.TestCase):
assert ast.src[0].src[0].op is BufferOps.LOAD
assert ast.src[0].src[1].op is BufferOps.CONST and ast.src[0].src[1].arg.val == 3
def test_shard_memory(self):
devices = (d0, d1, d2, d3)
t = Tensor.zeros(16, 16).contiguous()
t.shard_(devices, axis=0)
assert all([lb is lb.base and lb.buffer.base.size == 4 * 16 for lb in t.lazydata.lbs])
@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL"), "no GPU CI")
class TestHandleData(unittest.TestCase):
def test_copied_to_device(self):