hcq graph refactor (#5887)

* cleanup

* prof

* cleaner

* comments

* more types
nimlgen authored 2024-08-03 23:35:33 +03:00 (committed by GitHub)
commit dad8e72ee9, parent da61dea1b2
2 changed files with 73 additions and 70 deletions

@@ -492,7 +492,7 @@ class HCQCompiled(Compiled):
self.timeline_signal, self._shadow_timeline_signal = timeline_signals
self.sig_prof_records:List[Tuple[HCQSignal, HCQSignal, str, bool]] = []
self.raw_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, str, bool]] = []
self.dep_prof_records:List[Tuple[int, int, HCQCompiled, bool, int, int, HCQCompiled, bool]] = []
self.dep_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, HCQCompiled, bool, decimal.Decimal, decimal.Decimal, HCQCompiled, bool]] = []
if PROFILE: self._prof_setup()
from tinygrad.runtime.graph.hcq import HCQGraph

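As a note on the hunk above: dep_prof_records is now annotated with decimal.Decimal timestamps for both ends of a dependency edge, matching raw_prof_records. A minimal standalone sketch (illustrative names only, not tinygrad code) of how records with this shape can be consumed:

import decimal
from typing import List, Tuple

# Same shape as raw_prof_records above: (start_us, end_us, kernel/copy name, is_copy).
RawRec = Tuple[decimal.Decimal, decimal.Decimal, str, bool]

def total_busy_us(recs: List[RawRec]) -> decimal.Decimal:
  # Decimal keeps the sub-microsecond precision of device timestamps through the subtraction.
  return sum((en - st for st, en, _, _ in recs), decimal.Decimal(0))

recs: List[RawRec] = [(decimal.Decimal("10.0"), decimal.Decimal("12.5"), "E_add", False)]
print(total_busy_us(recs))  # -> 2.5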
@@ -38,71 +38,85 @@ class HCQGraph(MultiGraphRunner):
# graph-related tasks. This synchronization uses a global timeline signal per device. Within the graph, the compute queue coordinates with
# global operations and sets a kickoff signal. Any queue accessing a buffer from another device waits for this signal from the device's
# compute queue to ensure exclusive access. The compute queue signals the completion of the graph, synchronizing with the device's copy queue.
self.comp_queues: Dict[Compiled, HWComputeQueue] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
self.copy_queues: Dict[Compiled, HWCopyQueue] = {dev: dev.hw_copy_queue_t() for dev in self.devices}
self.ji_schedule: Dict[int, Tuple[HCQCompiled, HWCommandQueue, List, List, HCQSignal, Optional[int]]] = {}
self.signal_sched: Dict[int, Tuple[List, HCQSignal, Optional[int], Optional[List]]] = {} # Dict[ji_idx, (deps, signal, sigval, prof_info)]
self.signals = {q: dev.signal_t(value=0) for queues in (self.comp_queues, self.copy_queues) for dev,q in queues.items()} #type:ignore
self.dev_kickoff_signal = {**{dev.dname: dev.signal_t(value=0) for dev in self.devices}, **{"CPU": self.devices[0].signal_t(value=0)}}
self.kickoff_value = 0
self.comp_queues: Dict[HCQCompiled, HWComputeQueue] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
self.copy_queues: Dict[HCQCompiled, HWCopyQueue] = {} # lazy allocation
self.save_devs: Dict[HWCommandQueue, Set] = {q: set() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
for dev in self.devices: self.save_devs[self.comp_queues[dev]].add(dev)
self.signals: Dict[Any, HCQSignal] = {**{dev.dname: dev.signal_t(value=0) for dev in self.devices}, **{"CPU": self.devices[0].signal_t(value=0)}}
self.kickoff_value: int = 0
self.last_timeline: Dict[HCQCompiled, Tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
self.last_ji: Dict[HWCommandQueue, Optional[int]] = {q: None for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
self.prof_signals: List[HCQSignal] = [self.devices[0].signal_t() for i in range(len(self.jit_cache) * 2)] if PROFILE else []
self.prof_deps: List[Tuple[Tuple, Tuple]] = []
self.prof_records: List[Tuple[Tuple[int, bool], Tuple[int, bool], HCQCompiled, str, bool, List[int]]] = []
last_j: Dict[HWCommandQueue, Optional[int]] = collections.defaultdict(lambda: None)
queue_access: Dict[HWCommandQueue, Dict[HWCommandQueue, Optional[int]]] = collections.defaultdict(lambda: collections.defaultdict(lambda: None))
dev_access: Dict[HWCommandQueue, Set[str]] = collections.defaultdict(set)
for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev.dname)
for j,ji in enumerate(self.jit_cache):
enqueue_dev = ji.prg.device if isinstance(ji.prg, CompiledRunner) else Device[ji.bufs[1].device] #type:ignore
enqueue_queue = self.comp_queues[enqueue_dev] if isinstance(ji.prg, CompiledRunner) else self.copy_queues[enqueue_dev]
out_signal = self.signals[enqueue_queue] #type:ignore
writable_buffers = ji.prg.p.outcount if isinstance(ji.prg, CompiledRunner) else 1
deps = self.access_resources(enqueue_queue, ji.bufs[writable_buffers:], ji.bufs[:writable_buffers], j + 1)
enqueue_dev = ji.prg.device if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore
enqueue_queue = self.comp_queues[enqueue_dev] if is_exec_prg else self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
out_signal = self.signals.setdefault(enqueue_queue, enqueue_dev.signal_t(value=0))
# Profiler related info
prof_ji_desc = ji.prg.clprg.name if isinstance(ji.prg, CompiledRunner) else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore
prof_info = ([(j * 2, True), (j * 2 + 1, True), enqueue_dev, prof_ji_desc, isinstance(ji.prg, BufferXfer)]) if PROFILE else None
# Get dependencies based on input and output buffers.
rdeps = self._access_resources(ji.bufs[(wb:=ji.prg.p.outcount if is_exec_prg else 1):], ji.bufs[:wb], (enqueue_queue, j + 1)) #type:ignore
if isinstance(ji.prg, CompiledRunner):
# Update signal on compute kernel to depend on the previous kernel.
if (last_j:=self.last_ji[enqueue_queue]) is not None: deps = [x for x in deps if id(x[0]) != id(out_signal)] + [(out_signal, last_j + 1)]
# Update dependencies to include previous kernel in queue. This is required for timeline signals.
opt_deps, deps = [], rdeps + ([(enqueue_queue, prev_ji + 1)] if (prev_ji:=last_j[enqueue_queue]) is not None else [])
# Remove self-dependency for AMD or NV with only 1 same-queue dep, since NV chains 2+ execs in this case, eliminating dep need.
if (dname:=enqueue_dev.dname.split(":", 1)[0]) == "AMD" or (dname == "NV" and len(deps) == 1 and id(deps[0][0]) == id(out_signal)):
deps = [x for x in deps if id(x[0]) != id(out_signal)]
if last_j is not None and prof_info is not None: prof_info = [(self.signal_sched[last_j][3][1][0], False)] + prof_info[1:] # type: ignore
# Optimize dependencies by removing redundant ones. Remove waiting for the value of the queue which is known to be already
# synced with the current queue.
for dep_queue, dep_val in sorted(deps, key=lambda x: x[1], reverse=True):
if (qa:=queue_access[enqueue_queue][dep_queue]) is None or qa < dep_val:
opt_deps.append((self.signals[dep_queue], dep_val))
queue_access[enqueue_queue][dep_queue] = dep_val
elif isinstance(ji.prg, BufferXfer): deps = [x for x in deps if id(x[0]) != id(out_signal)]
# Ensure device is ready for use in current context: the graph has initialized the device and it's safe to operate on it within this graph.
for dep_queue, _ in opt_deps: dev_access[enqueue_queue].update(dev_access[dep_queue])
sync_signals = [(self.signals[bdev], self.kickoff_value) for b in ji.bufs if (bdev:=cast(Buffer, b).device) not in dev_access[enqueue_queue]]
dev_access[enqueue_queue].update(cast(Buffer, b).device for b in ji.bufs)
# Go through all dependencies and, if we need the signal from that ji, enable it by setting the signal value in the signal schedule.
for sig, val in deps:
if id(sig) in [id(x) for x in self.signals.values()]:
self.signal_sched[val - 1] = self.signal_sched[val - 1][:2] + (val,) + self.signal_sched[val - 1][3:]
if PROFILE: self.prof_deps += [(self.signal_sched[val - 1][3], prof_info)] # type: ignore
# Remove self-dependency for compute and copy queues.
# For compute, in case of NV, optimize when only 1 same-queue dependency exists, since NV chains 2+ executions in this case,
# eliminating dependency need.
dname = enqueue_dev.dname.split(":", 1)[0]
can_opt = (dname == "AMD" or (dname == "NV" and len(sync_signals) == 0 and len(opt_deps) == 1 and id(opt_deps[0][0]) == id(out_signal)))
if can_opt or isinstance(ji.prg, BufferXfer): opt_deps = [x for x in opt_deps if id(x[0]) != id(out_signal)]
self.signal_sched[j] = (deps, out_signal, None if isinstance(ji.prg, CompiledRunner) else (j + 1), prof_info)
self.last_ji[enqueue_queue] = j
# Enable necessary signals in the schedule by setting the signal value.
for sig, val in opt_deps: self.ji_schedule[val - 1] = self.ji_schedule[val - 1][:5] + (val,)
self.ji_schedule[j] = (enqueue_dev, enqueue_queue, sync_signals, opt_deps[::-1], out_signal, None if is_exec_prg else (j + 1))
# Collect profile information if profiling is enabled.
if PROFILE:
prof_ji_desc = ji.prg.clprg.name if is_exec_prg else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore
sig_st, sig_en = (j * 2, True), (j * 2 + 1, True)
if len(opt_deps) == 0 and (prev_ji:=last_j[enqueue_queue]) is not None: sig_st = (prev_ji * 2 + 1, False)
self.prof_records.append((sig_st, sig_en, enqueue_dev, prof_ji_desc, not is_exec_prg, [d - 1 for _, d in rdeps]))
last_j[enqueue_queue] = j
# Build hardware queues.
self.op_cmd_idx: Dict[int, Tuple[Any, int]] = {}
self.copy_to_devs: Dict[Compiled, Set[Compiled]] = {dev: set() for dev in self.devices}
self.copy_to_devs: Dict[HCQCompiled, Set[HCQCompiled]] = {dev: set() for dev in self.devices}
self.kickoff_wait_cmds: Dict[HWCommandQueue, List] = {q: list() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
for dev in self.devices:
self.comp_queues[dev].memory_barrier().wait(dev.timeline_signal, dev.timeline_value - 1) \
.wait(self.dev_kickoff_signal['CPU'], self.kickoff_value).signal(self.dev_kickoff_signal[dev.dname], self.kickoff_value)
.wait(self.signals['CPU'], self.kickoff_value).signal(self.signals[dev.dname], self.kickoff_value)
for j,ji in enumerate(self.jit_cache):
deps, signal, signal_val, prof_info = self.signal_sched[j]
enqueue_queue = self.copy_queues[Device[ji.bufs[1].device]] if isinstance(ji.prg, BufferXfer) else self.comp_queues[ji.prg.device] #type:ignore
enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j]
for i in range(len(sync_signals)): self.kickoff_wait_cmds[enqueue_queue].append(len(enqueue_queue) + i)
for sig, val in sync_signals + deps: enqueue_queue.wait(sig, val)
# Encode waits and start profile timestamp (if needed).
for sig, val in deps:
enqueue_queue.wait(sig, val)
if id(sig) in [id(x) for x in self.dev_kickoff_signal.values()]: self.kickoff_wait_cmds[enqueue_queue].append(len(enqueue_queue) - 1)
if prof_info and prof_info[0][1]: enqueue_queue.timestamp(self.prof_signals[prof_info[0][0]])
if PROFILE and self.prof_records[j][0][1]: enqueue_queue.timestamp(self.prof_signals[self.prof_records[j][0][0]])
# Encode main commands based on ji type.
if isinstance(ji.prg, CompiledRunner):
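The queue_access pruning in this hunk can be read on its own: a wait on (queue, value) is redundant when the current queue is already known to be synced with that queue at an equal or higher value, so at most one wait per dependency queue survives per jit item. A small self-contained sketch of the idea (plain strings stand in for queues; not the tinygrad API):

import collections
from typing import Dict, List, Optional, Tuple

def prune_deps(cur_queue: str, deps: List[Tuple[str, int]],
               queue_access: Dict[str, Dict[str, Optional[int]]]) -> List[Tuple[str, int]]:
  opt_deps = []
  for dep_queue, dep_val in sorted(deps, key=lambda x: x[1], reverse=True):
    if (qa:=queue_access[cur_queue][dep_queue]) is None or qa < dep_val:  # not yet synced up to dep_val
      opt_deps.append((dep_queue, dep_val))
      queue_access[cur_queue][dep_queue] = dep_val
  return opt_deps

access: Dict[str, Dict[str, Optional[int]]] = collections.defaultdict(lambda: collections.defaultdict(lambda: None))
print(prune_deps("comp0", [("copy0", 2), ("copy0", 3)], access))  # -> [('copy0', 3)]; the wait for value 2 is dropped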
@@ -111,30 +125,30 @@ class HCQGraph(MultiGraphRunner):
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
cast(HCQAllocator, Device[src.device].allocator).map(dest._buf)
cast(HWCopyQueue, enqueue_queue).copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes)
self.copy_to_devs[Device[dest.device]].add(Device[src.device])
self.copy_to_devs[cast(HCQCompiled, Device[dest.device])].add(cast(HCQCompiled, Device[src.device]))
self.op_cmd_idx[j] = (enqueue_queue, len(enqueue_queue) - 1)
# Encode finish profile timestamp (if needed).
if prof_info and prof_info[1][1]: enqueue_queue.timestamp(self.prof_signals[prof_info[1][0]])
if PROFILE and self.prof_records[j][1][1]: enqueue_queue.timestamp(self.prof_signals[self.prof_records[j][1][0]])
if signal_val is not None: enqueue_queue.signal(signal, signal_val)
for dev in self.devices:
for dep_dev in list(self.copy_to_devs[dev]) + [dev]:
if (last_j:=self.last_ji[self.copy_queues[dep_dev]]) is None: continue
self.comp_queues[dev].wait(self.signals[self.copy_queues[dep_dev]], self.signal_sched[last_j][2])
if dep_dev in self.copy_queues: self.comp_queues[dev].wait(self.signals[(copy_q:=self.copy_queues[dep_dev])], cast(int, last_j[copy_q]) + 1)
self.comp_queues[dev].signal(dev.timeline_signal, dev.timeline_value)
self.comp_queues[dev].bind(dev)
if self.last_ji[self.copy_queues[dev]] is not None: self.copy_queues[dev].bind(dev)
self.comp_queues[dev].signal(dev.timeline_signal, dev.timeline_value).bind(dev)
if dev in self.copy_queues: copy_q.bind(dev)
self.last_timeline: Dict[HCQCompiled, Tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
self.queue_signals_to_reset = [self.signals[q] for q in list(self.comp_queues.values()) + list(self.copy_queues.values()) if q in self.signals]
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
# Wait and restore signals
self.kickoff_value += 1
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
for comp_queue in self.comp_queues.values(): self.signals[comp_queue].value = 0
for copy_queue in self.copy_queues.values(): self.signals[copy_queue].value = 0
self.dev_kickoff_signal['CPU'].value = self.kickoff_value
for sig in self.queue_signals_to_reset: sig.value = 0
self.signals['CPU'].value = self.kickoff_value
if PROFILE and self.kickoff_value > 1: self.collect_timestamps()
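The values reset here drive the kickoff handshake described in the comment at the top of the graph build: the CPU raises a global kickoff signal, each device's compute queue waits for it and then publishes a per-device kickoff, and any queue touching that device's buffers waits on the per-device one. A toy model of the ordering with a hypothetical Signal class (real signals live in device memory and the waits run on hardware queues):

class Signal:
  def __init__(self, value: int = 0): self.value = value
  def wait(self, value: int): assert self.value >= value, "a real queue would block here until the value is reached"

signals = {"CPU": Signal(), "NV:0": Signal(), "NV:1": Signal()}
kickoff_value = 1                      # bumped once per graph launch
signals["CPU"].value = kickoff_value   # host releases the graph
for dname in ("NV:0", "NV:1"):
  signals["CPU"].wait(kickoff_value)   # each compute queue waits for the host kickoff...
  signals[dname].value = kickoff_value # ...then publishes the per-device kickoff
signals["NV:1"].wait(kickoff_value)    # e.g. a copy queue reading NV:1 buffers can now proceed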
@@ -152,12 +166,12 @@ class HCQGraph(MultiGraphRunner):
queue.update_exec(cmd_ptr, global_dims, local_dims)
for dev in self.devices:
comp_queue, copy_queue, need_sig_upd = self.comp_queues[dev], self.copy_queues[dev], dev.timeline_signal != self.last_timeline[dev][0]
comp_queue, copy_queue, need_sig_upd = self.comp_queues[dev], self.copy_queues.get(dev, None), dev.timeline_signal != self.last_timeline[dev][0]
comp_queue.update_wait(1, dev.timeline_signal if need_sig_upd else None, dev.timeline_value - 1) \
.update_wait(2, value=self.kickoff_value).update_signal(3, value=self.kickoff_value) \
.update_signal(len(comp_queue)-1, dev.timeline_signal if need_sig_upd else None, dev.timeline_value).submit(dev)
if self.last_ji[copy_queue] is not None:
if copy_queue is not None:
for cmd_idx in self.kickoff_wait_cmds[copy_queue]: copy_queue.update_wait(cmd_idx, value=self.kickoff_value)
copy_queue.submit(dev)
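On replay the graph does not re-encode anything: the queues were bound once at build time, and each call patches values (kickoff, timeline, launch parameters) at fixed command indices before resubmitting. A toy patchable queue showing the pattern (hypothetical class, not the HWCommandQueue interface):

class ToyQueue:
  def __init__(self): self.cmds = []
  def wait(self, sig: str, value: int): self.cmds.append(["wait", sig, value]); return self
  def signal(self, sig: str, value: int): self.cmds.append(["signal", sig, value]); return self
  def update_wait(self, idx: int, value: int): self.cmds[idx][2] = value; return self  # patch in place
  def __len__(self): return len(self.cmds)

q = ToyQueue().wait("timeline", 41).wait("kickoff_CPU", 1).signal("kickoff_NV:0", 1)
q.update_wait(1, value=2)   # next launch: only the kickoff value changes, command slots stay fixed
print(q.cmds)               # -> [['wait', 'timeline', 41], ['wait', 'kickoff_CPU', 2], ['signal', 'kickoff_NV:0', 1]]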
@@ -170,26 +184,15 @@ class HCQGraph(MultiGraphRunner):
return time.perf_counter() - st
return None
def access_resources(self, queue, read, write, new_val):
deps = self._access_resources(read, write, (queue, new_val))
sync_signals = []
for dep_queue,_ in deps: self.save_devs[queue].update(self.save_devs[dep_queue])
for buf in read+write:
if buf.device not in self.save_devs[queue]:
self.save_devs[queue].add(buf.device)
sync_signals += [(self.dev_kickoff_signal[Device[buf.device].dname], self.kickoff_value)]
return [(self.signals[k], max(v for x, v in deps if id(x) == idk)) for idk, k in {id(x[0]): x[0] for x in deps}.items()] + sync_signals
def collect_timestamps(self):
timestamps = [s.timestamp for s in self.prof_signals]
for _,_,_,((st,_),(en,_),dev,desc,is_cp) in self.signal_sched.values(): # type: ignore
for (st,_), (en,_), dev, desc, is_cp, deps in self.prof_records:
dev.raw_prof_records += [(timestamps[st], timestamps[en], desc, is_cp)]
for ((a_st,_), (a_en,_), a_dev, _, a_is_cp), ((b_st,_), (b_en,_), b_dev, _, b_is_cp) in self.prof_deps:
b_dev.dep_prof_records += [(timestamps[a_st], timestamps[a_en], a_dev, a_is_cp, timestamps[b_st], timestamps[b_en], b_dev, b_is_cp)]
for x in deps:
(b_st,_), (b_en,_), b_dev, _, b_is_cp, _ = self.prof_records[x]
dev.dep_prof_records += [(timestamps[b_st], timestamps[b_en], b_dev, b_is_cp, timestamps[st], timestamps[en], dev, is_cp)]
def __del__(self):
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
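For completeness, collect_timestamps above resolves each record's signal-slot indices (start at j*2, end at j*2+1, unless the start was folded into the previous item's end slot) into the timestamps captured on the profiling signals. A rough standalone sketch of that mapping with made-up data, leaving out the dependency records:

from typing import List, Tuple

ProfRec = Tuple[Tuple[int, bool], Tuple[int, bool], str, bool]  # ((st_idx, emit_st), (en_idx, emit_en), desc, is_copy)

def collect(timestamps: List[float], prof_records: List[ProfRec]) -> List[Tuple[float, float, str, bool]]:
  # Each record stores indices into the profiling-signal array; resolve them to the captured times.
  return [(timestamps[st], timestamps[en], desc, is_cp) for (st, _), (en, _), desc, is_cp in prof_records]

ts = [0.0, 1.5, 1.5, 4.0]  # two jit items -> four timestamp slots
recs: List[ProfRec] = [((0, True), (1, True), "E_add", False), ((1, False), (3, True), "copy", True)]
print(collect(ts, recs))   # -> [(0.0, 1.5, 'E_add', False), (1.5, 4.0, 'copy', True)]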