mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
do not create structs every call in CUDAProgram (#3855)
* do not create structs in cuda
* fix graph
* linter
* do not exec twice
* fix graph
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
import ctypes
|
||||
from typing import Any, Optional, Tuple, Dict, List, cast
|
||||
import tinygrad.runtime.autogen.cuda as cuda
|
||||
from tinygrad.helpers import init_c_var, encode_args_cuda_style, all_same, GraphException
|
||||
from tinygrad.helpers import init_c_var, all_same, GraphException
|
||||
from tinygrad.device import CompiledASTRunner, update_stats, Buffer
|
||||
from tinygrad.runtime.ops_cuda import check, cu_time_execution
|
||||
from tinygrad.runtime.ops_cuda import check, cu_time_execution, encode_args
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.features.jit import JitItem, get_input_replace, get_jit_stats, \
|
||||
get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
|
||||
@@ -31,7 +31,7 @@ class CUDAGraph:
|
||||
prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg)
|
||||
|
||||
c_deps = (type(graph_node)*1)(*(graph_node,)) if graph_node is not None else None
|
||||
c_kernel_input_config, c_input_params = encode_args_cuda_style([cast(Buffer, x)._buf for x in ji.rawbufs], [var_vals[x] for x in prg.vars], *self.encode_args_info()) # noqa: E501
|
||||
c_input_params, c_kernel_input_config = encode_args([cast(Buffer, x)._buf for x in ji.rawbufs], [var_vals[x] for x in prg.vars])
|
||||
c_node_params = self.build_kernel_node_params(prg, *cast(Tuple[List[int], List[int]], prg.launch_dims(var_vals)), c_kernel_input_config)
|
||||
graph_node = self.graph_add_kernel_node(self.graph, c_deps, c_node_params)
|
||||
|
||||
@@ -49,7 +49,7 @@ class CUDAGraph:
|
||||
# Update var_vals in the c_input_params struct.
|
||||
for j in self.jc_idxs_with_updatable_var_vals:
|
||||
for i,v in enumerate(cast(CompiledASTRunner, self.jit_cache[j].prg).vars):
|
||||
setattr(self.updatable_nodes[j][2], f'f{len(self.jit_cache[j].rawbufs) + i}', var_vals[v])
|
||||
setattr(self.updatable_nodes[j][2], f'v{i}', var_vals[v])
|
||||
|
||||
# Update launch dims in the c_node_params struct.
|
||||
for j in self.jc_idxs_with_updatable_launch_dims:
|
||||
@@ -69,7 +69,6 @@ class CUDAGraph:
|
||||
if hasattr(self, 'instance'): check(cuda.cuGraphExecDestroy(self.instance))
|
||||
|
||||
def set_device(self):
  """Make this graph's device context the current CUDA context.

  Calls cuCtxSetCurrent with self.device.context and hands the driver's
  status code to check().
  """
  context = self.device.context
  check(cuda.cuCtxSetCurrent(context))
|
||||
def encode_args_info(self):
  """Return the extra arguments splatted into encode_args_cuda_style.

  The result is a 2-tuple: the CUDA device pointer type
  (cuda.CUdeviceptr_v2) and the integer triple (1, 2, 0).
  NOTE(review): the meaning of the (1, 2, 0) triple is defined by
  encode_args_cuda_style, which is not visible here -- confirm there.
  """
  device_ptr_type = cuda.CUdeviceptr_v2
  marks = (1, 2, 0)
  return (device_ptr_type, marks)
|
||||
def graph_create(self):
  """Create a CUDA graph handle via cuGraphCreate.

  A fresh cuda.CUgraph() is given to init_c_var together with a callback
  that calls cuGraphCreate on a reference to it (flags argument 0) and
  passes the status through check(). Whatever init_c_var returns is
  returned to the caller (presumably the initialised handle -- verify
  against init_c_var).
  """
  def _create(handle):
    return check(cuda.cuGraphCreate(ctypes.byref(handle), 0))

  return init_c_var(cuda.CUgraph(), _create)
|
||||
def graph_instantiate(self, graph):
  """Instantiate `graph` into a cuda.CUgraphExec via cuGraphInstantiate_v2.

  A fresh cuda.CUgraphExec() is given to init_c_var together with a
  callback that calls cuGraphInstantiate_v2 with a reference to it, the
  given graph, and None, None, 0 for the remaining arguments; the status
  goes through check(). Whatever init_c_var returns is returned to the
  caller (presumably the initialised exec handle -- verify against
  init_c_var).
  """
  def _instantiate(exec_handle):
    return check(cuda.cuGraphInstantiate_v2(ctypes.byref(exec_handle), graph, None, None, 0))

  return init_c_var(cuda.CUgraphExec(), _instantiate)
|
||||
|
||||
Reference in New Issue
Block a user