hcq: better queue managment (#13991)

This commit is contained in:
nimlgen
2026-01-03 13:11:15 +03:00
committed by GitHub
parent efb2ae87c6
commit 3b354bc11f

View File

@@ -1,4 +1,4 @@
import collections, itertools, time
import collections, time
from typing import Any, cast
from tinygrad.helpers import round_up, PROFILE, ALL2ALL, merge_dicts, getenv, dedup, suppress_finalizing
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator, MMIOInterface
@@ -50,8 +50,7 @@ class HCQGraph(MultiGraphRunner):
self.comp_queues: dict[HCQCompiled, HWQueue] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
self.copy_queues: dict[tuple[HCQCompiled, int], HWQueue] = {} # lazy allocation, keyed by (device, queue_idx)
self.num_copy_queues: int = getenv("HCQ_NUM_SDMA", 7 if ALL2ALL >= 1 else 1)
self.copy_queue_cnt: collections.defaultdict[HCQCompiled, itertools.count] = collections.defaultdict(itertools.count)
self.num_copy_queues: int = getenv("HCQ_NUM_SDMA", min(len(self.devices), 8) if ALL2ALL >= 1 else 1)
self.signals: dict[Any, HCQSignal] = {**{dev: dev.new_signal(value=0) for dev in self.devices if not dev._is_cpu()},
**{"KICK": self.devices[0].new_signal(value=0)}, **{dev: self.devices[0].new_signal(value=0) for dev in self.devices if dev._is_cpu()}}
@@ -87,7 +86,7 @@ class HCQGraph(MultiGraphRunner):
enqueue_queue = self.comp_queues[enqueue_dev]
else:
assert (enqueue_dev.hw_copy_queue_t is not None), "device must implement a copy queue"
queue_idx = next(self.copy_queue_cnt[enqueue_dev]) % self.num_copy_queues
queue_idx = self.devices.index(cast(HCQCompiled, Device[cast(Buffer, ji.bufs[0]).device])) % self.num_copy_queues
enqueue_queue = self.copy_queues.setdefault((enqueue_dev, queue_idx),
enqueue_dev.hw_copy_queue_t(queue_idx=queue_idx).wait(self.signals['KICK'], self.kickoff_var))