hcq do not update the same signal (#5719)

* hcq do not update the same signal

* import them
This commit is contained in:
nimlgen
2024-07-26 00:24:45 +03:00
committed by GitHub
parent 6ec9ea9ddd
commit fb8148077e

View File

@@ -1,7 +1,7 @@
import collections, time
from typing import List, Any, Dict, cast, Optional, Tuple, Set
from tinygrad.helpers import round_up, to_mv, PROFILE
from tinygrad.device import HCQAllocator, Buffer, BufferOptions, Compiled, Device
from tinygrad.device import HCQCompiled, HCQAllocator, HCQSignal, Buffer, BufferOptions, Compiled, Device
from tinygrad.shape.symbolic import Variable
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
from tinygrad.engine.jit import MultiGraphRunner
@@ -48,7 +48,7 @@ class HCQGraph(MultiGraphRunner):
self.save_devs: Dict[Any, Set] = {q: set() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
for dev in self.devices: self.save_devs[self.comp_queues[dev]].add(dev)
self.graph_timeline = {dev: 0 for dev in self.devices} # Dict[dev, last graph sigval]
self.last_timeline: Dict[HCQCompiled, Tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
self.last_ji: Dict[Any, Any] = {q: None for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
for j,ji in enumerate(self.jit_cache):
@@ -122,7 +122,7 @@ class HCQGraph(MultiGraphRunner):
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
# Wait and restore signals
self.kickoff_value += 1
for dev in self.devices: dev.timeline_signal.wait(self.graph_timeline[dev])
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
for queue in self.comp_queues.values(): self.signals[queue].value = 0
for queue in self.copy_queues.values(): self.signals[queue].value = 0
self.dev_kickoff_signal['CPU'].value = self.kickoff_value
@@ -145,20 +145,21 @@ class HCQGraph(MultiGraphRunner):
queue.update_exec(cmd_ptr, *cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals))
for dev in self.devices:
self.comp_queues[dev].update_wait(1, dev.timeline_signal, dev.timeline_value - 1).update_wait(2, value=self.kickoff_value) \
.update_signal(3, value=self.kickoff_value) \
.update_signal(len(self.comp_queues[dev]) - 1, dev.timeline_signal, dev.timeline_value).submit(dev)
comp_queue, copy_queue, need_sig_upd = self.comp_queues[dev], self.copy_queues[dev], dev.timeline_signal != self.last_timeline[dev][0]
comp_queue.update_wait(1, dev.timeline_signal if need_sig_upd else None, dev.timeline_value - 1) \
.update_wait(2, value=self.kickoff_value).update_signal(3, value=self.kickoff_value) \
.update_signal(len(comp_queue)-1, dev.timeline_signal if need_sig_upd else None, dev.timeline_value).submit(dev)
if self.last_ji[(cp_queue:=self.copy_queues[dev])] is not None:
for cmd_idx in self.kickoff_wait_cmds[cp_queue]: cp_queue.update_wait(cmd_idx, value=self.kickoff_value)
cp_queue.submit(dev)
if self.last_ji[copy_queue] is not None:
for cmd_idx in self.kickoff_wait_cmds[copy_queue]: copy_queue.update_wait(cmd_idx, value=self.kickoff_value)
copy_queue.submit(dev)
self.graph_timeline[dev] = dev.timeline_value
self.last_timeline[dev] = (dev.timeline_signal, dev.timeline_value)
dev.timeline_value += 1
if wait:
st = time.perf_counter()
for dev in self.devices: dev.timeline_signal.wait(self.graph_timeline[dev])
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
return time.perf_counter() - st
return None
@@ -175,7 +176,7 @@ class HCQGraph(MultiGraphRunner):
return [(self.signals[k], max(v for x, v in deps if id(x) == idk)) for idk, k in {id(x[0]): x[0] for x in deps}.items()] + sync_signals
def __del__(self):
for dev in self.devices: dev.timeline_signal.wait(self.graph_timeline[dev])
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
# Graph is destructed. No need to keep signals any more, so return them as part of profiling.
if PROFILE and self.kickoff_value > 1: