hcq more types (#5791)

* hcq more types

* linter

* pylint

* docs: bind
nimlgen committed 2024-07-29 18:03:23 +03:00 (via GitHub)
parent 9c80f9adf9, commit 71e1472290
5 changed files with 46 additions and 27 deletions


@@ -26,6 +26,7 @@ Each runtime should implement the required functions that are defined in the `HW
"timestamp",
"update_signal",
"update_wait",
"bind",
"submit",
]
show_source: false
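
As context for the entries listed above, a hedged usage sketch of these queue entry points (not code from this commit): `dev` stands for any HCQCompiled device such as AMD or NV, and argument details beyond what the hunks below show are assumptions.

from tinygrad.device import Device, HCQCompiled

dev = Device["AMD"]                                     # any HCQ backend device (e.g. AMD, NV)
assert isinstance(dev, HCQCompiled)

q = dev.hw_compute_queue_t()                            # build a backend compute queue
q.wait(dev.timeline_signal, dev.timeline_value - 1)     # wait for previously submitted work
q.timestamp(ts := dev.signal_t(value=0))                # record a timestamp into a signal
q.signal(dev.timeline_signal, dev.timeline_value)       # mark this queue's work as done
q.bind(dev)                                             # optional: pin the queue to dev (added in this commit)
q.submit(dev)                                           # hand the recorded commands to the device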


@@ -289,6 +289,20 @@ class HWCommandQueue:
return self
def _update_wait(self, cmd_idx:int, signal:Optional[Any], value:Optional[int]): raise NotImplementedError("backend should overload this function")
def bind(self, device:HCQCompiled):
"""
Associates the queue with a specific device for optimized execution.
This optional method allows backend implementations to tailor the queue for efficient use on the given device. When implemented, it can eliminate
the need to copy queues into the device, thereby enhancing performance.
Args:
device: The target device for queue optimization.
Note:
Implementing this method is optional but recommended for performance gains.
"""
def submit(self, device:HCQCompiled):
"""
Submits the command queue to a specific device for execution.
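
For reference, a minimal backend-side sketch of what an overloaded bind can do, modeled on the AMD and NV implementations further down in this commit; MyComputeQueue is a hypothetical backend, and reusing the generic allocator for the queue page (rather than a backend-specific _gpu_alloc) is an assumption.

from tinygrad.helpers import to_mv
from tinygrad.device import HWComputeQueue, BufferOptions

class MyComputeQueue(HWComputeQueue):
  def bind(self, device):
    # keep the recorded command words in device-visible memory so submit() can reuse them without copying
    self.binded_device = device
    self.hw_page = device.allocator._alloc(len(self.q) * 4, BufferOptions(cpu_access=True))  # assumes self.q holds 32-bit words
    hw_view = to_mv(self.hw_page.va_addr, self.hw_page.size).cast("I")
    for i, value in enumerate(self.q): hw_view[i] = value   # copy the queue into the bound page
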
@@ -366,6 +380,8 @@ class HWCopyQueue(HWCommandQueue):
def _update_copy(self, cmd_idx, dest, src): raise NotImplementedError("backend should overload this function")
class HCQSignal:
def __init__(self, value:int=0): self._set_value(value)
@property
def value(self) -> int: return self._get_value()


@@ -1,7 +1,8 @@
import collections, time
from typing import List, Any, Dict, cast, Optional, Tuple, Set
from tinygrad.helpers import round_up, to_mv, PROFILE
from tinygrad.device import HCQCompiled, HCQAllocator, HCQSignal, Buffer, BufferOptions, Compiled, Device
from tinygrad.device import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWCommandQueue, HWComputeQueue, HWCopyQueue, \
Buffer, BufferOptions, Compiled, Device
from tinygrad.shape.symbolic import Variable
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
from tinygrad.engine.jit import MultiGraphRunner
@@ -9,14 +10,14 @@ from tinygrad.engine.jit import MultiGraphRunner
class HCQGraph(MultiGraphRunner):
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
super().__init__(jit_cache, input_rawbuffers, var_vals)
self.devices = list(set(cast(Any, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))
self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))
# Allocate kernel args.
kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
for ji in self.jit_cache:
if not isinstance(ji.prg, CompiledRunner): continue
kernargs_size[ji.prg.device] += round_up(ji.prg.clprg.kernargs_alloc_size, 16)
self.kernargs_bufs: Dict[Compiled, Any] = {dev:dev.allocator._alloc(sz, BufferOptions(cpu_access=True)) for dev,sz in kernargs_size.items()}
self.kernargs_bufs: Dict[Compiled, HCQBuffer] = {dev:dev.allocator._alloc(sz, BufferOptions(cpu_access=True)) for dev,sz in kernargs_size.items()}
kernargs_ptrs: Dict[Compiled, int] = {dev:buf.va_addr for dev,buf in self.kernargs_bufs.items()}
# Fill initial arguments.
@@ -37,19 +38,19 @@ class HCQGraph(MultiGraphRunner):
# graph-related tasks. This synchronization uses a global timeline signal per device. Within the graph, the compute queue coordinates with
# global operations and sets a kickoff signal. Any queue accessing a buffer from another device waits for this signal from the device's
# compute queue to ensure exclusive access. The compute queue signals the completion of the graph, synchronizing with the device's copy queue.
self.comp_queues: Dict[Compiled, Any] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
self.copy_queues: Dict[Compiled, Any] = {dev: dev.hw_copy_queue_t() for dev in self.devices}
self.comp_queues: Dict[Compiled, HWComputeQueue] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
self.copy_queues: Dict[Compiled, HWCopyQueue] = {dev: dev.hw_copy_queue_t() for dev in self.devices}
self.signal_sched: Dict[int, Tuple[List, Any, Optional[int], Optional[List]]] = {} # Dict[ji_idx, (deps, signal, sigval, prof_info)]
self.signals = {q: self.devices[0].signal_t(value=0) for q in list(self.comp_queues.values())+list(self.copy_queues.values())}
self.dev_kickoff_signal = {dev: self.devices[0].signal_t(value=0) for dev in self.devices + ['CPU']} # Dict[dev, signal]
self.signal_sched: Dict[int, Tuple[List, HCQSignal, Optional[int], Optional[List]]] = {} # Dict[ji_idx, (deps, signal, sigval, prof_info)]
self.signals = {q: dev.signal_t(value=0) for queues in (self.comp_queues, self.copy_queues) for dev,q in queues.items()} #type:ignore
self.dev_kickoff_signal = {**{dev.dname: dev.signal_t(value=0) for dev in self.devices}, **{"CPU": self.devices[0].signal_t(value=0)}}
self.kickoff_value = 0
self.save_devs: Dict[Any, Set] = {q: set() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
self.save_devs: Dict[HWCommandQueue, Set] = {q: set() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
for dev in self.devices: self.save_devs[self.comp_queues[dev]].add(dev)
self.last_timeline: Dict[HCQCompiled, Tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
self.last_ji: Dict[Any, Any] = {q: None for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
self.last_ji: Dict[HWCommandQueue, Optional[int]] = {q: None for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
for j,ji in enumerate(self.jit_cache):
enqueue_dev = ji.prg.device if isinstance(ji.prg, CompiledRunner) else Device[ji.bufs[1].device] #type:ignore
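
To make the comment at the top of the hunk above concrete, a hedged sketch of the per-launch ordering for one queue q on device dev that also touches a buffer owned by another device other (illustrative only; the real recording happens in the hunks below):

def kickoff_order(q, dev, other, dev_kickoff_signal, kickoff_value):
  # illustrative summary of the waits/signals HCQGraph records per device
  q.wait(dev.timeline_signal, dev.timeline_value - 1)       # earlier work on dev has finished
  q.wait(dev_kickoff_signal['CPU'], kickoff_value)          # the CPU has started this graph launch
  q.signal(dev_kickoff_signal[dev.dname], kickoff_value)    # (compute queue) announce that dev has kicked off
  q.wait(dev_kickoff_signal[other.dname], kickoff_value)    # before touching other's buffer, wait for its kickoff
  # ... exec / copy commands for this launch are recorded here ...
  q.signal(dev.timeline_signal, dev.timeline_value)         # publish completion of dev's part of the graph
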
@@ -80,11 +81,11 @@ class HCQGraph(MultiGraphRunner):
# Build hardware queues.
self.op_cmd_idx: Dict[int, Tuple[Any, int]] = {}
self.copy_to_devs: Dict[Compiled, Set[Compiled]] = {dev: set() for dev in self.devices}
self.kickoff_wait_cmds: Dict[Any, List] = {q: list() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
self.kickoff_wait_cmds: Dict[HWCommandQueue, List] = {q: list() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
for dev in self.devices:
self.comp_queues[dev].memory_barrier().wait(dev.timeline_signal, dev.timeline_value - 1) \
.wait(self.dev_kickoff_signal['CPU'], self.kickoff_value).signal(self.dev_kickoff_signal[dev], self.kickoff_value)
.wait(self.dev_kickoff_signal['CPU'], self.kickoff_value).signal(self.dev_kickoff_signal[dev.dname], self.kickoff_value)
for j,ji in enumerate(self.jit_cache):
deps, signal, signal_val, prof_info = self.signal_sched[j]
@@ -97,11 +98,12 @@ class HCQGraph(MultiGraphRunner):
if prof_info: enqueue_queue.timestamp(prof_info[0])
# Encode main commands based on ji type.
if isinstance(ji.prg, CompiledRunner): enqueue_queue.exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals))
if isinstance(ji.prg, CompiledRunner):
cast(HWComputeQueue, enqueue_queue).exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals))
elif isinstance(ji.prg, BufferXfer):
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
cast(HCQAllocator, Device[src.device].allocator).map(dest._buf)
enqueue_queue.copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes)
cast(HWCopyQueue, enqueue_queue).copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes)
self.copy_to_devs[Device[dest.device]].add(Device[src.device])
self.op_cmd_idx[j] = (enqueue_queue, len(enqueue_queue) - 1)
@@ -116,15 +118,15 @@ class HCQGraph(MultiGraphRunner):
self.comp_queues[dev].wait(self.signals[self.copy_queues[dep_dev]], self.signal_sched[last_j][2])
self.comp_queues[dev].signal(dev.timeline_signal, dev.timeline_value)
if hasattr(self.comp_queues[dev], 'bind'): self.comp_queues[dev].bind(dev)
if hasattr(self.copy_queues[dev], 'bind') and self.last_ji[self.copy_queues[dev]] is not None: self.copy_queues[dev].bind(dev)
self.comp_queues[dev].bind(dev)
if self.last_ji[self.copy_queues[dev]] is not None: self.copy_queues[dev].bind(dev)
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
# Wait and restore signals
self.kickoff_value += 1
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
for queue in self.comp_queues.values(): self.signals[queue].value = 0
for queue in self.copy_queues.values(): self.signals[queue].value = 0
for comp_queue in self.comp_queues.values(): self.signals[comp_queue].value = 0
for copy_queue in self.copy_queues.values(): self.signals[copy_queue].value = 0
self.dev_kickoff_signal['CPU'].value = self.kickoff_value
if PROFILE and self.kickoff_value > 1:
@@ -171,7 +173,7 @@ class HCQGraph(MultiGraphRunner):
for buf in read+write:
if buf.device not in self.save_devs[queue]:
self.save_devs[queue].add(buf.device)
sync_signals += [(self.dev_kickoff_signal[Device[buf.device]], self.kickoff_value)]
sync_signals += [(self.dev_kickoff_signal[Device[buf.device].dname], self.kickoff_value)]
return [(self.signals[k], max(v for x, v in deps if id(x) == idk)) for idk, k in {id(x[0]): x[0] for x in deps}.items()] + sync_signals
@@ -182,4 +184,4 @@ class HCQGraph(MultiGraphRunner):
if PROFILE and self.kickoff_value > 1:
for _,_,_,(st,en,dev,desc,is_cp) in self.signal_sched.values(): dev.sig_prof_records += [(st, en, desc, is_cp)] #type: ignore
for dev, buf in self.kernargs_bufs.items(): dev.allocator._free(buf, BufferOptions(cpu_access=True))
for fdev, buf in self.kernargs_bufs.items(): fdev.allocator._free(buf, BufferOptions(cpu_access=True))


@@ -40,7 +40,6 @@ def nbioreg(reg): return reg + 0x00000d20 # NBIO_BASE__INST0_SEG2
class AMDSignal(HCQSignal):
def __init__(self, value=0, alloc_event=False):
self._signal = AMDDevice.signals_pool.pop()
self._signal[0] = value
self._value_addr, self._timestamp_addr = mv_address(self._signal), mv_address(self._signal) + 8
if alloc_event:
sync_event = kio.create_event(AMDDevice.kfd, auto_reset=1)
@@ -48,6 +47,7 @@ class AMDSignal(HCQSignal):
self._event_id = sync_event.event_id
self._evt_array = (kfd.struct_kfd_event_data)(event_id=self._event_id)
else: self._event_mailbox_ptr = self._event_id = 0
super().__init__(value)
def __del__(self): AMDDevice.signals_pool.append(self._signal)
def _get_value(self) -> int: return self._signal[0]
def _get_timestamp(self) -> float: return self._signal[1] / 1e2
@@ -163,9 +163,9 @@ class AMDComputeQueue(HWComputeQueue):
if signal is not None and self.cmds_len[cmd_idx] > 8:
self._patch(cmd_idx, offset=11, data=[*data64_le(signal._event_mailbox_ptr), *data64_le(signal._event_id), signal._event_id])
def bind(self, device: AMDDevice):
def bind(self, device):
self.binded_device = device
self.hw_page = device._gpu_alloc(len(self.q) * 4, kfd.KFD_IOC_ALLOC_MEM_FLAGS_GTT, uncached=True)
self.hw_page = cast(AMDDevice, device)._gpu_alloc(len(self.q) * 4, kfd.KFD_IOC_ALLOC_MEM_FLAGS_GTT, uncached=True)
hw_view = to_mv(self.hw_page.va_addr, self.hw_page.size).cast("I")
for i, value in enumerate(self.q): hw_view[i] = value


@@ -64,9 +64,9 @@ def nvmethod(subc, mthd, size, typ=2): return (typ << 28) | (size << 16) | (subc
class NVSignal(HCQSignal):
def __init__(self, value=0, **kwargs):
def __init__(self, value=0):
self._signal = NVDevice.signals_pool.pop()
self._signal[0] = value
super().__init__(value)
def __del__(self): NVDevice.signals_pool.append(self._signal)
def _get_value(self) -> int: return self._signal[0]
def _get_timestamp(self) -> float: return self._signal[1] / 1e3
@@ -107,9 +107,9 @@ class NVCommandQueue(HWCommandQueue): # pylint: disable=abstract-method
if signal is not None: self.q[(sigoff:=self.cmds_offset[cmd_idx]+1):sigoff+2] = array.array('I', data64_le(mv_address(signal._signal)))
if value is not None: self.q[(valoff:=self.cmds_offset[cmd_idx]+3):valoff+2] = array.array('I', data64_le(value))
def bind(self, device: NVDevice):
def bind(self, device):
self.binded_device = device
self.hw_page = device._gpu_alloc(len(self.q) * 4, map_to_cpu=True)
self.hw_page = cast(NVDevice, device)._gpu_alloc(len(self.q) * 4, map_to_cpu=True)
hw_view = to_mv(self.hw_page.va_addr, self.hw_page.size).cast("I")
for i, value in enumerate(self.q): hw_view[i] = value