cuda transfer + async copyin (#3873)
--- a/tinygrad/runtime/ops_cuda.py
+++ b/tinygrad/runtime/ops_cuda.py
@@ -1,10 +1,10 @@
 from __future__ import annotations
 import subprocess, hashlib, tempfile, ctypes, ctypes.util, functools, re
 from pathlib import Path
-from typing import Tuple, Optional
+from typing import Tuple, Optional, List
 import tinygrad.runtime.autogen.cuda as cuda
 from tinygrad.helpers import DEBUG, getenv, from_mv, to_char_p_p, init_c_var, init_c_struct_t, colored, cpu_time_execution
-from tinygrad.device import Compiled, LRUAllocator, MallocAllocator, Compiler
+from tinygrad.device import Compiled, LRUAllocator, MallocAllocator, Compiler, BufferOptions
 from tinygrad.codegen.kernel import LinearizerOptions
 from tinygrad.renderer.cstyle import CUDARenderer
 from tinygrad.renderer.assembly import PTXRenderer
@@ -122,15 +122,32 @@ class CUDAAllocator(LRUAllocator):
   def _alloc(self, size):
     check(cuda.cuCtxSetCurrent(self.device.context))
     return init_c_var(cuda.CUdeviceptr(), lambda x: check(cuda.cuMemAlloc_v2(ctypes.byref(x), size)))
+  def _alloc_with_options(self, size:int, options:BufferOptions):
+    if options.host:
+      return init_c_var(ctypes.c_void_p(), lambda x: check(cuda.cuMemHostAlloc(ctypes.byref(x), size, 0)))
+    else:
+      raise Exception("no options")
   def _free(self, opaque): check(cuda.cuMemFree_v2(opaque))
   def copyin(self, dest, src:memoryview):
     check(cuda.cuCtxSetCurrent(self.device.context))
-    check(cuda.cuMemcpyHtoD_v2(dest, from_mv(src), len(src), None))
+    host_mem = self._alloc_with_options(len(src), BufferOptions(host=True))
+    self.device.pending_copyin.append(host_mem.value)
+    ctypes.memmove(host_mem, from_mv(src), len(src))
+    check(cuda.cuMemcpyHtoDAsync_v2(dest, host_mem, len(src), None))
   def copyout(self, dest:memoryview, src):
+    CUDADevice.synchronize_system()
     check(cuda.cuCtxSetCurrent(self.device.context))
     check(cuda.cuMemcpyDtoH_v2(from_mv(dest), src, len(dest)))
+  def transfer(self, dest, src, sz:int, src_dev, dest_dev):
+    check(cuda.cuCtxSetCurrent(src_dev.context))
+    check(cuda.cuEventCreate(ctypes.byref(sync_event := cuda.CUevent()), 0))
+    check(cuda.cuMemcpyDtoDAsync_v2(dest, src, sz, None))
+    check(cuda.cuEventRecord(sync_event, None))
+    check(cuda.cuCtxSetCurrent(dest_dev.context))
+    check(cuda.cuStreamWaitEvent(None, sync_event, 0)) # sync the default stream on the dest dev
 
 class CUDADevice(Compiled):
+  devices: List[CUDADevice] = []
   def __init__(self, device:str):
     device_id = int(device.split(":")[1]) if ":" in device else 0
     if not CUDACPU:
@@ -138,13 +155,23 @@ class CUDADevice(Compiled):
       check(cuda.cuDeviceGet(ctypes.byref(cu_device := cuda.CUdevice()), device_id))
       self.context = init_c_var(cuda.CUcontext(), lambda x: check(cuda.cuCtxCreate_v2(ctypes.byref(x), 0, cu_device)))
       check(cuda.cuDeviceComputeCapability(ctypes.byref(major := ctypes.c_int()), ctypes.byref(minor := ctypes.c_int()), device_id))
+
     self.arch = f"sm_{major.value}{minor.value}" if not CUDACPU else "sm_35"
+    self.pending_copyin: List[int] = []
+    CUDADevice.devices.append(self)
 
     from tinygrad.runtime.graph.cuda import CUDAGraph
     super().__init__(device, CUDAAllocator(self) if not CUDACPU else MallocAllocator,
                      PTXCompiler(self.arch) if getenv("PTX") else CUDACompiler(self.arch),
                      functools.partial(CUDAProgram, self), graph=CUDAGraph if not CUDACPU else None)
+
   def synchronize(self):
-    if not CUDACPU:
-      check(cuda.cuCtxSetCurrent(self.context))
-      check(cuda.cuCtxSynchronize())
+    if CUDACPU: return
+    check(cuda.cuCtxSetCurrent(self.context))
+    check(cuda.cuCtxSynchronize())
+    for opaque in self.pending_copyin: check(cuda.cuMemFreeHost(opaque))
+    self.pending_copyin.clear()
+
+  @staticmethod
+  def synchronize_system():
+    for d in CUDADevice.devices: d.synchronize()
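Why the pinned buffer lives in pending_copyin: cuMemcpyHtoDAsync_v2 returns before the copy has finished, so the page-locked staging buffer cannot be released or reused until the stream has been synchronized; copyin therefore records it on the device and synchronize() later frees it with cuMemFreeHost. Below is a standalone sketch of that pattern, not part of the commit, written against the same ctypes bindings and helpers the diff imports; it assumes a working CUDA driver with one GPU, and the names dev, dst, staging and n are illustrative.

import ctypes
import tinygrad.runtime.autogen.cuda as cuda
from tinygrad.helpers import init_c_var
from tinygrad.runtime.ops_cuda import CUDADevice, check  # check() raises on a non-zero CUresult

dev = CUDADevice("CUDA:0")                # creates a context, as in the diff's __init__
check(cuda.cuCtxSetCurrent(dev.context))

n = 1 << 20                               # 1 MiB, illustrative size
dst = init_c_var(cuda.CUdeviceptr(), lambda x: check(cuda.cuMemAlloc_v2(ctypes.byref(x), n)))

# pinned (page-locked) host staging buffer, the same call copyin now uses via _alloc_with_options
staging = init_c_var(ctypes.c_void_p(), lambda x: check(cuda.cuMemHostAlloc(ctypes.byref(x), n, 0)))
ctypes.memset(staging, 0xAB, n)           # stands in for the ctypes.memmove from the source memoryview
check(cuda.cuMemcpyHtoDAsync_v2(dst, staging, n, None))  # enqueued on the default stream, returns immediately

check(cuda.cuCtxSynchronize())            # what CUDADevice.synchronize() does before freeing
check(cuda.cuMemFreeHost(staging))        # only now is it safe to release the pinned buffer
check(cuda.cuMemFree_v2(dst))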
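At the user level, the commit changes data movement in three places: uploads (copyin) are now asynchronous and staged through pinned memory, device-to-device moves go through the new transfer() with a CUevent ordering the destination's default stream, and readbacks (copyout) first call CUDADevice.synchronize_system(), which also frees the pending pinned buffers. A minimal usage sketch of that flow, not from the commit, assuming a machine with at least two CUDA GPUs exposed as CUDA:0 and CUDA:1:

from tinygrad import Tensor

# copyin: the host -> CUDA:0 upload is staged through a pinned buffer (cuMemHostAlloc)
# and enqueued with cuMemcpyHtoDAsync_v2 on the default stream.
a = Tensor([1.0, 2.0, 3.0, 4.0], device="CUDA:0").realize()

# transfer: CUDA:0 -> CUDA:1 uses cuMemcpyDtoDAsync_v2 plus a recorded CUevent, so the
# destination device's default stream waits for the copy before running later work.
b = a.to("CUDA:1").realize()

# copyout: reading back calls CUDADevice.synchronize_system() first, which synchronizes
# every device and frees the staging buffers recorded in pending_copyin.
print(b.numpy())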