add dependency viewer to hcq profiler (#5874)

* hcq profiler supports deps

* clean up

* cleaner

* cleanup

* revert this

* linter

* mypy

* add test

* sync is strange, need to take the end

* linter + test
Author: nimlgen
Date: 2024-08-02 22:07:01 +03:00
Committed by: GitHub
Parent: 23e8c39288
Commit: 2777784b91
4 changed files with 91 additions and 32 deletions


@@ -1,5 +1,5 @@
 import unittest, ctypes, struct, contextlib, tempfile, pathlib, json, time, atexit, random
-from tinygrad import Device, Tensor, dtypes
+from tinygrad import Device, Tensor, dtypes, TinyJit
 from tinygrad.helpers import CI, getenv, Context
 from tinygrad.device import Buffer, BufferOptions, HCQCompiled
 from tinygrad.engine.schedule import create_schedule
@@ -397,7 +397,7 @@ def helper_collect_profile(*devs, random_setup_delay=False):
 def helper_profile_filter_node(profile, **kwargs):
   assert len(profile) > 0, "Empty profile"
   assert 'traceEvents' in profile, "traceEvents should be present"
-  return [x for x in profile['traceEvents'] if all(x[k] == v for k,v in kwargs.items())]
+  return [x for x in profile['traceEvents'] if all(x.get(k, None) == v for k,v in kwargs.items())]

 def helper_profile_parse_pids(profile):
   pids, tids = {}, {}
@@ -407,6 +407,20 @@ def helper_profile_parse_pids(profile):
   for th in threads: tids[th['tid']] = th['args']['name']
   return pids, tids

+def helper_profile_parse_deps(profile):
+  deps = []
+  for s in helper_profile_filter_node(profile, ph="s"):
+    f = helper_profile_filter_node(profile, ph="f", id=s['id'])[0]
+
+    starts, ends = [], []
+    for x in helper_profile_filter_node(profile, ph="X"):
+      if s['pid'] == x['pid'] and s['tid'] == x['tid'] and x['ts'] <= s['ts'] <= x['ts'] + x['dur']: starts.append(x)
+      if f['pid'] == x['pid'] and f['tid'] == x['tid'] and x['ts'] <= f['ts'] <= x['ts'] + x['dur']: ends.append(x)
+
+    assert len(starts) == 1 and len(ends) == 1, "more than one start and end possible, valid?"
+    deps.append((s, f, starts[0], ends[0]))
+  return deps
+
 def helper_validate_node(node, duration_s=10, ts_age_s=30, profile=None, pid_name=None, tid_name=None):
   pids, tids = helper_profile_parse_pids(profile)
   assert abs(node['ts'] - time.perf_counter_ns() / 1e3) < ts_age_s * 1e6, "timestamp is not in 30s range"
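For reference, the Chrome trace flow events this helper walks look roughly like the following (a minimal sketch, all values invented): a "s" (flow start) record is matched to the "f" (flow finish) record carrying the same id, and each endpoint is then attributed to the enclosing "X" (complete) slice on the same pid/tid whose [ts, ts+dur] window contains it.

# Minimal sketch of the trace shapes helper_profile_parse_deps consumes (values invented,
# assumes the helpers from the diff above are in scope).
kernel = {"name": "E_4", "ph": "X", "pid": 0, "tid": 0, "ts": 1000.0, "dur": 12.0}
copy   = {"name": "NV -> NV:1", "ph": "X", "pid": 1, "tid": 1, "ts": 1015.0, "dur": 3.0}
flow_s = {"ph": "s", "pid": 0, "tid": 0, "id": 7, "ts": 1011.9, "bp": "e"}  # falls inside the kernel slice
flow_f = {"ph": "f", "pid": 1, "tid": 1, "id": 7, "ts": 1015.0, "bp": "e"}  # lands on the copy slice
profile = {"traceEvents": [kernel, copy, flow_s, flow_f]}
deps = helper_profile_parse_deps(profile)  # -> [(flow_s, flow_f, kernel, copy)]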
@@ -483,6 +497,27 @@ class TestProfiler(unittest.TestCase):
     copyin_node_2 = helper_profile_filter_node(profile, name=f"CPU -> {Device.DEFAULT}:1")[0]
     helper_validate_node(copyin_node_2, profile=profile, pid_name=f"{Device.DEFAULT}:1", tid_name="DMA")

+  @unittest.skipIf(MOCKGPU and Device.DEFAULT == "AMD", "AMD mockgpu with indirect buffers does not support queue wait interrupts")
+  def test_profile_deps(self):
+    d1 = Device[f"{Device.DEFAULT}:1"]
+
+    def f(a):
+      x = (a + 1).realize()
+      return x, x.to(d1.dname).realize()
+
+    a = Tensor.randn(10, 10, device=TestProfiler.d0.dname).realize()
+    with helper_collect_profile(TestProfiler.d0, d1) as profile:
+      jf = TinyJit(f)
+      for _ in range(3): jf(a)
+      del jf
+
+    deps = helper_profile_parse_deps(profile)
+    assert len(deps) == 1, "one dep is expected, one launch"
+
+    _, _, l, r = deps[0]
+    assert l['name'].find("->") == -1, "should be kernel"
+    assert r['name'] == f"{Device.DEFAULT} -> {Device.DEFAULT}:1", "should be copy"
+
   @unittest.skipIf(CI, "skip CI")
   def test_profile_sync(self):
     mv = memoryview(bytearray(struct.pack("ff", 0, 1)))
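Outside the test suite, the same dependency capture can be exercised with a JIT'd cross-device function. A rough sketch, assuming the PROFILE environment flag checked by the device code below is set (e.g. PROFILE=1) and a second instance of the default backend is available:

# Sketch: produce one kernel->copy dependency under the profiler (run with PROFILE=1).
from tinygrad import Tensor, TinyJit, Device

d1 = f"{Device.DEFAULT}:1"

@TinyJit
def f(a: Tensor):
  x = (a + 1).realize()          # kernel on the default device
  return x, x.to(d1).realize()   # copy to the second device, depends on the kernel

a = Tensor.randn(10, 10).realize()
for _ in range(3): f(a)          # JIT replays run through HCQGraph, which records the dep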


@@ -489,6 +489,7 @@ class HCQCompiled(Compiled):
     self.timeline_signal, self._shadow_timeline_signal = timeline_signals
     self.sig_prof_records:List[Tuple[HCQSignal, HCQSignal, str, bool]] = []
     self.raw_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, str, bool]] = []
+    self.dep_prof_records:List[Any] = []
     if PROFILE: self._prof_setup()

     from tinygrad.runtime.graph.hcq import HCQGraph
@@ -501,13 +502,9 @@ class HCQCompiled(Compiled):
     self.timeline_signal.wait(self.timeline_value - 1)
     if self.timeline_value > (1 << 31): self._wrap_timeline_signal()
-    if PROFILE: self._prof_process_events()
-
-  def _gpu2cpu_time(self, gpu_time:decimal.Decimal, is_copy:bool) -> float:
-    """
-    Translates local gpu time (timestamp) into global cpu time.
-    """
-    return float(gpu_time + (self.gpu2cpu_copy_time_diff if is_copy else self.gpu2cpu_compute_time_diff))
+    if PROFILE:
+      self.raw_prof_records += [(st.timestamp, en.timestamp, name, is_cp) for st, en, name, is_cp in self.sig_prof_records]
+      self.sig_prof_records = []

   def _alloc_kernargs(self, alloc_size:int) -> int:
     """
@@ -517,26 +514,41 @@ class HCQCompiled(Compiled):
     self.kernargs_ptr = (res:=self.kernargs_ptr) + alloc_size
     return res

-  def _prof_setup(self):
-    if not hasattr(self, 'profile_logger'): atexit.register(self._prof_finalize)
-    self.profile_logger = ProfileLogger()
+  def _ensure_shared_time_base(self):
+    if hasattr(self, 'gpu2cpu_compute_time_diff'): return
+
+    def _sync_queue(q_t):
+      self.synchronize()
+      q_t().timestamp(self.timeline_signal).signal(self.timeline_signal, self.timeline_value).submit(self)
+      self.timeline_value += 1
+      cpu_start_time = decimal.Decimal(time.perf_counter_ns()) / decimal.Decimal(1000)
+      self.timeline_signal.wait(self.timeline_value - 1)
+      return cpu_start_time - self.timeline_signal.timestamp
+    self.gpu2cpu_compute_time_diff, self.gpu2cpu_copy_time_diff = _sync_queue(self.hw_compute_queue_t), _sync_queue(self.hw_copy_queue_t)

-  def _prof_process_events(self):
-    self.raw_prof_records += [(st.timestamp, en.timestamp, name, is_cp) for st, en, name, is_cp in self.sig_prof_records]
-    self.sig_prof_records = []
+  def _gpu2cpu_time(self, gpu_time:decimal.Decimal, is_copy:bool) -> float:
+    """
+    Translates local gpu time (timestamp) into global cpu time.
+    """
+    self._ensure_shared_time_base()
+    return float(gpu_time + (self.gpu2cpu_copy_time_diff if is_copy else self.gpu2cpu_compute_time_diff))
+
+  def _prof_setup(self):
+    if hasattr(self, 'profile_logger'): return
+    atexit.register(self._prof_finalize)
+    self.profile_logger = ProfileLogger()

   def _prof_finalize(self):
+    qname = ["COMPUTE", "DMA"]
     for st, en, name, is_cp in self.raw_prof_records:
-      self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, ["COMPUTE", "DMA"][is_cp])]
-    self.raw_prof_records = []
+      self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, qname[is_cp])]
+    for d_st, d_dev, d_cp, st, dev, cp in self.dep_prof_records:
+      self.profile_logger.deps += [(d_dev._gpu2cpu_time(d_st, d_cp), dev._gpu2cpu_time(st, cp), d_dev.dname, qname[d_cp], dev.dname, qname[cp])]
+    self.raw_prof_records, self.dep_prof_records = [], []

     # Remove the logger, this flushes all data written by the device.
     del self.profile_logger

   def _wrap_timeline_signal(self):
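In words: _sync_queue timestamps the GPU signal and reads the CPU clock at (nearly) the same instant, so the stored diff is CPU microseconds minus the GPU timestamp for that queue type, and _gpu2cpu_time shifts any raw GPU timestamp onto the shared CPU timeline by adding that diff back. A toy restatement of the arithmetic (all numbers invented):

# Toy restatement of the GPU->CPU translation above (all numbers invented).
from decimal import Decimal

gpu_ts_at_sync = Decimal("5000.0")      # GPU signal timestamp at the sync point, in us
cpu_us_at_sync = Decimal("1700000.0")   # time.perf_counter_ns() / 1000 at the same instant
diff = cpu_us_at_sync - gpu_ts_at_sync  # what _sync_queue returns for this queue type

def gpu2cpu(gpu_ts: Decimal) -> float:
  return float(gpu_ts + diff)           # mirrors _gpu2cpu_time

assert gpu2cpu(Decimal("5010.5")) == 1700010.5  # 10.5us after the sync point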


@@ -161,24 +161,34 @@ class Profiling(contextlib.ContextDecorator):
 class ProfileLogger:
   writers: int = 0
   mjson: List[Dict] = []
-  actors: Dict[str, int] = {}
-  subactors: Dict[Tuple[str, str], int] = {}
+  actors: Dict[Union[str, Tuple[str, str]], int] = {}

-  def __init__(self): self.events, ProfileLogger.writers = [], ProfileLogger.writers + 1
+  def __init__(self): self.events, self.deps, ProfileLogger.writers = [], [], ProfileLogger.writers + 1

   def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor)]

+  def _ensure_actor(self, actor_name, subactor_name):
+    if actor_name not in self.actors:
+      self.actors[actor_name] = (pid:=len(self.actors))
+      self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})
+
+    if (subactor_key:=(actor_name,subactor_name)) not in self.actors:
+      self.actors[subactor_key] = (tid:=len(self.actors))
+      self.mjson.append({"name": "thread_name", "ph": "M", "pid": self.actors[actor_name], "tid":tid, "args": {"name": subactor_name}})
+
+    return self.actors[actor_name], self.actors.get(subactor_key, -1)
+
   def __del__(self):
     # perfetto json docs: https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview
     for name,st,et,actor_name,subactor_name in self.events:
-      if actor_name not in self.actors:
-        self.actors[actor_name] = (pid:=len(self.actors))
-        self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})
-
-      if (subactor_key:=(actor_name,subactor_name)) not in self.subactors:
-        self.subactors[subactor_key] = (tid:=len(self.subactors))
-        self.mjson.append({"name": "thread_name", "ph": "M", "pid": self.actors[actor_name], "tid":tid, "args": {"name": subactor_name}})
-
-      self.mjson.append({"name": name, "ph": "X", "pid": self.actors[actor_name], "tid": self.subactors.get(subactor_key, -1), "ts":st, "dur":et-st})
+      pid, tid = self._ensure_actor(actor_name,subactor_name)
+      self.mjson.append({"name": name, "ph": "X", "pid": pid, "tid": tid, "ts":st, "dur":et-st})
+
+    for en,st,dep_actor_name,dep_subactor_name,actor_name,subactor_name in self.deps:
+      dep_pid, dep_tid = self._ensure_actor(dep_actor_name,dep_subactor_name)
+      pid, tid = self._ensure_actor(actor_name,subactor_name)
+      self.mjson.append({"ph": "s", "pid": dep_pid, "tid": dep_tid, "id": len(self.mjson), "ts":en-0.1, "bp": "e"})
+      self.mjson.append({"ph": "f", "pid": pid, "tid": tid, "id": len(self.mjson)-1, "ts":st, "bp": "e"})

     ProfileLogger.writers -= 1
     if ProfileLogger.writers == 0 and len(self.mjson) > 0:
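For each recorded dependency, __del__ thus emits a flow-event pair like the following (values illustrative); the flow start is placed 0.1us before the producer's end timestamp, presumably so it falls strictly inside the producer slice rather than on its boundary:

# Illustrative mjson output for one dep (producer on pid 0/tid 1, consumer on pid 2/tid 3).
[
  {"ph": "s", "pid": 0, "tid": 1, "id": 42, "ts": 1011.9, "bp": "e"},  # producer end - 0.1us
  {"ph": "f", "pid": 2, "tid": 3, "id": 42, "ts": 1015.0, "bp": "e"},  # consumer start
]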


@@ -52,7 +52,7 @@ class HCQGraph(MultiGraphRunner):
     self.last_timeline: Dict[HCQCompiled, Tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
     self.last_ji: Dict[HWCommandQueue, Optional[int]] = {q: None for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
     self.prof_signals: List[HCQSignal] = [self.devices[0].signal_t() for i in range(len(self.jit_cache) * 2)] if PROFILE else []
-    self.prof_records: List
+    self.prof_deps: List[Tuple[int, HCQCompiled, bool, int, HCQCompiled, bool]] = []

     for j,ji in enumerate(self.jit_cache):
       enqueue_dev = ji.prg.device if isinstance(ji.prg, CompiledRunner) else Device[ji.bufs[1].device] #type:ignore
@@ -79,7 +79,8 @@ class HCQGraph(MultiGraphRunner):
       # Go through all dependencies and, if we need the signal from that ji, enable it by setting the signal value in the signal schedule.
       for sig, val in deps:
         if id(sig) in [id(x) for x in self.signals.values()]:
-          self.signal_sched[val - 1] = self.signal_sched[val - 1][:2] + (val,) + self.signal_sched[val - 1][3:]
+          self.signal_sched[val - 1] = (sched:=(self.signal_sched[val - 1][:2] + (val,) + self.signal_sched[val - 1][3:]))
+          if PROFILE: self.prof_deps += [(sched[3][1][0], sched[3][2], sched[3][4], prof_info[0][0], prof_info[2], prof_info[4])] # type: ignore

       self.signal_sched[j] = (deps, out_signal, None if isinstance(ji.prg, CompiledRunner) else (j + 1), prof_info)
       self.last_ji[enqueue_queue] = j
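Decoding the indexing: prof_info unpacks as ((start_sig_idx,_),(end_sig_idx,_),dev,desc,is_cp) (visible in collect_timestamps below), and sched[3] is the dependency's prof_info, so each prof_deps entry pairs the producer's end signal with the consumer's start signal; this is what the "sync is strange, need to take the end" commit note refers to. A hypothetical named spelling of the tuple, purely for readability (not in the real code):

# Hypothetical NamedTuple mirroring the prof_deps tuple layout (illustration only).
from typing import NamedTuple

class ProfDep(NamedTuple):
  dep_end_sig_idx: int      # prof_signals index of the producer's end timestamp
  dep_dev: "HCQCompiled"    # device running the producer
  dep_is_copy: bool         # producer queue: False=COMPUTE, True=DMA
  start_sig_idx: int        # prof_signals index of the consumer's start timestamp
  dev: "HCQCompiled"        # device running the consumer
  is_copy: bool             # consumer queue type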
@@ -183,12 +184,13 @@ class HCQGraph(MultiGraphRunner):
   def collect_timestamps(self):
     timestamps = [s.timestamp for s in self.prof_signals]

-    for _,_,_,((st,_),(en,_),dev,desc,is_cp) in self.signal_sched.values(): #type: ignore
+    for _,_,_,((st,_),(en,_),dev,desc,is_cp) in self.signal_sched.values(): # type: ignore
       dev.raw_prof_records += [(timestamps[st], timestamps[en], desc, is_cp)]
+    for d_st,d_dev,d_is_cp,st,dev,is_cp in self.prof_deps: dev.dep_prof_records += [(timestamps[d_st],d_dev,d_is_cp,timestamps[st]+1,dev,is_cp)]

   def __del__(self):
     for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])

-    if PROFILE and self.kickoff_value > 1: self.collect_timestamps()
+    if PROFILE and self.kickoff_value >= 1: self.collect_timestamps()

     for fdev, buf in self.kernargs_bufs.items(): fdev.allocator._free(buf, BufferOptions(cpu_access=True))
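End to end: collect_timestamps fills dev.dep_prof_records, HCQCompiled._prof_finalize converts those into ProfileLogger.deps, and ProfileLogger.__del__ serializes them as the s/f flow pairs shown earlier. A small standalone reader can then list the captured dependencies (the file path is an assumption; point it at wherever ProfileLogger writes its JSON):

# Sketch: print dependency arrows from a saved trace (path is an assumption).
import json

with open("profile.json") as fd: prof = json.load(fd)
starts = {e["id"]: e for e in prof["traceEvents"] if e.get("ph") == "s"}
for f in (e for e in prof["traceEvents"] if e.get("ph") == "f"):
  s = starts[f["id"]]
  print(f"dep: pid {s['pid']}/tid {s['tid']} @ {s['ts']:.1f}us -> pid {f['pid']}/tid {f['tid']} @ {f['ts']:.1f}us")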