clean up update stats (#4226)

* WIP: clean up update stats

* line savings now

* fix graphs

* fix tests

* tighter prints

* remove extra jit=false

* debug=2 means wait

* that won't update stats

* still wait
This commit is contained in:
George Hotz
2024-04-19 15:41:30 +04:00
committed by GitHub
parent 1c87e5dbf6
commit b9570d6100
12 changed files with 70 additions and 87 deletions

View File

@@ -1,8 +1,8 @@
import ctypes, collections
from typing import Any, Optional, Tuple, Dict, List, cast
import tinygrad.runtime.autogen.cuda as cuda
from tinygrad.helpers import init_c_var, GraphException, getenv
from tinygrad.device import CompiledRunner, update_stats, Buffer, MultiDeviceJITGraph, BufferXfer, Device, BufferOptions
from tinygrad.helpers import init_c_var, GraphException, getenv, colored
from tinygrad.device import CompiledRunner, Buffer, MultiDeviceJITGraph, BufferXfer, Device, BufferOptions
from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
from tinygrad.shape.symbolic import Variable
from tinygrad.engine.realize import ExecItem
@@ -15,7 +15,6 @@ class CUDAGraph(MultiDeviceJITGraph):
self.jit_cache = jit_cache
self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache)
self.jc_idxs_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
self.jc_idxs_with_updatable_var_vals = get_jc_idxs_with_updatable_var_vals(jit_cache)
self.jc_idxs_with_updatable_rawbufs = list(set([x[0] for x in self.input_replace.keys()]))
@@ -70,8 +69,9 @@ class CUDAGraph(MultiDeviceJITGraph):
# clear jit inputs to allow their memory to be freed/reused
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), "CUDA", *get_jit_stats(jit_cache))
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
# Update rawbuffers in the c_args struct.
for (j,i),input_idx in self.input_replace.items():
if not self.updatable_nodes[j][3]: setattr(self.updatable_nodes[j][2], f'f{i}', input_rawbuffers[input_idx]._buf)
@@ -93,10 +93,7 @@ class CUDAGraph(MultiDeviceJITGraph):
if not is_copy: check(cuda.cuGraphExecKernelNodeSetParams(self.instance, node, ctypes.byref(c_node_params)))
else: check(cuda.cuGraphExecMemcpyNodeSetParams(self.instance, node, ctypes.byref(c_node_params), c_args))
et = cu_time_execution(lambda: check(cuda.cuGraphLaunch(self.instance, None)), enable=wait)
update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers),
jit=jit, num_kernels=len(self.jit_cache), device="CUDA")
return et
return cu_time_execution(lambda: check(cuda.cuGraphLaunch(self.instance, None)), enable=wait)
def __del__(self):
if hasattr(self, 'graph'): check(cuda.cuGraphDestroy(self.graph))