do not create structs every call in CUDAProgram (#3855)

* do not create structs in cuda

* fix graph

* linter

* do not exec twice

* fix graph
This commit is contained in:
nimlgen
2024-03-21 17:51:40 +03:00
committed by GitHub
parent e5745c1a0d
commit b78352b423
3 changed files with 36 additions and 30 deletions

View File

@@ -1,9 +1,9 @@
import ctypes
from typing import Any, Optional, Tuple, Dict, List, cast
import tinygrad.runtime.autogen.cuda as cuda
from tinygrad.helpers import init_c_var, encode_args_cuda_style, all_same, GraphException
from tinygrad.helpers import init_c_var, all_same, GraphException
from tinygrad.device import CompiledASTRunner, update_stats, Buffer
from tinygrad.runtime.ops_cuda import check, cu_time_execution
from tinygrad.runtime.ops_cuda import check, cu_time_execution, encode_args
from tinygrad.shape.symbolic import Variable
from tinygrad.features.jit import JitItem, get_input_replace, get_jit_stats, \
get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
@@ -31,7 +31,7 @@ class CUDAGraph:
prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg)
c_deps = (type(graph_node)*1)(*(graph_node,)) if graph_node is not None else None
c_kernel_input_config, c_input_params = encode_args_cuda_style([cast(Buffer, x)._buf for x in ji.rawbufs], [var_vals[x] for x in prg.vars], *self.encode_args_info()) # noqa: E501
c_input_params, c_kernel_input_config = encode_args([cast(Buffer, x)._buf for x in ji.rawbufs], [var_vals[x] for x in prg.vars])
c_node_params = self.build_kernel_node_params(prg, *cast(Tuple[List[int], List[int]], prg.launch_dims(var_vals)), c_kernel_input_config)
graph_node = self.graph_add_kernel_node(self.graph, c_deps, c_node_params)
@@ -49,7 +49,7 @@ class CUDAGraph:
# Update var_vals in the c_input_params struct.
for j in self.jc_idxs_with_updatable_var_vals:
for i,v in enumerate(cast(CompiledASTRunner, self.jit_cache[j].prg).vars):
setattr(self.updatable_nodes[j][2], f'f{len(self.jit_cache[j].rawbufs) + i}', var_vals[v])
setattr(self.updatable_nodes[j][2], f'v{i}', var_vals[v])
# Update launch dims in the c_node_params struct.
for j in self.jc_idxs_with_updatable_launch_dims:
@@ -69,7 +69,6 @@ class CUDAGraph:
if hasattr(self, 'instance'): check(cuda.cuGraphExecDestroy(self.instance))
def set_device(self):
  """Select this graph's device context via `cuCtxSetCurrent`.

  Reads `self.device.context`, hands it to `cuda.cuCtxSetCurrent`, and passes
  the driver's return value to `check`. Returns nothing.
  """
  ctx = self.device.context
  status = cuda.cuCtxSetCurrent(ctx)
  check(status)
def encode_args_info(self):
  """Return the pair `(cuda.CUdeviceptr_v2, (1, 2, 0))`.

  A caller in this file unpacks the result as the trailing positional
  arguments of `encode_args_cuda_style`.
  NOTE(review): the first item looks like the device-pointer ctypes type and
  the second like a triple of marker values for the launch config; confirm
  against `encode_args_cuda_style`.
  """
  device_ptr_type = cuda.CUdeviceptr_v2
  marks = (1, 2, 0)
  return device_ptr_type, marks
def graph_create(self):
  """Build a `cuda.CUgraph` handle through `cuGraphCreate` and return what `init_c_var` yields.

  The initializer callback passes a reference to the handle and a flags value
  of 0 to `cuda.cuGraphCreate`, routing the driver's return value through `check`.
  NOTE(review): `init_c_var` presumably invokes the callback on the fresh
  handle and returns that handle; confirm in tinygrad.helpers.
  """
  def _create(handle):
    # Same work as the original inline lambda: create the graph in place and check the status.
    return check(cuda.cuGraphCreate(ctypes.byref(handle), 0))

  return init_c_var(cuda.CUgraph(), _create)
def graph_instantiate(self, graph):
  """Instantiate `graph` into a `cuda.CUgraphExec` handle and return what `init_c_var` yields.

  The initializer callback calls `cuda.cuGraphInstantiate_v2` with a reference
  to the exec handle, the given `graph`, two `None` arguments and a trailing 0,
  routing the driver's return value through `check`.
  NOTE(review): `init_c_var` presumably invokes the callback on the fresh
  handle and returns that handle; confirm in tinygrad.helpers.
  """
  def _instantiate(exec_handle):
    # Same work as the original inline lambda: instantiate in place and check the status.
    return check(cuda.cuGraphInstantiate_v2(ctypes.byref(exec_handle), graph, None, None, 0))

  return init_c_var(cuda.CUgraphExec(), _instantiate)