mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
hcq: _sleep report status (#13992)
* hcq: _sleep report status * msg * print all
This commit is contained in:
@@ -42,9 +42,10 @@ class ProfilePMCEvent(ProfileEvent): device:str; kern:str; sched:list[PMCSample]
|
||||
class AMDSignal(HCQSignal):
|
||||
def __init__(self, *args, **kwargs): super().__init__(*args, **{**kwargs, 'timestamp_divider': 100})
|
||||
|
||||
def _sleep(self, time_spent_waiting_ms:int):
|
||||
def _sleep(self, time_spent_waiting_ms:int) -> bool:
|
||||
# Resonable to sleep for long workloads (which take more than 2s) and only timeline signals.
|
||||
if time_spent_waiting_ms > 2000 and self.is_timeline and self.owner is not None: self.owner.iface.sleep(200)
|
||||
if time_spent_waiting_ms > 2000 and self.is_timeline and self.owner is not None: return self.owner.iface.sleep(200)
|
||||
return False
|
||||
|
||||
class AMDComputeQueue(HWQueue):
|
||||
def __init__(self, dev:AMDDevice):
|
||||
@@ -773,7 +774,9 @@ class KFDIface:
|
||||
write_ptrs=[MMIOInterface(queue.write_pointer_address, 8, fmt='Q')],
|
||||
doorbells=[MMIOInterface(self.doorbells + queue.doorbell_offset - self.doorbells_base, 8, fmt='Q')])
|
||||
|
||||
def sleep(self, tm:int): kfd.AMDKFD_IOC_WAIT_EVENTS(KFDIface.kfd, events_ptr=self.queue_event_arr_ptr, num_events=1, wait_for_all=1, timeout=tm)
|
||||
def sleep(self, tm:int) -> bool:
|
||||
kfd.AMDKFD_IOC_WAIT_EVENTS(KFDIface.kfd, events_ptr=self.queue_event_arr_ptr, num_events=1, wait_for_all=1, timeout=tm)
|
||||
return False
|
||||
|
||||
def on_device_hang(self):
|
||||
def _collect_str(st): return ' '.join(f'{k[0]}={getattr(st, k[0])}' for k in st._fields_)
|
||||
@@ -843,15 +846,16 @@ class PCIIface(PCIIfaceBase):
|
||||
return AMDQueueDesc(ring=ring.cpu_view().view(fmt='I'), doorbells=[self.dev_impl.doorbell64.view(doorbell_index * 8, 8, fmt='Q')],
|
||||
read_ptrs=[gart.cpu_view().view(offset=rptr, size=8, fmt='Q')], write_ptrs=[gart.cpu_view().view(offset=wptr, size=8, fmt='Q')], put_value=pv)
|
||||
|
||||
def sleep(self, timeout):
|
||||
def sleep(self, timeout) -> bool:
|
||||
if hasattr(self.pci_dev, 'irq_poller') and self.pci_dev.irq_poller is not None and (events_cnt:=len(self.pci_dev.irq_poller.poll(timeout))):
|
||||
self.pci_dev.irq_fd.read(8 * events_cnt)
|
||||
self.dev_impl.ih.interrupt_handler()
|
||||
return self.dev_impl.gmc.check_fault() is not None
|
||||
|
||||
def on_device_hang(self):
|
||||
devs:list[AMDDevice] = [d for pg in HCQCompiled.peer_groups.values() for d in pg if isinstance(d, AMDDevice) and d.is_am()]
|
||||
for d in devs: d.iface.dev_impl.gmc.on_interrupt()
|
||||
raise RuntimeError("Device hang detected")
|
||||
faults = [f for d in devs if (f:=d.iface.dev_impl.gmc.check_fault())]
|
||||
raise RuntimeError(f"Device hang detected: {'; '.join(faults)}" if faults else "Device hang detected")
|
||||
|
||||
def device_fini(self): self.dev_impl.fini()
|
||||
|
||||
@@ -883,7 +887,7 @@ class USBIface(PCIIface):
|
||||
if queue_type == kfd.KFD_IOC_QUEUE_TYPE_COMPUTE: self.pci_dev.usb._pci_cacheable += [(ring.cpu_view().addr, ring.size)]
|
||||
return super().create_queue(queue_type, ring, gart, rptr, wptr, eop_buffer, cwsr_buffer, ctl_stack_size, ctx_save_restore_size, xcc_id, idx)
|
||||
|
||||
def sleep(self, timeout): pass
|
||||
def sleep(self, timeout) -> bool: return False
|
||||
|
||||
class AMDDevice(HCQCompiled):
|
||||
def is_am(self) -> bool: return isinstance(self.iface, (PCIIface, USBIface))
|
||||
|
||||
@@ -13,8 +13,9 @@ from tinygrad.runtime.support.elf import jit_loader
|
||||
from tinygrad.uop.ops import sint
|
||||
|
||||
class CPUSignal(HCQSignal):
|
||||
def _sleep(self, time_spent_waiting_ms:int):
|
||||
def _sleep(self, time_spent_waiting_ms:int) -> bool:
|
||||
if self.is_timeline and self.owner is not None: self.owner.tasks.join()
|
||||
return False
|
||||
|
||||
class CPUWorker(threading.Thread):
|
||||
def __init__(self, dev, tasks, thread_id):
|
||||
|
||||
@@ -45,10 +45,11 @@ class QCOMCompiler(CLCompiler):
|
||||
class QCOMSignal(HCQSignal):
|
||||
def __init__(self, *args, **kwargs): super().__init__(*args, **{**kwargs, 'timestamp_divider': 19.2})
|
||||
|
||||
def _sleep(self, time_spent_waiting_ms:int):
|
||||
def _sleep(self, time_spent_waiting_ms:int) -> bool:
|
||||
# Sleep only for timeline signals. Do it immediately to free cpu.
|
||||
if self.is_timeline and self.owner is not None:
|
||||
kgsl.IOCTL_KGSL_DEVICE_WAITTIMESTAMP_CTXTID(self.owner.fd, context_id=self.owner.ctx, timestamp=self.owner.last_cmd, timeout=0xffffffff)
|
||||
return False
|
||||
|
||||
class QCOMComputeQueue(HWQueue):
|
||||
def __init__(self, dev:QCOMDevice):
|
||||
|
||||
@@ -158,11 +158,11 @@ class AM_GMC(AM_IP):
|
||||
if self.adev.ip_ver[am.GC_HWIP] < (10,0,0): return (pte & am.AMDGPU_PDE_PTE) if pte_lv != am.AMDGPU_VM_PDB0 else not (pte & am.AMDGPU_PTE_TF)
|
||||
return pte & (am.AMDGPU_PDE_PTE_GFX12 if self.adev.ip_ver[am.GC_HWIP] >= (12,0,0) else am.AMDGPU_PDE_PTE)
|
||||
|
||||
def on_interrupt(self):
|
||||
for ip in ["MM", "GC"]:
|
||||
va = (self.adev.reg(f'reg{ip}VM_L2_PROTECTION_FAULT_ADDR_HI32').read()<<32) | self.adev.reg(f'reg{ip}VM_L2_PROTECTION_FAULT_ADDR_LO32').read()
|
||||
if self.adev.reg(self.pf_status_reg(ip)).read():
|
||||
raise RuntimeError(f"{ip}VM_L2_PROTECTION_FAULT_STATUS: {self.adev.reg(self.pf_status_reg(ip)).read_bitfields()} {va<<12:#x}")
|
||||
def check_fault(self) -> str|None:
|
||||
va = (self.adev.reg('regGCVM_L2_PROTECTION_FAULT_ADDR_HI32').read()<<32) | self.adev.reg('regGCVM_L2_PROTECTION_FAULT_ADDR_LO32').read()
|
||||
if self.adev.reg(self.pf_status_reg("GC")).read():
|
||||
return f"GCVM_L2_PROTECTION_FAULT_STATUS: {self.adev.reg(self.pf_status_reg('GC')).read_bitfields()} {va<<12:#x}"
|
||||
return None
|
||||
|
||||
class AM_SMU(AM_IP):
|
||||
def init_sw(self):
|
||||
|
||||
@@ -243,10 +243,12 @@ class HCQSignal(Generic[HCQDeviceType]):
|
||||
"""
|
||||
return self.timestamp_mv[0] / self.timestamp_divider
|
||||
|
||||
def _sleep(self, time_spent_waiting_ms:int):
|
||||
def _sleep(self, time_spent_waiting_ms:int) -> bool:
|
||||
"""
|
||||
Optional function which can implement sleep functionality for the signal.
|
||||
Returns True if a fault was detected, False otherwise.
|
||||
"""
|
||||
return False
|
||||
|
||||
def wait(self, value:int, timeout:int=getenv("HCQDEV_WAIT_TIMEOUT_MS", 30000)):
|
||||
"""
|
||||
@@ -256,11 +258,12 @@ class HCQSignal(Generic[HCQDeviceType]):
|
||||
value: The value to wait for.
|
||||
timeout: Maximum time to wait in milliseconds. Defaults to 30s.
|
||||
"""
|
||||
start_time = int(time.perf_counter() * 1000)
|
||||
start_time, fault = int(time.perf_counter() * 1000), False
|
||||
while (not_passed:=(prev_value:=self.value) < value) and (time_spent:=int(time.perf_counter() * 1000) - start_time) < timeout:
|
||||
self._sleep(time_spent)
|
||||
if fault:=self._sleep(time_spent): break
|
||||
if self.value != prev_value: start_time = int(time.perf_counter() * 1000) # progress was made, reset timer
|
||||
if not_passed and self.value < value: raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")
|
||||
if not_passed and self.value < value:
|
||||
raise RuntimeError("Device fault detected" if fault else f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")
|
||||
|
||||
@contextlib.contextmanager
|
||||
def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type:Callable[[], HWQueue]|None=None, queue:HWQueue|None=None):
|
||||
|
||||
Reference in New Issue
Block a user