beam: add dev_timeout for am (#15063)

* beam: add dev_timeout for am

* all covered

* fk

* x

* fuzz

* reset

* f
This commit is contained in:
nimlgen
2026-03-01 16:57:29 +03:00
committed by GitHub
parent 8cb4368967
commit ccbbca05ef
18 changed files with 79 additions and 41 deletions
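Summary as read from the diff: BEAM search can now arm a device-side timeout while timing candidate kernels. _time_program gains a dev_timeout flag (enabled by default via BEAM_DEV_TIMEOUT) and converts its early_stop budget from seconds into a millisecond timeout, which is threaded through CompiledRunner.__call__ and HCQProgram.__call__ down to HCQCompiled.synchronize. A kernel that exceeds the budget raises RuntimeError, and since AMDDevice now passes can_recover=self.is_am(), the AM driver resets the device (recover(force=...) -> reset_mec) and the search moves on to the next candidate. Non-HCQ backends (CL, CUDA, HIP, DSP, Metal, Null, Python, WebGPU) accept and ignore the new keyword via **kw.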

View File

@@ -0,0 +1,31 @@
+#!/usr/bin/env python3
+"""
+Stress test for beam timeout + device recovery on AM devices.
+Usage:
+AMD=1 python test/external/external_test_beam_timeout_recovery.py
+"""
+from tinygrad import Tensor, Device
+from tinygrad.helpers import Context
+from tinygrad.runtime.ops_amd import AMDDevice
+
+if __name__ == "__main__":
+  dev = Device["AMD"]
+  assert isinstance(dev, AMDDevice) and dev.is_am(), "not am"
+  N = 10000
+  for i in range(N):
+    with Context(DEBUG=0, BEAM=0):
+      a = Tensor.rand(4096, 4096, device="AMD").contiguous().realize()
+      b = Tensor.rand(4096, 4096, device="AMD").contiguous().realize()
+      c = a.matmul(b)
+      c.realize()
+    try: dev.synchronize(timeout=1)
+    except RuntimeError as e: print(e)
+    with Context(DEBUG=0, BEAM=0):
+      a = Tensor.ones(512, 512, device="AMD").contiguous().realize()
+      b = Tensor.ones(512, 512, device="AMD").contiguous().realize()
+      result = a.matmul(b).realize()[0, 0].item()
+      assert result == 512.0, f"iter {i}: got {result}"
+    print(f" iter {i+1}/{N}: ok")
+  print(f"=== All {N} iterations passed ===")

View File

@@ -10,7 +10,7 @@ from tinygrad.helpers import Profiling
class FakeProgram:
def __init__(self, name:str, prg:bytes, **kwargs): pass
-def __call__(self, *bufs, global_size, local_size, vals=(), wait=False): pass
+def __call__(self, *bufs, global_size, local_size, vals=(), wait=False, **kw): pass
class FakeAllocator(Allocator[Compiled]):
def _alloc(self, sz, options): return None

View File

@@ -36,7 +36,8 @@ def get_test_global_size(global_size, max_global_size, var_vals):
return test_global_size, input_size / prod(test_global_size)
def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[str, int], rawbufs:list[Buffer], early_stop:float|None=None,
-allow_test_size:int=True, max_global_size:int|None=65536, clear_l2=False, cnt=3, name="test") -> list[float]:
+allow_test_size:int=True, max_global_size:int|None=65536, clear_l2=False, cnt=3, name="test", dev_timeout=False) -> list[float]:
+timeout = int(early_stop * 1e3) if dev_timeout and early_stop is not None and early_stop < math.inf else None
factor = 1
if allow_test_size and max_global_size is not None:
global_size, factor = get_test_global_size(p.global_size, max_global_size, var_vals)
@@ -50,7 +51,7 @@ def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[str, int], rawbufs:lis
if hasattr(dev:=Device[p.device], 'invalidate_caches'): dev.invalidate_caches()
else:
with Context(DEBUG=0, BEAM=0, CAPTURING=0, TRACK_MATCH_STATS=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
-tms.append(unwrap(car(input_bufs, var_vals, wait=True))*factor)
+tms.append(unwrap(car(input_bufs, var_vals, wait=True, timeout=timeout))*factor)
if early_stop is not None and early_stop < min(tms): break
return tms
@@ -161,7 +162,8 @@ def beam_search(s:Scheduler, rawbufs:list[Buffer], amt:int, allow_test_size=True
continue
seen_libs.add(lib)
try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0,
-allow_test_size=allow_test_size, clear_l2=hasattr(dev, 'invalidate_caches'))
+allow_test_size=allow_test_size, clear_l2=hasattr(dev, 'invalidate_caches'),
+dev_timeout=getenv("BEAM_DEV_TIMEOUT", 1))
except Exception as e:
if BEAM_DEBUG: print(f"BEAM failed for opts: {candidates[i].applied_opts}\n{e}")
if isinstance(e, RuntimeError): continue
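For clarity, the timeout derivation above restated as a worked example (a hedged sketch: derive_timeout_ms is a hypothetical helper name, the real code inlines the expression and assumes math is imported in search.py):

import math

def derive_timeout_ms(early_stop: float|None, dev_timeout: bool) -> int|None:
  # mirrors _time_program: early_stop is a budget in seconds, the device timeout is in milliseconds
  if dev_timeout and early_stop is not None and early_stop < math.inf:
    return int(early_stop * 1e3)
  return None  # no finite budget (or dev_timeout off): keep the untimed wait

assert derive_timeout_ms(1.0, True) == 1000      # first candidate: early_stop=1.0 s
assert derive_timeout_ms(3 * 0.0021, True) == 6  # later candidates: beam[0][1]*3 = 6.3 ms -> 6
assert derive_timeout_ms(1.0, False) is None     # BEAM_DEV_TIMEOUT=0 disables the device timeout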

View File

@@ -50,7 +50,7 @@ class CompiledRunner(Runner):
def __reduce__(self): return self.__class__, (self.p,)
-def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None, wait=False) -> float|None:
+def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None, wait=False, timeout:int|None=None) -> float|None:
if var_vals is None: var_vals = {}
global_size, local_size = self.p.launch_dims(var_vals)
if Device[self.p.device].renderer.has_local and local_size is None and all_int(self.p.global_size):
@@ -58,7 +58,7 @@ class CompiledRunner(Runner):
global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
self.p = replace(self.p, global_size=global_size, local_size=local_size)
return self._prg(*[x._buf for x in rawbufs], global_size=tuple(global_size), local_size=tuple(local_size) if local_size else None,
-vals=tuple(var_vals[k.expr] if k.expr not in self.p.runtimevars else None for k in self.p.vars), wait=wait)
+vals=tuple(var_vals[k.expr] if k.expr not in self.p.runtimevars else None for k in self.p.vars), wait=wait, timeout=timeout)
class ViewOp(Runner):
def __init__(self, buf:Buffer): super().__init__(colored(f"view {buf.nbytes:8d} @ {buf.offset:<10d}", "yellow"), buf.device)

View File

@@ -598,9 +598,10 @@ class AMDProgram(HCQProgram):
base=self.lib_gpu.va_addr)
weakref.finalize(self, self._fini, self.dev, self.lib_gpu, buf_spec)
-def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int|None, ...]=(), wait=False):
+def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int|None, ...]=(),
+wait=False, timeout:int|None=None):
if self.dev.sqtt_enabled: cast(AMDComputeQueue, self.dev.hw_compute_queue_t()).sqtt_start(self.dev.sqtt_buffers).submit(self.dev)
-res = super().__call__(*bufs, global_size=global_size, local_size=local_size, vals=vals, wait=wait)
+res = super().__call__(*bufs, global_size=global_size, local_size=local_size, vals=vals, wait=wait, timeout=timeout)
if self.dev.pmc_enabled:
cast(AMDComputeQueue, self.dev.hw_compute_queue_t()).pmc_read(self.dev.pmc_buffer, self.dev.pmc_sched) \
.signal(self.dev.timeline_signal, self.dev.next_timeline()).submit(self.dev)
@@ -869,7 +870,7 @@ class PCIIface(PCIIfaceBase):
devs:list[AMDDevice] = [d for pg in HCQCompiled.peer_groups.values() for d in pg if isinstance(d, AMDDevice) and d.is_am()]
for d in devs:
d.iface.dev_impl.ih.interrupt_handler()
-if reset and d.iface.dev_impl.recover():
+if reset and d.iface.dev_impl.recover(force=d.error_state is not None):
d.compute_queue.put_value, _ = d.iface.dev_impl.gfx.setup_ring(*d.compute_queue.params)
d.compute_queue.read_ptr[0] = d.compute_queue.write_ptr[0] = d.compute_queue.put_value
d.timeline_signal.value = d.timeline_value - 1
@@ -977,7 +978,8 @@ class AMDDevice(HCQCompiled):
super().__init__(device, AMDAllocator(self), compilers, functools.partial(AMDProgram, self), AMDSignal,
functools.partial(AMDComputeAQLQueue if self.is_aql else AMDComputeQueue, self),
functools.partial(AMDCopyQueue, self, max_copy_size=self.max_copy_size) if self.has_sdma_queue else None,
-kernargs_size=(8 << 10) if self.is_usb() else (16 << 20), sigalloc_size=0x100 if self.is_usb() else 0x1000)
+kernargs_size=(8 << 10) if self.is_usb() else (16 << 20), sigalloc_size=0x100 if self.is_usb() else 0x1000,
+can_recover=self.is_am())
# Scratch setup
self.max_private_segment_size = 0
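Only AM devices opt in here (can_recover=self.is_am()); KFD-backed AMD devices keep the old untimed wait. A minimal sketch of the gating HCQCompiled.synchronize applies further down (wait_timeout_ms is a hypothetical helper, the condition is copied from the diff):

def wait_timeout_ms(timeout: int|None, can_recover: bool) -> int|None:
  # a caller-supplied timeout is honored only on recoverable devices; None makes
  # HCQSignal.wait fall back to HCQDEV_WAIT_TIMEOUT_MS (30 s by default)
  return timeout if timeout is not None and can_recover else None

assert wait_timeout_ms(1, True) == 1      # AM device: the 1 ms beam budget applies
assert wait_timeout_ms(1, False) is None  # non-recoverable device: budget ignored
assert wait_timeout_ms(None, True) is None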

View File

@@ -54,7 +54,7 @@ class CLProgram:
except (TypeError, AttributeError): pass
def __call__(self, *bufs:tuple[cl.cl_mem, BufferSpec], global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]|None=None,
-vals:tuple[int, ...]=(), wait=False) -> float|None:
+vals:tuple[int, ...]=(), wait=False, **kw) -> float|None:
i = 0
for i,(b,_) in enumerate(bufs):
for real_i, dt in self.arg_dtypes[i]:

View File

@@ -51,7 +51,7 @@ class CUDAProgram:
@suppress_finalizing
def __del__(self): check(cuda.cuModuleUnload(self.module))
-def __call__(self, *args, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
+def __call__(self, *args, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw):
check(cuda.cuCtxSetCurrent(self.dev.context))
if not hasattr(self, "vargs"):
self.c_args, self.vargs = encode_args(args, vals)

View File

@@ -84,7 +84,7 @@ class DSPProgram:
def __init__(self, dev:DSPDevice, name:str, lib:bytes, **kwargs):
self.dev, self.lib = dev, lib
-def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
+def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw):
if len(bufs) >= 16: raise RuntimeError(f"Too many buffers to execute: {len(bufs)}")
pra, fds, attrs, _ = rpc_prep_args(ins=[var_vals_mv:=memoryview(bytearray((len(bufs)+len(vals))*4)), off_mv:=memoryview(bytearray(len(bufs)*4))],
@@ -293,7 +293,7 @@ class MockDSPRenderer(DSPRenderer):
class MockDSPProgram:
def __init__(self, name:str, lib:bytes, **kwargs): self.lib = lib
-def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
+def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw):
with tempfile.NamedTemporaryFile(suffix=".out") as dsp_lib:
dsp_lib.write(self.lib)
dsp_lib.flush()

View File

@@ -32,7 +32,7 @@ class HIPProgram:
def __del__(self):
if hasattr(self, 'module'): check(hip.hipModuleUnload(self.module))
-def __call__(self, *args, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
+def __call__(self, *args, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw):
check(hip.hipSetDevice(self.dev.device_id))
if not hasattr(self, "vargs"):
fields = [(f'f{i}', hip.hipDeviceptr_t, i*8) for i in range(len(args))] + [(f'v{i}', ctypes.c_int, len(args)*8+i*4) for i in range(len(vals))]

View File

@@ -123,7 +123,7 @@ class MetalProgram:
# cache these msg calls
self.max_total_threads: int = self.pipeline_state.maxTotalThreadsPerThreadgroup()
-def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
+def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw):
if prod(local_size) > self.max_total_threads:
exec_width = self.pipeline_state.threadExecutionWidth()
memory_length = self.pipeline_state.staticThreadgroupMemoryLength()

View File

@@ -15,7 +15,7 @@ class NullRenderer(CStyleLanguage):
class NullProgram:
def __init__(self, device:str, name:str, lib:bytes, *args, **kwargs): self.device, self.name = device, name
-def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
+def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw):
with cpu_profile(self.name, self.device): return 1e-3
class NullAllocator(Allocator['NullDevice']):

View File

@@ -312,12 +312,13 @@ class NVProgram(HCQProgram):
yield typ, param, sh.content[start_off+4:start_off+sz+4] if typ == 0x4 else sz
start_off += (sz if typ == 0x4 else 0) + 4
-def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int|None, ...]=(), wait=False):
+def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int|None, ...]=(),
+wait=False, timeout:int|None=None):
if prod(local_size) > 1024 or self.max_threads < prod(local_size) or self.lcmem_usage > cast(NVDevice, self.dev).slm_per_thread:
raise RuntimeError(f"Too many resources requested for launch, {prod(local_size)=}, {self.max_threads=}")
if any(cur > mx for cur,mx in zip(global_size, [2147483647, 65535, 65535])) or any(cur > mx for cur,mx in zip(local_size, [1024, 1024, 64])):
raise RuntimeError(f"Invalid global/local dims {global_size=}, {local_size=}")
-res = super().__call__(*bufs, global_size=global_size, local_size=local_size, vals=vals, wait=wait)
+res = super().__call__(*bufs, global_size=global_size, local_size=local_size, vals=vals, wait=wait, timeout=timeout)
if self.dev.pma_enabled:
self.dev.synchronize()
if pma_blob:=self.dev._prof_readback():

View File

@@ -41,7 +41,7 @@ def generic_wmma_helper(inp, warp_size, WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_
class PythonProgram:
def __init__(self, name:str, lib:bytes, **kwargs):
self.uops: list[tuple[Ops, DType, list[int], Any]] = pickle.loads(lib)
-def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
+def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw):
st = time.perf_counter()
warp = list(itertools.product(*[range(x) for x in local_size[::-1]]))
warp_size = len(warp)

View File

@@ -266,7 +266,8 @@ class QCOMProgram(HCQProgram):
super().__init__(QCOMArgsState, self.dev, self.name, kernargs_alloc_size=kernargs_alloc_size)
weakref.finalize(self, self._fini, self.dev, self.lib_gpu, buf_spec)
-def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int|None, ...]=(), wait=False):
+def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1),
+vals:tuple[int|None, ...]=(), wait=False, **kw):
if self.max_threads < prod(local_size): raise RuntimeError("Too many resources requested for launch")
if any(g*l>mx for g,l,mx in zip(global_size, local_size, [65536, 65536, 65536])) and any(l>mx for l,mx in zip(local_size, [1024, 1024, 1024])):
raise RuntimeError(f"Invalid global/local dims {global_size=}, {local_size=}")

View File

@@ -90,7 +90,7 @@ class WebGPUProgram:
self.name, self.lib, self.prg = name, lib, shader_module
def __call__(self, *bufs:WGPUBufPtr, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1),
-vals:tuple[int, ...]=(), wait=False) -> float|None:
+vals:tuple[int, ...]=(), wait=False, **kw) -> float|None:
wait = wait and self.timestamp_supported
tmp_bufs = [*bufs]
buf_patch = False

View File

@@ -225,13 +225,13 @@ class AMDev(PCIDevImplBase):
self.ih.interrupt_handler()
self.reg("regSCRATCH_REG6").write(self.is_err_state) # set finalized state.
-def recover(self) -> bool:
-if not self.is_err_state: return False
-if DEBUG >= 2: print(f"am {self.devfmt}: Start recovery")
+def recover(self, force=False) -> bool:
+if not force and not self.is_err_state: return False
+if DEBUG >= 3: print(f"am {self.devfmt}: Start recovery")
self.ih.interrupt_handler()
self.gfx.reset_mec()
self.is_err_state = False
-if DEBUG >= 2: print(f"am {self.devfmt}: Recovery complete")
+if DEBUG >= 3: print(f"am {self.devfmt}: Recovery complete")
return True
def is_hive(self) -> bool: return self.gmc.xgmi_seg_sz > 0
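A note on the new force flag: the only caller in this commit is the hang handler above, which passes force=d.error_state is not None, so a device that timed out on the host side is reset even if the firmware never latched is_err_state. When recover() now proceeds, as a sketch (condition copied from the diff):

def will_recover(is_err_state: bool, force: bool) -> bool:
  # from: if not force and not self.is_err_state: return False
  return force or is_err_state

assert will_recover(is_err_state=True, force=False)   # firmware fault, as before
assert will_recover(is_err_state=False, force=True)   # host-side timeout, error_state set upstream
assert not will_recover(is_err_state=False, force=False)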

View File

@@ -290,13 +290,11 @@ class AM_GFX(AM_IP):
def fini_hw(self): self._dequeue_hqds()
def reset_mec(self):
-self._dequeue_hqds(reset=True)
+self._dequeue_hqds()
# issue a soft reset to reset aql sync counter on multixcc systems.
-if self.xccs > 1:
-for xcc in range(self.xccs): self.adev.regGRBM_SOFT_RESET.write(soft_reset_cp=1, soft_reset_gfx=1, inst=xcc)
-time.sleep(0.05)
-for xcc in range(self.xccs): self.adev.regGRBM_SOFT_RESET.write(0x0, inst=xcc)
+for xcc in range(self.xccs): self.adev.regGRBM_SOFT_RESET.write(soft_reset_cp=1, soft_reset_cpc=1, inst=xcc)
+time.sleep(0.05)
+for xcc in range(self.xccs): self.adev.regGRBM_SOFT_RESET.write(0x0, inst=xcc)
self._config_mec()
self._enable_mec()
@@ -384,13 +382,13 @@ class AM_GFX(AM_IP):
if self.adev.ip_ver[am.GC_HWIP] >= (10,0,0):
_config_helper(eng_name="MEC", cntl_reg="MEC_RS64", eng_reg="MEC_RS64", pipe_cnt=1, me=1, xcc=xcc)
-def _dequeue_hqds(self, reset=False):
+def _dequeue_hqds(self):
for q in range(2):
for xcc in range(self.xccs):
self._grbm_select(me=1, pipe=0, queue=q, inst=xcc)
if self.adev.regCP_HQD_ACTIVE.read(inst=xcc) & 1:
self.adev.regCP_HQD_DEQUEUE_REQUEST.write(0x2, inst=xcc) # 1 - DRAIN_PIPE; 2 - RESET_WAVES
-if not reset: wait_cond(lambda: self.adev.regCP_HQD_ACTIVE.read(inst=xcc) & 1, value=0, msg="HQD dequeue timeout")
+if not self.adev.is_err_state: wait_cond(lambda: self.adev.regCP_HQD_ACTIVE.read(inst=xcc) & 1, value=0, msg="HQD dequeue timeout")
self._grbm_select()
class AM_IH(AM_IP):

View File

@@ -253,7 +253,7 @@ class HCQSignal(Generic[HCQDeviceType]):
Raises RuntimeError if a fault is detected.
"""
-def wait(self, value:int, timeout:int=getenv("HCQDEV_WAIT_TIMEOUT_MS", 30000)):
+def wait(self, value:int, timeout:int|None=None):
"""
Waits until the signal is greater than or equal to a specific value.
@@ -261,6 +261,7 @@ class HCQSignal(Generic[HCQDeviceType]):
value: The value to wait for.
timeout: Maximum time to wait in milliseconds. Defaults to 30s.
"""
+timeout = timeout or getenv("HCQDEV_WAIT_TIMEOUT_MS", 30000)
start_time = int(time.perf_counter() * 1000)
while (not_passed:=(prev_value:=self.value) < value) and (cur_time:=int(time.perf_counter() * 1000)) - start_time < timeout:
self._sleep(cur_time - start_time)
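The default timeout now lives in the body rather than the signature. Note the fallback uses or, not an is-None check, so an explicit timeout=0 also falls back to the default (a sketch; DEFAULT_MS stands in for the getenv lookup):

DEFAULT_MS = 30000  # getenv("HCQDEV_WAIT_TIMEOUT_MS", 30000) in the diff

def effective_timeout_ms(timeout: int|None) -> int:
  return timeout or DEFAULT_MS

assert effective_timeout_ms(None) == 30000  # old default preserved
assert effective_timeout_ms(0) == 30000     # or treats 0 like None
assert effective_timeout_ms(1) == 1         # a 1 ms beam budget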
@@ -325,7 +326,7 @@ class HCQProgram(Generic[HCQDeviceType]):
return self.args_state_t(argsbuf, self, bufs, vals=vals)
def __call__(self, *bufs:HCQBuffer, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1),
-vals:tuple[int|None, ...]=(), wait:bool=False) -> float|None:
+vals:tuple[int|None, ...]=(), wait:bool=False, timeout:int|None=None) -> float|None:
"""
Enqueues the program for execution with the given arguments and dimensions.
@@ -349,7 +350,7 @@ class HCQProgram(Generic[HCQDeviceType]):
q.signal(self.dev.timeline_signal, self.dev.next_timeline()).submit(self.dev)
-if wait: self.dev.synchronize()
+if wait: self.dev.synchronize(timeout=timeout)
return (float(sig_en.timestamp - sig_st.timestamp) / 1e6) if wait else None
class HCQCompiled(Compiled, Generic[SignalType]):
@@ -362,7 +363,8 @@ class HCQCompiled(Compiled, Generic[SignalType]):
cpu_devices: list[HCQCompiled] = []
def __init__(self, device:str, allocator:HCQAllocatorBase, compilers:CompilerSet, runtime, signal_t:Type[SignalType],
-comp_queue_t:Callable[..., HWQueue], copy_queue_t:Callable[..., HWQueue]|None=None, kernargs_size=(16 << 20), sigalloc_size=0x1000):
+comp_queue_t:Callable[..., HWQueue], copy_queue_t:Callable[..., HWQueue]|None=None, kernargs_size=(16 << 20), sigalloc_size=0x1000,
+can_recover:bool=False):
self.device_id:int = int(device.split(":")[1]) if ":" in device else 0
from tinygrad.runtime.graph.hcq import HCQGraph
@@ -386,22 +388,23 @@ class HCQCompiled(Compiled, Generic[SignalType]):
self.kernargs_buf:HCQBuffer = self.allocator.alloc(kernargs_size, BufferSpec(cpu_access=True))
self.kernargs_offset_allocator:BumpAllocator = BumpAllocator(self.kernargs_buf.size, wrap=True)
+self.can_recover = can_recover # Whether the device can recover from faults or timeouts
self.error_state:Exception|None = None # Exception if error is unrecoverable and sync will always fail
if self._is_cpu(): HCQCompiled.cpu_devices.append(self)
-def synchronize(self):
+def synchronize(self, timeout:int|None=None):
if self.error_state is not None: raise self.error_state
# If we have any work on CPU devices, need to synchronize them. This is just an optimization to release GIL allowing to finish faster.
if not self._is_cpu():
for dev in HCQCompiled.cpu_devices: dev.synchronize()
-try: self.timeline_signal.wait(self.timeline_value - 1)
+try: self.timeline_signal.wait(self.timeline_value - 1, timeout=timeout if timeout is not None and self.can_recover else None)
except RuntimeError as e:
self.error_state = e
if hasattr(self, 'on_device_hang'): self.on_device_hang()
-else: raise e
+raise e
if self.timeline_value > (1 << 31): self._wrap_timeline_signal()
if PROFILE: