mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
@@ -88,7 +88,8 @@ class HCQGraph(MultiGraphRunner):
|
|||||||
else:
|
else:
|
||||||
assert (enqueue_dev.hw_copy_queue_t is not None), "device must implement a copy queue"
|
assert (enqueue_dev.hw_copy_queue_t is not None), "device must implement a copy queue"
|
||||||
queue_idx = next(self.copy_queue_cnt[enqueue_dev]) % self.num_copy_queues
|
queue_idx = next(self.copy_queue_cnt[enqueue_dev]) % self.num_copy_queues
|
||||||
enqueue_queue = self.copy_queues.setdefault((enqueue_dev, queue_idx), enqueue_dev.hw_copy_queue_t(queue_idx=queue_idx))
|
enqueue_queue = self.copy_queues.setdefault((enqueue_dev, queue_idx),
|
||||||
|
enqueue_dev.hw_copy_queue_t(queue_idx=queue_idx).wait(self.signals['KICK'], self.kickoff_var))
|
||||||
|
|
||||||
out_signal = self.signals.setdefault(enqueue_queue, self.devices[0].new_signal(value=0))
|
out_signal = self.signals.setdefault(enqueue_queue, self.devices[0].new_signal(value=0))
|
||||||
|
|
||||||
@@ -190,15 +191,13 @@ class HCQGraph(MultiGraphRunner):
|
|||||||
def _dev_copy_queues(self, dev): return [q for (d, _), q in self.copy_queues.items() if d == dev]
|
def _dev_copy_queues(self, dev): return [q for (d, _), q in self.copy_queues.items() if d == dev]
|
||||||
|
|
||||||
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
|
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
|
||||||
# Wait and restore signals
|
# Map input rawbuffers
|
||||||
self.kickoff_value += 1
|
|
||||||
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
|
|
||||||
for sig in self.queue_signals_to_reset: sig.value = 0
|
|
||||||
self.signals['KICK'].value = self.kickoff_value
|
|
||||||
|
|
||||||
for dev in self.devices:
|
for dev in self.devices:
|
||||||
for idx_to_map in self.input_replace_map[dev]: cast(HCQAllocator, dev.allocator).map(input_rawbuffers[idx_to_map]._buf)
|
for idx_to_map in self.input_replace_map[dev]: cast(HCQAllocator, dev.allocator).map(input_rawbuffers[idx_to_map]._buf)
|
||||||
|
|
||||||
|
# Wait and restore signals
|
||||||
|
self.kickoff_value += 1
|
||||||
|
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
|
||||||
if PROFILE and self.kickoff_value > 1: self.collect_timestamps()
|
if PROFILE and self.kickoff_value > 1: self.collect_timestamps()
|
||||||
|
|
||||||
hcq_var_vals = {self.kickoff_var.expr: self.kickoff_value, **var_vals,
|
hcq_var_vals = {self.kickoff_var.expr: self.kickoff_value, **var_vals,
|
||||||
@@ -214,6 +213,10 @@ class HCQGraph(MultiGraphRunner):
|
|||||||
for copy_queue in self._dev_copy_queues(dev): copy_queue.submit(dev, hcq_var_vals_local)
|
for copy_queue in self._dev_copy_queues(dev): copy_queue.submit(dev, hcq_var_vals_local)
|
||||||
self.last_timeline[dev] = (dev.timeline_signal, dev.next_timeline())
|
self.last_timeline[dev] = (dev.timeline_signal, dev.next_timeline())
|
||||||
|
|
||||||
|
# Launch graph
|
||||||
|
for sig in self.queue_signals_to_reset: sig.value = 0
|
||||||
|
self.signals['KICK'].value = self.kickoff_value
|
||||||
|
|
||||||
if wait:
|
if wait:
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
|
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
|
||||||
|
|||||||
Reference in New Issue
Block a user