mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
mypy for mockgpu/cuda & dsp/run (#8575)
This commit is contained in:
@@ -16,7 +16,7 @@ except Exception: pass
|
||||
# Global state
|
||||
class CUDAState:
|
||||
def __init__(self):
|
||||
self.memory: dict[int, bytearray] = {}
|
||||
self.memory: dict[int, memoryview] = {}
|
||||
self.events: dict[int, float] = {} # Event ID -> timestamp
|
||||
self.modules: dict[int, memoryview] = {} # Module ID -> code
|
||||
self.current_context: int|None = None
|
||||
@@ -53,7 +53,7 @@ def cuCtxCreate_v2(pctx, flags: int, dev: int) -> int:
|
||||
pctx._obj.value = ctx_id
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuCtxSetCurrent(context: orig_cuda.CUcontext) -> int:
|
||||
def cuCtxSetCurrent(context) -> int:
|
||||
if context.value not in cuda_state.contexts:
|
||||
return orig_cuda.CUDA_ERROR_INVALID_VALUE
|
||||
cuda_state.current_context = context.value
|
||||
@@ -65,17 +65,17 @@ def cuMemAlloc_v2(dptr, bytesize: int) -> int:
|
||||
cuda_state.memory[dptr._obj.value] = x
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuMemFree_v2(dptr: orig_cuda.CUdeviceptr_v2) -> int:
|
||||
def cuMemFree_v2(dptr) -> int:
|
||||
if dptr.value in cuda_state.memory:
|
||||
del cuda_state.memory[dptr.value]
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
return orig_cuda.CUDA_ERROR_INVALID_VALUE
|
||||
|
||||
def cuMemcpyHtoDAsync_v2(dst: orig_cuda.CUdeviceptr_v2, src: ctypes.c_void_p, bytesize: int, stream: Any) -> int:
|
||||
def cuMemcpyHtoDAsync_v2(dst, src: ctypes.c_void_p, bytesize: int, stream: Any) -> int:
|
||||
ctypes.memmove(dst.value, src, bytesize)
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuMemcpyDtoH_v2(dst: ctypes.c_void_p, src: orig_cuda.CUdeviceptr_v2, bytesize: int) -> int:
|
||||
def cuMemcpyDtoH_v2(dst: ctypes.c_void_p, src, bytesize: int) -> int:
|
||||
ctypes.memmove(dst, src.value, bytesize)
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
@@ -86,25 +86,25 @@ def cuEventCreate(phEvent, flags: int) -> int:
|
||||
phEvent._obj.value = event_id
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuEventRecord(hEvent: orig_cuda.CUevent, hStream: Any) -> int:
|
||||
def cuEventRecord(hEvent, hStream: Any) -> int:
|
||||
if hEvent.value not in cuda_state.events:
|
||||
return orig_cuda.CUDA_ERROR_INVALID_VALUE
|
||||
cuda_state.events[hEvent.value] = time.perf_counter_ns()
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuEventSynchronize(hEvent: orig_cuda.CUevent) -> int:
|
||||
def cuEventSynchronize(hEvent) -> int:
|
||||
if hEvent.value not in cuda_state.events:
|
||||
return orig_cuda.CUDA_ERROR_INVALID_VALUE
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuEventElapsedTime(pMilliseconds, hStart: orig_cuda.CUevent, hEnd: orig_cuda.CUevent) -> int:
|
||||
def cuEventElapsedTime(pMilliseconds, hStart, hEnd) -> int:
|
||||
if hStart.value not in cuda_state.events or hEnd.value not in cuda_state.events:
|
||||
return orig_cuda.CUDA_ERROR_INVALID_VALUE
|
||||
elapsed = (cuda_state.events[hEnd.value] - cuda_state.events[hStart.value]) * 1000
|
||||
pMilliseconds._obj.value = elapsed
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuEventDestroy_v2(hEvent: orig_cuda.CUevent) -> int:
|
||||
def cuEventDestroy_v2(hEvent) -> int:
|
||||
if hEvent.value in cuda_state.events:
|
||||
del cuda_state.events[hEvent.value]
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
@@ -116,18 +116,18 @@ def cuModuleLoadData(module, image: bytes) -> int:
|
||||
module._obj.value = module_id
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuModuleGetFunction(hfunc, hmod: orig_cuda.CUmodule, name: bytes) -> int:
|
||||
def cuModuleGetFunction(hfunc, hmod, name: bytes) -> int:
|
||||
if hmod.value not in cuda_state.modules:
|
||||
return orig_cuda.CUDA_ERROR_INVALID_VALUE
|
||||
hfunc._obj.value = mv_address(cuda_state.modules[hmod.value])
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuModuleUnload(hmod: orig_cuda.CUmodule) -> int:
|
||||
def cuModuleUnload(hmod) -> int:
|
||||
if hmod.value in cuda_state.modules:
|
||||
del cuda_state.modules[hmod.value]
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuLaunchKernel(f: orig_cuda.CUfunction, gx: int, gy: int, gz: int, lx: int, ly: int, lz: int, sharedMemBytes: int,
|
||||
def cuLaunchKernel(f, gx: int, gy: int, gz: int, lx: int, ly: int, lz: int, sharedMemBytes: int,
|
||||
hStream: Any, kernelParams: Any, extra: Any) -> int:
|
||||
cargs = [ctypes.cast(getattr(extra, field[0]), ctypes.c_void_p) for field in extra._fields_]
|
||||
gpuocelot_lib.ptx_run(ctypes.cast(f.value, ctypes.c_char_p), len(cargs), (ctypes.c_void_p*len(cargs))(*cargs), lx, ly, lz, gx, gy, gz, 0)
|
||||
@@ -144,7 +144,7 @@ def cuDeviceCanAccessPeer(canAccessPeer, dev: int, peerDev: int) -> int:
|
||||
canAccessPeer._obj.value = 1 # Always allow peer access in simulation
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuCtxEnablePeerAccess(peerContext: orig_cuda.CUcontext, flags: int) -> int:
|
||||
def cuCtxEnablePeerAccess(peerContext, flags: int) -> int:
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuMemHostAlloc(pp, bytesize: int, flags: int) -> int:
|
||||
@@ -152,19 +152,19 @@ def cuMemHostAlloc(pp, bytesize: int, flags: int) -> int:
|
||||
|
||||
def cuMemFreeHost(p: ctypes.c_void_p) -> int: return cuMemFree_v2(p)
|
||||
|
||||
def cuMemcpyDtoDAsync_v2(dst: orig_cuda.CUdeviceptr_v2, src: orig_cuda.CUdeviceptr_v2, bytesize: int, stream: Any) -> int:
|
||||
def cuMemcpyDtoDAsync_v2(dst, src, bytesize: int, stream: Any) -> int:
|
||||
ctypes.memmove(dst.value, src.value, bytesize)
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuFuncSetAttribute(hfunc: orig_cuda.CUfunction, attrib: int, value: int) -> int:
|
||||
def cuFuncSetAttribute(hfunc, attrib: int, value: int) -> int:
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuStreamWaitEvent(stream: Any, event: orig_cuda.CUevent, flags: int) -> int: return orig_cuda.CUDA_SUCCESS
|
||||
def cuStreamWaitEvent(stream: Any, event, flags: int) -> int: return orig_cuda.CUDA_SUCCESS
|
||||
def cuCtxSynchronize() -> int: return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
def cuGetErrorString(error: int, pStr) -> int:
|
||||
error_str = orig_cuda.cudaError_enum__enumvalues.get(error, b"Unknown CUDA error")
|
||||
error_str = orig_cuda.cudaError_enum__enumvalues.get(error, "Unknown CUDA error").encode()
|
||||
buf = ctypes.create_string_buffer(error_str)
|
||||
# Set the pointer to point to our error string buffer
|
||||
pStr._obj.value = ctypes.cast(buf, ctypes.POINTER(ctypes.c_char)).value
|
||||
pStr._obj.value = ctypes.cast(buf, ctypes.POINTER(ctypes.c_char))
|
||||
return orig_cuda.CUDA_SUCCESS
|
||||
|
||||
Reference in New Issue
Block a user