Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-02-13 08:05:10 -05:00)
clean up update stats (#4226)

* WIP: clean up update stats
* line savings now
* fix graphs
* fix tests
* tighter prints
* remove extra jit=false
* debug=2 means wait
* that won't update stats
* still wait
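
The change removes the per-call update_stats(...) bookkeeping from the CUDA, HSA, and Metal graph runners and from DiskRunner: each runner now hands its display name, device, and JIT op/memory estimates to super().__init__ once, and __call__ simply returns the elapsed time (or None). A minimal sketch of that shape, using made-up stand-in classes (the Runner below is simplified and is not tinygrad's exact API; how the caller ultimately reports stats is not shown in this diff):

from typing import Any, Dict, List, Optional

class Runner:  # simplified stand-in for tinygrad's runner base class (illustrative only)
  def __init__(self, display_name:str, dname:str, op_estimate=0, mem_estimate=0):
    self.display_name, self.dname = display_name, dname
    self.op_estimate, self.mem_estimate = op_estimate, mem_estimate

class BatchedGraph(Runner):  # hypothetical batched-graph runner, not a real tinygrad class
  def __init__(self, jit_cache:List[Any], dname:str, op_estimate:int, mem_estimate:int):
    self.jit_cache = jit_cache
    # stats are handed to the base class once, at construction time
    super().__init__(f"<batched {len(jit_cache)}>", dname, op_estimate, mem_estimate)

  def __call__(self, rawbufs:List[Any], var_vals:Dict[Any, int], wait=False) -> Optional[float]:
    # __call__ only launches and returns the elapsed time (or None when not timed);
    # whoever owns the Runner is responsible for recording stats
    return self._launch(wait=wait)

  def _launch(self, wait=False) -> Optional[float]:  # placeholder for the real graph launch
    return 0.0 if wait else None

The per-backend diffs below all follow this pattern.
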
@@ -1,8 +1,8 @@
 import ctypes, collections
 from typing import Any, Optional, Tuple, Dict, List, cast
 import tinygrad.runtime.autogen.cuda as cuda
-from tinygrad.helpers import init_c_var, GraphException, getenv
-from tinygrad.device import CompiledRunner, update_stats, Buffer, MultiDeviceJITGraph, BufferXfer, Device, BufferOptions
+from tinygrad.helpers import init_c_var, GraphException, getenv, colored
+from tinygrad.device import CompiledRunner, Buffer, MultiDeviceJITGraph, BufferXfer, Device, BufferOptions
 from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
 from tinygrad.shape.symbolic import Variable
 from tinygrad.engine.realize import ExecItem
@@ -15,7 +15,6 @@ class CUDAGraph(MultiDeviceJITGraph):

     self.jit_cache = jit_cache
     self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
-    self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache)
     self.jc_idxs_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
     self.jc_idxs_with_updatable_var_vals = get_jc_idxs_with_updatable_var_vals(jit_cache)
     self.jc_idxs_with_updatable_rawbufs = list(set([x[0] for x in self.input_replace.keys()]))
@@ -70,8 +69,9 @@ class CUDAGraph(MultiDeviceJITGraph):

     # clear jit inputs to allow their memory to be freed/reused
     for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
+    super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), "CUDA", *get_jit_stats(jit_cache))

-  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
+  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
     # Update rawbuffers in the c_args struct.
     for (j,i),input_idx in self.input_replace.items():
       if not self.updatable_nodes[j][3]: setattr(self.updatable_nodes[j][2], f'f{i}', input_rawbuffers[input_idx]._buf)
@@ -93,10 +93,7 @@ class CUDAGraph(MultiDeviceJITGraph):
       if not is_copy: check(cuda.cuGraphExecKernelNodeSetParams(self.instance, node, ctypes.byref(c_node_params)))
       else: check(cuda.cuGraphExecMemcpyNodeSetParams(self.instance, node, ctypes.byref(c_node_params), c_args))

-    et = cu_time_execution(lambda: check(cuda.cuGraphLaunch(self.instance, None)), enable=wait)
-    update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers),
-                 jit=jit, num_kernels=len(self.jit_cache), device="CUDA")
-    return et
+    return cu_time_execution(lambda: check(cuda.cuGraphLaunch(self.instance, None)), enable=wait)

   def __del__(self):
     if hasattr(self, 'graph'): check(cuda.cuGraphDestroy(self.graph))
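
With update_stats gone, the CUDA __call__ can return cu_time_execution(...) directly, because that helper only measures when enable is true (the commit message notes "debug=2 means wait"). A rough sketch of the enable-gated timing pattern with a made-up timed_call helper; the real cu_time_execution uses CUDA's own event timing rather than perf_counter:

import time
from typing import Callable, Optional

def timed_call(fn: Callable[[], None], enable=False) -> Optional[float]:
  # run fn; measure and return elapsed seconds only when timing was requested
  if not enable:
    fn()
    return None
  st = time.perf_counter()
  fn()
  return time.perf_counter() - st

# usage: et is a float only if the caller asked to wait for / time the launch
et = timed_call(lambda: None, enable=True)
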
@@ -1,8 +1,8 @@
 import ctypes, collections, time, itertools
 from typing import List, Any, Dict, cast, Optional, Union, Tuple
-from tinygrad.helpers import GraphException, init_c_var, round_up
+from tinygrad.helpers import GraphException, init_c_var, round_up, colored
 from tinygrad.buffer import Buffer, BufferOptions
-from tinygrad.device import Compiled, CompiledRunner, BufferXfer, MultiDeviceJITGraph, update_stats, Device
+from tinygrad.device import Compiled, CompiledRunner, BufferXfer, MultiDeviceJITGraph, Device
 from tinygrad.shape.symbolic import Variable
 from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
 from tinygrad.engine.realize import ExecItem
@@ -29,7 +29,6 @@ class HSAGraph(MultiDeviceJITGraph):
   def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
     self.jit_cache = jit_cache
     self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
-    self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache) #type:ignore
     self.jc_idxs_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
     self.jc_idxs_with_updatable_var_vals = get_jc_idxs_with_updatable_var_vals(jit_cache)

@@ -114,8 +113,9 @@ class HSAGraph(MultiDeviceJITGraph):

     # clear jit inputs to allow their memory to be freed/reused
     for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
+    super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), "HSA", *get_jit_stats(jit_cache))

-  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
+  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
     # Wait and restore signals
     hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
     for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 1)
@@ -158,8 +158,6 @@ class HSAGraph(MultiDeviceJITGraph):
       et = time.perf_counter() - st

     for profdev,profdata in self.profile_info.items(): Profiler.tracked_signals[profdev] += profdata
-    update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers),
-                 jit=jit, num_kernels=len(self.jit_cache), device="HSA")
     return et

   def alloc_signal(self, reset_on_start=False, wait_on=None):

@@ -1,20 +1,19 @@
 from typing import List, Any, Dict, cast, Optional
 import Metal
 from tinygrad.dtype import dtypes
-from tinygrad.helpers import dedup, unwrap2, GraphException
-from tinygrad.device import Buffer, CompiledRunner, update_stats
+from tinygrad.helpers import dedup, unwrap2, GraphException, colored
+from tinygrad.device import Buffer, CompiledRunner, Runner
 from tinygrad.engine.realize import ExecItem
 from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims
 from tinygrad.shape.symbolic import Variable
 from tinygrad.runtime.ops_metal import MetalDevice, wait_check

-class MetalGraph:
+class MetalGraph(Runner):
   def __init__(self, device:MetalDevice, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
     if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException

     self.jit_cache = jit_cache
     self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
-    self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache)
     self.jc_idx_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
     self.device: MetalDevice = device

@@ -54,8 +53,9 @@ class MetalGraph:

     # clear jit inputs to allow their memory to be freed/reused
     for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
+    super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), device.dname, *get_jit_stats(jit_cache))

-  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
+  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
     # NOTE: you at least can't update the ints if this is running
     if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: wait_check(self.command_buffer)
     all_resources = self.all_resources + [x._buf for x in input_rawbuffers]
@@ -74,9 +74,6 @@ class MetalGraph:
     self.command_buffer = command_buffer
     if wait:
       wait_check(command_buffer)
-      et = command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
-    else:
-      self.device.mtl_buffers_in_flight.append(command_buffer)
-      et = None
-    update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache)) # noqa: E501
-    return et
+      return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
+    self.device.mtl_buffers_in_flight.append(command_buffer)
+    return None

@@ -52,7 +52,8 @@ class DiskRunner(Runner):
     assert strides_for_shape(view.shape) == view.strides, "disk tensors don't support strides"
     self.new_size = prod(view.shape)
     self.new_offset = view.offset * top_src.arg.dtype.itemsize
-  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Any, int], wait=False, jit=False):
+    super().__init__(f"sz 0x{self.new_size:X} offset 0x{self.new_offset:X}", "DISK")
+  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Any, int], wait=False):
     assert len(rawbufs) == 2
     # TODO: this is a terrible hack that should be moved to lazy.py
     rawbufs[0]._buf.offset = rawbufs[1]._buf.offset+self.new_offset