mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
the slowest line in hcq graph (#15635)
* the slowest line in hcq graph
* x
This commit is contained in:
@@ -144,7 +144,7 @@ class HCQGraph(MultiGraphRunner):
|
||||
self.last_j[enqueue_queue] = j
|
||||
|
||||
# Check which signals are used in the profile graph.
|
||||
self.prof_signal_is_used = [any(ent.st_id == j or ent.en_id == j for ent in self.prof_graph_entries) for j in range(len(self.jit_cache) * 2)]
|
||||
self.prof_signal_is_used: set[int] = {sid for ent in self.prof_graph_entries for sid in (ent.st_id, ent.en_id)}
|
||||
|
||||
# Build hardware queues.
|
||||
self.copy_to_devs: dict[HCQCompiled, set[HCQCompiled]] = {dev: set() for dev in self.devices}
|
||||
@@ -167,7 +167,7 @@ class HCQGraph(MultiGraphRunner):
|
||||
for sig, val in sync_signals + deps: enqueue_queue.wait(sig, val)
|
||||
|
||||
# Encode waits and start profile timestamp (if needed).
|
||||
if PROFILE and self.prof_signal_is_used[j * 2]: enqueue_queue.timestamp(self.prof_signals[j * 2])
|
||||
if PROFILE and j * 2 in self.prof_signal_is_used: enqueue_queue.timestamp(self.prof_signals[j * 2])
|
||||
|
||||
# Encode main commands based on ji type.
|
||||
if isinstance(ji.prg, CompiledRunner):
|
||||
@@ -203,7 +203,7 @@ class HCQGraph(MultiGraphRunner):
|
||||
self.copy_to_devs[cast(HCQCompiled, Device[dest.device])].add(cast(HCQCompiled, Device[src.device]))
|
||||
|
||||
# Encode finish profile timestamp (if needed).
|
||||
if PROFILE and self.prof_signal_is_used[j * 2 + 1]: enqueue_queue.timestamp(self.prof_signals[j * 2 + 1])
|
||||
if PROFILE and j * 2 + 1 in self.prof_signal_is_used: enqueue_queue.timestamp(self.prof_signals[j * 2 + 1])
|
||||
|
||||
if signal_val is not None: enqueue_queue.signal(signal, signal_val)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user