Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-25 06:48:22 -05:00
do not create structs every call in CUDAProgram (#3855)
* do not create structs in cuda
* fix graph
* linter
* do not exec twice
* fix graph
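
In short: instead of packing the kernel arguments into a fresh ctypes struct on every __call__, CUDAProgram now builds the struct once and only overwrites its field values on later calls. A minimal sketch of that pattern, for orientation only (CachedArgsProgram is an illustrative stand-in, not tinygrad's actual API, and ctypes.c_void_p stands in for cuda.CUdeviceptr_v2):

import ctypes
from typing import Tuple

def init_c_struct_t(fields: Tuple[Tuple[str, type], ...]):
  # same idea as tinygrad.helpers.init_c_struct_t: a packed ctypes.Structure type
  class CStruct(ctypes.Structure):
    _pack_, _fields_ = 1, fields
  return CStruct

class CachedArgsProgram:  # hypothetical stand-in for CUDAProgram
  def __call__(self, *args, vals: Tuple[int, ...] = ()):
    # assumes the same number of args/vals on every call, like a fixed kernel signature
    if not hasattr(self, "c_args"):
      # first call: create the struct type and instance (the expensive part)
      struct_t = init_c_struct_t(tuple([(f'f{i}', ctypes.c_void_p) for i in range(len(args))] +
                                       [(f'v{i}', ctypes.c_int) for i in range(len(vals))]))
      self.c_args = struct_t(*args, *vals)
    else:
      # later calls: only overwrite the field values in place
      for i, a in enumerate(args): setattr(self.c_args, f'f{i}', a)
      for i, v in enumerate(vals): setattr(self.c_args, f'v{i}', v)
    return self.c_args  # the real code hands this (via vargs) to cuLaunchKernel

prog = CachedArgsProgram()
prog(0x1000, vals=(8,))   # builds the struct
prog(0x2000, vals=(16,))  # reuses it, just updates f0 and v0
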
@@ -213,20 +213,3 @@ def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
   return CStruct
 def init_c_var(ctypes_var, creat_cb): return (creat_cb(ctypes_var), ctypes_var)[1]
 def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(mv.nbytes,))
-
-# *** Helpers for CUDA-like APIs.
-
-def encode_args_cuda_style(bufs, vals, device_ptr_t, marks) -> Tuple[ctypes.Array, ctypes.Structure]:
-  c_args = init_c_struct_t(tuple([(f'f{i}', device_ptr_t) for i in range(len(bufs))] + [(f'f{i}', ctypes.c_int) for i in range(len(bufs), len(bufs)+len(vals))]))(*bufs, *vals) # noqa: E501
-  return (ctypes.c_void_p * 5)(ctypes.c_void_p(marks[0]), ctypes.cast(ctypes.pointer(c_args), ctypes.c_void_p), ctypes.c_void_p(marks[1]), ctypes.cast(ctypes.pointer(ctypes.c_size_t(ctypes.sizeof(c_args))), ctypes.c_void_p), ctypes.c_void_p(marks[2])), c_args # noqa: E501
-
-def time_execution_cuda_style(cb, ev_t, evcreate, evrecord, evsync, evdestroy, evtime, enable=False) -> Optional[float]:
-  if not enable: return cb()
-  evs = [init_c_var(ev_t(), lambda x: evcreate(ctypes.byref(x), 0)) for _ in range(2)]
-  evrecord(evs[0], None)
-  cb()
-  evrecord(evs[1], None)
-  evsync(evs[1])
-  evtime(ctypes.byref(ret := ctypes.c_float()), evs[0], evs[1])
-  for ev in evs: evdestroy(ev)
-  return ret.value * 1e-3

@@ -1,9 +1,9 @@
 import ctypes
 from typing import Any, Optional, Tuple, Dict, List, cast
 import tinygrad.runtime.autogen.cuda as cuda
-from tinygrad.helpers import init_c_var, encode_args_cuda_style, all_same, GraphException
+from tinygrad.helpers import init_c_var, all_same, GraphException
 from tinygrad.device import CompiledASTRunner, update_stats, Buffer
-from tinygrad.runtime.ops_cuda import check, cu_time_execution
+from tinygrad.runtime.ops_cuda import check, cu_time_execution, encode_args
 from tinygrad.shape.symbolic import Variable
 from tinygrad.features.jit import JitItem, get_input_replace, get_jit_stats, \
   get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals

@@ -31,7 +31,7 @@ class CUDAGraph:
       prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg)

       c_deps = (type(graph_node)*1)(*(graph_node,)) if graph_node is not None else None
-      c_kernel_input_config, c_input_params = encode_args_cuda_style([cast(Buffer, x)._buf for x in ji.rawbufs], [var_vals[x] for x in prg.vars], *self.encode_args_info()) # noqa: E501
+      c_input_params, c_kernel_input_config = encode_args([cast(Buffer, x)._buf for x in ji.rawbufs], [var_vals[x] for x in prg.vars])
       c_node_params = self.build_kernel_node_params(prg, *cast(Tuple[List[int], List[int]], prg.launch_dims(var_vals)), c_kernel_input_config)
       graph_node = self.graph_add_kernel_node(self.graph, c_deps, c_node_params)

@@ -49,7 +49,7 @@ class CUDAGraph:
     # Update var_vals in the c_input_params struct.
     for j in self.jc_idxs_with_updatable_var_vals:
       for i,v in enumerate(cast(CompiledASTRunner, self.jit_cache[j].prg).vars):
-        setattr(self.updatable_nodes[j][2], f'f{len(self.jit_cache[j].rawbufs) + i}', var_vals[v])
+        setattr(self.updatable_nodes[j][2], f'v{i}', var_vals[v])

     # Update launch dims in the c_node_params struct.
     for j in self.jc_idxs_with_updatable_launch_dims:

@@ -69,7 +69,6 @@ class CUDAGraph:
     if hasattr(self, 'instance'): check(cuda.cuGraphExecDestroy(self.instance))

   def set_device(self): check(cuda.cuCtxSetCurrent(self.device.context))
-  def encode_args_info(self): return (cuda.CUdeviceptr_v2, (1,2,0))
   def graph_create(self): return init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
   def graph_instantiate(self, graph):
     return init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), graph, None, None, 0)))

@@ -3,7 +3,7 @@ import subprocess, hashlib, tempfile, ctypes, ctypes.util, functools, re
 from pathlib import Path
 from typing import Tuple, Optional
 import tinygrad.runtime.autogen.cuda as cuda
-from tinygrad.helpers import DEBUG, getenv, from_mv, to_char_p_p, init_c_var, colored, cpu_time_execution, encode_args_cuda_style, time_execution_cuda_style # noqa: E501
+from tinygrad.helpers import DEBUG, getenv, from_mv, to_char_p_p, init_c_var, init_c_struct_t, colored, cpu_time_execution
 from tinygrad.device import Compiled, LRUAllocator, MallocAllocator, Compiler
 from tinygrad.codegen.kernel import LinearizerOptions
 from tinygrad.renderer.cstyle import CUDARenderer

@@ -28,7 +28,24 @@ if CUDACPU:
 def check(status):
   if status != 0: raise RuntimeError(f"CUDA Error {status}, {ctypes.string_at(init_c_var(ctypes.POINTER(ctypes.c_char)(), lambda x: cuda.cuGetErrorString(status, ctypes.byref(x)))).decode()}") # noqa: E501

-def cu_time_execution(cb, enable=False) -> Optional[float]: return time_execution_cuda_style(cb, cuda.CUevent, cuda.cuEventCreate, cuda.cuEventRecord, cuda.cuEventSynchronize, cuda.cuEventDestroy_v2, cuda.cuEventElapsedTime, enable=enable) if not CUDACPU else cpu_time_execution(cb, enable=enable) # noqa: E501
+def encode_args(args, vals) -> Tuple[ctypes.Structure, ctypes.Array]:
+  c_args = init_c_struct_t(tuple([(f'f{i}', cuda.CUdeviceptr_v2) for i in range(len(args))] +
+                                 [(f'v{i}', ctypes.c_int) for i in range(len(vals))]))(*args, *vals)
+  vargs = (ctypes.c_void_p * 5)(ctypes.c_void_p(1), ctypes.cast(ctypes.byref(c_args), ctypes.c_void_p), ctypes.c_void_p(2),
+                                ctypes.cast(ctypes.pointer(ctypes.c_size_t(ctypes.sizeof(c_args))), ctypes.c_void_p), ctypes.c_void_p(0))
+  return c_args, vargs
+
+def cu_time_execution(cb, enable=False) -> Optional[float]:
+  if CUDACPU: return cpu_time_execution(cb, enable=enable)
+  if not enable: return cb()
+  evs = [init_c_var(cuda.CUevent(), lambda x: cuda.cuEventCreate(ctypes.byref(x), 0)) for _ in range(2)]
+  cuda.cuEventRecord(evs[0], None)
+  cb()
+  cuda.cuEventRecord(evs[1], None)
+  cuda.cuEventSynchronize(evs[1])
+  cuda.cuEventElapsedTime(ctypes.byref(ret := ctypes.c_float()), evs[0], evs[1])
+  for ev in evs: cuda.cuEventDestroy_v2(ev)
+  return ret.value * 1e-3

 def _get_bytes(arg, get_str, get_sz, check) -> bytes:
   sz = init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x))))

@@ -73,7 +90,8 @@ class CUDAProgram:
     if DEBUG >= 5: print("\n".join([f"{i+1:>3} {line}" for i, line in enumerate(pretty_ptx(lib.decode('utf-8')).split("\n"))]))
     if DEBUG >= 6: cuda_disassemble(lib, device.arch)

-    if not CUDACPU:
+    if CUDACPU: self.prg = lib
+    else:
       check(cuda.cuCtxSetCurrent(self.device.context))
       self.module = cuda.CUmodule()
       status = cuda.cuModuleLoadData(ctypes.byref(self.module), lib)

@@ -82,15 +100,21 @@ class CUDAProgram:
         cuda_disassemble(lib, device.arch)
         raise RuntimeError("module load failed")
       check(cuda.cuModuleGetFunction(ctypes.byref(prg := cuda.CUfunction()), self.module, name.encode("utf-8")))
-    self.prg = prg if not CUDACPU else lib
+      self.prg = prg #type: ignore

   def __del__(self):
     if hasattr(self, 'module'): check(cuda.cuModuleUnload(self.module))

-  def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
-    if not CUDACPU: check(cuda.cuCtxSetCurrent(self.device.context))
-    c_kernel_input_config = encode_args_cuda_style(bufs, vals, cuda.CUdeviceptr_v2, (1,2,0))[0] if not CUDACPU else (bufs+tuple(vals))
-    return cu_time_execution(lambda: check(cuda.cuLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, c_kernel_input_config)), enable=wait) # noqa: E501
+  def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
+    if CUDACPU: self.vargs = args+tuple(vals)
+    else:
+      check(cuda.cuCtxSetCurrent(self.device.context))
+      if not hasattr(self, "vargs"):
+        self.c_args, self.vargs = encode_args(args, vals) #type: ignore
+      else:
+        for i in range(len(args)): self.c_args.__setattr__(f'f{i}', args[i])
+        for i in range(len(vals)): self.c_args.__setattr__(f'v{i}', vals[i])
+    return cu_time_execution(lambda: check(cuda.cuLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, self.vargs)), enable=wait)

 class CUDAAllocator(LRUAllocator):
   def __init__(self, device:CUDADevice):
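
Side note on the (ctypes.c_void_p * 5) array that encode_args builds: it is the `extra` argument of cuLaunchKernel, and the void-pointer values 1, 2 and 0 are the CUDA driver sentinels CU_LAUNCH_PARAM_BUFFER_POINTER, CU_LAUNCH_PARAM_BUFFER_SIZE and CU_LAUNCH_PARAM_END. The array stores the address of the argument struct itself, which is why overwriting the cached struct's fields in place is enough for the next launch to see new values. A standalone sketch of the same layout (no GPU needed; the Args struct and values are just for illustration):

import ctypes

# CUDA driver sentinels for cuLaunchKernel's `extra` argument (from cuda.h)
CU_LAUNCH_PARAM_BUFFER_POINTER, CU_LAUNCH_PARAM_BUFFER_SIZE, CU_LAUNCH_PARAM_END = 1, 2, 0

class Args(ctypes.Structure):  # hypothetical packed argument buffer: one pointer, one int
  _pack_ = 1
  _fields_ = [('f0', ctypes.c_void_p), ('v0', ctypes.c_int)]

c_args = Args(0x1000, 8)
size = ctypes.c_size_t(ctypes.sizeof(c_args))
extra = (ctypes.c_void_p * 5)(
  ctypes.c_void_p(CU_LAUNCH_PARAM_BUFFER_POINTER), ctypes.cast(ctypes.byref(c_args), ctypes.c_void_p),
  ctypes.c_void_p(CU_LAUNCH_PARAM_BUFFER_SIZE),    ctypes.cast(ctypes.byref(size), ctypes.c_void_p),
  ctypes.c_void_p(CU_LAUNCH_PARAM_END))

c_args.v0 = 16  # mutating the cached struct is visible through `extra` on the next launch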