diff --git a/tinygrad/device.py b/tinygrad/device.py index 0c6b5e6b1f..8711b0bc5a 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -223,6 +223,7 @@ class Allocator(Generic[DeviceType]): def __init__(self, dev:DeviceType): self.dev: DeviceType = dev self.default_buffer_spec: BufferSpec = BufferSpec() + self.supports_copy_from_disk: bool = True # overridden in LRUAllocator def alloc(self, size:int, options:BufferSpec|None=None): assert size > 0, f"alloc size must be positive, getting {size}" diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index 7af3b0b550..a2f82d5282 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -102,7 +102,7 @@ class BufferCopy(Runner): super().__init__(colored(name, "yellow"), dest_device, Estimates(lds=total_sz, mem=total_sz)) def copy(self, dest, src): disk_supports_fast_copyout = src.device.startswith("DISK") and hasattr(src.allocator.dev, 'io_uring') and \ - getattr(src.allocator.dev, 'fd', None) is not None + getattr(src.allocator.dev, 'fd', None) is not None and dest.allocator.supports_copy_from_disk if src.device.startswith("DISK") and hasattr(dest.allocator, 'copy_from_disk') and disk_supports_fast_copyout and src.nbytes >= 4096: dest.allocator.copy_from_disk(dest._buf, src._buf, src.nbytes) elif src.device.startswith("DISK") and hasattr(dest.allocator, '_as_buffer'): diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index 5da9cde227..5a644a5377 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -469,6 +469,7 @@ class AMDAllocator(HCQAllocator['AMDDevice']): def __init__(self, dev:AMDDevice): super().__init__(dev, copy_bufs=getattr(dev.iface, 'copy_bufs', None), max_copyout_size=0x1000 if dev.is_usb() else None) if hasattr(dev.iface, "as_dmaref"): self._as_dmaref = dev.iface.as_dmaref + self.supports_copy_from_disk = not dev.is_usb() def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer: return self.dev.iface.alloc(size, host=options.host, uncached=options.uncached, cpu_access=options.cpu_access)