hcq graph refactor (#5887)
* cleanup
* prof
* cleaner
* comments
* more types
@@ -492,7 +492,7 @@ class HCQCompiled(Compiled):
self.timeline_signal, self._shadow_timeline_signal = timeline_signals
self.sig_prof_records:List[Tuple[HCQSignal, HCQSignal, str, bool]] = []
self.raw_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, str, bool]] = []
self.dep_prof_records:List[Tuple[int, int, HCQCompiled, bool, int, int, HCQCompiled, bool]] = []
self.dep_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, HCQCompiled, bool, decimal.Decimal, decimal.Decimal, HCQCompiled, bool]] = []
if PROFILE: self._prof_setup()
from tinygrad.runtime.graph.hcq import HCQGraph
@@ -38,71 +38,85 @@ class HCQGraph(MultiGraphRunner):
# graph-related tasks. This synchronization uses a global timeline signal per device. Within the graph, the compute queue coordinates with
# global operations and sets a kickoff signal. Any queue accessing a buffer from another device waits for this signal from the device’s
# compute queue to ensure exclusive access. The compute queue signals the completion of the graph, synchronizing with the device's copy queue.
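The comment above describes the ordering protocol: a global timeline signal per device gates graph entry and exit, and a per-device kickoff signal, published by that device's compute queue, is what any other queue must wait on before touching that device's buffers. Below is a minimal standalone sketch of that wait/signal pattern; Signal, Queue and the DEV0/DEV1 names are hypothetical stand-ins, not tinygrad's HCQSignal or hardware queue API.

class Signal:
  def __init__(self, value=0): self.value = value

class Queue:
  def __init__(self): self.cmds = []
  def wait(self, sig, val): self.cmds.append(("wait", sig, val)); return self
  def signal(self, sig, val): self.cmds.append(("signal", sig, val)); return self

# One kickoff signal per device, plus one the CPU raises at launch time.
kickoff = {"CPU": Signal(), "DEV0": Signal(), "DEV1": Signal()}
kickoff_value = 1

# DEV0's compute queue waits for the CPU kickoff, then publishes its own kickoff so queues
# on other devices know DEV0 is initialized and its buffers are safe to access.
comp_q0 = Queue().wait(kickoff["CPU"], kickoff_value).signal(kickoff["DEV0"], kickoff_value)
# A DEV1 copy queue that reads a DEV0 buffer waits for DEV0's kickoff before copying.
copy_q1 = Queue().wait(kickoff["DEV0"], kickoff_value)
print(comp_q0.cmds, copy_q1.cmds)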
self.comp_queues: Dict[Compiled, HWComputeQueue] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
self.copy_queues: Dict[Compiled, HWCopyQueue] = {dev: dev.hw_copy_queue_t() for dev in self.devices}
self.ji_schedule: Dict[int, Tuple[HCQCompiled, HWCommandQueue, List, List, HCQSignal, Optional[int]]] = {}
self.signal_sched: Dict[int, Tuple[List, HCQSignal, Optional[int], Optional[List]]] = {} # Dict[ji_idx, (deps, signal, sigval, prof_info)]
self.signals = {q: dev.signal_t(value=0) for queues in (self.comp_queues, self.copy_queues) for dev,q in queues.items()} #type:ignore
self.dev_kickoff_signal = {**{dev.dname: dev.signal_t(value=0) for dev in self.devices}, **{"CPU": self.devices[0].signal_t(value=0)}}
self.kickoff_value = 0
self.comp_queues: Dict[HCQCompiled, HWComputeQueue] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
self.copy_queues: Dict[HCQCompiled, HWCopyQueue] = {} # lazy allocation
self.save_devs: Dict[HWCommandQueue, Set] = {q: set() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
for dev in self.devices: self.save_devs[self.comp_queues[dev]].add(dev)
self.signals: Dict[Any, HCQSignal] = {**{dev.dname: dev.signal_t(value=0) for dev in self.devices}, **{"CPU": self.devices[0].signal_t(value=0)}}
self.kickoff_value: int = 0
self.last_timeline: Dict[HCQCompiled, Tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
self.last_ji: Dict[HWCommandQueue, Optional[int]] = {q: None for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
self.prof_signals: List[HCQSignal] = [self.devices[0].signal_t() for i in range(len(self.jit_cache) * 2)] if PROFILE else []
self.prof_deps: List[Tuple[Tuple, Tuple]] = []
self.prof_records: List[Tuple[Tuple[int, bool], Tuple[int, bool], HCQCompiled, str, bool, List[int]]] = []
last_j: Dict[HWCommandQueue, Optional[int]] = collections.defaultdict(lambda: None)
queue_access: Dict[HWCommandQueue, Dict[HWCommandQueue, Optional[int]]] = collections.defaultdict(lambda: collections.defaultdict(lambda: None))
dev_access: Dict[HWCommandQueue, Set[str]] = collections.defaultdict(set)
for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev.dname)
for j,ji in enumerate(self.jit_cache):
enqueue_dev = ji.prg.device if isinstance(ji.prg, CompiledRunner) else Device[ji.bufs[1].device] #type:ignore
enqueue_queue = self.comp_queues[enqueue_dev] if isinstance(ji.prg, CompiledRunner) else self.copy_queues[enqueue_dev]
out_signal = self.signals[enqueue_queue] #type:ignore
writable_buffers = ji.prg.p.outcount if isinstance(ji.prg, CompiledRunner) else 1
deps = self.access_resources(enqueue_queue, ji.bufs[writable_buffers:], ji.bufs[:writable_buffers], j + 1)
enqueue_dev = ji.prg.device if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore
enqueue_queue = self.comp_queues[enqueue_dev] if is_exec_prg else self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
out_signal = self.signals.setdefault(enqueue_queue, enqueue_dev.signal_t(value=0))
# Profiler related info
prof_ji_desc = ji.prg.clprg.name if isinstance(ji.prg, CompiledRunner) else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore
prof_info = ([(j * 2, True), (j * 2 + 1, True), enqueue_dev, prof_ji_desc, isinstance(ji.prg, BufferXfer)]) if PROFILE else None
# Get dependencies based on input and output buffers.
rdeps = self._access_resources(ji.bufs[(wb:=ji.prg.p.outcount if is_exec_prg else 1):], ji.bufs[:wb], (enqueue_queue, j + 1)) #type:ignore
if isinstance(ji.prg, CompiledRunner):
# Update signal on compute kernel to depend on the previous kernel.
if (last_j:=self.last_ji[enqueue_queue]) is not None: deps = [x for x in deps if id(x[0]) != id(out_signal)] + [(out_signal, last_j + 1)]
# Update dependencies to include previous kernel in queue. This is required for timeline signals.
opt_deps, deps = [], rdeps + ([(enqueue_queue, prev_ji + 1)] if (prev_ji:=last_j[enqueue_queue]) is not None else [])
# Remove self-dependency for AMD or NV with only 1 same-queue dep, since NV chains 2+ execs in this case, eliminating dep need.
if (dname:=enqueue_dev.dname.split(":", 1)[0]) == "AMD" or (dname == "NV" and len(deps) == 1 and id(deps[0][0]) == id(out_signal)):
deps = [x for x in deps if id(x[0]) != id(out_signal)]
if last_j is not None and prof_info is not None: prof_info = [(self.signal_sched[last_j][3][1][0], False)] + prof_info[1:] # type: ignore
# Optimize dependencies by removing redundant ones. Remove waiting for the value of the queue which is known to be already
# synced with the current queue.
for dep_queue, dep_val in sorted(deps, key=lambda x: x[1], reverse=True):
if (qa:=queue_access[enqueue_queue][dep_queue]) is None or qa < dep_val:
opt_deps.append((self.signals[dep_queue], dep_val))
queue_access[enqueue_queue][dep_queue] = dep_val
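A standalone sketch of the pruning done by the loop above, assuming queue signal values only grow: once the enqueuing queue has already waited on another queue up to some value, any later dependency on that queue at a lower or equal value is implied and can be dropped. Plain strings stand in for queues here, and the helper name is hypothetical.

import collections
from typing import Dict, List, Optional, Tuple

def prune_deps(enqueue_queue: str, deps: List[Tuple[str, int]],
               queue_access: Dict[str, Dict[str, Optional[int]]]) -> List[Tuple[str, int]]:
  opt_deps = []
  for dep_queue, dep_val in sorted(deps, key=lambda x: x[1], reverse=True):
    seen = queue_access[enqueue_queue][dep_queue]
    if seen is None or seen < dep_val:            # not yet synced this far: keep the wait
      opt_deps.append((dep_queue, dep_val))
      queue_access[enqueue_queue][dep_queue] = dep_val
  return opt_deps

qa: Dict[str, Dict[str, Optional[int]]] = collections.defaultdict(lambda: collections.defaultdict(lambda: None))
print(prune_deps("comp0", [("copy1", 3), ("copy1", 2), ("comp1", 5)], qa))  # [('comp1', 5), ('copy1', 3)]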
elif isinstance(ji.prg, BufferXfer): deps = [x for x in deps if id(x[0]) != id(out_signal)]
# Ensure device is ready for use in current context: the graph has initialized the device and it's safe to operate on it within this graph.
for dep_queue, _ in opt_deps: dev_access[enqueue_queue].update(dev_access[dep_queue])
sync_signals = [(self.signals[bdev], self.kickoff_value) for b in ji.bufs if (bdev:=cast(Buffer, b).device) not in dev_access[enqueue_queue]]
dev_access[enqueue_queue].update(cast(Buffer, b).device for b in ji.bufs)
# Go through all dependencies and, if we need the signal from that ji, enable it by setting the signal value in the signal schedule.
for sig, val in deps:
if id(sig) in [id(x) for x in self.signals.values()]:
self.signal_sched[val - 1] = self.signal_sched[val - 1][:2] + (val,) + self.signal_sched[val - 1][3:]
if PROFILE: self.prof_deps += [(self.signal_sched[val - 1][3], prof_info)] # type: ignore
# Remove self-dependency for compute and copy queues.
# For compute, in case of NV, optimize when only 1 same-queue dependency exists, since NV chains 2+ executions in this case,
# eliminating dependency need.
dname = enqueue_dev.dname.split(":", 1)[0]
can_opt = (dname == "AMD" or (dname == "NV" and len(sync_signals) == 0 and len(opt_deps) == 1 and id(opt_deps[0][0]) == id(out_signal)))
if can_opt or isinstance(ji.prg, BufferXfer): opt_deps = [x for x in opt_deps if id(x[0]) != id(out_signal)]
self.signal_sched[j] = (deps, out_signal, None if isinstance(ji.prg, CompiledRunner) else (j + 1), prof_info)
self.last_ji[enqueue_queue] = j
# Enable necessary signals in the schedule by setting the signal value.
for sig, val in opt_deps: self.ji_schedule[val - 1] = self.ji_schedule[val - 1][:5] + (val,)
self.ji_schedule[j] = (enqueue_dev, enqueue_queue, sync_signals, opt_deps[::-1], out_signal, None if is_exec_prg else (j + 1))
# Collect profile information if profiling is enabled.
if PROFILE:
prof_ji_desc = ji.prg.clprg.name if is_exec_prg else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore
sig_st, sig_en = (j * 2, True), (j * 2 + 1, True)
if len(opt_deps) == 0 and (prev_ji:=last_j[enqueue_queue]) is not None: sig_st = (prev_ji * 2 + 1, False)
self.prof_records.append((sig_st, sig_en, enqueue_dev, prof_ji_desc, not is_exec_prg, [d - 1 for _, d in rdeps]))
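A hedged sketch of the slot assignment above (hypothetical helper, not part of this diff): jit item j owns timestamp slots 2*j and 2*j+1 in prof_signals, and when the item has no waits of its own the end timestamp of the previous item on the same queue doubles as its start, so no extra timestamp command is needed. The bool records whether a fresh timestamp must be emitted.

from typing import Optional, Tuple

def prof_slots(j: int, prev_ji_on_queue: Optional[int], has_deps: bool) -> Tuple[Tuple[int, bool], Tuple[int, bool]]:
  sig_st, sig_en = (j * 2, True), (j * 2 + 1, True)
  if not has_deps and prev_ji_on_queue is not None:
    sig_st = (prev_ji_on_queue * 2 + 1, False)    # reuse the previous end timestamp
  return sig_st, sig_en

print(prof_slots(3, prev_ji_on_queue=2, has_deps=False))  # ((5, False), (7, True))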
last_j[enqueue_queue] = j
# Build hardware queues.
self.op_cmd_idx: Dict[int, Tuple[Any, int]] = {}
self.copy_to_devs: Dict[Compiled, Set[Compiled]] = {dev: set() for dev in self.devices}
self.copy_to_devs: Dict[HCQCompiled, Set[HCQCompiled]] = {dev: set() for dev in self.devices}
self.kickoff_wait_cmds: Dict[HWCommandQueue, List] = {q: list() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
for dev in self.devices:
self.comp_queues[dev].memory_barrier().wait(dev.timeline_signal, dev.timeline_value - 1) \
.wait(self.dev_kickoff_signal['CPU'], self.kickoff_value).signal(self.dev_kickoff_signal[dev.dname], self.kickoff_value)
.wait(self.signals['CPU'], self.kickoff_value).signal(self.signals[dev.dname], self.kickoff_value)
for j,ji in enumerate(self.jit_cache):
deps, signal, signal_val, prof_info = self.signal_sched[j]
enqueue_queue = self.copy_queues[Device[ji.bufs[1].device]] if isinstance(ji.prg, BufferXfer) else self.comp_queues[ji.prg.device] #type:ignore
enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j]
for i in range(len(sync_signals)): self.kickoff_wait_cmds[enqueue_queue].append(len(enqueue_queue) + i)
for sig, val in sync_signals + deps: enqueue_queue.wait(sig, val)
# Encode waits and start profile timestamp (if needed).
for sig, val in deps:
enqueue_queue.wait(sig, val)
if id(sig) in [id(x) for x in self.dev_kickoff_signal.values()]: self.kickoff_wait_cmds[enqueue_queue].append(len(enqueue_queue) - 1)
if prof_info and prof_info[0][1]: enqueue_queue.timestamp(self.prof_signals[prof_info[0][0]])
if PROFILE and self.prof_records[j][0][1]: enqueue_queue.timestamp(self.prof_signals[self.prof_records[j][0][0]])
# Encode main commands based on ji type.
if isinstance(ji.prg, CompiledRunner):
@@ -111,30 +125,30 @@ class HCQGraph(MultiGraphRunner):
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
cast(HCQAllocator, Device[src.device].allocator).map(dest._buf)
cast(HWCopyQueue, enqueue_queue).copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes)
self.copy_to_devs[Device[dest.device]].add(Device[src.device])
self.copy_to_devs[cast(HCQCompiled, Device[dest.device])].add(cast(HCQCompiled, Device[src.device]))
self.op_cmd_idx[j] = (enqueue_queue, len(enqueue_queue) - 1)
# Encode finish profile timestamp (if needed).
if prof_info and prof_info[1][1]: enqueue_queue.timestamp(self.prof_signals[prof_info[1][0]])
if PROFILE and self.prof_records[j][1][1]: enqueue_queue.timestamp(self.prof_signals[self.prof_records[j][1][0]])
if signal_val is not None: enqueue_queue.signal(signal, signal_val)
for dev in self.devices:
for dep_dev in list(self.copy_to_devs[dev]) + [dev]:
if (last_j:=self.last_ji[self.copy_queues[dep_dev]]) is None: continue
self.comp_queues[dev].wait(self.signals[self.copy_queues[dep_dev]], self.signal_sched[last_j][2])
if dep_dev in self.copy_queues: self.comp_queues[dev].wait(self.signals[(copy_q:=self.copy_queues[dep_dev])], cast(int, last_j[copy_q]) + 1)
self.comp_queues[dev].signal(dev.timeline_signal, dev.timeline_value)
self.comp_queues[dev].bind(dev)
if self.last_ji[self.copy_queues[dev]] is not None: self.copy_queues[dev].bind(dev)
self.comp_queues[dev].signal(dev.timeline_signal, dev.timeline_value).bind(dev)
if dev in self.copy_queues: copy_q.bind(dev)
self.last_timeline: Dict[HCQCompiled, Tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
self.queue_signals_to_reset = [self.signals[q] for q in list(self.comp_queues.values()) + list(self.copy_queues.values()) if q in self.signals]
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
# Wait and restore signals
self.kickoff_value += 1
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
for comp_queue in self.comp_queues.values(): self.signals[comp_queue].value = 0
for copy_queue in self.copy_queues.values(): self.signals[copy_queue].value = 0
self.dev_kickoff_signal['CPU'].value = self.kickoff_value
for sig in self.queue_signals_to_reset: sig.value = 0
self.signals['CPU'].value = self.kickoff_value
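The reset above is a small per-launch handshake: bump kickoff_value, make sure the previous run has fully retired, zero the per-queue progress signals, then release the devices by raising the CPU kickoff signal. A standalone sketch with a toy Signal class and hypothetical names, assuming signal values only increase within one run:

class Signal:
  def __init__(self, value=0): self.value = value
  def wait(self, value): assert self.value >= value, "previous graph run not finished"  # toy: check instead of block

def launch(kickoff_value, last_timeline, queue_signals, cpu_kickoff):
  kickoff_value += 1
  for sig, val in last_timeline.values(): sig.wait(val)   # previous run fully retired
  for sig in queue_signals: sig.value = 0                 # per-queue progress counters restart
  cpu_kickoff.value = kickoff_value                       # releases every device's first wait
  return kickoff_value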
if PROFILE and self.kickoff_value > 1: self.collect_timestamps()
@@ -152,12 +166,12 @@ class HCQGraph(MultiGraphRunner):
queue.update_exec(cmd_ptr, global_dims, local_dims)
for dev in self.devices:
comp_queue, copy_queue, need_sig_upd = self.comp_queues[dev], self.copy_queues[dev], dev.timeline_signal != self.last_timeline[dev][0]
comp_queue, copy_queue, need_sig_upd = self.comp_queues[dev], self.copy_queues.get(dev, None), dev.timeline_signal != self.last_timeline[dev][0]
comp_queue.update_wait(1, dev.timeline_signal if need_sig_upd else None, dev.timeline_value - 1) \
.update_wait(2, value=self.kickoff_value).update_signal(3, value=self.kickoff_value) \
.update_signal(len(comp_queue)-1, dev.timeline_signal if need_sig_upd else None, dev.timeline_value).submit(dev)
if self.last_ji[copy_queue] is not None:
if copy_queue is not None:
for cmd_idx in self.kickoff_wait_cmds[copy_queue]: copy_queue.update_wait(cmd_idx, value=self.kickoff_value)
copy_queue.submit(dev)
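Replay never re-encodes the queues: it only rewrites values at command indices captured while the graph was built (the kickoff waits, the timeline wait/signal slots) and resubmits. A toy sketch of that update-in-place idea, with a hypothetical RecordedQueue rather than the real HWCommandQueue:

class RecordedQueue:
  def __init__(self): self.cmds = []
  def wait(self, sig, value): self.cmds.append(["wait", sig, value]); return self
  def update_wait(self, idx, sig=None, value=None):
    if sig is not None: self.cmds[idx][1] = sig
    if value is not None: self.cmds[idx][2] = value
    return self

q = RecordedQueue().wait("kickoff_CPU", 1)
q.update_wait(0, value=2)   # next launch: same recorded command, new kickoff value
print(q.cmds)               # [['wait', 'kickoff_CPU', 2]]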
@@ -170,26 +184,15 @@ class HCQGraph(MultiGraphRunner):
return time.perf_counter() - st
return None
def access_resources(self, queue, read, write, new_val):
deps = self._access_resources(read, write, (queue, new_val))
sync_signals = []
for dep_queue,_ in deps: self.save_devs[queue].update(self.save_devs[dep_queue])
for buf in read+write:
if buf.device not in self.save_devs[queue]:
self.save_devs[queue].add(buf.device)
sync_signals += [(self.dev_kickoff_signal[Device[buf.device].dname], self.kickoff_value)]
return [(self.signals[k], max(v for x, v in deps if id(x) == idk)) for idk, k in {id(x[0]): x[0] for x in deps}.items()] + sync_signals
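The return statement above folds the raw dependency list so each signal is waited on only once, at the largest value requested for it. A minimal standalone version of that reduction (plain strings as signals, hypothetical name):

from typing import Dict, List, Tuple

def collapse_deps(deps: List[Tuple[str, int]]) -> List[Tuple[str, int]]:
  highest: Dict[str, int] = {}
  for sig, val in deps: highest[sig] = max(val, highest.get(sig, val))
  return list(highest.items())

print(collapse_deps([("q0", 3), ("q1", 7), ("q0", 5)]))  # [('q0', 5), ('q1', 7)]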
def collect_timestamps(self):
timestamps = [s.timestamp for s in self.prof_signals]
for _,_,_,((st,_),(en,_),dev,desc,is_cp) in self.signal_sched.values(): # type: ignore
for (st,_), (en,_), dev, desc, is_cp, deps in self.prof_records:
dev.raw_prof_records += [(timestamps[st], timestamps[en], desc, is_cp)]
for ((a_st,_), (a_en,_), a_dev, _, a_is_cp), ((b_st,_), (b_en,_), b_dev, _, b_is_cp) in self.prof_deps:
b_dev.dep_prof_records += [(timestamps[a_st], timestamps[a_en], a_dev, a_is_cp, timestamps[b_st], timestamps[b_en], b_dev, b_is_cp)]
for x in deps:
(b_st,_), (b_en,_), b_dev, _, b_is_cp, _ = self.prof_records[x]
dev.dep_prof_records += [(timestamps[b_st], timestamps[b_en], b_dev, b_is_cp, timestamps[st], timestamps[en], dev, is_cp)]
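A self-contained sketch of what the new collect_timestamps layout does: the recorded slot indices are resolved against the timestamp values read back from the signals, producing one raw record per item and one dependency record per recorded dep (toy data, hypothetical device names; the real code appends to per-device lists).

timestamps = [10.0, 12.5, 13.0, 19.0]                     # values read back from prof_signals
prof_records = [((0, True), (1, True), "DEV0", "kernel_a", False, []),
                ((2, True), (3, True), "DEV1", "copy_ab", True, [0])]

raw, dep = [], []
for (st, _), (en, _), dev, desc, is_cp, deps in prof_records:
  raw.append((timestamps[st], timestamps[en], dev, desc, is_cp))
  for x in deps:                                          # link back to the records we waited on
    (b_st, _), (b_en, _), b_dev, _, b_is_cp, _ = prof_records[x]
    dep.append((timestamps[b_st], timestamps[b_en], b_dev, b_is_cp, timestamps[st], timestamps[en], dev, is_cp))
print(raw[0], dep[0])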
def __del__(self):
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])