mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
add dependency viewer to hcq profiler (#5874)
* hcq profiler support deps * clean up * cleaner * cleanup * revert this * linter * mypy * add test * sync is strange, need to take the end * linter + test
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import unittest, ctypes, struct, contextlib, tempfile, pathlib, json, time, atexit, random
|
||||
from tinygrad import Device, Tensor, dtypes
|
||||
from tinygrad import Device, Tensor, dtypes, TinyJit
|
||||
from tinygrad.helpers import CI, getenv, Context
|
||||
from tinygrad.device import Buffer, BufferOptions, HCQCompiled
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
@@ -397,7 +397,7 @@ def helper_collect_profile(*devs, random_setup_delay=False):
|
||||
def helper_profile_filter_node(profile, **kwargs):
  """Return all 'traceEvents' entries whose fields exactly match every given kwarg.

  Uses dict.get so events that lack a queried key are simply filtered out instead
  of raising KeyError (e.g. flow events carry 'id' while duration events may not).
  """
  assert len(profile) > 0, "Empty profile"
  assert 'traceEvents' in profile, "traceEvents should present"
  return [x for x in profile['traceEvents'] if all(x.get(k, None) == v for k,v in kwargs.items())]
|
||||
|
||||
def helper_profile_parse_pids(profile):
|
||||
pids, tids = {}, {}
|
||||
@@ -407,6 +407,20 @@ def helper_profile_parse_pids(profile):
|
||||
for th in threads: tids[th['tid']] = th['args']['name']
|
||||
return pids, tids
|
||||
|
||||
def helper_profile_parse_deps(profile):
  # Pair every flow-start ("s") event with its matching flow-end ("f") event and
  # find the duration ("X") slices that contain each endpoint.
  deps = []
  for s in helper_profile_filter_node(profile, ph="s"):
    # Flow start/end are matched by id; assumes exactly one "f" exists per "s".
    f = helper_profile_filter_node(profile, ph="f", id=s['id'])[0]

    starts, ends = [], []
    for x in helper_profile_filter_node(profile, ph="X"):
      # An "X" slice owns an endpoint when it is on the same pid/tid and its
      # [ts, ts+dur] interval contains the endpoint's timestamp.
      if s['pid'] == x['pid'] and s['tid'] == x['tid'] and x['ts'] <= s['ts'] <= x['ts'] + x['dur']: starts.append(x)
      if f['pid'] == x['pid'] and f['tid'] == x['tid'] and x['ts'] <= f['ts'] <= x['ts'] + x['dur']: ends.append(x)

    # Each flow endpoint is expected to land inside exactly one slice.
    assert len(starts) == 1 and len(ends) == 1, "more than one start and end possible, valid?"
    deps.append((s, f, starts[0], ends[0]))
  return deps
|
||||
|
||||
def helper_validate_node(node, duration_s=10, ts_age_s=30, profile=None, pid_name=None, tid_name=None):
|
||||
pids, tids = helper_profile_parse_pids(profile)
|
||||
assert abs(node['ts'] - time.perf_counter_ns() / 1e3) < ts_age_s * 1e6, "timestimp is not in 30s range"
|
||||
@@ -483,6 +497,27 @@ class TestProfiler(unittest.TestCase):
|
||||
copyin_node_2 = helper_profile_filter_node(profile, name=f"CPU -> {Device.DEFAULT}:1")[0]
|
||||
helper_validate_node(copyin_node_2, profile=profile, pid_name=f"{Device.DEFAULT}:1", tid_name="DMA")
|
||||
|
||||
@unittest.skipIf(MOCKGPU and Device.DEFAULT == "AMD", "AMD mockgpu with indirect buffers does not support queue wait interrupts")
def test_profile_deps(self):
  # One kernel (a + 1) on the default device followed by a copy of its output to
  # a second device: the copy depends on the kernel, so exactly one profiler
  # dependency edge is expected in the captured profile.
  d1 = Device[f"{Device.DEFAULT}:1"]

  def f(a):
    x = (a + 1).realize()
    return x, x.to(d1.dname).realize()

  a = Tensor.randn(10, 10, device=TestProfiler.d0.dname).realize()
  with helper_collect_profile(TestProfiler.d0, d1) as profile:
    jf = TinyJit(f)
    # NOTE(review): 3 calls — presumably enough for TinyJit to capture and replay the graph.
    for _ in range(3): jf(a)
    del jf

  deps = helper_profile_parse_deps(profile)
  assert len(deps) == 1, "one dep is expected, one launch"

  # l is the slice containing the flow start, r the slice containing the flow end:
  # source should be the kernel (no "->" in its name), target the inter-device copy.
  _, _, l, r = deps[0]
  assert l['name'].find("->") == -1, "should be kernel"
  assert r['name'] == f"{Device.DEFAULT} -> {Device.DEFAULT}:1", "should be copy"
|
||||
|
||||
@unittest.skipIf(CI, "skip CI")
|
||||
def test_profile_sync(self):
|
||||
mv = memoryview(bytearray(struct.pack("ff", 0, 1)))
|
||||
|
||||
@@ -489,6 +489,7 @@ class HCQCompiled(Compiled):
|
||||
self.timeline_signal, self._shadow_timeline_signal = timeline_signals
|
||||
self.sig_prof_records:List[Tuple[HCQSignal, HCQSignal, str, bool]] = []
|
||||
self.raw_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, str, bool]] = []
|
||||
self.dep_prof_records:List[Any] = []
|
||||
if PROFILE: self._prof_setup()
|
||||
|
||||
from tinygrad.runtime.graph.hcq import HCQGraph
|
||||
@@ -501,13 +502,9 @@ class HCQCompiled(Compiled):
|
||||
self.timeline_signal.wait(self.timeline_value - 1)
|
||||
|
||||
if self.timeline_value > (1 << 31): self._wrap_timeline_signal()
|
||||
if PROFILE: self._prof_process_events()
|
||||
|
||||
def _gpu2cpu_time(self, gpu_time:decimal.Decimal, is_copy:bool) -> float:
|
||||
"""
|
||||
Translates local gpu time (timestamp) into global cpu time.
|
||||
"""
|
||||
return float(gpu_time + (self.gpu2cpu_copy_time_diff if is_copy else self.gpu2cpu_compute_time_diff))
|
||||
if PROFILE:
|
||||
self.raw_prof_records += [(st.timestamp, en.timestamp, name, is_cp) for st, en, name, is_cp in self.sig_prof_records]
|
||||
self.sig_prof_records = []
|
||||
|
||||
def _alloc_kernargs(self, alloc_size:int) -> int:
|
||||
"""
|
||||
@@ -517,26 +514,41 @@ class HCQCompiled(Compiled):
|
||||
self.kernargs_ptr = (res:=self.kernargs_ptr) + alloc_size
|
||||
return res
|
||||
|
||||
def _prof_setup(self):
|
||||
if not hasattr(self, 'profile_logger'): atexit.register(self._prof_finalize)
|
||||
self.profile_logger = ProfileLogger()
|
||||
def _ensure_shared_time_base(self):
  # Lazily compute the gpu->cpu clock offsets; cached after the first call.
  if hasattr(self, 'gpu2cpu_compute_time_diff'): return

  def _sync_queue(q_t):
    # Drain outstanding work, then have the queue record a gpu timestamp and
    # compare it to the cpu clock sampled right after submission.
    self.synchronize()
    q_t().timestamp(self.timeline_signal).signal(self.timeline_signal, self.timeline_value).submit(self)
    self.timeline_value += 1
    cpu_start_time = decimal.Decimal(time.perf_counter_ns()) / decimal.Decimal(1000)  # ns -> us
    self.timeline_signal.wait(self.timeline_value - 1)
    # Offset to add to a gpu timestamp to land on the cpu timeline.
    return cpu_start_time - self.timeline_signal.timestamp

  # NOTE(review): a separate offset is measured per queue type — presumably the
  # compute and copy engines can run off distinct hardware clocks; confirm.
  self.gpu2cpu_compute_time_diff, self.gpu2cpu_copy_time_diff = _sync_queue(self.hw_compute_queue_t), _sync_queue(self.hw_copy_queue_t)
|
||||
|
||||
def _prof_process_events(self):
  # Snapshot the pending signal-based records into plain timestamp tuples.
  # NOTE(review): presumably the signals are reused after this, which is why the
  # timestamps must be read out here rather than at finalize time — confirm.
  self.raw_prof_records += [(st.timestamp, en.timestamp, name, is_cp) for st, en, name, is_cp in self.sig_prof_records]
  self.sig_prof_records = []
|
||||
def _gpu2cpu_time(self, gpu_time:decimal.Decimal, is_copy:bool) -> float:
  """
  Translates local gpu time (timestamp) into global cpu time.
  """
  # Offsets are computed lazily on first use.
  self._ensure_shared_time_base()
  # Copy and compute queues have independently measured clock offsets.
  return float(gpu_time + (self.gpu2cpu_copy_time_diff if is_copy else self.gpu2cpu_compute_time_diff))
|
||||
|
||||
def _prof_setup(self):
  # Idempotent: register the atexit flush hook and create the logger only once per device.
  if hasattr(self, 'profile_logger'): return
  atexit.register(self._prof_finalize)
  self.profile_logger = ProfileLogger()
|
||||
|
||||
def _prof_finalize(self):
  # Queue name by the is_copy flag: False -> COMPUTE, True -> DMA.
  qname = ["COMPUTE", "DMA"]

  # Convert every recorded gpu interval to cpu time and hand it to the logger.
  for st, en, name, is_cp in self.raw_prof_records:
    self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, qname[is_cp])]
  # Dependency records may span devices: each side is converted with its own device's clock offset.
  for d_st, d_dev, d_cp, st, dev, cp in self.dep_prof_records:
    self.profile_logger.deps += [(d_dev._gpu2cpu_time(d_st, d_cp), dev._gpu2cpu_time(st, cp), d_dev.dname, qname[d_cp], dev.dname, qname[cp])]
  self.raw_prof_records, self.dep_prof_records = [], []

  # Remove the logger, this flushes all data written by the device.
  del self.profile_logger
|
||||
|
||||
def _wrap_timeline_signal(self):
|
||||
|
||||
@@ -161,24 +161,34 @@ class Profiling(contextlib.ContextDecorator):
|
||||
class ProfileLogger:
|
||||
writers: int = 0
|
||||
mjson: List[Dict] = []
|
||||
actors: Dict[str, int] = {}
|
||||
subactors: Dict[Tuple[str, str], int] = {}
|
||||
actors: Dict[Union[str, Tuple[str, str]], int] = {}
|
||||
|
||||
def __init__(self): self.events, ProfileLogger.writers = [], ProfileLogger.writers + 1
|
||||
def __init__(self): self.events, self.deps, ProfileLogger.writers = [], [], ProfileLogger.writers + 1
|
||||
|
||||
def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor)]
|
||||
|
||||
def _ensure_actor(self, actor_name, subactor_name):
  # Assign stable pid/tid numbers for a (process, thread) pair, emitting the
  # trace metadata ("M") records the first time each name is seen.
  if actor_name not in self.actors:
    self.actors[actor_name] = (pid:=len(self.actors))
    self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})

  # Subactors share the same id space as actors, keyed by (actor, subactor).
  if (subactor_key:=(actor_name,subactor_name)) not in self.actors:
    self.actors[subactor_key] = (tid:=len(self.actors))
    self.mjson.append({"name": "thread_name", "ph": "M", "pid": self.actors[actor_name], "tid":tid, "args": {"name": subactor_name}})

  # Returns (pid, tid); tid falls back to -1 if the subactor key is absent.
  return self.actors[actor_name], self.actors.get(subactor_key, -1)
|
||||
|
||||
def __del__(self):
|
||||
# perfetto json docs: https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview
|
||||
for name,st,et,actor_name,subactor_name in self.events:
|
||||
if actor_name not in self.actors:
|
||||
self.actors[actor_name] = (pid:=len(self.actors))
|
||||
self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})
|
||||
pid, tid = self._ensure_actor(actor_name,subactor_name)
|
||||
self.mjson.append({"name": name, "ph": "X", "pid": pid, "tid": tid, "ts":st, "dur":et-st})
|
||||
|
||||
if (subactor_key:=(actor_name,subactor_name)) not in self.subactors:
|
||||
self.subactors[subactor_key] = (tid:=len(self.subactors))
|
||||
self.mjson.append({"name": "thread_name", "ph": "M", "pid": self.actors[actor_name], "tid":tid, "args": {"name": subactor_name}})
|
||||
|
||||
self.mjson.append({"name": name, "ph": "X", "pid": self.actors[actor_name], "tid": self.subactors.get(subactor_key, -1), "ts":st, "dur":et-st})
|
||||
for en,st,dep_actor_name,dep_subactor_name,actor_name,subactor_name in self.deps:
|
||||
dep_pid, dep_tid = self._ensure_actor(dep_actor_name,dep_subactor_name)
|
||||
pid, tid = self._ensure_actor(actor_name,subactor_name)
|
||||
self.mjson.append({"ph": "s", "pid": dep_pid, "tid": dep_tid, "id": len(self.mjson), "ts":en-0.1, "bp": "e"})
|
||||
self.mjson.append({"ph": "f", "pid": pid, "tid": tid, "id": len(self.mjson)-1, "ts":st, "bp": "e"})
|
||||
|
||||
ProfileLogger.writers -= 1
|
||||
if ProfileLogger.writers == 0 and len(self.mjson) > 0:
|
||||
|
||||
@@ -52,7 +52,7 @@ class HCQGraph(MultiGraphRunner):
|
||||
self.last_timeline: Dict[HCQCompiled, Tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
|
||||
self.last_ji: Dict[HWCommandQueue, Optional[int]] = {q: None for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
|
||||
self.prof_signals: List[HCQSignal] = [self.devices[0].signal_t() for i in range(len(self.jit_cache) * 2)] if PROFILE else []
|
||||
self.prof_records: List
|
||||
self.prof_deps: List[Tuple[int, HCQCompiled, bool, int, HCQCompiled, bool]] = []
|
||||
|
||||
for j,ji in enumerate(self.jit_cache):
|
||||
enqueue_dev = ji.prg.device if isinstance(ji.prg, CompiledRunner) else Device[ji.bufs[1].device] #type:ignore
|
||||
@@ -79,7 +79,8 @@ class HCQGraph(MultiGraphRunner):
|
||||
# Go through all dependencies and, if we need the signal from that ji, enable it by setting the signal value in the signal schedule.
|
||||
for sig, val in deps:
|
||||
if id(sig) in [id(x) for x in self.signals.values()]:
|
||||
self.signal_sched[val - 1] = self.signal_sched[val - 1][:2] + (val,) + self.signal_sched[val - 1][3:]
|
||||
self.signal_sched[val - 1] = (sched:=(self.signal_sched[val - 1][:2] + (val,) + self.signal_sched[val - 1][3:]))
|
||||
if PROFILE: self.prof_deps += [(sched[3][1][0], sched[3][2], sched[3][4], prof_info[0][0], prof_info[2], prof_info[4])] # type: ignore
|
||||
|
||||
self.signal_sched[j] = (deps, out_signal, None if isinstance(ji.prg, CompiledRunner) else (j + 1), prof_info)
|
||||
self.last_ji[enqueue_queue] = j
|
||||
@@ -183,12 +184,13 @@ class HCQGraph(MultiGraphRunner):
|
||||
|
||||
def collect_timestamps(self):
  # Read back every profiling signal, then hand the per-node intervals and the
  # cross-queue dependency records to their owning devices.
  timestamps = [s.timestamp for s in self.prof_signals]
  for _,_,_,((st,_),(en,_),dev,desc,is_cp) in self.signal_sched.values(): # type: ignore
    dev.raw_prof_records += [(timestamps[st], timestamps[en], desc, is_cp)]
  # NOTE(review): the +1 on the dependent start — presumably to keep the flow
  # arrow strictly after the source timestamp in the viewer; confirm.
  for d_st,d_dev,d_is_cp,st,dev,is_cp in self.prof_deps: dev.dep_prof_records += [(timestamps[d_st],d_dev,d_is_cp,timestamps[st]+1,dev,is_cp)]
|
||||
|
||||
def __del__(self):
  # Wait for every device to reach its last observed timeline value before teardown.
  for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])

  # NOTE(review): kickoff_value >= 1 — presumably means the graph was launched at
  # least once, so the profiling signals hold valid timestamps; confirm.
  if PROFILE and self.kickoff_value >= 1: self.collect_timestamps()

  for fdev, buf in self.kernargs_bufs.items(): fdev.allocator._free(buf, BufferOptions(cpu_access=True))
|
||||
|
||||
Reference in New Issue
Block a user