diff --git a/tinygrad/device.py b/tinygrad/device.py
index c378c64c8b..27511c1cb9 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -492,7 +492,7 @@ class HCQCompiled(Compiled):
     self.timeline_signal, self._shadow_timeline_signal = timeline_signals
     self.sig_prof_records:List[Tuple[HCQSignal, HCQSignal, str, bool]] = []
     self.raw_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, str, bool]] = []
-    self.dep_prof_records:List[Tuple[int, int, HCQCompiled, bool, int, int, HCQCompiled, bool]] = []
+    self.dep_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, HCQCompiled, bool, decimal.Decimal, decimal.Decimal, HCQCompiled, bool]] = []
     if PROFILE: self._prof_setup()
 
     from tinygrad.runtime.graph.hcq import HCQGraph
diff --git a/tinygrad/runtime/graph/hcq.py b/tinygrad/runtime/graph/hcq.py
index 61b134eb61..298370e8d3 100644
--- a/tinygrad/runtime/graph/hcq.py
+++ b/tinygrad/runtime/graph/hcq.py
@@ -38,71 +38,85 @@ class HCQGraph(MultiGraphRunner):
     # graph-related tasks. This synchronization uses a global timeline signal per device. Within the graph, the compute queue coordinates with
     # global operations and sets a kickoff signal. Any queue accessing a buffer from another device waits for this signal from the device’s
     # compute queue to ensure exclusive access. The compute queue signals the completion of the graph, synchronizing with the device's copy queue.
-    self.comp_queues: Dict[Compiled, HWComputeQueue] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
-    self.copy_queues: Dict[Compiled, HWCopyQueue] = {dev: dev.hw_copy_queue_t() for dev in self.devices}
+    self.ji_schedule: Dict[int, Tuple[HCQCompiled, HWCommandQueue, List, List, HCQSignal, Optional[int]]] = {}
 
-    self.signal_sched: Dict[int, Tuple[List, HCQSignal, Optional[int], Optional[List]]] = {} # Dict[ji_idx, (deps, signal, sigval, prof_info)]
-    self.signals = {q: dev.signal_t(value=0) for queues in (self.comp_queues, self.copy_queues) for dev,q in queues.items()} #type:ignore
-    self.dev_kickoff_signal = {**{dev.dname: dev.signal_t(value=0) for dev in self.devices}, **{"CPU": self.devices[0].signal_t(value=0)}}
-    self.kickoff_value = 0
+    self.comp_queues: Dict[HCQCompiled, HWComputeQueue] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
+    self.copy_queues: Dict[HCQCompiled, HWCopyQueue] = {} # lazy allocation
 
-    self.save_devs: Dict[HWCommandQueue, Set] = {q: set() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
-    for dev in self.devices: self.save_devs[self.comp_queues[dev]].add(dev)
+    self.signals: Dict[Any, HCQSignal] = {**{dev.dname: dev.signal_t(value=0) for dev in self.devices}, **{"CPU": self.devices[0].signal_t(value=0)}}
+    self.kickoff_value: int = 0
 
-    self.last_timeline: Dict[HCQCompiled, Tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
-    self.last_ji: Dict[HWCommandQueue, Optional[int]] = {q: None for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
     self.prof_signals: List[HCQSignal] = [self.devices[0].signal_t() for i in range(len(self.jit_cache) * 2)] if PROFILE else []
-    self.prof_deps: List[Tuple[Tuple, Tuple]] = []
+    self.prof_records: List[Tuple[Tuple[int, bool], Tuple[int, bool], HCQCompiled, str, bool, List[int]]] = []
+
+    last_j: Dict[HWCommandQueue, Optional[int]] = collections.defaultdict(lambda: None)
+    queue_access: Dict[HWCommandQueue, Dict[HWCommandQueue, Optional[int]]] = collections.defaultdict(lambda: collections.defaultdict(lambda: None))
+    dev_access: Dict[HWCommandQueue, Set[str]] = collections.defaultdict(set)
+
+    for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev.dname)
 
     for j,ji in enumerate(self.jit_cache):
-      enqueue_dev = ji.prg.device if isinstance(ji.prg, CompiledRunner) else Device[ji.bufs[1].device] #type:ignore
-      enqueue_queue = self.comp_queues[enqueue_dev] if isinstance(ji.prg, CompiledRunner) else self.copy_queues[enqueue_dev]
-      out_signal = self.signals[enqueue_queue] #type:ignore
-      writable_buffers = ji.prg.p.outcount if isinstance(ji.prg, CompiledRunner) else 1
-      deps = self.access_resources(enqueue_queue, ji.bufs[writable_buffers:], ji.bufs[:writable_buffers], j + 1)
+      enqueue_dev = ji.prg.device if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore
+      enqueue_queue = self.comp_queues[enqueue_dev] if is_exec_prg else self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
+      out_signal = self.signals.setdefault(enqueue_queue, enqueue_dev.signal_t(value=0))
 
-      # Profiler related info
-      prof_ji_desc = ji.prg.clprg.name if isinstance(ji.prg, CompiledRunner) else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore
-      prof_info = ([(j * 2, True), (j * 2 + 1, True), enqueue_dev, prof_ji_desc, isinstance(ji.prg, BufferXfer)]) if PROFILE else None
+      # Get dependencies based on input and output buffers.
+      rdeps = self._access_resources(ji.bufs[(wb:=ji.prg.p.outcount if is_exec_prg else 1):], ji.bufs[:wb], (enqueue_queue, j + 1)) #type:ignore
 
-      if isinstance(ji.prg, CompiledRunner):
-        # Update signal on compute kernel to depend on the previous kernel.
-        if (last_j:=self.last_ji[enqueue_queue]) is not None: deps = [x for x in deps if id(x[0]) != id(out_signal)] + [(out_signal, last_j + 1)]
+      # Update dependencies to include previous kernel in queue. This is required for timeline signals.
+      opt_deps, deps = [], rdeps + ([(enqueue_queue, prev_ji + 1)] if (prev_ji:=last_j[enqueue_queue]) is not None else [])
 
-        # Remove self-dependency for AMD or NV with only 1 same-queue dep, since NV chains 2+ execs in this case, eliminating dep need.
-        if (dname:=enqueue_dev.dname.split(":", 1)[0]) == "AMD" or (dname == "NV" and len(deps) == 1 and id(deps[0][0]) == id(out_signal)):
-          deps = [x for x in deps if id(x[0]) != id(out_signal)]
-          if last_j is not None and prof_info is not None: prof_info = [(self.signal_sched[last_j][3][1][0], False)] + prof_info[1:] # type: ignore
+      # Optimize dependencies by removing redundant ones. Remove waiting for the value of the queue which is known to be already
+      # synced with the current queue.
+      for dep_queue, dep_val in sorted(deps, key=lambda x: x[1], reverse=True):
+        if (qa:=queue_access[enqueue_queue][dep_queue]) is None or qa < dep_val:
+          opt_deps.append((self.signals[dep_queue], dep_val))
+          queue_access[enqueue_queue][dep_queue] = dep_val
 
-      elif isinstance(ji.prg, BufferXfer): deps = [x for x in deps if id(x[0]) != id(out_signal)]
+      # Ensure device is ready for use in current context: the graph has initialized the device and it's safe to operate on it within this graph.
+      for dep_queue, _ in opt_deps: dev_access[enqueue_queue].update(dev_access[dep_queue])
+      sync_signals = [(self.signals[bdev], self.kickoff_value) for b in ji.bufs if (bdev:=cast(Buffer, b).device) not in dev_access[enqueue_queue]]
+      dev_access[enqueue_queue].update(cast(Buffer, b).device for b in ji.bufs)
 
-      # Go through all dependencies and, if we need the signal from that ji, enable it by setting the signal value in the signal schedule.
-      for sig, val in deps:
-        if id(sig) in [id(x) for x in self.signals.values()]:
-          self.signal_sched[val - 1] = self.signal_sched[val - 1][:2] + (val,) + self.signal_sched[val - 1][3:]
-          if PROFILE: self.prof_deps += [(self.signal_sched[val - 1][3], prof_info)] # type: ignore
+      # Remove self-dependency for compute and copy queues.
+      # For compute, in case of NV, optimize when only 1 same-queue dependency exists, since NV chains 2+ executions in this case,
+      # eliminating dependency need.
+      dname = enqueue_dev.dname.split(":", 1)[0]
+      can_opt = (dname == "AMD" or (dname == "NV" and len(sync_signals) == 0 and len(opt_deps) == 1 and id(opt_deps[0][0]) == id(out_signal)))
+      if can_opt or isinstance(ji.prg, BufferXfer): opt_deps = [x for x in opt_deps if id(x[0]) != id(out_signal)]
 
-      self.signal_sched[j] = (deps, out_signal, None if isinstance(ji.prg, CompiledRunner) else (j + 1), prof_info)
-      self.last_ji[enqueue_queue] = j
+      # Enable necessary signals in the schedule by setting the signal value.
+      for sig, val in opt_deps: self.ji_schedule[val - 1] = self.ji_schedule[val - 1][:5] + (val,)
+      self.ji_schedule[j] = (enqueue_dev, enqueue_queue, sync_signals, opt_deps[::-1], out_signal, None if is_exec_prg else (j + 1))
+
+      # Collect profile information if profiling is enabled.
+      if PROFILE:
+        prof_ji_desc = ji.prg.clprg.name if is_exec_prg else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore
+
+        sig_st, sig_en = (j * 2, True), (j * 2 + 1, True)
+        if len(opt_deps) == 0 and (prev_ji:=last_j[enqueue_queue]) is not None: sig_st = (prev_ji * 2 + 1, False)
+
+        self.prof_records.append((sig_st, sig_en, enqueue_dev, prof_ji_desc, not is_exec_prg, [d - 1 for _, d in rdeps]))
+
+      last_j[enqueue_queue] = j
 
     # Build hardware queues.
     self.op_cmd_idx: Dict[int, Tuple[Any, int]] = {}
-    self.copy_to_devs: Dict[Compiled, Set[Compiled]] = {dev: set() for dev in self.devices}
+    self.copy_to_devs: Dict[HCQCompiled, Set[HCQCompiled]] = {dev: set() for dev in self.devices}
     self.kickoff_wait_cmds: Dict[HWCommandQueue, List] = {q: list() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
 
     for dev in self.devices:
       self.comp_queues[dev].memory_barrier().wait(dev.timeline_signal, dev.timeline_value - 1) \
-                           .wait(self.dev_kickoff_signal['CPU'], self.kickoff_value).signal(self.dev_kickoff_signal[dev.dname], self.kickoff_value)
+                           .wait(self.signals['CPU'], self.kickoff_value).signal(self.signals[dev.dname], self.kickoff_value)
 
     for j,ji in enumerate(self.jit_cache):
-      deps, signal, signal_val, prof_info = self.signal_sched[j]
-      enqueue_queue = self.copy_queues[Device[ji.bufs[1].device]] if isinstance(ji.prg, BufferXfer) else self.comp_queues[ji.prg.device] #type:ignore
+      enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j]
+
+      for i in range(len(sync_signals)): self.kickoff_wait_cmds[enqueue_queue].append(len(enqueue_queue) + i)
+      for sig, val in sync_signals + deps: enqueue_queue.wait(sig, val)
 
       # Encode waits and start profile timestamp (if needed).
-      for sig, val in deps:
-        enqueue_queue.wait(sig, val)
-        if id(sig) in [id(x) for x in self.dev_kickoff_signal.values()]: self.kickoff_wait_cmds[enqueue_queue].append(len(enqueue_queue) - 1)
-      if prof_info and prof_info[0][1]: enqueue_queue.timestamp(self.prof_signals[prof_info[0][0]])
+      if PROFILE and self.prof_records[j][0][1]: enqueue_queue.timestamp(self.prof_signals[self.prof_records[j][0][0]])
 
       # Encode main commands based on ji type.
       if isinstance(ji.prg, CompiledRunner):
@@ -111,30 +125,30 @@ class HCQGraph(MultiGraphRunner):
         dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
         cast(HCQAllocator, Device[src.device].allocator).map(dest._buf)
         cast(HWCopyQueue, enqueue_queue).copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes)
-        self.copy_to_devs[Device[dest.device]].add(Device[src.device])
+        self.copy_to_devs[cast(HCQCompiled, Device[dest.device])].add(cast(HCQCompiled, Device[src.device]))
       self.op_cmd_idx[j] = (enqueue_queue, len(enqueue_queue) - 1)
 
       # Encode finish profile timestamp (if needed).
-      if prof_info and prof_info[1][1]: enqueue_queue.timestamp(self.prof_signals[prof_info[1][0]])
+      if PROFILE and self.prof_records[j][1][1]: enqueue_queue.timestamp(self.prof_signals[self.prof_records[j][1][0]])
 
       if signal_val is not None: enqueue_queue.signal(signal, signal_val)
 
     for dev in self.devices:
      for dep_dev in list(self.copy_to_devs[dev]) + [dev]:
-        if (last_j:=self.last_ji[self.copy_queues[dep_dev]]) is None: continue
-        self.comp_queues[dev].wait(self.signals[self.copy_queues[dep_dev]], self.signal_sched[last_j][2])
+        if dep_dev in self.copy_queues: self.comp_queues[dev].wait(self.signals[(copy_q:=self.copy_queues[dep_dev])], cast(int, last_j[copy_q]) + 1)
 
-      self.comp_queues[dev].signal(dev.timeline_signal, dev.timeline_value)
-      self.comp_queues[dev].bind(dev)
-      if self.last_ji[self.copy_queues[dev]] is not None: self.copy_queues[dev].bind(dev)
+      self.comp_queues[dev].signal(dev.timeline_signal, dev.timeline_value).bind(dev)
+      if dev in self.copy_queues: copy_q.bind(dev)
+
+    self.last_timeline: Dict[HCQCompiled, Tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
+    self.queue_signals_to_reset = [self.signals[q] for q in list(self.comp_queues.values()) + list(self.copy_queues.values()) if q in self.signals]
 
   def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
     # Wait and restore signals
     self.kickoff_value += 1
     for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
-    for comp_queue in self.comp_queues.values(): self.signals[comp_queue].value = 0
-    for copy_queue in self.copy_queues.values(): self.signals[copy_queue].value = 0
-    self.dev_kickoff_signal['CPU'].value = self.kickoff_value
+    for sig in self.queue_signals_to_reset: sig.value = 0
+    self.signals['CPU'].value = self.kickoff_value
 
     if PROFILE and self.kickoff_value > 1: self.collect_timestamps()
 
@@ -152,12 +166,12 @@ class HCQGraph(MultiGraphRunner):
        queue.update_exec(cmd_ptr, global_dims, local_dims)
 
     for dev in self.devices:
-      comp_queue, copy_queue, need_sig_upd = self.comp_queues[dev], self.copy_queues[dev], dev.timeline_signal != self.last_timeline[dev][0]
+      comp_queue, copy_queue, need_sig_upd = self.comp_queues[dev], self.copy_queues.get(dev, None), dev.timeline_signal != self.last_timeline[dev][0]
       comp_queue.update_wait(1, dev.timeline_signal if need_sig_upd else None, dev.timeline_value - 1) \
                 .update_wait(2, value=self.kickoff_value).update_signal(3, value=self.kickoff_value) \
                 .update_signal(len(comp_queue)-1, dev.timeline_signal if need_sig_upd else None, dev.timeline_value).submit(dev)
 
-      if self.last_ji[copy_queue] is not None:
+      if copy_queue is not None:
         for cmd_idx in self.kickoff_wait_cmds[copy_queue]: copy_queue.update_wait(cmd_idx, value=self.kickoff_value)
         copy_queue.submit(dev)
 
@@ -170,26 +184,15 @@ class HCQGraph(MultiGraphRunner):
       return time.perf_counter() - st
     return None
 
-  def access_resources(self, queue, read, write, new_val):
-    deps = self._access_resources(read, write, (queue, new_val))
-
-    sync_signals = []
-    for dep_queue,_ in deps: self.save_devs[queue].update(self.save_devs[dep_queue])
-    for buf in read+write:
-      if buf.device not in self.save_devs[queue]:
-        self.save_devs[queue].add(buf.device)
-        sync_signals += [(self.dev_kickoff_signal[Device[buf.device].dname], self.kickoff_value)]
-
-    return [(self.signals[k], max(v for x, v in deps if id(x) == idk)) for idk, k in {id(x[0]): x[0] for x in deps}.items()] + sync_signals
-
   def collect_timestamps(self):
     timestamps = [s.timestamp for s in self.prof_signals]
 
-    for _,_,_,((st,_),(en,_),dev,desc,is_cp) in self.signal_sched.values(): # type: ignore
+    for (st,_), (en,_), dev, desc, is_cp, deps in self.prof_records:
       dev.raw_prof_records += [(timestamps[st], timestamps[en], desc, is_cp)]
 
-    for ((a_st,_), (a_en,_), a_dev, _, a_is_cp), ((b_st,_), (b_en,_), b_dev, _, b_is_cp) in self.prof_deps:
-      b_dev.dep_prof_records += [(timestamps[a_st], timestamps[a_en], a_dev, a_is_cp, timestamps[b_st], timestamps[b_en], b_dev, b_is_cp)]
+      for x in deps:
+        (b_st,_), (b_en,_), b_dev, _, b_is_cp, _ = self.prof_records[x]
+        dev.dep_prof_records += [(timestamps[b_st], timestamps[b_en], b_dev, b_is_cp, timestamps[st], timestamps[en], dev, is_cp)]
 
   def __del__(self):
     for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
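
Note for reviewers: the `queue_access` map introduced above is the core of the wait deduplication. Each consumer queue remembers the highest signal value it has already waited on per producer queue; since a queue executes its commands in order, any later dependency at or below that value is redundant and can be dropped. Below is a minimal standalone sketch of that idea, not part of the patch: queue names are hypothetical strings standing in for HCQSignal/queue objects.

# Illustrative sketch only (not part of the patch). Queue names and values are hypothetical.
import collections
from typing import Dict, List, Optional, Tuple

def dedup_waits(deps: List[Tuple[str, int]], seen: Dict[str, Optional[int]]) -> List[Tuple[str, int]]:
  # `deps` are (producer_queue, signal_value) pairs for one jit item on a single consumer queue.
  # `seen` persists across items and stores the highest value this consumer has already waited on per producer.
  opt_deps = []
  for dep_queue, dep_val in sorted(deps, key=lambda x: x[1], reverse=True):
    if seen[dep_queue] is None or seen[dep_queue] < dep_val:
      opt_deps.append((dep_queue, dep_val))
      seen[dep_queue] = dep_val
  return opt_deps

seen: Dict[str, Optional[int]] = collections.defaultdict(lambda: None)
print(dedup_waits([("copy0", 3), ("copy0", 1), ("comp1", 2)], seen))  # [('copy0', 3), ('comp1', 2)]
print(dedup_waits([("copy0", 2), ("comp1", 4)], seen))                # [('comp1', 4)]: copy0 value 2 is already covered by the earlier wait on 3

Sorting candidates by value first keeps only the highest wait per producer and skips everything below it in one pass, which mirrors the `sorted(deps, key=lambda x: x[1], reverse=True)` loop in the patch.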