diff --git a/test/test_multitensor.py b/test/test_multitensor.py
index d2d365b37e..fda003e0ab 100644
--- a/test/test_multitensor.py
+++ b/test/test_multitensor.py
@@ -40,6 +40,7 @@ class TestMultiTensor(unittest.TestCase):
     assert lb.shape == (128,)
     (X + X).realize()
 
+  @unittest.skipIf(Device.DEFAULT == "METAL", "metal multi-device is fake")
   def test_sharded_memory(self):
     mem_base = GlobalCounters.mem_used
 
diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py
index 402085da78..70ea3c7f4b 100644
--- a/tinygrad/runtime/ops_metal.py
+++ b/tinygrad/runtime/ops_metal.py
@@ -71,7 +71,8 @@ class MetalAllocator(LRUAllocator):
     ret = self.device.device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared)
     if ret is None: raise MemoryError(f"Metal OOM while allocating {size=}")
     return ret
-  def transfer(self, dest:Any, src:Any, sz:int, **kwargs):
+  def transfer(self, dest:Any, src:Any, sz:int, src_dev: MetalDevice, **kwargs):
+    src_dev.synchronize()
     command_buffer = self.device.mtl_queue.commandBuffer()
     encoder = command_buffer.blitCommandEncoder()
     encoder.copyFromBuffer_sourceOffset_toBuffer_destinationOffset_size_(src, 0, dest, 0, sz)
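
For context, the ordering bug the new src_dev.synchronize() call guards against can be shown without any Metal at all. The sketch below is illustrative only: FakeDevice, launch(), and this free-standing transfer() are hypothetical stand-ins, not tinygrad's real MetalDevice/MetalAllocator API. The point is that the destination side must drain the source device's queue before copying, or it may read a half-written buffer.

    # Illustrative sketch only: FakeDevice and this transfer() are hypothetical
    # stand-ins, not tinygrad's actual classes.
    from typing import Callable, List

    class FakeDevice:
      def __init__(self, name:str):
        self.name = name
        self.pending: List[Callable[[], None]] = []   # queued async work
      def launch(self, kernel:Callable[[], None]):    # enqueue a kernel that writes buffers
        self.pending.append(kernel)
      def synchronize(self):                          # drain the queue, like a device sync
        while self.pending: self.pending.pop(0)()

    def transfer(dest:bytearray, src:bytearray, sz:int, src_dev:FakeDevice, **kwargs):
      # Without this synchronize, the copy below could run while src_dev's queued
      # kernels are still writing `src`, so dest would get stale or partial data.
      src_dev.synchronize()
      dest[:sz] = src[:sz]                            # stand-in for the Metal blit copy

    # Usage: a kernel on device `a` fills `src`; the copy is only correct because
    # transfer() drains a's queue before reading `src`.
    a, b = FakeDevice("GPU:0"), FakeDevice("GPU:1")
    src, dest = bytearray(4), bytearray(4)
    a.launch(lambda: src.__setitem__(slice(0, 4), b"\x01\x02\x03\x04"))
    transfer(dest, src, 4, src_dev=a)
    assert bytes(dest) == b"\x01\x02\x03\x04"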