class CUDAGraph(GraphBatchExecutor):
  """Batch executor that replays a jit_cache through instantiated CUDA graphs.

  Each batch handed to `create_graph` becomes one `cuda.Graph` in which every
  kernel node depends on the previous one, so kernels run in jit_cache order.
  `self.graphs` holds `(instantiated_graph, graph)` tuples and `self.jc_info`
  holds one kernel node per captured jit_cache entry, in capture order
  (`update_node` looks nodes up with `self.jc_info[jcid]`).

  NOTE(review): the base class is not visible from here. `__init__` may return
  without creating any graph, so an executor with no graphs is assumed to be a
  state the base class already handles - confirm against GraphBatchExecutor.
  """
  def __init__(self, jit_cache: List[Tuple[Any, Any, Any]]):
    super().__init__(jit_cache)
    self.jc_info: List[Any] = []
    # Set when any batch fails to capture; remaining batches are then skipped.
    self._capture_failed = False

    # Check if CUDAGraph could run the given jit_cache.
    if DEBUG>0 or getenv("CUDACPU") or not all(isinstance(prg, ASTRunner) and isinstance(prg.clprg, CUDAProgram) for prg,_,_ in jit_cache): return # Only CUDAProgram can be captured.
    self.split_into_graphs(jit_cache)

  @staticmethod
  def _kernel_params(prg, pargs, variables):
    """Return `(positional kernel args, launch kwargs)` for one jit_cache entry.

    Buffers are passed as their underlying `_buf`; everything else (including
    the values of `variables`) is converted with `np.int32`.
    """
    global_size, local_size = prg.launch_dims(variables)
    cuda_args = [x._buf if isinstance(x, RawCUDABuffer) else np.int32(x) for x in [*pargs, *variables.values()]]
    return cuda_args, {"block": tuple(local_size), "grid": tuple(global_size), "func": prg.clprg.prg}

  def create_graph(self, jit_cache: List[Tuple[Any, Any, Any]]):
    """Capture one batch of `(prg, pargs, variables)` entries as a CUDA graph.

    Nodes are staged locally and only committed to `self.jc_info` together
    with the instantiated graph, so a failure part-way through a batch cannot
    leave `jc_info` out of step with `self.graphs`. After a failure every
    previously captured graph is discarded and later batches are ignored,
    because a partial set of graphs would not cover the whole jit_cache.
    """
    if self._capture_failed: return
    nodes: List[Any] = []
    try:
      graph, graph_node = cuda.Graph(), None # type: ignore

      for prg, pargs, variables in jit_cache:
        cuda_args, launch = self._kernel_params(prg, pargs, variables)
        # Chain each kernel to the previous node to keep jit_cache ordering.
        deps = [graph_node] if graph_node is not None else []
        graph_node = graph.add_kernel_node(*cuda_args, dependencies=deps, **launch)
        nodes.append(graph_node)

      instance = graph.instantiate()
    except Exception as e:
      # CudaGraph might not be supported with the installed version of pycuda.
      # NOTE(review): __init__ only reaches this path when DEBUG<=0, so this print
      # fires only if DEBUG changed since construction - confirm intent.
      if DEBUG>=3: print(f"Error creating CUDAGraph: {e}")
      self._capture_failed = True
      self.graphs.clear()
      self.jc_info.clear()
      return

    self.jc_info.extend(nodes)
    self.graphs.append((instance, graph))

  def update_node(self, instid, jcid, prg, pargs, variables, updated_args=None):
    """Re-set the parameters of kernel node `jcid` on graph instance `instid`.

    `updated_args` is accepted for interface compatibility and is not used;
    the full argument list is rebuilt from `pargs` and `variables`.
    """
    cuda_args, launch = self._kernel_params(prg, pargs, variables)
    self.graphs[instid][0].kernel_node_set_params(*cuda_args, kernel_node=self.jc_info[jcid], **launch)

  def exec_instance(self, instid):
    """Launch the instantiated graph stored at index `instid`."""
    self.graphs[instid][0].launch()
Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False if getenv("PTX") else True, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024]), renderer, CUDAProgram, cuda.Context.synchronize, CUDAGraph)