hcq: _sleep report status (#13992)

* hcq: _sleep report status

* msg

* print all
This commit is contained in:
nimlgen
2026-01-03 14:28:28 +03:00
committed by GitHub
parent 3b354bc11f
commit a49924a0e9
5 changed files with 27 additions and 18 deletions

View File

@@ -42,9 +42,10 @@ class ProfilePMCEvent(ProfileEvent): device:str; kern:str; sched:list[PMCSample]
class AMDSignal(HCQSignal):
def __init__(self, *args, **kwargs): super().__init__(*args, **{**kwargs, 'timestamp_divider': 100})
def _sleep(self, time_spent_waiting_ms:int):
def _sleep(self, time_spent_waiting_ms:int) -> bool:
# Resonable to sleep for long workloads (which take more than 2s) and only timeline signals.
if time_spent_waiting_ms > 2000 and self.is_timeline and self.owner is not None: self.owner.iface.sleep(200)
if time_spent_waiting_ms > 2000 and self.is_timeline and self.owner is not None: return self.owner.iface.sleep(200)
return False
class AMDComputeQueue(HWQueue):
def __init__(self, dev:AMDDevice):
@@ -773,7 +774,9 @@ class KFDIface:
write_ptrs=[MMIOInterface(queue.write_pointer_address, 8, fmt='Q')],
doorbells=[MMIOInterface(self.doorbells + queue.doorbell_offset - self.doorbells_base, 8, fmt='Q')])
def sleep(self, tm:int): kfd.AMDKFD_IOC_WAIT_EVENTS(KFDIface.kfd, events_ptr=self.queue_event_arr_ptr, num_events=1, wait_for_all=1, timeout=tm)
def sleep(self, tm:int) -> bool:
kfd.AMDKFD_IOC_WAIT_EVENTS(KFDIface.kfd, events_ptr=self.queue_event_arr_ptr, num_events=1, wait_for_all=1, timeout=tm)
return False
def on_device_hang(self):
def _collect_str(st): return ' '.join(f'{k[0]}={getattr(st, k[0])}' for k in st._fields_)
@@ -843,15 +846,16 @@ class PCIIface(PCIIfaceBase):
return AMDQueueDesc(ring=ring.cpu_view().view(fmt='I'), doorbells=[self.dev_impl.doorbell64.view(doorbell_index * 8, 8, fmt='Q')],
read_ptrs=[gart.cpu_view().view(offset=rptr, size=8, fmt='Q')], write_ptrs=[gart.cpu_view().view(offset=wptr, size=8, fmt='Q')], put_value=pv)
def sleep(self, timeout):
def sleep(self, timeout) -> bool:
if hasattr(self.pci_dev, 'irq_poller') and self.pci_dev.irq_poller is not None and (events_cnt:=len(self.pci_dev.irq_poller.poll(timeout))):
self.pci_dev.irq_fd.read(8 * events_cnt)
self.dev_impl.ih.interrupt_handler()
return self.dev_impl.gmc.check_fault() is not None
def on_device_hang(self):
devs:list[AMDDevice] = [d for pg in HCQCompiled.peer_groups.values() for d in pg if isinstance(d, AMDDevice) and d.is_am()]
for d in devs: d.iface.dev_impl.gmc.on_interrupt()
raise RuntimeError("Device hang detected")
faults = [f for d in devs if (f:=d.iface.dev_impl.gmc.check_fault())]
raise RuntimeError(f"Device hang detected: {'; '.join(faults)}" if faults else "Device hang detected")
def device_fini(self): self.dev_impl.fini()
@@ -883,7 +887,7 @@ class USBIface(PCIIface):
if queue_type == kfd.KFD_IOC_QUEUE_TYPE_COMPUTE: self.pci_dev.usb._pci_cacheable += [(ring.cpu_view().addr, ring.size)]
return super().create_queue(queue_type, ring, gart, rptr, wptr, eop_buffer, cwsr_buffer, ctl_stack_size, ctx_save_restore_size, xcc_id, idx)
def sleep(self, timeout): pass
def sleep(self, timeout) -> bool: return False
class AMDDevice(HCQCompiled):
def is_am(self) -> bool: return isinstance(self.iface, (PCIIface, USBIface))

View File

