hcq sync aql (#13756)

* hcq sync aql

* w
This commit is contained in:
nimlgen
2026-01-03 12:59:24 +03:00
committed by GitHub
parent bd55507ee4
commit efb2ae87c6

View File

@@ -88,7 +88,8 @@ class HCQGraph(MultiGraphRunner):
else:
assert (enqueue_dev.hw_copy_queue_t is not None), "device must implement a copy queue"
queue_idx = next(self.copy_queue_cnt[enqueue_dev]) % self.num_copy_queues
enqueue_queue = self.copy_queues.setdefault((enqueue_dev, queue_idx), enqueue_dev.hw_copy_queue_t(queue_idx=queue_idx))
enqueue_queue = self.copy_queues.setdefault((enqueue_dev, queue_idx),
enqueue_dev.hw_copy_queue_t(queue_idx=queue_idx).wait(self.signals['KICK'], self.kickoff_var))
out_signal = self.signals.setdefault(enqueue_queue, self.devices[0].new_signal(value=0))
@@ -190,15 +191,13 @@ class HCQGraph(MultiGraphRunner):
def _dev_copy_queues(self, dev): return [q for (d, _), q in self.copy_queues.items() if d == dev]
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
# Wait and restore signals
self.kickoff_value += 1
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
for sig in self.queue_signals_to_reset: sig.value = 0
self.signals['KICK'].value = self.kickoff_value
# Map input rawbuffers
for dev in self.devices:
for idx_to_map in self.input_replace_map[dev]: cast(HCQAllocator, dev.allocator).map(input_rawbuffers[idx_to_map]._buf)
# Wait and restore signals
self.kickoff_value += 1
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
if PROFILE and self.kickoff_value > 1: self.collect_timestamps()
hcq_var_vals = {self.kickoff_var.expr: self.kickoff_value, **var_vals,
@@ -214,6 +213,10 @@ class HCQGraph(MultiGraphRunner):
for copy_queue in self._dev_copy_queues(dev): copy_queue.submit(dev, hcq_var_vals_local)
self.last_timeline[dev] = (dev.timeline_signal, dev.next_timeline())
# Launch graph
for sig in self.queue_signals_to_reset: sig.value = 0
self.signals['KICK'].value = self.kickoff_value
if wait:
st = time.perf_counter()
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])