mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-20 20:38:03 -05:00
hcq: lazy prof signal allocation (#11531)
This commit is contained in:
@@ -58,8 +58,8 @@ class HCQGraph(MultiGraphRunner):
|
||||
|
||||
# When profiling allocate 2 signals for each jit item to measure speed. The jth jit item have signals at 2*j and 2*j+1.
|
||||
# TODO: This logic might allocate a few extra signals...
|
||||
self.prof_signals: list[HCQSignal] = [self.devices[0].new_signal() for i in range(len(jit_cache) * 2)] if PROFILE else []
|
||||
self.prog_graph_deps: list[list[int]] = []
|
||||
self.prof_signals: list[HCQSignal] = []
|
||||
self.prof_graph_deps: list[list[int]] = []
|
||||
self.prof_graph_entries: list[ProfileGraphEntry] = []
|
||||
|
||||
last_j: dict[HWQueue, int|None] = collections.defaultdict(lambda: None)
|
||||
@@ -127,12 +127,12 @@ class HCQGraph(MultiGraphRunner):
|
||||
prof_ji_desc = ji.prg._prg.name if is_exec_prg else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore
|
||||
|
||||
self.prof_graph_entries.append(ProfileGraphEntry(enqueue_dev.device, prof_ji_desc, sig_st, j * 2 + 1, is_copy=not is_exec_prg))
|
||||
self.prog_graph_deps.append([d - 1 for _, d in rdeps])
|
||||
self.prof_graph_deps.append([d - 1 for _, d in rdeps])
|
||||
|
||||
last_j[enqueue_queue] = j
|
||||
|
||||
# Check which signals are used in the profile graph.
|
||||
self.prof_signal_is_used = [any(ent.st_id == j or ent.en_id == j for ent in self.prof_graph_entries) for j in range(len(self.prof_signals))]
|
||||
self.prof_signal_is_used = [any(ent.st_id == j or ent.en_id == j for ent in self.prof_graph_entries) for j in range(len(jit_cache) * 2)]
|
||||
|
||||
# Build hardware queues.
|
||||
self.copy_to_devs: dict[HCQCompiled, set[HCQCompiled]] = {dev: set() for dev in self.devices}
|
||||
@@ -149,6 +149,9 @@ class HCQGraph(MultiGraphRunner):
|
||||
for j,ji in enumerate(jit_cache):
|
||||
enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j]
|
||||
|
||||
# Lazy allocate signals
|
||||
if PROFILE: self.prof_signals += [enqueue_dev.new_signal(value=0) for _ in range(2)]
|
||||
|
||||
for sig, val in sync_signals + deps: enqueue_queue.wait(sig, val)
|
||||
|
||||
# Encode waits and start profile timestamp (if needed).
|
||||
@@ -213,7 +216,7 @@ class HCQGraph(MultiGraphRunner):
|
||||
|
||||
def collect_timestamps(self):
|
||||
# NOTE: Append to any device is fine...
|
||||
self.devices[0].profile_events += [ProfileGraphEvent(self.prof_graph_entries, self.prog_graph_deps, [s.timestamp for s in self.prof_signals])]
|
||||
self.devices[0].profile_events += [ProfileGraphEvent(self.prof_graph_entries, self.prof_graph_deps, [s.timestamp for s in self.prof_signals])]
|
||||
|
||||
def dev_name(self, dev) -> str: return dev.device.replace(":", "_")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user