hcq: update signal logic (#14531)

This commit is contained in:
nimlgen
2026-02-04 19:32:56 +03:00
committed by GitHub
parent 62786d488a
commit ec2b6bbda8
5 changed files with 23 additions and 29 deletions

View File

@@ -43,10 +43,9 @@ class ProfilePMCEvent(ProfileEvent): device:str; kern:int; sched:list[PMCSample]
class AMDSignal(HCQSignal):
def __init__(self, *args, **kwargs): super().__init__(*args, **{**kwargs, 'timestamp_divider': 100})
def _sleep(self, time_spent_waiting_ms:int) -> bool:
# Resonable to sleep for long workloads (which take more than 2s) and only timeline signals.
if time_spent_waiting_ms > 2000 and self.is_timeline and self.owner is not None: return self.owner.iface.sleep(200)
return False
def _sleep(self, time_spent_since_last_sleep_ms:int):
# Reasonable to sleep for long workloads (which take more than 200ms) and only timeline signals.
if time_spent_since_last_sleep_ms > 200 and self.is_timeline and self.owner is not None: self.owner.iface.sleep(200)
class AMDComputeQueue(HWQueue):
def __init__(self, dev:AMDDevice):
@@ -778,9 +777,8 @@ class KFDIface:
write_ptr=MMIOInterface(queue.write_pointer_address, 8, fmt='Q'),
doorbell=MMIOInterface(self.doorbells + queue.doorbell_offset - self.doorbells_base, 8, fmt='Q'))
def sleep(self, tm:int) -> bool:
def sleep(self, tm:int):
kfd.AMDKFD_IOC_WAIT_EVENTS(KFDIface.kfd, events_ptr=self.queue_event_arr_ptr, num_events=1, wait_for_all=1, timeout=tm)
return False
def on_device_hang(self):
def _collect_str(st): return ' '.join(f'{k[0]}={getattr(st, k[0])}' for k in st._real_fields_)
@@ -857,11 +855,11 @@ class PCIIface(PCIIfaceBase):
return AMDQueueDesc(ring=ring.cpu_view().view(fmt='I'), doorbell=self.dev_impl.doorbell64.view(doorbell_index * 8, 8, fmt='Q'), put_value=pv,
read_ptr=gart.cpu_view().view(offset=rptr, size=8, fmt='Q'), write_ptr=gart.cpu_view().view(offset=wptr, size=8, fmt='Q'), params=rcvr_params)
def sleep(self, timeout) -> bool:
def sleep(self, timeout):
if hasattr(self.pci_dev, 'irq_poller') and self.pci_dev.irq_poller is not None and (events_cnt:=len(self.pci_dev.irq_poller.poll(timeout))):
self.pci_dev.irq_fd.read(8 * events_cnt)
self.dev_impl.ih.interrupt_handler()
return self.dev_impl.is_err_state
if self.dev_impl.is_err_state: raise RuntimeError("Device fault detected")
def on_device_hang(self):
devs:list[AMDDevice] = [d for pg in HCQCompiled.peer_groups.values() for d in pg if isinstance(d, AMDDevice) and d.is_am()]
@@ -905,7 +903,7 @@ class USBIface(PCIIface):
if queue_type == kfd.KFD_IOC_QUEUE_TYPE_COMPUTE: self.pci_dev.usb._pci_cacheable += [(ring.cpu_view().addr, ring.size)]
return super().create_queue(queue_type, ring, gart, rptr, wptr, eop_buffer, cwsr_buffer, ctl_stack_size, ctx_save_restore_size, xcc_id, idx)
def sleep(self, timeout) -> bool: return False
def sleep(self, timeout): pass
class AMDDevice(HCQCompiled):
def is_am(self) -> bool: return isinstance(self.iface, (PCIIface, USBIface))

View File

@@ -13,9 +13,8 @@ from tinygrad.runtime.support.elf import jit_loader
from tinygrad.uop.ops import sint
class CPUSignal(HCQSignal):
def _sleep(self, time_spent_waiting_ms:int) -> bool:
def _sleep(self, time_spent_since_last_sleep_ms:int):
if self.is_timeline and self.owner is not None: self.owner.tasks.join()
return False
class CPUWorker(threading.Thread):
def __init__(self, dev, tasks, thread_id):

View File

