hcq: refactor into peer_groups (#11277)

* hcq: refactor into peer_groups

* fix fors

* fixes

* ooops

* mypy

* tiny fixes
This commit is contained in:
nimlgen
2025-07-18 16:34:18 +03:00
committed by GitHub
parent f432eef708
commit 9a88bd841c
7 changed files with 51 additions and 67 deletions

View File

@@ -52,7 +52,7 @@ Signals are device-dependent structures used for synchronization and timing in H
The following Python code demonstrates the usage of signals:
```python
signal = your_device.signal_t()
signal = your_device.new_signal(value=0)
HWQueue().timestamp(signal) \
.signal(signal, value_to_fire) \

View File

@@ -68,7 +68,7 @@ class TestHCQ(unittest.TestCase):
if queue_type is None: continue
with self.subTest(name=str(queue_type)):
fake_signal = TestHCQ.d0.signal_t()
fake_signal = TestHCQ.d0.new_signal()
fake_signal.value = 1
queue_type().wait(fake_signal, 1) \
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
@@ -81,7 +81,7 @@ class TestHCQ(unittest.TestCase):
if queue_type is None: continue
with self.subTest(name=str(queue_type)):
fake_signal = TestHCQ.d0.signal_t()
fake_signal = TestHCQ.d0.new_signal()
queue_type().wait(fake_signal, 1) \
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
@@ -101,7 +101,7 @@ class TestHCQ(unittest.TestCase):
virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32)
virt_signal = TestHCQ.d0.signal_t(base_buf=HCQBuffer(Variable("sig_addr", 0, 0xffffffffffffffff, dtypes.uint64), 16))
fake_signal = TestHCQ.d0.signal_t()
fake_signal = TestHCQ.d0.new_signal()
q = queue_type().wait(virt_signal, virt_val).signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
fake_signal.value = 0x30
@@ -293,7 +293,7 @@ class TestHCQ(unittest.TestCase):
virt_signal = TestHCQ.d0.signal_t(base_buf=HCQBuffer(Variable("sig_addr", 0, 0xffffffffffffffff, dtypes.uint64), 16))
with self.subTest(name=str(queue_type)):
fake_signal = TestHCQ.d0.signal_t()
fake_signal = TestHCQ.d0.new_signal()
q = queue_type().wait(virt_signal, virt_val).signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
q.bind(TestHCQ.d0)
@@ -310,7 +310,7 @@ class TestHCQ(unittest.TestCase):
try: d1 = Device[f"{Device.DEFAULT}:1"]
except Exception: self.skipTest("no multidevice, test skipped")
TestHCQ.d0.hw_copy_queue_t().signal(sig:=TestHCQ.d0.signal_t(value=0), value=0xfff) \
TestHCQ.d0.hw_copy_queue_t().signal(sig:=TestHCQ.d0.new_signal(value=0), value=0xfff) \
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
d1.hw_copy_queue_t().wait(sig, value=0xfff) \
@@ -324,7 +324,7 @@ class TestHCQ(unittest.TestCase):
# Test profile api
def test_speed_exec_time(self):
sig_st, sig_en = TestHCQ.d0.signal_t(), TestHCQ.d0.signal_t()
sig_st, sig_en = TestHCQ.d0.new_signal(), TestHCQ.d0.new_signal()
TestHCQ.d0.hw_compute_queue_t().timestamp(sig_st) \
.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) \
.timestamp(sig_en) \
@@ -346,7 +346,7 @@ class TestHCQ(unittest.TestCase):
a = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
b = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
sig_st, sig_en = TestHCQ.d0.signal_t(), TestHCQ.d0.signal_t()
sig_st, sig_en = TestHCQ.d0.new_signal(), TestHCQ.d0.new_signal()
TestHCQ.d0.hw_copy_queue_t().timestamp(sig_st) \
.copy(a._buf.va_addr, b._buf.va_addr, SZ) \
.timestamp(sig_en) \
@@ -373,7 +373,7 @@ class TestHCQ(unittest.TestCase):
a = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
TestHCQ.d0.allocator.map(b._buf)
sig_st, sig_en = TestHCQ.d0.signal_t(), TestHCQ.d0.signal_t()
sig_st, sig_en = TestHCQ.d0.new_signal(), TestHCQ.d0.new_signal()
TestHCQ.d0.hw_copy_queue_t().timestamp(sig_st) \
.copy(a._buf.va_addr, b._buf.va_addr, SZ) \
.timestamp(sig_en) \
@@ -531,8 +531,8 @@ class TestHCQ(unittest.TestCase):
try: nv_dev = Device["NV"]
except Exception: self.skipTest("no NV device, test skipped")
x = amd_dev.signal_t()
y = nv_dev.signal_t()
x = amd_dev.new_signal()
y = nv_dev.new_signal()
assert type(x) is amd_dev.signal_t
assert type(y) is nv_dev.signal_t

View File

@@ -48,13 +48,13 @@ class HCQGraph(MultiGraphRunner):
self.comp_queues: dict[HCQCompiled, HWQueue] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
self.copy_queues: dict[HCQCompiled, HWQueue] = {} # lazy allocation
self.signals: dict[Any, HCQSignal] = {**{dev: dev.signal_t(value=0) for dev in self.devices}, **{"KICK": self.devices[0].signal_t(value=0)}}
self.signals: dict[Any, HCQSignal] = {**{dev: dev.new_signal(value=0) for dev in self.devices}, **{"KICK": self.devices[0].new_signal(value=0)}}
self.kickoff_value: int = 0
self.kickoff_var = UOp.variable("kickoff_var", 0, 0xffffffff, dtype=dtypes.uint32)
# When profiling allocate 2 signals for each jit item to measure speed. The jth jit item has signals at 2*j and 2*j+1.
# TODO: This logic might allocate a few extra signals...
self.prof_signals: list[HCQSignal] = [self.devices[0].signal_t() for i in range(len(jit_cache) * 2)] if PROFILE else []
self.prof_signals: list[HCQSignal] = [self.devices[0].new_signal() for i in range(len(jit_cache) * 2)] if PROFILE else []
self.prog_graph_deps: list[list[int]] = []
self.prof_graph_entries: list[ProfileGraphEntry] = []
@@ -78,7 +78,7 @@ class HCQGraph(MultiGraphRunner):
assert (enqueue_dev.hw_copy_queue_t is not None), "device must implement a copy queue"
enqueue_queue = self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
out_signal = self.signals.setdefault(enqueue_queue, enqueue_dev.signal_t(value=0))
out_signal = self.signals.setdefault(enqueue_queue, enqueue_dev.new_signal(value=0))
# Get dependencies based on input and output buffers.
rdeps = self._access_resources(ji.bufs, ji.prg.p.outs if is_exec_prg else [0], (enqueue_queue, j + 1)) #type:ignore
@@ -131,7 +131,7 @@ class HCQGraph(MultiGraphRunner):
# Create variable timeline signals for each device.
timeline_sigaddrs = {dev: UOp.variable(f"timeline_sig_{dev.device_id}", 0, 0xffffffffffffffff, dtype=dtypes.uint64) for dev in self.devices}
self.virt_timeline_vals = {dev: UOp.variable(f"timeline_var_{dev.device_id}", 0, 0xffffffff, dtype=dtypes.uint32) for dev in self.devices}
self.virt_timeline_signals = {dev: dev.signal_t(base_buf=HCQBuffer(timeline_sigaddrs[dev], 16), timeline_for_device=dev) for dev in self.devices}
self.virt_timeline_signals = {dev: dev.signal_t(HCQBuffer(timeline_sigaddrs[dev], 16), owner=dev, is_timeline=True) for dev in self.devices}
for dev in self.devices:
self.comp_queues[dev].memory_barrier().wait(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev]) \

View File

@@ -26,12 +26,11 @@ WAIT_REG_MEM_FUNCTION_NEQ = 4 # !=
WAIT_REG_MEM_FUNCTION_GEQ = 5 # >=
class AMDSignal(HCQSignal):
def __init__(self, base_buf:HCQBuffer|None=None, **kwargs):
super().__init__(base_buf, **kwargs, timestamp_divider=100, dev_t=AMDDevice)
def __init__(self, *args, **kwargs): super().__init__(*args, **{**kwargs, 'timestamp_divider': 100})
def _sleep(self, time_spent_waiting_ms:int):
# Reasonable to sleep for long workloads (which take more than 2s) and only timeline signals.
if time_spent_waiting_ms > 2000 and self.timeline_for_device is not None: self.timeline_for_device.iface.sleep(200)
if time_spent_waiting_ms > 2000 and self.is_timeline and self.owner is not None: self.owner.iface.sleep(200)
class AMDComputeQueue(HWQueue):
def __init__(self, dev:AMDDevice):
@@ -299,7 +298,7 @@ class AMDComputeQueue(HWQueue):
self.release_mem(signal.value_addr, value, self.pm4.data_sel__mec_release_mem__send_32_bit_low,
self.pm4.int_sel__mec_release_mem__send_interrupt_after_write_confirm, cache_flush=True)
if (dev:=signal.timeline_for_device) is not None and not dev.is_am():
if (dev:=signal.owner) is not None and signal.is_timeline and not dev.is_am():
self.release_mem(dev.queue_event_mailbox_ptr, dev.queue_event.event_id, self.pm4.data_sel__mec_release_mem__send_32_bit_low,
self.pm4.int_sel__mec_release_mem__send_interrupt_after_write_confirm, ctxid=dev.queue_event.event_id)
return self
@@ -355,7 +354,7 @@ class AMDCopyQueue(HWQueue):
fence_flags = self.sdma.SDMA_PKT_FENCE_HEADER_MTYPE(3) if self.dev.target >= (10,0,0) else 0
self.q(self.sdma.SDMA_OP_FENCE | fence_flags, *data64_le(signal.value_addr), value)
if (dev:=signal.timeline_for_device) is not None and not dev.is_am():
if (dev:=signal.owner) is not None and signal.is_timeline and not dev.is_am():
self.q(self.sdma.SDMA_OP_FENCE | fence_flags, *data64_le(dev.queue_event_mailbox_ptr), dev.queue_event.event_id)
self.q(self.sdma.SDMA_OP_TRAP, self.sdma.SDMA_PKT_TRAP_INT_CONTEXT_INT_CONTEXT(dev.queue_event.event_id))
elif dev is not None and dev.is_am(): self.q(self.sdma.SDMA_OP_TRAP, self.sdma.SDMA_PKT_TRAP_INT_CONTEXT_INT_CONTEXT(0))
@@ -682,7 +681,8 @@ class PCIIface(PCIIfaceBase):
self.dev_impl.ih.interrupt_handler()
def on_device_hang(self):
for d in self.dev.devices: d.iface.dev_impl.gmc.on_interrupt()
devs:list[AMDDevice] = [d for pg in HCQCompiled.peer_groups.values() for d in pg if isinstance(d, AMDDevice) and d.is_am()]
for d in devs: d.iface.dev_impl.gmc.on_interrupt()
raise RuntimeError("Device hang detected")
def device_fini(self): self.dev_impl.fini()
@@ -730,10 +730,6 @@ class USBIface(PCIIface):
def sleep(self, timeout): pass
class AMDDevice(HCQCompiled):
devices: ClassVar[list[HCQCompiled]] = []
signal_pages: ClassVar[list[HCQBuffer]] = []
signal_pool: ClassVar[list[HCQBuffer]] = []
def is_am(self) -> bool: return isinstance(self.iface, (PCIIface, USBIface))
def is_usb(self) -> bool: return isinstance(self.iface, USBIface)

View File

@@ -72,8 +72,7 @@ class QMD:
else: self.write(**{f'constant_buffer_addr_upper_shifted6_{i}':hi32(addr >> 6), f'constant_buffer_addr_lower_shifted6_{i}':lo32(addr >> 6)})
class NVSignal(HCQSignal):
def __init__(self, base_buf:HCQBuffer|None=None, **kwargs):
super().__init__(base_buf, **kwargs, timestamp_divider=1000, dev_t=NVDevice)
def __init__(self, *args, **kwargs): super().__init__(*args, **{**kwargs, 'timestamp_divider': 1000})
class NVCommandQueue(HWQueue[NVSignal, 'NVDevice', 'NVProgram', 'NVArgsState']):
def __init__(self):
@@ -365,7 +364,7 @@ class NVKIface:
uvm.register_gpu(self.fd_uvm, rmCtrlFd=-1, gpu_uuid=self.gpu_uuid)
uvm.register_gpu_vaspace(self.fd_uvm, gpuUuid=self.gpu_uuid, rmCtrlFd=self.fd_ctl.fd, hClient=self.root, hVaSpace=vaspace)
for dev in cast(list[NVDevice], self.dev.devices):
for dev in cast(list[NVDevice], [d for pg in HCQCompiled.peer_groups.values() for d in pg if isinstance(d, NVDevice) and not d.is_nvd()]):
try: uvm.enable_peer_access(self.fd_uvm, gpuUuidA=self.gpu_uuid, gpuUuidB=dev.iface.gpu_uuid)
except RuntimeError as e: raise RuntimeError(f"{e}. Make sure GPUs #{self.gpu_minor} & #{dev.iface.gpu_minor} have P2P enabled.") from e
@@ -481,10 +480,6 @@ class PCIIface(PCIIfaceBase):
def device_fini(self): self.dev_impl.fini()
class NVDevice(HCQCompiled[NVSignal]):
devices: ClassVar[list[HCQCompiled]] = []
signal_pages: ClassVar[list[HCQBuffer]] = []
signal_pool: ClassVar[list[HCQBuffer]] = []
def is_nvd(self) -> bool: return isinstance(self.iface, PCIIface)
def __init__(self, device:str=""):

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import os, ctypes, functools, mmap, struct, array, math, sys, weakref
assert sys.platform != 'win32'
from types import SimpleNamespace
from typing import Any, cast, ClassVar
from typing import Any, cast
from tinygrad.device import BufferSpec
from tinygrad.runtime.support.hcq import HCQBuffer, HWQueue, HCQProgram, HCQCompiled, HCQAllocatorBase, HCQSignal, HCQArgsState, BumpAllocator
from tinygrad.runtime.support.hcq import FileIOInterface, MMIOInterface
@@ -37,14 +37,12 @@ class QCOMCompiler(CLCompiler):
def disassemble(self, lib:bytes): fromimport('extra.disassemblers.adreno', 'disasm')(lib)
class QCOMSignal(HCQSignal):
def __init__(self, base_buf:HCQBuffer|None=None, **kwargs):
super().__init__(base_buf, **kwargs, timestamp_divider=19.2, dev_t=QCOMDevice)
def __init__(self, *args, **kwargs): super().__init__(*args, **{**kwargs, 'timestamp_divider': 19.2})
def _sleep(self, time_spent_waiting_ms:int):
# Sleep only for only timeline signals. Do it immediately to free cpu.
if self.timeline_for_device is not None:
kgsl.IOCTL_KGSL_DEVICE_WAITTIMESTAMP_CTXTID(self.timeline_for_device.fd, context_id=self.timeline_for_device.ctx,
timestamp=self.timeline_for_device.last_cmd, timeout=0xffffffff)
# Sleep only for timeline signals. Do it immediately to free cpu.
if self.is_timeline and self.owner is not None:
kgsl.IOCTL_KGSL_DEVICE_WAITTIMESTAMP_CTXTID(self.owner.fd, context_id=self.owner.ctx, timestamp=self.owner.last_cmd, timeout=0xffffffff)
class QCOMComputeQueue(HWQueue):
def __del__(self):
@@ -315,10 +313,6 @@ class QCOMAllocator(HCQAllocatorBase):
self.dev._gpu_free(opaque)
class QCOMDevice(HCQCompiled):
devices: ClassVar[list[HCQCompiled]] = []
signal_pages: ClassVar[list[HCQBuffer]] = []
signal_pool: ClassVar[list[HCQBuffer]] = []
gpu_id: int = 0
dummy_addr: int = 0

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import cast, Callable, Type, TypeVar, Generic, Any, ClassVar
import contextlib, decimal, statistics, time, ctypes, array, os, fcntl, struct, traceback
from typing import cast, Callable, Type, TypeVar, Generic, Any
import contextlib, decimal, statistics, time, ctypes, array, os, fcntl, struct, traceback, collections
from tinygrad.helpers import PROFILE, getenv, to_mv, round_up, ProfileRangeEvent
from tinygrad.renderer import Renderer
from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator, ProfileDeviceEvent, ProfileProgramEvent
@@ -217,19 +217,17 @@ class HWQueue(Generic[SignalType, HCQDeviceType, ProgramType, ArgsStateType]):
def _submit(self, dev:HCQDeviceType): raise NotImplementedError("need _submit")
class HCQSignal(Generic[HCQDeviceType]):
def __init__(self, base_buf:HCQBuffer|None=None, value:int=0, dev_t:Type[HCQDeviceType]|None=None, timeline_for_device:HCQDeviceType|None=None,
timestamp_divider=1, value_off=0, timestamp_off=8):
self.base_buf = cast(HCQBuffer, dev_t._alloc_signal() if dev_t is not None and base_buf is None else base_buf)
self.value_addr, self.timestamp_addr, self.dev_t = self.base_buf.va_addr+value_off, self.base_buf.va_addr+timestamp_off, dev_t
def __init__(self, base_buf:HCQBuffer, value:int=0, owner:HCQDeviceType|None=None, is_timeline:bool=False, timestamp_divider=1):
self.base_buf, self.value_addr, self.timestamp_addr, self.owner = base_buf, base_buf.va_addr+0, base_buf.va_addr+8, owner
self.is_timeline = is_timeline
self.timestamp_divider:decimal.Decimal = decimal.Decimal(timestamp_divider)
self.timeline_for_device:HCQDeviceType|None = timeline_for_device
if isinstance(self.base_buf.va_addr, int):
self.value_mv, self.timestamp_mv = self.base_buf.cpu_view().view(value_off, 8, 'Q'), self.base_buf.cpu_view().view(timestamp_off, 8, 'Q')
self.value_mv, self.timestamp_mv = self.base_buf.cpu_view().view(0, 8, 'Q'), self.base_buf.cpu_view().view(8, 8, 'Q')
self.value_mv[0] = value
def __del__(self):
if isinstance(self.base_buf.va_addr, int) and self.dev_t is not None: self.dev_t.signal_pool.append(self.base_buf)
if isinstance(self.base_buf.va_addr, int) and self.owner is not None: HCQCompiled.signal_pool[self.owner.peer_group].append(self.base_buf)
@property
def value(self) -> int: return self.value_mv[0]
@@ -270,7 +268,7 @@ class HCQSignal(Generic[HCQDeviceType]):
@contextlib.contextmanager
def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type:Callable[[], HWQueue]|None=None, queue:HWQueue|None=None):
st, en = (dev.signal_t(), dev.signal_t()) if enabled else (None, None)
st, en = (dev.new_signal(), dev.new_signal()) if enabled else (None, None)
if enabled and queue is not None: queue.timestamp(st)
elif enabled:
@@ -355,9 +353,9 @@ class HCQCompiled(Compiled, Generic[SignalType]):
"""
A base class for devices compatible with the HCQ (Hardware Command Queue) API.
"""
devices: ClassVar[list[HCQCompiled]] = []
signal_pages: ClassVar[list[HCQBuffer]] = []
signal_pool: ClassVar[list[HCQBuffer]] = []
peer_groups: dict[str, list[HCQCompiled]] = collections.defaultdict(list)
signal_pages: dict[str, list[HCQBuffer]] = collections.defaultdict(list) # per peer group
signal_pool: dict[str, list[HCQBuffer]] = collections.defaultdict(list) # per peer group
def __init__(self, device:str, allocator:HCQAllocatorBase, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[SignalType],
comp_queue_t:Callable[[], HWQueue], copy_queue_t:Callable[[], HWQueue]|None, kernargs_size=(16 << 20), sigalloc_size=0x1000):
@@ -366,15 +364,17 @@ class HCQCompiled(Compiled, Generic[SignalType]):
from tinygrad.runtime.graph.hcq import HCQGraph
super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)
# TODO: peer logic is determined based on device name.
self.peer_group = device.split(":")[0]
HCQCompiled.peer_groups[self.peer_group].append(self)
# Map signals if any
for sig_page in self.signal_pages: cast(HCQAllocator, self.allocator).map(sig_page)
self.devices.append(self)
for sig_page in HCQCompiled.signal_pages[self.peer_group]: cast(HCQAllocator, self.allocator).map(sig_page)
self.sigalloc_size = sigalloc_size
self.signal_t, self.hw_compute_queue_t, self.hw_copy_queue_t = signal_t, comp_queue_t, copy_queue_t
self.timeline_value:int = 1
self.timeline_signal:SignalType = self.signal_t(value=0, timeline_for_device=self)
self._shadow_timeline_signal:SignalType = self.signal_t(value=0, timeline_for_device=self)
self.timeline_signal, self._shadow_timeline_signal = self.new_signal(value=0, is_timeline=True), self.new_signal(value=0, is_timeline=True)
self.sig_prof_records:list[tuple[HCQSignal, HCQSignal, str, bool]] = []
self.kernargs_buf:HCQBuffer = self.allocator.alloc(kernargs_size, BufferSpec(cpu_access=True))
@@ -395,13 +395,12 @@ class HCQCompiled(Compiled, Generic[SignalType]):
self.timeline_value += 1
return self.timeline_value - 1
@classmethod
def _alloc_signal(cls) -> HCQBuffer:
if not cls.signal_pool:
cls.signal_pages.append(alc:=cls.devices[0].allocator.alloc(cls.devices[0].sigalloc_size, BufferSpec(host=True,uncached=True,cpu_access=True)))
cls.signal_pool += [alc.offset(offset=off, size=16) for off in range(0, alc.size, 16)]
for dev in cls.devices: cast(HCQAllocator, dev.allocator).map(alc)
return cls.signal_pool.pop()
def new_signal(self, **kwargs) -> SignalType:
if not HCQCompiled.signal_pool[pg:=self.peer_group]:
HCQCompiled.signal_pages[pg].append(alc:=self.allocator.alloc(self.sigalloc_size, BufferSpec(host=True, uncached=True, cpu_access=True)))
HCQCompiled.signal_pool[pg] += [alc.offset(offset=off, size=16) for off in range(0, alc.size, 16)]
for dev in HCQCompiled.peer_groups[pg]: cast(HCQAllocator, dev.allocator).map(alc)
return self.signal_t(base_buf=HCQCompiled.signal_pool[pg].pop(), owner=self, **kwargs)
def _at_profile_finalize(self):
def _sync(d:HCQCompiled, q_t:Callable[[], HWQueue]):