diff --git a/tinygrad/runtime/graph/hcq.py b/tinygrad/runtime/graph/hcq.py index f9884851b4..d41c56fe9e 100644 --- a/tinygrad/runtime/graph/hcq.py +++ b/tinygrad/runtime/graph/hcq.py @@ -58,8 +58,8 @@ class HCQGraph(MultiGraphRunner): # When profiling allocate 2 signals for each jit item to measure speed. The jth jit item have signals at 2*j and 2*j+1. # TODO: This logic might allocate a few extra signals... - self.prof_signals: list[HCQSignal] = [self.devices[0].new_signal() for i in range(len(jit_cache) * 2)] if PROFILE else [] - self.prog_graph_deps: list[list[int]] = [] + self.prof_signals: list[HCQSignal] = [] + self.prof_graph_deps: list[list[int]] = [] self.prof_graph_entries: list[ProfileGraphEntry] = [] last_j: dict[HWQueue, int|None] = collections.defaultdict(lambda: None) @@ -127,12 +127,12 @@ class HCQGraph(MultiGraphRunner): prof_ji_desc = ji.prg._prg.name if is_exec_prg else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore self.prof_graph_entries.append(ProfileGraphEntry(enqueue_dev.device, prof_ji_desc, sig_st, j * 2 + 1, is_copy=not is_exec_prg)) - self.prog_graph_deps.append([d - 1 for _, d in rdeps]) + self.prof_graph_deps.append([d - 1 for _, d in rdeps]) last_j[enqueue_queue] = j # Check which signals are used in the profile graph. - self.prof_signal_is_used = [any(ent.st_id == j or ent.en_id == j for ent in self.prof_graph_entries) for j in range(len(self.prof_signals))] + self.prof_signal_is_used = [any(ent.st_id == j or ent.en_id == j for ent in self.prof_graph_entries) for j in range(len(jit_cache) * 2)] # Build hardware queues. self.copy_to_devs: dict[HCQCompiled, set[HCQCompiled]] = {dev: set() for dev in self.devices} @@ -149,6 +149,9 @@ class HCQGraph(MultiGraphRunner): for j,ji in enumerate(jit_cache): enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j] + # Lazy allocate signals + if PROFILE: self.prof_signals += [enqueue_dev.new_signal(value=0) for _ in range(2)] + for sig, val in sync_signals + deps: enqueue_queue.wait(sig, val) # Encode waits and start profile timestamp (if needed). @@ -213,7 +216,7 @@ class HCQGraph(MultiGraphRunner): def collect_timestamps(self): # NOTE: Append to any device is fine... - self.devices[0].profile_events += [ProfileGraphEvent(self.prof_graph_entries, self.prog_graph_deps, [s.timestamp for s in self.prof_signals])] + self.devices[0].profile_events += [ProfileGraphEvent(self.prof_graph_entries, self.prof_graph_deps, [s.timestamp for s in self.prof_signals])] def dev_name(self, dev) -> str: return dev.device.replace(":", "_")