From efb2ae87c60f10f8db30fd67b5f9129ebfaf470d Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Sat, 3 Jan 2026 12:59:24 +0300 Subject: [PATCH] hcq sync aql (#13756) * hcq sync aql * w --- tinygrad/runtime/graph/hcq.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/tinygrad/runtime/graph/hcq.py b/tinygrad/runtime/graph/hcq.py index 868d1cc5d7..1331e57c64 100644 --- a/tinygrad/runtime/graph/hcq.py +++ b/tinygrad/runtime/graph/hcq.py @@ -88,7 +88,8 @@ class HCQGraph(MultiGraphRunner): else: assert (enqueue_dev.hw_copy_queue_t is not None), "device must implement a copy queue" queue_idx = next(self.copy_queue_cnt[enqueue_dev]) % self.num_copy_queues - enqueue_queue = self.copy_queues.setdefault((enqueue_dev, queue_idx), enqueue_dev.hw_copy_queue_t(queue_idx=queue_idx)) + enqueue_queue = self.copy_queues.setdefault((enqueue_dev, queue_idx), + enqueue_dev.hw_copy_queue_t(queue_idx=queue_idx).wait(self.signals['KICK'], self.kickoff_var)) out_signal = self.signals.setdefault(enqueue_queue, self.devices[0].new_signal(value=0)) @@ -190,15 +191,13 @@ class HCQGraph(MultiGraphRunner): def _dev_copy_queues(self, dev): return [q for (d, _), q in self.copy_queues.items() if d == dev] def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None: - # Wait and restore signals - self.kickoff_value += 1 - for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1]) - for sig in self.queue_signals_to_reset: sig.value = 0 - self.signals['KICK'].value = self.kickoff_value - + # Map input rawbuffers for dev in self.devices: for idx_to_map in self.input_replace_map[dev]: cast(HCQAllocator, dev.allocator).map(input_rawbuffers[idx_to_map]._buf) + # Wait and restore signals + self.kickoff_value += 1 + for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1]) if PROFILE and self.kickoff_value > 1: self.collect_timestamps() hcq_var_vals = {self.kickoff_var.expr: self.kickoff_value, **var_vals, @@ -214,6 +213,10 @@ class HCQGraph(MultiGraphRunner): for copy_queue in self._dev_copy_queues(dev): copy_queue.submit(dev, hcq_var_vals_local) self.last_timeline[dev] = (dev.timeline_signal, dev.next_timeline()) + # Launch graph + for sig in self.queue_signals_to_reset: sig.value = 0 + self.signals['KICK'].value = self.kickoff_value + if wait: st = time.perf_counter() for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])