mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
@@ -88,7 +88,8 @@ class HCQGraph(MultiGraphRunner):
|
||||
else:
|
||||
assert (enqueue_dev.hw_copy_queue_t is not None), "device must implement a copy queue"
|
||||
queue_idx = next(self.copy_queue_cnt[enqueue_dev]) % self.num_copy_queues
|
||||
enqueue_queue = self.copy_queues.setdefault((enqueue_dev, queue_idx), enqueue_dev.hw_copy_queue_t(queue_idx=queue_idx))
|
||||
enqueue_queue = self.copy_queues.setdefault((enqueue_dev, queue_idx),
|
||||
enqueue_dev.hw_copy_queue_t(queue_idx=queue_idx).wait(self.signals['KICK'], self.kickoff_var))
|
||||
|
||||
out_signal = self.signals.setdefault(enqueue_queue, self.devices[0].new_signal(value=0))
|
||||
|
||||
@@ -190,15 +191,13 @@ class HCQGraph(MultiGraphRunner):
|
||||
def _dev_copy_queues(self, dev):
  # self.copy_queues is keyed by (device, queue index) pairs; keep only the queues whose key names `dev`,
  # in the dict's iteration order.
  return [queue for (owner, _), queue in self.copy_queues.items() if owner == dev]
|
||||
|
||||
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
|
||||
# Wait and restore signals
|
||||
self.kickoff_value += 1
|
||||
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
|
||||
for sig in self.queue_signals_to_reset: sig.value = 0
|
||||
self.signals['KICK'].value = self.kickoff_value
|
||||
|
||||
# Map input rawbuffers
|
||||
for dev in self.devices:
|
||||
for idx_to_map in self.input_replace_map[dev]: cast(HCQAllocator, dev.allocator).map(input_rawbuffers[idx_to_map]._buf)
|
||||
|
||||
# Wait and restore signals
|
||||
self.kickoff_value += 1
|
||||
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
|
||||
if PROFILE and self.kickoff_value > 1: self.collect_timestamps()
|
||||
|
||||
hcq_var_vals = {self.kickoff_var.expr: self.kickoff_value, **var_vals,
|
||||
@@ -214,6 +213,10 @@ class HCQGraph(MultiGraphRunner):
|
||||
for copy_queue in self._dev_copy_queues(dev): copy_queue.submit(dev, hcq_var_vals_local)
|
||||
self.last_timeline[dev] = (dev.timeline_signal, dev.next_timeline())
|
||||
|
||||
# Launch graph
|
||||
for sig in self.queue_signals_to_reset: sig.value = 0
|
||||
self.signals['KICK'].value = self.kickoff_value
|
||||
|
||||
if wait:
|
||||
st = time.perf_counter()
|
||||
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
|
||||
|
||||
Reference in New Issue
Block a user