From ccbbca05eff780be55ad8bc125f4de7b4ec65c9b Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Sun, 1 Mar 2026 16:57:29 +0300 Subject: [PATCH] beam: add dev_timeout for am (#15063) * beam: add dev_timeout for am * all covered * fk * x * fuzz * reset * f --- .../external_fuzz_beam_timeout_recovery.py | 31 +++++++++++++++++++ test/external/external_test_speed_llama.py | 2 +- tinygrad/codegen/opt/search.py | 8 +++-- tinygrad/engine/realize.py | 4 +-- tinygrad/runtime/ops_amd.py | 10 +++--- tinygrad/runtime/ops_cl.py | 2 +- tinygrad/runtime/ops_cuda.py | 2 +- tinygrad/runtime/ops_dsp.py | 4 +-- tinygrad/runtime/ops_hip.py | 2 +- tinygrad/runtime/ops_metal.py | 2 +- tinygrad/runtime/ops_null.py | 2 +- tinygrad/runtime/ops_nv.py | 5 +-- tinygrad/runtime/ops_python.py | 2 +- tinygrad/runtime/ops_qcom.py | 3 +- tinygrad/runtime/ops_webgpu.py | 2 +- tinygrad/runtime/support/am/amdev.py | 8 ++--- tinygrad/runtime/support/am/ip.py | 14 ++++----- tinygrad/runtime/support/hcq.py | 17 +++++----- 18 files changed, 79 insertions(+), 41 deletions(-) create mode 100644 test/external/external_fuzz_beam_timeout_recovery.py diff --git a/test/external/external_fuzz_beam_timeout_recovery.py b/test/external/external_fuzz_beam_timeout_recovery.py new file mode 100644 index 0000000000..73b67c8551 --- /dev/null +++ b/test/external/external_fuzz_beam_timeout_recovery.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +""" +Stress test for beam timeout + device recovery on AM devices. 
+ +Usage: + AMD=1 python test/external/external_fuzz_beam_timeout_recovery.py +""" +from tinygrad import Tensor, Device +from tinygrad.helpers import Context +from tinygrad.runtime.ops_amd import AMDDevice + +if __name__ == "__main__": + dev = Device["AMD"] + assert isinstance(dev, AMDDevice) and dev.is_am(), "not am" + + N = 10000 + for i in range(N): + with Context(DEBUG=0, BEAM=0): + a = Tensor.rand(4096, 4096, device="AMD").contiguous().realize() + b = Tensor.rand(4096, 4096, device="AMD").contiguous().realize() + c = a.matmul(b) + c.realize() + try: dev.synchronize(timeout=1) + except RuntimeError as e: print(e) + with Context(DEBUG=0, BEAM=0): + a = Tensor.ones(512, 512, device="AMD").contiguous().realize() + b = Tensor.ones(512, 512, device="AMD").contiguous().realize() + result = a.matmul(b).realize()[0, 0].item() + assert result == 512.0, f"iter {i}: got {result}" + print(f" iter {i+1}/{N}: ok") + print(f"=== All {N} iterations passed ===") diff --git a/test/external/external_test_speed_llama.py b/test/external/external_test_speed_llama.py index 7d50fd51af..1234113e77 100644 --- a/test/external/external_test_speed_llama.py +++ b/test/external/external_test_speed_llama.py @@ -10,7 +10,7 @@ from tinygrad.helpers import Profiling class FakeProgram: def __init__(self, name:str, prg:bytes, **kwargs): pass - def __call__(self, *bufs, global_size, local_size, vals=(), wait=False): pass + def __call__(self, *bufs, global_size, local_size, vals=(), wait=False, **kw): pass class FakeAllocator(Allocator[Compiled]): def _alloc(self, sz, options): return None diff --git a/tinygrad/codegen/opt/search.py b/tinygrad/codegen/opt/search.py index 13e86e8924..2e3df02330 100644 --- a/tinygrad/codegen/opt/search.py +++ b/tinygrad/codegen/opt/search.py @@ -36,7 +36,8 @@ def get_test_global_size(global_size, max_global_size, var_vals): return test_global_size, input_size / prod(test_global_size) def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[str, int], 
rawbufs:list[Buffer], early_stop:float|None=None, - allow_test_size:int=True, max_global_size:int|None=65536, clear_l2=False, cnt=3, name="test") -> list[float]: + allow_test_size:int=True, max_global_size:int|None=65536, clear_l2=False, cnt=3, name="test", dev_timeout=False) -> list[float]: + timeout = int(early_stop * 1e3) if dev_timeout and early_stop is not None and early_stop < math.inf else None factor = 1 if allow_test_size and max_global_size is not None: global_size, factor = get_test_global_size(p.global_size, max_global_size, var_vals) @@ -50,7 +51,7 @@ def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[str, int], rawbufs:lis if hasattr(dev:=Device[p.device], 'invalidate_caches'): dev.invalidate_caches() else: with Context(DEBUG=0, BEAM=0, CAPTURING=0, TRACK_MATCH_STATS=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False) - tms.append(unwrap(car(input_bufs, var_vals, wait=True))*factor) + tms.append(unwrap(car(input_bufs, var_vals, wait=True, timeout=timeout))*factor) if early_stop is not None and early_stop < min(tms): break return tms @@ -161,7 +162,8 @@ def beam_search(s:Scheduler, rawbufs:list[Buffer], amt:int, allow_test_size=True continue seen_libs.add(lib) try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0, - allow_test_size=allow_test_size, clear_l2=hasattr(dev, 'invalidate_caches')) + allow_test_size=allow_test_size, clear_l2=hasattr(dev, 'invalidate_caches'), + dev_timeout=getenv("BEAM_DEV_TIMEOUT", 1)) except Exception as e: if BEAM_DEBUG: print(f"BEAM failed for opts: {candidates[i].applied_opts}\n{e}") if isinstance(e, RuntimeError): continue diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index b18546970b..f57e909171 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -50,7 +50,7 @@ class CompiledRunner(Runner): def __reduce__(self): return self.__class__, (self.p,) - def __call__(self, rawbufs:list[Buffer], 
var_vals:dict[str, int]|None=None, wait=False) -> float|None: + def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None, wait=False, timeout:int|None=None) -> float|None: if var_vals is None: var_vals = {} global_size, local_size = self.p.launch_dims(var_vals) if Device[self.p.device].renderer.has_local and local_size is None and all_int(self.p.global_size): @@ -58,7 +58,7 @@ class CompiledRunner(Runner): global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)] self.p = replace(self.p, global_size=global_size, local_size=local_size) return self._prg(*[x._buf for x in rawbufs], global_size=tuple(global_size), local_size=tuple(local_size) if local_size else None, - vals=tuple(var_vals[k.expr] if k.expr not in self.p.runtimevars else None for k in self.p.vars), wait=wait) + vals=tuple(var_vals[k.expr] if k.expr not in self.p.runtimevars else None for k in self.p.vars), wait=wait, timeout=timeout) class ViewOp(Runner): def __init__(self, buf:Buffer): super().__init__(colored(f"view {buf.nbytes:8d} @ {buf.offset:<10d}", "yellow"), buf.device) diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index 0fbdb7c0d8..58fa71ccf9 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -598,9 +598,10 @@ class AMDProgram(HCQProgram): base=self.lib_gpu.va_addr) weakref.finalize(self, self._fini, self.dev, self.lib_gpu, buf_spec) - def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int|None, ...]=(), wait=False): + def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int|None, ...]=(), + wait=False, timeout:int|None=None): if self.dev.sqtt_enabled: cast(AMDComputeQueue, self.dev.hw_compute_queue_t()).sqtt_start(self.dev.sqtt_buffers).submit(self.dev) - res = super().__call__(*bufs, global_size=global_size, local_size=local_size, vals=vals, wait=wait) + res = 
super().__call__(*bufs, global_size=global_size, local_size=local_size, vals=vals, wait=wait, timeout=timeout) if self.dev.pmc_enabled: cast(AMDComputeQueue, self.dev.hw_compute_queue_t()).pmc_read(self.dev.pmc_buffer, self.dev.pmc_sched) \ .signal(self.dev.timeline_signal, self.dev.next_timeline()).submit(self.dev) @@ -869,7 +870,7 @@ class PCIIface(PCIIfaceBase): devs:list[AMDDevice] = [d for pg in HCQCompiled.peer_groups.values() for d in pg if isinstance(d, AMDDevice) and d.is_am()] for d in devs: d.iface.dev_impl.ih.interrupt_handler() - if reset and d.iface.dev_impl.recover(): + if reset and d.iface.dev_impl.recover(force=d.error_state is not None): d.compute_queue.put_value, _ = d.iface.dev_impl.gfx.setup_ring(*d.compute_queue.params) d.compute_queue.read_ptr[0] = d.compute_queue.write_ptr[0] = d.compute_queue.put_value d.timeline_signal.value = d.timeline_value - 1 @@ -977,7 +978,8 @@ class AMDDevice(HCQCompiled): super().__init__(device, AMDAllocator(self), compilers, functools.partial(AMDProgram, self), AMDSignal, functools.partial(AMDComputeAQLQueue if self.is_aql else AMDComputeQueue, self), functools.partial(AMDCopyQueue, self, max_copy_size=self.max_copy_size) if self.has_sdma_queue else None, - kernargs_size=(8 << 10) if self.is_usb() else (16 << 20), sigalloc_size=0x100 if self.is_usb() else 0x1000) + kernargs_size=(8 << 10) if self.is_usb() else (16 << 20), sigalloc_size=0x100 if self.is_usb() else 0x1000, + can_recover=self.is_am()) # Scratch setup self.max_private_segment_size = 0 diff --git a/tinygrad/runtime/ops_cl.py b/tinygrad/runtime/ops_cl.py index 75f430df9c..cc2ff4ed94 100644 --- a/tinygrad/runtime/ops_cl.py +++ b/tinygrad/runtime/ops_cl.py @@ -54,7 +54,7 @@ class CLProgram: except (TypeError, AttributeError): pass def __call__(self, *bufs:tuple[cl.cl_mem, BufferSpec], global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]|None=None, - vals:tuple[int, ...]=(), wait=False) -> float|None: + vals:tuple[int, ...]=(), 
wait=False, **kw) -> float|None: i = 0 for i,(b,_) in enumerate(bufs): for real_i, dt in self.arg_dtypes[i]: diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py index 75b7dbb5df..311f918eaf 100644 --- a/tinygrad/runtime/ops_cuda.py +++ b/tinygrad/runtime/ops_cuda.py @@ -51,7 +51,7 @@ class CUDAProgram: @suppress_finalizing def __del__(self): check(cuda.cuModuleUnload(self.module)) - def __call__(self, *args, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False): + def __call__(self, *args, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw): check(cuda.cuCtxSetCurrent(self.dev.context)) if not hasattr(self, "vargs"): self.c_args, self.vargs = encode_args(args, vals) diff --git a/tinygrad/runtime/ops_dsp.py b/tinygrad/runtime/ops_dsp.py index c0694c879d..44fd550303 100644 --- a/tinygrad/runtime/ops_dsp.py +++ b/tinygrad/runtime/ops_dsp.py @@ -84,7 +84,7 @@ class DSPProgram: def __init__(self, dev:DSPDevice, name:str, lib:bytes, **kwargs): self.dev, self.lib = dev, lib - def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False): + def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw): if len(bufs) >= 16: raise RuntimeError(f"Too many buffers to execute: {len(bufs)}") pra, fds, attrs, _ = rpc_prep_args(ins=[var_vals_mv:=memoryview(bytearray((len(bufs)+len(vals))*4)), off_mv:=memoryview(bytearray(len(bufs)*4))], @@ -293,7 +293,7 @@ class MockDSPRenderer(DSPRenderer): class MockDSPProgram: def __init__(self, name:str, lib:bytes, **kwargs): self.lib = lib - def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False): + def __call__(self, *bufs, 
global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw): with tempfile.NamedTemporaryFile(suffix=".out") as dsp_lib: dsp_lib.write(self.lib) dsp_lib.flush() diff --git a/tinygrad/runtime/ops_hip.py b/tinygrad/runtime/ops_hip.py index ce5a4214b5..0586062088 100644 --- a/tinygrad/runtime/ops_hip.py +++ b/tinygrad/runtime/ops_hip.py @@ -32,7 +32,7 @@ class HIPProgram: def __del__(self): if hasattr(self, 'module'): check(hip.hipModuleUnload(self.module)) - def __call__(self, *args, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False): + def __call__(self, *args, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw): check(hip.hipSetDevice(self.dev.device_id)) if not hasattr(self, "vargs"): fields = [(f'f{i}', hip.hipDeviceptr_t, i*8) for i in range(len(args))] + [(f'v{i}', ctypes.c_int, len(args)*8+i*4) for i in range(len(vals))] diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py index b004fdddbd..158234c441 100644 --- a/tinygrad/runtime/ops_metal.py +++ b/tinygrad/runtime/ops_metal.py @@ -123,7 +123,7 @@ class MetalProgram: # cache these msg calls self.max_total_threads: int = self.pipeline_state.maxTotalThreadsPerThreadgroup() - def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False): + def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw): if prod(local_size) > self.max_total_threads: exec_width = self.pipeline_state.threadExecutionWidth() memory_length = self.pipeline_state.staticThreadgroupMemoryLength() diff --git a/tinygrad/runtime/ops_null.py b/tinygrad/runtime/ops_null.py index 37519bbb12..529204562c 100644 --- a/tinygrad/runtime/ops_null.py +++ 
b/tinygrad/runtime/ops_null.py @@ -15,7 +15,7 @@ class NullRenderer(CStyleLanguage): class NullProgram: def __init__(self, device:str, name:str, lib:bytes, *args, **kwargs): self.device, self.name = device, name - def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False): + def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw): with cpu_profile(self.name, self.device): return 1e-3 class NullAllocator(Allocator['NullDevice']): diff --git a/tinygrad/runtime/ops_nv.py b/tinygrad/runtime/ops_nv.py index eb645bd0d2..2cec43faf6 100644 --- a/tinygrad/runtime/ops_nv.py +++ b/tinygrad/runtime/ops_nv.py @@ -312,12 +312,13 @@ class NVProgram(HCQProgram): yield typ, param, sh.content[start_off+4:start_off+sz+4] if typ == 0x4 else sz start_off += (sz if typ == 0x4 else 0) + 4 - def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int|None, ...]=(), wait=False): + def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int|None, ...]=(), + wait=False, timeout:int|None=None): if prod(local_size) > 1024 or self.max_threads < prod(local_size) or self.lcmem_usage > cast(NVDevice, self.dev).slm_per_thread: raise RuntimeError(f"Too many resources requested for launch, {prod(local_size)=}, {self.max_threads=}") if any(cur > mx for cur,mx in zip(global_size, [2147483647, 65535, 65535])) or any(cur > mx for cur,mx in zip(local_size, [1024, 1024, 64])): raise RuntimeError(f"Invalid global/local dims {global_size=}, {local_size=}") - res = super().__call__(*bufs, global_size=global_size, local_size=local_size, vals=vals, wait=wait) + res = super().__call__(*bufs, global_size=global_size, local_size=local_size, vals=vals, wait=wait, timeout=timeout) if self.dev.pma_enabled: 
self.dev.synchronize() if pma_blob:=self.dev._prof_readback(): diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index 22092c7fed..484b862bb9 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -41,7 +41,7 @@ def generic_wmma_helper(inp, warp_size, WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_ class PythonProgram: def __init__(self, name:str, lib:bytes, **kwargs): self.uops: list[tuple[Ops, DType, list[int], Any]] = pickle.loads(lib) - def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False): + def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw): st = time.perf_counter() warp = list(itertools.product(*[range(x) for x in local_size[::-1]])) warp_size = len(warp) diff --git a/tinygrad/runtime/ops_qcom.py b/tinygrad/runtime/ops_qcom.py index fcfdb4abfe..632d1d79fd 100644 --- a/tinygrad/runtime/ops_qcom.py +++ b/tinygrad/runtime/ops_qcom.py @@ -266,7 +266,8 @@ class QCOMProgram(HCQProgram): super().__init__(QCOMArgsState, self.dev, self.name, kernargs_alloc_size=kernargs_alloc_size) weakref.finalize(self, self._fini, self.dev, self.lib_gpu, buf_spec) - def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int|None, ...]=(), wait=False): + def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), + vals:tuple[int|None, ...]=(), wait=False, **kw): if self.max_threads < prod(local_size): raise RuntimeError("Too many resources requested for launch") if any(g*l>mx for g,l,mx in zip(global_size, local_size, [65536, 65536, 65536])) and any(l>mx for l,mx in zip(local_size, [1024, 1024, 1024])): raise RuntimeError(f"Invalid global/local dims {global_size=}, {local_size=}") diff --git a/tinygrad/runtime/ops_webgpu.py 
b/tinygrad/runtime/ops_webgpu.py index 0d9f0c4b33..c41014b198 100644 --- a/tinygrad/runtime/ops_webgpu.py +++ b/tinygrad/runtime/ops_webgpu.py @@ -90,7 +90,7 @@ class WebGPUProgram: self.name, self.lib, self.prg = name, lib, shader_module def __call__(self, *bufs:WGPUBufPtr, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), - vals:tuple[int, ...]=(), wait=False) -> float|None: + vals:tuple[int, ...]=(), wait=False, **kw) -> float|None: wait = wait and self.timestamp_supported tmp_bufs = [*bufs] buf_patch = False diff --git a/tinygrad/runtime/support/am/amdev.py b/tinygrad/runtime/support/am/amdev.py index 514b5f2281..34416cd9af 100644 --- a/tinygrad/runtime/support/am/amdev.py +++ b/tinygrad/runtime/support/am/amdev.py @@ -225,13 +225,13 @@ class AMDev(PCIDevImplBase): self.ih.interrupt_handler() self.reg("regSCRATCH_REG6").write(self.is_err_state) # set finalized state. - def recover(self) -> bool: - if not self.is_err_state: return False - if DEBUG >= 2: print(f"am {self.devfmt}: Start recovery") + def recover(self, force=False) -> bool: + if not force and not self.is_err_state: return False + if DEBUG >= 3: print(f"am {self.devfmt}: Start recovery") self.ih.interrupt_handler() self.gfx.reset_mec() self.is_err_state = False - if DEBUG >= 2: print(f"am {self.devfmt}: Recovery complete") + if DEBUG >= 3: print(f"am {self.devfmt}: Recovery complete") return True def is_hive(self) -> bool: return self.gmc.xgmi_seg_sz > 0 diff --git a/tinygrad/runtime/support/am/ip.py b/tinygrad/runtime/support/am/ip.py index 21f06e601c..534a0376d0 100644 --- a/tinygrad/runtime/support/am/ip.py +++ b/tinygrad/runtime/support/am/ip.py @@ -290,13 +290,11 @@ class AM_GFX(AM_IP): def fini_hw(self): self._dequeue_hqds() def reset_mec(self): - self._dequeue_hqds(reset=True) + self._dequeue_hqds() - # issue a soft reset to reset aql sync counter on multixcc systems. 
- if self.xccs > 1: - for xcc in range(self.xccs): self.adev.regGRBM_SOFT_RESET.write(soft_reset_cp=1, soft_reset_gfx=1, inst=xcc) - time.sleep(0.05) - for xcc in range(self.xccs): self.adev.regGRBM_SOFT_RESET.write(0x0, inst=xcc) + for xcc in range(self.xccs): self.adev.regGRBM_SOFT_RESET.write(soft_reset_cp=1, soft_reset_cpc=1, inst=xcc) + time.sleep(0.05) + for xcc in range(self.xccs): self.adev.regGRBM_SOFT_RESET.write(0x0, inst=xcc) self._config_mec() self._enable_mec() @@ -384,13 +382,13 @@ class AM_GFX(AM_IP): if self.adev.ip_ver[am.GC_HWIP] >= (10,0,0): _config_helper(eng_name="MEC", cntl_reg="MEC_RS64", eng_reg="MEC_RS64", pipe_cnt=1, me=1, xcc=xcc) - def _dequeue_hqds(self, reset=False): + def _dequeue_hqds(self): for q in range(2): for xcc in range(self.xccs): self._grbm_select(me=1, pipe=0, queue=q, inst=xcc) if self.adev.regCP_HQD_ACTIVE.read(inst=xcc) & 1: self.adev.regCP_HQD_DEQUEUE_REQUEST.write(0x2, inst=xcc) # 1 - DRAIN_PIPE; 2 - RESET_WAVES - if not reset: wait_cond(lambda: self.adev.regCP_HQD_ACTIVE.read(inst=xcc) & 1, value=0, msg="HQD dequeue timeout") + if not self.adev.is_err_state: wait_cond(lambda: self.adev.regCP_HQD_ACTIVE.read(inst=xcc) & 1, value=0, msg="HQD dequeue timeout") self._grbm_select() class AM_IH(AM_IP): diff --git a/tinygrad/runtime/support/hcq.py b/tinygrad/runtime/support/hcq.py index a2d08a92dc..5cb6e7c4e9 100644 --- a/tinygrad/runtime/support/hcq.py +++ b/tinygrad/runtime/support/hcq.py @@ -253,7 +253,7 @@ class HCQSignal(Generic[HCQDeviceType]): Raises RuntimeError if a fault is detected. """ - def wait(self, value:int, timeout:int=getenv("HCQDEV_WAIT_TIMEOUT_MS", 30000)): + def wait(self, value:int, timeout:int|None=None): """ Waits the signal is greater than or equal to a specific value. @@ -261,6 +261,7 @@ class HCQSignal(Generic[HCQDeviceType]): value: The value to wait for. timeout: Maximum time to wait in milliseconds. Defaults to 30s. 
""" + timeout = timeout or getenv("HCQDEV_WAIT_TIMEOUT_MS", 30000) start_time = int(time.perf_counter() * 1000) while (not_passed:=(prev_value:=self.value) < value) and (cur_time:=int(time.perf_counter() * 1000)) - start_time < timeout: self._sleep(cur_time - start_time) @@ -325,7 +326,7 @@ class HCQProgram(Generic[HCQDeviceType]): return self.args_state_t(argsbuf, self, bufs, vals=vals) def __call__(self, *bufs:HCQBuffer, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), - vals:tuple[int|None, ...]=(), wait:bool=False) -> float|None: + vals:tuple[int|None, ...]=(), wait:bool=False, timeout:int|None=None) -> float|None: """ Enqueues the program for execution with the given arguments and dimensions. @@ -349,7 +350,7 @@ class HCQProgram(Generic[HCQDeviceType]): q.signal(self.dev.timeline_signal, self.dev.next_timeline()).submit(self.dev) - if wait: self.dev.synchronize() + if wait: self.dev.synchronize(timeout=timeout) return (float(sig_en.timestamp - sig_st.timestamp) / 1e6) if wait else None class HCQCompiled(Compiled, Generic[SignalType]): @@ -362,7 +363,8 @@ class HCQCompiled(Compiled, Generic[SignalType]): cpu_devices: list[HCQCompiled] = [] def __init__(self, device:str, allocator:HCQAllocatorBase, compilers:CompilerSet, runtime, signal_t:Type[SignalType], - comp_queue_t:Callable[..., HWQueue], copy_queue_t:Callable[..., HWQueue]|None=None, kernargs_size=(16 << 20), sigalloc_size=0x1000): + comp_queue_t:Callable[..., HWQueue], copy_queue_t:Callable[..., HWQueue]|None=None, kernargs_size=(16 << 20), sigalloc_size=0x1000, + can_recover:bool=False): self.device_id:int = int(device.split(":")[1]) if ":" in device else 0 from tinygrad.runtime.graph.hcq import HCQGraph @@ -386,22 +388,23 @@ class HCQCompiled(Compiled, Generic[SignalType]): self.kernargs_buf:HCQBuffer = self.allocator.alloc(kernargs_size, BufferSpec(cpu_access=True)) self.kernargs_offset_allocator:BumpAllocator = BumpAllocator(self.kernargs_buf.size, wrap=True) + 
self.can_recover = can_recover # Whether the device can recover from faults or timeouts self.error_state:Exception|None = None # Exception if error is unrecoverable and sync will always fail if self._is_cpu(): HCQCompiled.cpu_devices.append(self) - def synchronize(self): + def synchronize(self, timeout:int|None=None): if self.error_state is not None: raise self.error_state # If we have any work on CPU devices, need to synchronize them. This is just an optimization to release GIL allowing to finish faster. if not self._is_cpu(): for dev in HCQCompiled.cpu_devices: dev.synchronize() - try: self.timeline_signal.wait(self.timeline_value - 1) + try: self.timeline_signal.wait(self.timeline_value - 1, timeout=timeout if timeout is not None and self.can_recover else None) except RuntimeError as e: self.error_state = e if hasattr(self, 'on_device_hang'): self.on_device_hang() - else: raise e + raise e if self.timeline_value > (1 << 31): self._wrap_timeline_signal() if PROFILE: