From b78352b4236f91edfb4593c57695593ffe873b3a Mon Sep 17 00:00:00 2001
From: nimlgen <138685161+nimlgen@users.noreply.github.com>
Date: Thu, 21 Mar 2024 17:51:40 +0300
Subject: [PATCH] do not create structs every call in CUDAProgram (#3855)

* do not create structs in cuda

* fix graph

* linter

* do not exec twice

* fix graph
---
 tinygrad/helpers.py            | 17 ---------------
 tinygrad/runtime/graph/cuda.py |  9 ++++----
 tinygrad/runtime/ops_cuda.py   | 40 ++++++++++++++++++++++++++++++++-------
 3 files changed, 36 insertions(+), 30 deletions(-)

diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index b92fb19e56..8629de7995 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -213,20 +213,3 @@ def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
   return CStruct
 def init_c_var(ctypes_var, creat_cb): return (creat_cb(ctypes_var), ctypes_var)[1]
 def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(mv.nbytes,))
-
-# *** Helpers for CUDA-like APIs.
-
-def encode_args_cuda_style(bufs, vals, device_ptr_t, marks) -> Tuple[ctypes.Array, ctypes.Structure]:
-  c_args = init_c_struct_t(tuple([(f'f{i}', device_ptr_t) for i in range(len(bufs))] + [(f'f{i}', ctypes.c_int) for i in range(len(bufs), len(bufs)+len(vals))]))(*bufs, *vals) # noqa: E501
-  return (ctypes.c_void_p * 5)(ctypes.c_void_p(marks[0]), ctypes.cast(ctypes.pointer(c_args), ctypes.c_void_p), ctypes.c_void_p(marks[1]), ctypes.cast(ctypes.pointer(ctypes.c_size_t(ctypes.sizeof(c_args))), ctypes.c_void_p), ctypes.c_void_p(marks[2])), c_args # noqa: E501
-
-def time_execution_cuda_style(cb, ev_t, evcreate, evrecord, evsync, evdestroy, evtime, enable=False) -> Optional[float]:
-  if not enable: return cb()
-  evs = [init_c_var(ev_t(), lambda x: evcreate(ctypes.byref(x), 0)) for _ in range(2)]
-  evrecord(evs[0], None)
-  cb()
-  evrecord(evs[1], None)
-  evsync(evs[1])
-  evtime(ctypes.byref(ret := ctypes.c_float()), evs[0], evs[1])
-  for ev in evs: evdestroy(ev)
-  return ret.value * 1e-3
diff --git a/tinygrad/runtime/graph/cuda.py b/tinygrad/runtime/graph/cuda.py
index f8eda4f34b..9814be0c8e 100644
--- a/tinygrad/runtime/graph/cuda.py
+++ b/tinygrad/runtime/graph/cuda.py
@@ -1,9 +1,9 @@
 import ctypes
 from typing import Any, Optional, Tuple, Dict, List, cast
 import tinygrad.runtime.autogen.cuda as cuda
-from tinygrad.helpers import init_c_var, encode_args_cuda_style, all_same, GraphException
+from tinygrad.helpers import init_c_var, all_same, GraphException
 from tinygrad.device import CompiledASTRunner, update_stats, Buffer
-from tinygrad.runtime.ops_cuda import check, cu_time_execution
+from tinygrad.runtime.ops_cuda import check, cu_time_execution, encode_args
 from tinygrad.shape.symbolic import Variable
 from tinygrad.features.jit import JitItem, get_input_replace, get_jit_stats, \
   get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
@@ -31,7 +31,7 @@ class CUDAGraph:
       prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg)

       c_deps = (type(graph_node)*1)(*(graph_node,)) if graph_node is not None else None
-      c_kernel_input_config, c_input_params = encode_args_cuda_style([cast(Buffer, x)._buf for x in ji.rawbufs], [var_vals[x] for x in prg.vars], *self.encode_args_info()) # noqa: E501
+      c_input_params, c_kernel_input_config = encode_args([cast(Buffer, x)._buf for x in ji.rawbufs], [var_vals[x] for x in prg.vars])
       c_node_params = self.build_kernel_node_params(prg, *cast(Tuple[List[int], List[int]], prg.launch_dims(var_vals)), c_kernel_input_config)
       graph_node = self.graph_add_kernel_node(self.graph, c_deps, c_node_params)

@@ -49,7 +49,7 @@ class CUDAGraph:
     # Update var_vals in the c_input_params struct.
     for j in self.jc_idxs_with_updatable_var_vals:
       for i,v in enumerate(cast(CompiledASTRunner, self.jit_cache[j].prg).vars):
-        setattr(self.updatable_nodes[j][2], f'f{len(self.jit_cache[j].rawbufs) + i}', var_vals[v])
+        setattr(self.updatable_nodes[j][2], f'v{i}', var_vals[v])

     # Update launch dims in the c_node_params struct.
     for j in self.jc_idxs_with_updatable_launch_dims:
@@ -69,7 +69,6 @@ class CUDAGraph:
     if hasattr(self, 'instance'): check(cuda.cuGraphExecDestroy(self.instance))

   def set_device(self): check(cuda.cuCtxSetCurrent(self.device.context))
-  def encode_args_info(self): return (cuda.CUdeviceptr_v2, (1,2,0))
   def graph_create(self): return init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
   def graph_instantiate(self, graph): return init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), graph, None, None, 0)))

diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py
index 62a991c15e..2405145f84 100644
--- a/tinygrad/runtime/ops_cuda.py
+++ b/tinygrad/runtime/ops_cuda.py
@@ -3,7 +3,7 @@ import subprocess, hashlib, tempfile, ctypes, ctypes.util, functools, re
 from pathlib import Path
 from typing import Tuple, Optional
 import tinygrad.runtime.autogen.cuda as cuda
-from tinygrad.helpers import DEBUG, getenv, from_mv, to_char_p_p, init_c_var, colored, cpu_time_execution, encode_args_cuda_style, time_execution_cuda_style # noqa: E501
+from tinygrad.helpers import DEBUG, getenv, from_mv, to_char_p_p, init_c_var, init_c_struct_t, colored, cpu_time_execution
 from tinygrad.device import Compiled, LRUAllocator, MallocAllocator, Compiler
 from tinygrad.codegen.kernel import LinearizerOptions
 from tinygrad.renderer.cstyle import CUDARenderer
@@ -28,7 +28,24 @@ if CUDACPU:
 def check(status):
   if status != 0: raise RuntimeError(f"CUDA Error {status}, {ctypes.string_at(init_c_var(ctypes.POINTER(ctypes.c_char)(), lambda x: cuda.cuGetErrorString(status, ctypes.byref(x)))).decode()}") # noqa: E501

-def cu_time_execution(cb, enable=False) -> Optional[float]: return time_execution_cuda_style(cb, cuda.CUevent, cuda.cuEventCreate, cuda.cuEventRecord, cuda.cuEventSynchronize, cuda.cuEventDestroy_v2, cuda.cuEventElapsedTime, enable=enable) if not CUDACPU else cpu_time_execution(cb, enable=enable) # noqa: E501
+def encode_args(args, vals) -> Tuple[ctypes.Structure, ctypes.Array]:
+  c_args = init_c_struct_t(tuple([(f'f{i}', cuda.CUdeviceptr_v2) for i in range(len(args))] +
+                                 [(f'v{i}', ctypes.c_int) for i in range(len(vals))]))(*args, *vals)
+  vargs = (ctypes.c_void_p * 5)(ctypes.c_void_p(1), ctypes.cast(ctypes.byref(c_args), ctypes.c_void_p), ctypes.c_void_p(2),
+                                ctypes.cast(ctypes.pointer(ctypes.c_size_t(ctypes.sizeof(c_args))), ctypes.c_void_p), ctypes.c_void_p(0))
+  return c_args, vargs
+
+def cu_time_execution(cb, enable=False) -> Optional[float]:
+  if CUDACPU: return cpu_time_execution(cb, enable=enable)
+  if not enable: return cb()
+  evs = [init_c_var(cuda.CUevent(), lambda x: cuda.cuEventCreate(ctypes.byref(x), 0)) for _ in range(2)]
+  cuda.cuEventRecord(evs[0], None)
+  cb()
+  cuda.cuEventRecord(evs[1], None)
+  cuda.cuEventSynchronize(evs[1])
+  cuda.cuEventElapsedTime(ctypes.byref(ret := ctypes.c_float()), evs[0], evs[1])
+  for ev in evs: cuda.cuEventDestroy_v2(ev)
+  return ret.value * 1e-3

 def _get_bytes(arg, get_str, get_sz, check) -> bytes:
   sz = init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x))))
@@ -73,7 +90,8 @@ class CUDAProgram:
     if DEBUG >= 5: print("\n".join([f"{i+1:>3} {line}" for i, line in enumerate(pretty_ptx(lib.decode('utf-8')).split("\n"))]))
     if DEBUG >= 6: cuda_disassemble(lib, device.arch)

-    if not CUDACPU:
+    if CUDACPU: self.prg = lib
+    else:
       check(cuda.cuCtxSetCurrent(self.device.context))
       self.module = cuda.CUmodule()
       status = cuda.cuModuleLoadData(ctypes.byref(self.module), lib)
@@ -82,15 +100,21 @@ class CUDAProgram:
         cuda_disassemble(lib, device.arch)
         raise RuntimeError("module load failed")
       check(cuda.cuModuleGetFunction(ctypes.byref(prg := cuda.CUfunction()), self.module, name.encode("utf-8")))
-    self.prg = prg if not CUDACPU else lib
+      self.prg = prg #type: ignore

   def __del__(self):
     if hasattr(self, 'module'): check(cuda.cuModuleUnload(self.module))

-  def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
-    if not CUDACPU: check(cuda.cuCtxSetCurrent(self.device.context))
-    c_kernel_input_config = encode_args_cuda_style(bufs, vals, cuda.CUdeviceptr_v2, (1,2,0))[0] if not CUDACPU else (bufs+tuple(vals))
-    return cu_time_execution(lambda: check(cuda.cuLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, c_kernel_input_config)), enable=wait) # noqa: E501
+  def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
+    if CUDACPU: self.vargs = args+tuple(vals)
+    else:
+      check(cuda.cuCtxSetCurrent(self.device.context))
+      if not hasattr(self, "vargs"):
+        self.c_args, self.vargs = encode_args(args, vals) #type: ignore
+      else:
+        for i in range(len(args)): self.c_args.__setattr__(f'f{i}', args[i])
+        for i in range(len(vals)): self.c_args.__setattr__(f'v{i}', vals[i])
+    return cu_time_execution(lambda: check(cuda.cuLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, self.vargs)), enable=wait)

 class CUDAAllocator(LRUAllocator):
   def __init__(self, device:CUDADevice):
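
Note on the CUDAProgram change above: encode_args packs the buffer pointers into f{i} fields and the integer vals into v{i} fields of a single packed ctypes struct, then wraps it in the 5-element "extra" array that cuLaunchKernel takes (the 1, 2 and 0 markers correspond to CUDA's CU_LAUNCH_PARAM_BUFFER_POINTER, CU_LAUNCH_PARAM_BUFFER_SIZE and CU_LAUNCH_PARAM_END). __call__ now builds that struct and array only on the first launch and, on later launches, just overwrites the f{i}/v{i} fields in place, which works because a given CUDAProgram is always launched with the same kernel signature. The CUDAGraph path relies on the same field naming when it pokes v{i} into the cached c_input_params for updated var_vals. The following is a minimal standalone sketch of that build-once/update-in-place pattern using plain ctypes; make_args_struct and DemoProgram are illustrative names, not tinygrad or CUDA API.

# Sketch of the "build once, poke fields afterwards" pattern from the patch above.
# Plain ctypes only; make_args_struct/DemoProgram are illustrative names.
import ctypes
from typing import Tuple

def make_args_struct(n_ptrs: int, n_ints: int):
  # One pointer field per buffer (f0, f1, ...) and one int field per value (v0, v1, ...),
  # packed the same way the real kernel argument struct is packed.
  class CArgs(ctypes.Structure):
    _pack_ = 1
    _fields_ = [(f'f{i}', ctypes.c_void_p) for i in range(n_ptrs)] + \
               [(f'v{i}', ctypes.c_int) for i in range(n_ints)]
  return CArgs

class DemoProgram:
  def __call__(self, *args: int, vals: Tuple[int, ...] = ()):
    if not hasattr(self, "c_args"):
      # First call: create the struct type and instance once (the work the patch stops repeating).
      self.c_args = make_args_struct(len(args), len(vals))(*args, *vals)
    else:
      # Later calls: same layout, just overwrite the existing fields in place.
      for i, a in enumerate(args): setattr(self.c_args, f'f{i}', a)
      for i, v in enumerate(vals): setattr(self.c_args, f'v{i}', v)
    return ctypes.addressof(self.c_args)  # stand-in for handing the struct to cuLaunchKernel

prog = DemoProgram()
first = prog(0x1000, 0x2000, vals=(16,))
second = prog(0x3000, 0x4000, vals=(32,))
assert first == second  # the same struct instance (and address) is reused across calls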