hcq: lazy prof signal allocation (#11531)

This commit is contained in:
nimlgen
2025-08-06 15:28:11 +03:00
committed by GitHub
parent eafc7fda12
commit 930d8dae0c

View File

@@ -58,8 +58,8 @@ class HCQGraph(MultiGraphRunner):
# When profiling allocate 2 signals for each jit item to measure speed. The jth jit item have signals at 2*j and 2*j+1.
# TODO: This logic might allocate a few extra signals...
self.prof_signals: list[HCQSignal] = [self.devices[0].new_signal() for i in range(len(jit_cache) * 2)] if PROFILE else []
self.prog_graph_deps: list[list[int]] = []
self.prof_signals: list[HCQSignal] = []
self.prof_graph_deps: list[list[int]] = []
self.prof_graph_entries: list[ProfileGraphEntry] = []
last_j: dict[HWQueue, int|None] = collections.defaultdict(lambda: None)
@@ -127,12 +127,12 @@ class HCQGraph(MultiGraphRunner):
prof_ji_desc = ji.prg._prg.name if is_exec_prg else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore
self.prof_graph_entries.append(ProfileGraphEntry(enqueue_dev.device, prof_ji_desc, sig_st, j * 2 + 1, is_copy=not is_exec_prg))
self.prog_graph_deps.append([d - 1 for _, d in rdeps])
self.prof_graph_deps.append([d - 1 for _, d in rdeps])
last_j[enqueue_queue] = j
# Check which signals are used in the profile graph.
self.prof_signal_is_used = [any(ent.st_id == j or ent.en_id == j for ent in self.prof_graph_entries) for j in range(len(self.prof_signals))]
self.prof_signal_is_used = [any(ent.st_id == j or ent.en_id == j for ent in self.prof_graph_entries) for j in range(len(jit_cache) * 2)]
# Build hardware queues.
self.copy_to_devs: dict[HCQCompiled, set[HCQCompiled]] = {dev: set() for dev in self.devices}
@@ -149,6 +149,9 @@ class HCQGraph(MultiGraphRunner):
for j,ji in enumerate(jit_cache):
enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j]
# Lazy allocate signals
if PROFILE: self.prof_signals += [enqueue_dev.new_signal(value=0) for _ in range(2)]
for sig, val in sync_signals + deps: enqueue_queue.wait(sig, val)
# Encode waits and start profile timestamp (if needed).
@@ -213,7 +216,7 @@ class HCQGraph(MultiGraphRunner):
def collect_timestamps(self):
# NOTE: Append to any device is fine...
self.devices[0].profile_events += [ProfileGraphEvent(self.prof_graph_entries, self.prog_graph_deps, [s.timestamp for s in self.prof_signals])]
self.devices[0].profile_events += [ProfileGraphEvent(self.prof_graph_entries, self.prof_graph_deps, [s.timestamp for s in self.prof_signals])]
def dev_name(self, dev) -> str: return dev.device.replace(":", "_")