diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index fc1f803fd0..279324fbbb 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -42,9 +42,10 @@ class ProfilePMCEvent(ProfileEvent): device:str; kern:str; sched:list[PMCSample] class AMDSignal(HCQSignal): def __init__(self, *args, **kwargs): super().__init__(*args, **{**kwargs, 'timestamp_divider': 100}) - def _sleep(self, time_spent_waiting_ms:int): + def _sleep(self, time_spent_waiting_ms:int) -> bool: # Resonable to sleep for long workloads (which take more than 2s) and only timeline signals. - if time_spent_waiting_ms > 2000 and self.is_timeline and self.owner is not None: self.owner.iface.sleep(200) + if time_spent_waiting_ms > 2000 and self.is_timeline and self.owner is not None: return self.owner.iface.sleep(200) + return False class AMDComputeQueue(HWQueue): def __init__(self, dev:AMDDevice): @@ -773,7 +774,9 @@ class KFDIface: write_ptrs=[MMIOInterface(queue.write_pointer_address, 8, fmt='Q')], doorbells=[MMIOInterface(self.doorbells + queue.doorbell_offset - self.doorbells_base, 8, fmt='Q')]) - def sleep(self, tm:int): kfd.AMDKFD_IOC_WAIT_EVENTS(KFDIface.kfd, events_ptr=self.queue_event_arr_ptr, num_events=1, wait_for_all=1, timeout=tm) + def sleep(self, tm:int) -> bool: + kfd.AMDKFD_IOC_WAIT_EVENTS(KFDIface.kfd, events_ptr=self.queue_event_arr_ptr, num_events=1, wait_for_all=1, timeout=tm) + return False def on_device_hang(self): def _collect_str(st): return ' '.join(f'{k[0]}={getattr(st, k[0])}' for k in st._fields_) @@ -843,15 +846,16 @@ class PCIIface(PCIIfaceBase): return AMDQueueDesc(ring=ring.cpu_view().view(fmt='I'), doorbells=[self.dev_impl.doorbell64.view(doorbell_index * 8, 8, fmt='Q')], read_ptrs=[gart.cpu_view().view(offset=rptr, size=8, fmt='Q')], write_ptrs=[gart.cpu_view().view(offset=wptr, size=8, fmt='Q')], put_value=pv) - def sleep(self, timeout): + def sleep(self, timeout) -> bool: if hasattr(self.pci_dev, 'irq_poller') and self.pci_dev.irq_poller is not None and (events_cnt:=len(self.pci_dev.irq_poller.poll(timeout))): self.pci_dev.irq_fd.read(8 * events_cnt) self.dev_impl.ih.interrupt_handler() + return self.dev_impl.gmc.check_fault() is not None def on_device_hang(self): devs:list[AMDDevice] = [d for pg in HCQCompiled.peer_groups.values() for d in pg if isinstance(d, AMDDevice) and d.is_am()] - for d in devs: d.iface.dev_impl.gmc.on_interrupt() - raise RuntimeError("Device hang detected") + faults = [f for d in devs if (f:=d.iface.dev_impl.gmc.check_fault())] + raise RuntimeError(f"Device hang detected: {'; '.join(faults)}" if faults else "Device hang detected") def device_fini(self): self.dev_impl.fini() @@ -883,7 +887,7 @@ class USBIface(PCIIface): if queue_type == kfd.KFD_IOC_QUEUE_TYPE_COMPUTE: self.pci_dev.usb._pci_cacheable += [(ring.cpu_view().addr, ring.size)] return super().create_queue(queue_type, ring, gart, rptr, wptr, eop_buffer, cwsr_buffer, ctl_stack_size, ctx_save_restore_size, xcc_id, idx) - def sleep(self, timeout): pass + def sleep(self, timeout) -> bool: return False class AMDDevice(HCQCompiled): def is_am(self) -> bool: return isinstance(self.iface, (PCIIface, USBIface)) diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py index 9affe333c2..6a1d035794 100644 --- a/tinygrad/runtime/ops_cpu.py +++ b/tinygrad/runtime/ops_cpu.py @@ -13,8 +13,9 @@ from tinygrad.runtime.support.elf import jit_loader from tinygrad.uop.ops import sint class CPUSignal(HCQSignal): - def _sleep(self, time_spent_waiting_ms:int): + def _sleep(self, time_spent_waiting_ms:int) -> bool: if self.is_timeline and self.owner is not None: self.owner.tasks.join() + return False class CPUWorker(threading.Thread): def __init__(self, dev, tasks, thread_id): diff --git a/tinygrad/runtime/ops_qcom.py b/tinygrad/runtime/ops_qcom.py index fee8b7eb58..7a90ff2d9b 100644 --- a/tinygrad/runtime/ops_qcom.py +++ b/tinygrad/runtime/ops_qcom.py @@ -45,10 +45,11 @@ class QCOMCompiler(CLCompiler): class QCOMSignal(HCQSignal): def __init__(self, *args, **kwargs): super().__init__(*args, **{**kwargs, 'timestamp_divider': 19.2}) - def _sleep(self, time_spent_waiting_ms:int): + def _sleep(self, time_spent_waiting_ms:int) -> bool: # Sleep only for timeline signals. Do it immediately to free cpu. if self.is_timeline and self.owner is not None: kgsl.IOCTL_KGSL_DEVICE_WAITTIMESTAMP_CTXTID(self.owner.fd, context_id=self.owner.ctx, timestamp=self.owner.last_cmd, timeout=0xffffffff) + return False class QCOMComputeQueue(HWQueue): def __init__(self, dev:QCOMDevice): diff --git a/tinygrad/runtime/support/am/ip.py b/tinygrad/runtime/support/am/ip.py index 97769638a3..2eb8765049 100644 --- a/tinygrad/runtime/support/am/ip.py +++ b/tinygrad/runtime/support/am/ip.py @@ -158,11 +158,11 @@ class AM_GMC(AM_IP): if self.adev.ip_ver[am.GC_HWIP] < (10,0,0): return (pte & am.AMDGPU_PDE_PTE) if pte_lv != am.AMDGPU_VM_PDB0 else not (pte & am.AMDGPU_PTE_TF) return pte & (am.AMDGPU_PDE_PTE_GFX12 if self.adev.ip_ver[am.GC_HWIP] >= (12,0,0) else am.AMDGPU_PDE_PTE) - def on_interrupt(self): - for ip in ["MM", "GC"]: - va = (self.adev.reg(f'reg{ip}VM_L2_PROTECTION_FAULT_ADDR_HI32').read()<<32) | self.adev.reg(f'reg{ip}VM_L2_PROTECTION_FAULT_ADDR_LO32').read() - if self.adev.reg(self.pf_status_reg(ip)).read(): - raise RuntimeError(f"{ip}VM_L2_PROTECTION_FAULT_STATUS: {self.adev.reg(self.pf_status_reg(ip)).read_bitfields()} {va<<12:#x}") + def check_fault(self) -> str|None: + va = (self.adev.reg('regGCVM_L2_PROTECTION_FAULT_ADDR_HI32').read()<<32) | self.adev.reg('regGCVM_L2_PROTECTION_FAULT_ADDR_LO32').read() + if self.adev.reg(self.pf_status_reg("GC")).read(): + return f"GCVM_L2_PROTECTION_FAULT_STATUS: {self.adev.reg(self.pf_status_reg('GC')).read_bitfields()} {va<<12:#x}" + return None class AM_SMU(AM_IP): def init_sw(self): diff --git a/tinygrad/runtime/support/hcq.py b/tinygrad/runtime/support/hcq.py index a3bfbe1315..803ebd5fdc 100644 --- a/tinygrad/runtime/support/hcq.py +++ b/tinygrad/runtime/support/hcq.py @@ -243,10 +243,12 @@ class HCQSignal(Generic[HCQDeviceType]): """ return self.timestamp_mv[0] / self.timestamp_divider - def _sleep(self, time_spent_waiting_ms:int): + def _sleep(self, time_spent_waiting_ms:int) -> bool: """ Optional function which can implement sleep functionality for the signal. + Returns True if a fault was detected, False otherwise. """ + return False def wait(self, value:int, timeout:int=getenv("HCQDEV_WAIT_TIMEOUT_MS", 30000)): """ @@ -256,11 +258,12 @@ class HCQSignal(Generic[HCQDeviceType]): value: The value to wait for. timeout: Maximum time to wait in milliseconds. Defaults to 30s. """ - start_time = int(time.perf_counter() * 1000) + start_time, fault = int(time.perf_counter() * 1000), False while (not_passed:=(prev_value:=self.value) < value) and (time_spent:=int(time.perf_counter() * 1000) - start_time) < timeout: - self._sleep(time_spent) + if fault:=self._sleep(time_spent): break if self.value != prev_value: start_time = int(time.perf_counter() * 1000) # progress was made, reset timer - if not_passed and self.value < value: raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})") + if not_passed and self.value < value: + raise RuntimeError("Device fault detected" if fault else f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})") @contextlib.contextmanager def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type:Callable[[], HWQueue]|None=None, queue:HWQueue|None=None):