mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
hcq do not update the same signal (#5719)
* hcq do not update the same signal
* import them
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import collections, time
|
||||
from typing import List, Any, Dict, cast, Optional, Tuple, Set
|
||||
from tinygrad.helpers import round_up, to_mv, PROFILE
|
||||
from tinygrad.device import HCQAllocator, Buffer, BufferOptions, Compiled, Device
|
||||
from tinygrad.device import HCQCompiled, HCQAllocator, HCQSignal, Buffer, BufferOptions, Compiled, Device
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
|
||||
from tinygrad.engine.jit import MultiGraphRunner
|
||||
@@ -48,7 +48,7 @@ class HCQGraph(MultiGraphRunner):
|
||||
self.save_devs: Dict[Any, Set] = {q: set() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
|
||||
for dev in self.devices: self.save_devs[self.comp_queues[dev]].add(dev)
|
||||
|
||||
self.graph_timeline = {dev: 0 for dev in self.devices} # Dict[dev, last graph sigval]
|
||||
self.last_timeline: Dict[HCQCompiled, Tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
|
||||
self.last_ji: Dict[Any, Any] = {q: None for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
|
||||
|
||||
for j,ji in enumerate(self.jit_cache):
|
||||
@@ -122,7 +122,7 @@ class HCQGraph(MultiGraphRunner):
|
||||
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
|
||||
# Wait and restore signals
|
||||
self.kickoff_value += 1
|
||||
for dev in self.devices: dev.timeline_signal.wait(self.graph_timeline[dev])
|
||||
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
|
||||
for queue in self.comp_queues.values(): self.signals[queue].value = 0
|
||||
for queue in self.copy_queues.values(): self.signals[queue].value = 0
|
||||
self.dev_kickoff_signal['CPU'].value = self.kickoff_value
|
||||
@@ -145,20 +145,21 @@ class HCQGraph(MultiGraphRunner):
|
||||
queue.update_exec(cmd_ptr, *cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals))
|
||||
|
||||
for dev in self.devices:
|
||||
self.comp_queues[dev].update_wait(1, dev.timeline_signal, dev.timeline_value - 1).update_wait(2, value=self.kickoff_value) \
|
||||
.update_signal(3, value=self.kickoff_value) \
|
||||
.update_signal(len(self.comp_queues[dev]) - 1, dev.timeline_signal, dev.timeline_value).submit(dev)
|
||||
comp_queue, copy_queue, need_sig_upd = self.comp_queues[dev], self.copy_queues[dev], dev.timeline_signal != self.last_timeline[dev][0]
|
||||
comp_queue.update_wait(1, dev.timeline_signal if need_sig_upd else None, dev.timeline_value - 1) \
|
||||
.update_wait(2, value=self.kickoff_value).update_signal(3, value=self.kickoff_value) \
|
||||
.update_signal(len(comp_queue)-1, dev.timeline_signal if need_sig_upd else None, dev.timeline_value).submit(dev)
|
||||
|
||||
if self.last_ji[(cp_queue:=self.copy_queues[dev])] is not None:
|
||||
for cmd_idx in self.kickoff_wait_cmds[cp_queue]: cp_queue.update_wait(cmd_idx, value=self.kickoff_value)
|
||||
cp_queue.submit(dev)
|
||||
if self.last_ji[copy_queue] is not None:
|
||||
for cmd_idx in self.kickoff_wait_cmds[copy_queue]: copy_queue.update_wait(cmd_idx, value=self.kickoff_value)
|
||||
copy_queue.submit(dev)
|
||||
|
||||
self.graph_timeline[dev] = dev.timeline_value
|
||||
self.last_timeline[dev] = (dev.timeline_signal, dev.timeline_value)
|
||||
dev.timeline_value += 1
|
||||
|
||||
if wait:
|
||||
st = time.perf_counter()
|
||||
for dev in self.devices: dev.timeline_signal.wait(self.graph_timeline[dev])
|
||||
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
|
||||
return time.perf_counter() - st
|
||||
return None
|
||||
|
||||
@@ -175,7 +176,7 @@ class HCQGraph(MultiGraphRunner):
|
||||
return [(self.signals[k], max(v for x, v in deps if id(x) == idk)) for idk, k in {id(x[0]): x[0] for x in deps}.items()] + sync_signals
|
||||
|
||||
def __del__(self):
|
||||
for dev in self.devices: dev.timeline_signal.wait(self.graph_timeline[dev])
|
||||
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
|
||||
|
||||
# Graph is destructed. No need to keep signals any more, so return them as part of profiling.
|
||||
if PROFILE and self.kickoff_value > 1:
|
||||
|
||||
Reference in New Issue
Block a user