From 16e31f7f0d8ab4fb0b85d172ae21acdf6e3fd019 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Fri, 22 Mar 2024 23:49:48 +0300 Subject: [PATCH] init multidevice cuda graph (#3858) * init multidevice cuda graph * cuda just works! * clean * linter happier * liners happy * update transfer inputs * do not change free * useless check for cuda --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com> --- tinygrad/features/jit.py | 3 +- tinygrad/runtime/graph/cuda.py | 101 ++++++++++++++++++++------------- 2 files changed, 63 insertions(+), 41 deletions(-) diff --git a/tinygrad/features/jit.py b/tinygrad/features/jit.py index 11d0afc6b8..b36ccd35c6 100644 --- a/tinygrad/features/jit.py +++ b/tinygrad/features/jit.py @@ -57,7 +57,8 @@ def apply_graph_to_jit(jit_cache: List[JitItem], input_rawbuffers: List[Buffer], for ji in jit_cache: ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None. if isinstance(ji.prg, CompiledASTRunner): ji_graph_dev = ji.prg.device - elif isinstance(ji.prg, BufferXfer) and ji.rawbufs[0] and ji.rawbufs[0].d.dname.startswith("HSA"): ji_graph_dev = ji.rawbufs[0].d + elif isinstance(ji.prg, BufferXfer) and ji.rawbufs[0] and ji.rawbufs[0].d.dname.split(":", 1)[0] in {"HSA", "CUDA"}: + ji_graph_dev = ji.rawbufs[0].d can_be_graphed = ji_graph_dev and ji_graph_dev.graph can_extend_graph_batch = can_be_graphed and len(current_batch) < max_batch_size and (ji_graph_dev == current_device or diff --git a/tinygrad/runtime/graph/cuda.py b/tinygrad/runtime/graph/cuda.py index 9814be0c8e..3fab54513a 100644 --- a/tinygrad/runtime/graph/cuda.py +++ b/tinygrad/runtime/graph/cuda.py @@ -1,19 +1,17 @@ -import ctypes +import ctypes, collections from typing import Any, Optional, Tuple, Dict, List, cast import tinygrad.runtime.autogen.cuda as cuda -from tinygrad.helpers import init_c_var, all_same, GraphException -from tinygrad.device import CompiledASTRunner, update_stats, Buffer -from tinygrad.runtime.ops_cuda import check, cu_time_execution, encode_args +from tinygrad.helpers import init_c_var, GraphException +from tinygrad.device import CompiledASTRunner, update_stats, Buffer, MultiDeviceJITGraph, BufferXfer +from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution from tinygrad.shape.symbolic import Variable from tinygrad.features.jit import JitItem, get_input_replace, get_jit_stats, \ get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals -class CUDAGraph: +class CUDAGraph(MultiDeviceJITGraph): def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]): - devices = [ji.prg.clprg.device if isinstance(ji.prg, CompiledASTRunner) else None for ji in jit_cache] - if len(devices) == 0 or not all_same(devices) or devices[0] is None: raise GraphException - self.device = devices[0] - self.set_device() + # Check all jit items are compatible. + if not all(isinstance(ji.prg, CompiledASTRunner) or isinstance(ji.prg, BufferXfer) for ji in jit_cache): raise GraphException self.jit_cache = jit_cache self.input_replace = get_input_replace(jit_cache, input_rawbuffers) @@ -21,62 +19,85 @@ class CUDAGraph: self.jc_idxs_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache) self.jc_idxs_with_updatable_var_vals = get_jc_idxs_with_updatable_var_vals(jit_cache) self.jc_idxs_with_updatable_rawbufs = list(set([x[0] for x in self.input_replace.keys()])) - self.updatable_nodes: Dict[int, Tuple[Any, Any, Any]] = {} # Dict[jc index] = tuple(graph node, node params, input kernel params) + self.updatable_nodes: Dict[int, Tuple[Any, Any, Any, bool]] = {} # Dict[jc index] = tuple(graph node, node params, input kernel params, is memcpy) - self.graph = self.graph_create() - graph_node: Optional[ctypes._CData] = None + self.graph = init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0))) + self.w_dependency_map: Dict[Any, Any] = {} + self.r_dependency_map: Dict[Any, List[Any]] = collections.defaultdict(list) - for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name] for j,ji in enumerate(self.jit_cache): - prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg) + if isinstance(ji.prg, CompiledASTRunner): + global_size, local_size = ji.prg.launch_dims(var_vals) - c_deps = (type(graph_node)*1)(*(graph_node,)) if graph_node is not None else None - c_input_params, c_kernel_input_config = encode_args([cast(Buffer, x)._buf for x in ji.rawbufs], [var_vals[x] for x in prg.vars]) - c_node_params = self.build_kernel_node_params(prg, *cast(Tuple[List[int], List[int]], prg.launch_dims(var_vals)), c_kernel_input_config) - graph_node = self.graph_add_kernel_node(self.graph, c_deps, c_node_params) + new_node = cuda.CUgraphNode() + deps = self.access_resources(ji.rawbufs[(outs:=ji.prg.outcount):], ji.rawbufs[:outs], new_dependency=new_node) + c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None - if j in self.jc_idxs_with_updatable_launch_dims or j in self.jc_idxs_with_updatable_var_vals or j in self.jc_idxs_with_updatable_rawbufs: - self.updatable_nodes[j] = (graph_node, c_node_params, c_input_params) + c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.rawbufs], [var_vals[x] for x in ji.prg.vars]) + kern_params = cuda.CUDA_KERNEL_NODE_PARAMS(ji.prg.clprg.prg, *global_size, *local_size, 0, None, vargs) + check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params))) - self.instance = self.graph_instantiate(self.graph) + if j in self.jc_idxs_with_updatable_launch_dims or j in self.jc_idxs_with_updatable_var_vals or j in self.jc_idxs_with_updatable_rawbufs: + self.updatable_nodes[j] = (new_node, kern_params, c_args, False) + elif isinstance(ji.prg, BufferXfer): + dest, src = [cast(Buffer, x) for x in ji.rawbufs[0:2]] + src_dev = cast(CUDADevice, src.d) + + new_node = cuda.CUgraphNode() + deps = self.access_resources(read=[src], write=[dest], new_dependency=new_node) + c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None + + cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1, + dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1, + WidthInBytes=dest.nbytes, Height=1, Depth=1) + check(cuda.cuGraphAddMemcpyNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context)) + if j in self.jc_idxs_with_updatable_rawbufs: self.updatable_nodes[j] = (new_node, cp_params, src_dev.context, True) + + self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0))) def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]: - self.set_device() - # Update rawbuffers in the c_input_params struct. + # Update rawbuffers in the c_args struct. for (j,i),input_idx in self.input_replace.items(): - setattr(self.updatable_nodes[j][2], f'f{i}', input_rawbuffers[input_idx]._buf) + if not self.updatable_nodes[j][3]: setattr(self.updatable_nodes[j][2], f'f{i}', input_rawbuffers[input_idx]._buf) + else: + if i == 0: self.updatable_nodes[j][1].destDevice = input_rawbuffers[input_idx]._buf + elif i == 1: self.updatable_nodes[j][1].srcDevice = input_rawbuffers[input_idx]._buf - # Update var_vals in the c_input_params struct. + # Update var_vals in the c_args struct. for j in self.jc_idxs_with_updatable_var_vals: for i,v in enumerate(cast(CompiledASTRunner, self.jit_cache[j].prg).vars): setattr(self.updatable_nodes[j][2], f'v{i}', var_vals[v]) - # Update launch dims in the c_node_params struct. + # Update launch dims in the kern_params struct. for j in self.jc_idxs_with_updatable_launch_dims: self.set_kernel_node_launch_dims(self.updatable_nodes[j][1], *cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals)) # Update graph nodes with the updated structs. - for node, c_node_params, _ in self.updatable_nodes.values(): - self.graph_exec_kernel_node_set_params(self.instance, node, ctypes.byref(c_node_params)) + for node, c_node_params, c_args, is_copy in self.updatable_nodes.values(): + if not is_copy: check(cuda.cuGraphExecKernelNodeSetParams(self.instance, node, ctypes.byref(c_node_params))) + else: check(cuda.cuGraphExecMemcpyNodeSetParams(self.instance, node, ctypes.byref(c_node_params), c_args)) - et = self.graph_launch(self.instance, None, wait=wait) + et = cu_time_execution(lambda: check(cuda.cuGraphLaunch(self.instance, None)), enable=wait) update_stats(f"", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), - jit=jit, num_kernels=len(self.jit_cache), device=f":{self.device}") + jit=jit, num_kernels=len(self.jit_cache), device="CUDA") return et def __del__(self): if hasattr(self, 'graph'): check(cuda.cuGraphDestroy(self.graph)) if hasattr(self, 'instance'): check(cuda.cuGraphExecDestroy(self.instance)) - def set_device(self): check(cuda.cuCtxSetCurrent(self.device.context)) - def graph_create(self): return init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0))) - def graph_instantiate(self, graph): - return init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), graph, None, None, 0))) - def graph_add_kernel_node(self, graph, c_deps, c_node_params): - return init_c_var(cuda.CUgraphNode(), lambda x: check(cuda.cuGraphAddKernelNode(ctypes.byref(x), graph, c_deps, ctypes.sizeof(c_deps)//8 if c_deps else 0, ctypes.byref(c_node_params)))) # noqa: E501 - def graph_launch(self, *args, wait=False): return cu_time_execution(lambda: check(cuda.cuGraphLaunch(*args)), enable=wait) - def graph_exec_kernel_node_set_params(self, *args): return check(cuda.cuGraphExecKernelNodeSetParams(*args)) - def build_kernel_node_params(self, prg, global_size, local_size, c_kernel_config): - return cuda.CUDA_KERNEL_NODE_PARAMS(prg.clprg.prg, *global_size, *local_size, 0, None, c_kernel_config) def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]): node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_size, *global_size + + def access_resources(self, read, write, new_dependency): + wait_nodes = [] + + for rawbuf in read + write: + if rawbuf._buf.value in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[rawbuf._buf.value]) + for rawbuf in write: + if rawbuf._buf.value in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(rawbuf._buf.value)) + + if new_dependency is not None: + for rawbuf in read: self.r_dependency_map[rawbuf._buf.value].append(new_dependency) + for rawbuf in write: self.w_dependency_map[rawbuf._buf.value] = new_dependency + return {id(x):x for x in wait_nodes}.values()