mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
hotfix: hcq profiler use mid point for deps flow (#5882)
* hcq profiler use mid point for deps * fixes * mypy
This commit is contained in:
@@ -489,7 +489,7 @@ class HCQCompiled(Compiled):
|
||||
self.timeline_signal, self._shadow_timeline_signal = timeline_signals
|
||||
self.sig_prof_records:List[Tuple[HCQSignal, HCQSignal, str, bool]] = []
|
||||
self.raw_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, str, bool]] = []
|
||||
self.dep_prof_records:List[Any] = []
|
||||
self.dep_prof_records:List[Tuple[int, int, HCQCompiled, bool, int, int, HCQCompiled, bool]] = []
|
||||
if PROFILE: self._prof_setup()
|
||||
|
||||
from tinygrad.runtime.graph.hcq import HCQGraph
|
||||
@@ -544,8 +544,10 @@ class HCQCompiled(Compiled):
|
||||
|
||||
for st, en, name, is_cp in self.raw_prof_records:
|
||||
self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, qname[is_cp])]
|
||||
for d_st, d_dev, d_cp, st, dev, cp in self.dep_prof_records:
|
||||
self.profile_logger.deps += [(d_dev._gpu2cpu_time(d_st, d_cp), dev._gpu2cpu_time(st, cp), d_dev.dname, qname[d_cp], dev.dname, qname[cp])]
|
||||
for a_st, a_en, a_dev, a_is_copy, b_st, b_en, b_dev, b_is_copy in self.dep_prof_records:
|
||||
# Perfetto connects nodes based on timing data, ensuring every choice is valid by averaging times to a midpoint.
|
||||
a_tm, b_tm = a_dev._gpu2cpu_time((a_st+a_en)/decimal.Decimal(2), a_is_copy), b_dev._gpu2cpu_time((b_st+b_en)/decimal.Decimal(2), b_is_copy)
|
||||
self.profile_logger.deps += [(a_tm, b_tm, a_dev.dname, qname[a_is_copy], b_dev.dname, qname[b_is_copy])]
|
||||
self.raw_prof_records, self.dep_prof_records = [], []
|
||||
|
||||
# Remove the logger, this flushes all data written by the device.
|
||||
|
||||
@@ -187,7 +187,7 @@ class ProfileLogger:
|
||||
for en,st,dep_actor_name,dep_subactor_name,actor_name,subactor_name in self.deps:
|
||||
dep_pid, dep_tid = self._ensure_actor(dep_actor_name,dep_subactor_name)
|
||||
pid, tid = self._ensure_actor(actor_name,subactor_name)
|
||||
self.mjson.append({"ph": "s", "pid": dep_pid, "tid": dep_tid, "id": len(self.mjson), "ts":en-0.1, "bp": "e"})
|
||||
self.mjson.append({"ph": "s", "pid": dep_pid, "tid": dep_tid, "id": len(self.mjson), "ts":en, "bp": "e"})
|
||||
self.mjson.append({"ph": "f", "pid": pid, "tid": tid, "id": len(self.mjson)-1, "ts":st, "bp": "e"})
|
||||
|
||||
ProfileLogger.writers -= 1
|
||||
|
||||
@@ -52,7 +52,7 @@ class HCQGraph(MultiGraphRunner):
|
||||
self.last_timeline: Dict[HCQCompiled, Tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
|
||||
self.last_ji: Dict[HWCommandQueue, Optional[int]] = {q: None for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
|
||||
self.prof_signals: List[HCQSignal] = [self.devices[0].signal_t() for i in range(len(self.jit_cache) * 2)] if PROFILE else []
|
||||
self.prof_deps: List[Tuple[int, HCQCompiled, bool, int, HCQCompiled, bool]] = []
|
||||
self.prof_deps: List[Tuple[Tuple, Tuple]] = []
|
||||
|
||||
for j,ji in enumerate(self.jit_cache):
|
||||
enqueue_dev = ji.prg.device if isinstance(ji.prg, CompiledRunner) else Device[ji.bufs[1].device] #type:ignore
|
||||
@@ -79,8 +79,8 @@ class HCQGraph(MultiGraphRunner):
|
||||
# Go through all dependencies and, if we need the signal from that ji, enable it by setting the signal value in the signal schedule.
|
||||
for sig, val in deps:
|
||||
if id(sig) in [id(x) for x in self.signals.values()]:
|
||||
self.signal_sched[val - 1] = (sched:=(self.signal_sched[val - 1][:2] + (val,) + self.signal_sched[val - 1][3:]))
|
||||
if PROFILE: self.prof_deps += [(sched[3][1][0], sched[3][2], sched[3][4], prof_info[0][0], prof_info[2], prof_info[4])] # type: ignore
|
||||
self.signal_sched[val - 1] = self.signal_sched[val - 1][:2] + (val,) + self.signal_sched[val - 1][3:]
|
||||
if PROFILE: self.prof_deps += [(self.signal_sched[val - 1][3], prof_info)] # type: ignore
|
||||
|
||||
self.signal_sched[j] = (deps, out_signal, None if isinstance(ji.prg, CompiledRunner) else (j + 1), prof_info)
|
||||
self.last_ji[enqueue_queue] = j
|
||||
@@ -184,9 +184,12 @@ class HCQGraph(MultiGraphRunner):
|
||||
|
||||
def collect_timestamps(self):
|
||||
timestamps = [s.timestamp for s in self.prof_signals]
|
||||
|
||||
for _,_,_,((st,_),(en,_),dev,desc,is_cp) in self.signal_sched.values(): # type: ignore
|
||||
dev.raw_prof_records += [(timestamps[st], timestamps[en], desc, is_cp)]
|
||||
for d_st,d_dev,d_is_cp,st,dev,is_cp in self.prof_deps: dev.dep_prof_records += [(timestamps[d_st],d_dev,d_is_cp,timestamps[st]+1,dev,is_cp)]
|
||||
|
||||
for ((a_st,_), (a_en,_), a_dev, _, a_is_cp), ((b_st,_), (b_en,_), b_dev, _, b_is_cp) in self.prof_deps:
|
||||
b_dev.dep_prof_records += [(timestamps[a_st], timestamps[a_en], a_dev, a_is_cp, timestamps[b_st], timestamps[b_en], b_dev, b_is_cp)]
|
||||
|
||||
def __del__(self):
|
||||
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
|
||||
|
||||
Reference in New Issue
Block a user