diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py index 7096284ed1..e6d08ba727 100644 --- a/tinygrad/runtime/ops_cuda.py +++ b/tinygrad/runtime/ops_cuda.py @@ -127,14 +127,13 @@ class CUDAAllocator(LRUAllocator): check(cuda.cuCtxSetCurrent(self.device.context)) return init_c_var(cuda.CUdeviceptr(), lambda x: check(cuda.cuMemAlloc_v2(ctypes.byref(x), size))) def _alloc_with_options(self, size:int, options:BufferOptions): - if options.host: - return init_c_var(ctypes.c_void_p(), lambda x: check(cuda.cuMemHostAlloc(ctypes.byref(x), size, 0))) - else: - raise ValueError("no options") + if options.host: return init_c_var(ctypes.c_void_p(), lambda x: check(cuda.cuMemHostAlloc(ctypes.byref(x), size, 0))) + else: raise ValueError("no options") def _free(self, opaque): check(cuda.cuMemFree_v2(opaque)) def copyin(self, dest, src:memoryview): - host_mem = self._alloc_with_options(len(src), BufferOptions(host=True)) - self.device.pending_copyin.append(host_mem.value) + check(cuda.cuCtxSetCurrent(self.device.context)) + host_mem = self.alloc(len(src), BufferOptions(host=True)) + self.device.pending_copyin.append((host_mem, len(src), BufferOptions(host=True))) ctypes.memmove(host_mem, from_mv(src), len(src)) check(cuda.cuMemcpyHtoDAsync_v2(dest, host_mem, len(src), None)) def copyout(self, dest:memoryview, src): @@ -161,7 +160,7 @@ class CUDADevice(Compiled): check(cuda.cuDeviceComputeCapability(ctypes.byref(major := ctypes.c_int()), ctypes.byref(minor := ctypes.c_int()), device_id)) self.arch = f"sm_{major.value}{minor.value}" if not CUDACPU else "sm_35" - self.pending_copyin: List[int] = [] + self.pending_copyin: List[Tuple[int, int, Optional[BufferOptions]]] = [] CUDADevice.devices.append(self) from tinygrad.runtime.graph.cuda import CUDAGraph @@ -173,7 +172,7 @@ class CUDADevice(Compiled): if CUDACPU: return check(cuda.cuCtxSetCurrent(self.context)) check(cuda.cuCtxSynchronize()) - for opaque in self.pending_copyin: check(cuda.cuMemFreeHost(opaque)) + for opaque,sz,options in self.pending_copyin: self.allocator.free(opaque, sz, options) self.pending_copyin.clear() @staticmethod