@@ -13,8 +13,9 @@ from tinygrad.runtime.support.elf import jit_loader
from tinygrad.uop.ops import sint
class CPUSignal(HCQSignal):
def _sleep(self, time_spent_waiting_ms:int):
def _sleep(self, time_spent_waiting_ms:int) -> bool:
if self.is_timeline and self.owner is not None: self.owner.tasks.join()
return False
class CPUWorker(threading.Thread):
def __init__(self, dev, tasks, thread_id):

View File

@@ -45,10 +45,11 @@ class QCOMCompiler(CLCompiler):
class QCOMSignal(HCQSignal):
def __init__(self, *args, **kwargs): super().__init__(*args, **{**kwargs, 'timestamp_divider': 19.2})
def _sleep(self, time_spent_waiting_ms:int):
def _sleep(self, time_spent_waiting_ms:int) -> bool:
# Sleep only for timeline signals. Do it immediately to free cpu.
if self.is_timeline and self.owner is not None:
kgsl.IOCTL_KGSL_DEVICE_WAITTIMESTAMP_CTXTID(self.owner.fd, context_id=self.owner.ctx, timestamp=self.owner.last_cmd, timeout=0xffffffff)
return False
class QCOMComputeQueue(HWQueue):
def __init__(self, dev:QCOMDevice):

View File

@@ -158,11 +158,11 @@ class AM_GMC(AM_IP):
if self.adev.ip_ver[am.GC_HWIP] < (10,0,0): return (pte & am.AMDGPU_PDE_PTE) if pte_lv != am.AMDGPU_VM_PDB0 else not (pte & am.AMDGPU_PTE_TF)
return pte & (am.AMDGPU_PDE_PTE_GFX12 if self.adev.ip_ver[am.GC_HWIP] >= (12,0,0) else am.AMDGPU_PDE_PTE)
def on_interrupt(self):
for ip in ["MM", "GC"]:
va = (self.adev.reg(f'reg{ip}VM_L2_PROTECTION_FAULT_ADDR_HI32').read()<<32) | self.adev.reg(f'reg{ip}VM_L2_PROTECTION_FAULT_ADDR_LO32').read()
if self.adev.reg(self.pf_status_reg(ip)).read():
raise RuntimeError(f"{ip}VM_L2_PROTECTION_FAULT_STATUS: {self.adev.reg(self.pf_status_reg(ip)).read_bitfields()} {va<<12:#x}")
def check_fault(self) -> str|None:
va = (self.adev.reg('regGCVM_L2_PROTECTION_FAULT_ADDR_HI32').read()<<32) | self.adev.reg('regGCVM_L2_PROTECTION_FAULT_ADDR_LO32').read()
if self.adev.reg(self.pf_status_reg("GC")).read():
return f"GCVM_L2_PROTECTION_FAULT_STATUS: {self.adev.reg(self.pf_status_reg('GC')).read_bitfields()} {va<<12:#x}"
return None
class AM_SMU(AM_IP):
def init_sw(self):

View File

@@ -243,10 +243,12 @@ class HCQSignal(Generic[HCQDeviceType]):
"""
return self.timestamp_mv[0] / self.timestamp_divider
def _sleep(self, time_spent_waiting_ms:int):
def _sleep(self, time_spent_waiting_ms:int) -> bool:
"""
Optional function which can implement sleep functionality for the signal.
Returns True if a fault was detected, False otherwise.
"""
return False
def wait(self, value:int, timeout:int=getenv("HCQDEV_WAIT_TIMEOUT_MS", 30000)):
"""
@@ -256,11 +258,12 @@ class HCQSignal(Generic[HCQDeviceType]):
value: The value to wait for.
timeout: Maximum time to wait in milliseconds. Defaults to 30s.
"""
start_time = int(time.perf_counter() * 1000)
start_time, fault = int(time.perf_counter() * 1000), False
while (not_passed:=(prev_value:=self.value) < value) and (time_spent:=int(time.perf_counter() * 1000) - start_time) < timeout:
self._sleep(time_spent)
if fault:=self._sleep(time_spent): break
if self.value != prev_value: start_time = int(time.perf_counter() * 1000) # progress was made, reset timer
if not_passed and self.value < value: raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")
if not_passed and self.value < value:
raise RuntimeError("Device fault detected" if fault else f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")
@contextlib.contextmanager
def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type:Callable[[], HWQueue]|None=None, queue:HWQueue|None=None):