From 3b354bc11f548848343874b127122dc7e9c24bf8 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Sat, 3 Jan 2026 13:11:15 +0300 Subject: [PATCH] hcq: better queue managment (#13991) --- tinygrad/runtime/graph/hcq.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tinygrad/runtime/graph/hcq.py b/tinygrad/runtime/graph/hcq.py index 1331e57c64..9220f4a9c8 100644 --- a/tinygrad/runtime/graph/hcq.py +++ b/tinygrad/runtime/graph/hcq.py @@ -1,4 +1,4 @@ -import collections, itertools, time +import collections, time from typing import Any, cast from tinygrad.helpers import round_up, PROFILE, ALL2ALL, merge_dicts, getenv, dedup, suppress_finalizing from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator, MMIOInterface @@ -50,8 +50,7 @@ class HCQGraph(MultiGraphRunner): self.comp_queues: dict[HCQCompiled, HWQueue] = {dev: dev.hw_compute_queue_t() for dev in self.devices} self.copy_queues: dict[tuple[HCQCompiled, int], HWQueue] = {} # lazy allocation, keyed by (device, queue_idx) - self.num_copy_queues: int = getenv("HCQ_NUM_SDMA", 7 if ALL2ALL >= 1 else 1) - self.copy_queue_cnt: collections.defaultdict[HCQCompiled, itertools.count] = collections.defaultdict(itertools.count) + self.num_copy_queues: int = getenv("HCQ_NUM_SDMA", min(len(self.devices), 8) if ALL2ALL >= 1 else 1) self.signals: dict[Any, HCQSignal] = {**{dev: dev.new_signal(value=0) for dev in self.devices if not dev._is_cpu()}, **{"KICK": self.devices[0].new_signal(value=0)}, **{dev: self.devices[0].new_signal(value=0) for dev in self.devices if dev._is_cpu()}} @@ -87,7 +86,7 @@ class HCQGraph(MultiGraphRunner): enqueue_queue = self.comp_queues[enqueue_dev] else: assert (enqueue_dev.hw_copy_queue_t is not None), "device must implement a copy queue" - queue_idx = next(self.copy_queue_cnt[enqueue_dev]) % self.num_copy_queues + queue_idx = self.devices.index(cast(HCQCompiled, Device[cast(Buffer, ji.bufs[0]).device])) % self.num_copy_queues enqueue_queue = self.copy_queues.setdefault((enqueue_dev, queue_idx), enqueue_dev.hw_copy_queue_t(queue_idx=queue_idx).wait(self.signals['KICK'], self.kickoff_var))