@@ -27,10 +27,9 @@ PMA = ContextVar("PMA", abs(VIZ.value)>=2)
class ProfilePMAEvent(ProfileEvent): device:str; kern:str; blob:bytes # noqa: E702
class NVSignal(HCQSignal):
def _sleep(self, time_spent_waiting_ms:int) -> bool:
# Resonable to sleep for long workloads (which take more than 2s) and only timeline signals.
if time_spent_waiting_ms > 2000 and self.is_timeline and self.owner is not None: return self.owner.iface.sleep(200)
return False
def _sleep(self, time_spent_since_last_sleep_ms:int):
# Reasonable to sleep for long workloads (which take more than 200ms) and only timeline signals.
if time_spent_since_last_sleep_ms > 200 and self.is_timeline and self.owner is not None: self.owner.iface.sleep(200)
def get_error_str(status): return f"{status}: {nv_gpu.nv_status_codes.get(status, 'Unknown error')}"
@@ -525,7 +524,7 @@ class NVKIface:
def _alloc_gpu_vaddr(self, size, alignment=(4 << 10), force_low=False):
return NVKIface.low_uvm_vaddr_allocator.alloc(size, alignment) if force_low else NVKIface.uvm_vaddr_allocator.alloc(size, alignment)
def sleep(self, tm:int) -> bool: return False
def sleep(self, tm:int): pass
class PCIIface(PCIIfaceBase):
gpus:ClassVar[list[str]] = []
@@ -561,9 +560,9 @@ class PCIIface(PCIIfaceBase):
def device_fini(self): self.dev_impl.fini()
def sleep(self, timeout) -> bool:
def sleep(self, timeout):
for _ in self.dev_impl.gsp.stat_q.read_resp(): pass
return self.dev_impl.is_err_state
if self.dev_impl.is_err_state: raise RuntimeError("Device fault detected")
class NVDevice(HCQCompiled[NVSignal]):
def is_nvd(self) -> bool: return isinstance(self.iface, PCIIface)

View File

@@ -57,11 +57,10 @@ class QCOMCompiler(CLCompiler):
class QCOMSignal(HCQSignal):
def __init__(self, *args, **kwargs): super().__init__(*args, **{**kwargs, 'timestamp_divider': 19.2})
def _sleep(self, time_spent_waiting_ms:int) -> bool:
def _sleep(self, time_spent_since_last_sleep_ms:int):
# Sleep only for timeline signals. Do it immediately to free cpu.
if self.is_timeline and self.owner is not None:
kgsl.IOCTL_KGSL_DEVICE_WAITTIMESTAMP_CTXTID(self.owner.fd, context_id=self.owner.ctx, timestamp=self.owner.last_cmd, timeout=0xffffffff)
return False
class QCOMComputeQueue(HWQueue):
def __init__(self, dev:QCOMDevice):

View File

@@ -243,12 +243,11 @@ class HCQSignal(Generic[HCQDeviceType]):
"""
return self.timestamp_mv[0] / self.timestamp_divider
def _sleep(self, time_spent_waiting_ms:int) -> bool:
def _sleep(self, time_spent_since_last_sleep_ms:int):
"""
Optional function which can implement sleep functionality for the signal.
Returns True if a fault was detected, False otherwise.
Raises RuntimeError if a fault is detected.
"""
return False
def wait(self, value:int, timeout:int=getenv("HCQDEV_WAIT_TIMEOUT_MS", 30000)):
"""
@@ -258,12 +257,12 @@ class HCQSignal(Generic[HCQDeviceType]):
value: The value to wait for.
timeout: Maximum time to wait in milliseconds. Defaults to 30s.
"""
start_time, fault = int(time.perf_counter() * 1000), False
while (not_passed:=(prev_value:=self.value) < value) and (time_spent:=int(time.perf_counter() * 1000) - start_time) < timeout:
if fault:=self._sleep(time_spent): break
if self.value != prev_value: start_time = int(time.perf_counter() * 1000) # progress was made, reset timer
if not_passed and self.value < value:
raise RuntimeError("Device fault detected" if fault else f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")
start_time = last_sleep_time = int(time.perf_counter() * 1000)
while (not_passed:=(prev_value:=self.value) < value) and (cur_time:=int(time.perf_counter() * 1000)) - start_time < timeout:
self._sleep(cur_time - last_sleep_time)
last_sleep_time = int(time.perf_counter() * 1000)
if self.value != prev_value: start_time = last_sleep_time # progress was made, reset timer
if not_passed and self.value < value: raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")
@contextlib.contextmanager
def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type:Callable[[], HWQueue]|None=None, queue:HWQueue|None=None):