mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
beam: add dev_timeout for am (#15063)
* beam: add dev_timeout for am * all covered * fk * x * fuzz * reset * f
This commit is contained in:
31
test/external/external_fuzz_beam_timeout_recovery.py
vendored
Normal file
31
test/external/external_fuzz_beam_timeout_recovery.py
vendored
Normal file
@@ -0,0 +1,31 @@
|
||||
#!/usr/bin/env python3
"""
Stress test for beam timeout + device recovery on AM devices.

Repeatedly launches a large matmul, synchronizes with a 1 ms timeout so the
wait is expected to fail, then verifies the device recovered by running a
small known-answer matmul (ones @ ones of size 512 -> every entry is 512.0).

Usage:
  AMD=1 python test/external/external_fuzz_beam_timeout_recovery.py
"""
from tinygrad import Tensor, Device
from tinygrad.helpers import Context
from tinygrad.runtime.ops_amd import AMDDevice

if __name__ == "__main__":
  dev = Device["AMD"]
  # recovery after a timed-out synchronize is only supported on AM devices
  assert isinstance(dev, AMDDevice) and dev.is_am(), "not am"

  N = 10000
  for i in range(N):
    # kick off work big enough that a 1ms sync timeout will likely fire
    with Context(DEBUG=0, BEAM=0):
      a = Tensor.rand(4096, 4096, device="AMD").contiguous().realize()
      b = Tensor.rand(4096, 4096, device="AMD").contiguous().realize()
      c = a.matmul(b)
      c.realize()
    # a timeout raises RuntimeError; the device should recover afterwards
    try: dev.synchronize(timeout=1)
    except RuntimeError as e: print(e)
    # correctness probe: ones(512,512) @ ones(512,512) must give exactly 512.0
    with Context(DEBUG=0, BEAM=0):
      a = Tensor.ones(512, 512, device="AMD").contiguous().realize()
      b = Tensor.ones(512, 512, device="AMD").contiguous().realize()
      result = a.matmul(b).realize()[0, 0].item()
      assert result == 512.0, f"iter {i}: got {result}"
    print(f"  iter {i+1}/{N}: ok")
  print(f"=== All {N} iterations passed ===")
|
||||
2
test/external/external_test_speed_llama.py
vendored
2
test/external/external_test_speed_llama.py
vendored
@@ -10,7 +10,7 @@ from tinygrad.helpers import Profiling
|
||||
|
||||
class FakeProgram:
|
||||
def __init__(self, name:str, prg:bytes, **kwargs): pass
|
||||
def __call__(self, *bufs, global_size, local_size, vals=(), wait=False): pass
|
||||
def __call__(self, *bufs, global_size, local_size, vals=(), wait=False, **kw): pass
|
||||
|
||||
class FakeAllocator(Allocator[Compiled]):
|
||||
def _alloc(self, sz, options): return None
|
||||
|
||||
@@ -36,7 +36,8 @@ def get_test_global_size(global_size, max_global_size, var_vals):
|
||||
return test_global_size, input_size / prod(test_global_size)
|
||||
|
||||
def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[str, int], rawbufs:list[Buffer], early_stop:float|None=None,
|
||||
allow_test_size:int=True, max_global_size:int|None=65536, clear_l2=False, cnt=3, name="test") -> list[float]:
|
||||
allow_test_size:int=True, max_global_size:int|None=65536, clear_l2=False, cnt=3, name="test", dev_timeout=False) -> list[float]:
|
||||
timeout = int(early_stop * 1e3) if dev_timeout and early_stop is not None and early_stop < math.inf else None
|
||||
factor = 1
|
||||
if allow_test_size and max_global_size is not None:
|
||||
global_size, factor = get_test_global_size(p.global_size, max_global_size, var_vals)
|
||||
@@ -50,7 +51,7 @@ def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[str, int], rawbufs:lis
|
||||
if hasattr(dev:=Device[p.device], 'invalidate_caches'): dev.invalidate_caches()
|
||||
else:
|
||||
with Context(DEBUG=0, BEAM=0, CAPTURING=0, TRACK_MATCH_STATS=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
|
||||
tms.append(unwrap(car(input_bufs, var_vals, wait=True))*factor)
|
||||
tms.append(unwrap(car(input_bufs, var_vals, wait=True, timeout=timeout))*factor)
|
||||
if early_stop is not None and early_stop < min(tms): break
|
||||
return tms
|
||||
|
||||
@@ -161,7 +162,8 @@ def beam_search(s:Scheduler, rawbufs:list[Buffer], amt:int, allow_test_size=True
|
||||
continue
|
||||
seen_libs.add(lib)
|
||||
try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0,
|
||||
allow_test_size=allow_test_size, clear_l2=hasattr(dev, 'invalidate_caches'))
|
||||
allow_test_size=allow_test_size, clear_l2=hasattr(dev, 'invalidate_caches'),
|
||||
dev_timeout=getenv("BEAM_DEV_TIMEOUT", 1))
|
||||
except Exception as e:
|
||||
if BEAM_DEBUG: print(f"BEAM failed for opts: {candidates[i].applied_opts}\n{e}")
|
||||
if isinstance(e, RuntimeError): continue
|
||||
|
||||
@@ -50,7 +50,7 @@ class CompiledRunner(Runner):
|
||||
|
||||
def __reduce__(self): return self.__class__, (self.p,)
|
||||
|
||||
def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None, wait=False) -> float|None:
|
||||
def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int]|None=None, wait=False, timeout:int|None=None) -> float|None:
|
||||
if var_vals is None: var_vals = {}
|
||||
global_size, local_size = self.p.launch_dims(var_vals)
|
||||
if Device[self.p.device].renderer.has_local and local_size is None and all_int(self.p.global_size):
|
||||
@@ -58,7 +58,7 @@ class CompiledRunner(Runner):
|
||||
global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
|
||||
self.p = replace(self.p, global_size=global_size, local_size=local_size)
|
||||
return self._prg(*[x._buf for x in rawbufs], global_size=tuple(global_size), local_size=tuple(local_size) if local_size else None,
|
||||
vals=tuple(var_vals[k.expr] if k.expr not in self.p.runtimevars else None for k in self.p.vars), wait=wait)
|
||||
vals=tuple(var_vals[k.expr] if k.expr not in self.p.runtimevars else None for k in self.p.vars), wait=wait, timeout=timeout)
|
||||
|
||||
class ViewOp(Runner):
|
||||
def __init__(self, buf:Buffer): super().__init__(colored(f"view {buf.nbytes:8d} @ {buf.offset:<10d}", "yellow"), buf.device)
|
||||
|
||||
@@ -598,9 +598,10 @@ class AMDProgram(HCQProgram):
|
||||
base=self.lib_gpu.va_addr)
|
||||
weakref.finalize(self, self._fini, self.dev, self.lib_gpu, buf_spec)
|
||||
|
||||
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int|None, ...]=(), wait=False):
|
||||
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int|None, ...]=(),
|
||||
wait=False, timeout:int|None=None):
|
||||
if self.dev.sqtt_enabled: cast(AMDComputeQueue, self.dev.hw_compute_queue_t()).sqtt_start(self.dev.sqtt_buffers).submit(self.dev)
|
||||
res = super().__call__(*bufs, global_size=global_size, local_size=local_size, vals=vals, wait=wait)
|
||||
res = super().__call__(*bufs, global_size=global_size, local_size=local_size, vals=vals, wait=wait, timeout=timeout)
|
||||
if self.dev.pmc_enabled:
|
||||
cast(AMDComputeQueue, self.dev.hw_compute_queue_t()).pmc_read(self.dev.pmc_buffer, self.dev.pmc_sched) \
|
||||
.signal(self.dev.timeline_signal, self.dev.next_timeline()).submit(self.dev)
|
||||
@@ -869,7 +870,7 @@ class PCIIface(PCIIfaceBase):
|
||||
devs:list[AMDDevice] = [d for pg in HCQCompiled.peer_groups.values() for d in pg if isinstance(d, AMDDevice) and d.is_am()]
|
||||
for d in devs:
|
||||
d.iface.dev_impl.ih.interrupt_handler()
|
||||
if reset and d.iface.dev_impl.recover():
|
||||
if reset and d.iface.dev_impl.recover(force=d.error_state is not None):
|
||||
d.compute_queue.put_value, _ = d.iface.dev_impl.gfx.setup_ring(*d.compute_queue.params)
|
||||
d.compute_queue.read_ptr[0] = d.compute_queue.write_ptr[0] = d.compute_queue.put_value
|
||||
d.timeline_signal.value = d.timeline_value - 1
|
||||
@@ -977,7 +978,8 @@ class AMDDevice(HCQCompiled):
|
||||
super().__init__(device, AMDAllocator(self), compilers, functools.partial(AMDProgram, self), AMDSignal,
|
||||
functools.partial(AMDComputeAQLQueue if self.is_aql else AMDComputeQueue, self),
|
||||
functools.partial(AMDCopyQueue, self, max_copy_size=self.max_copy_size) if self.has_sdma_queue else None,
|
||||
kernargs_size=(8 << 10) if self.is_usb() else (16 << 20), sigalloc_size=0x100 if self.is_usb() else 0x1000)
|
||||
kernargs_size=(8 << 10) if self.is_usb() else (16 << 20), sigalloc_size=0x100 if self.is_usb() else 0x1000,
|
||||
can_recover=self.is_am())
|
||||
|
||||
# Scratch setup
|
||||
self.max_private_segment_size = 0
|
||||
|
||||
@@ -54,7 +54,7 @@ class CLProgram:
|
||||
except (TypeError, AttributeError): pass
|
||||
|
||||
def __call__(self, *bufs:tuple[cl.cl_mem, BufferSpec], global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]|None=None,
|
||||
vals:tuple[int, ...]=(), wait=False) -> float|None:
|
||||
vals:tuple[int, ...]=(), wait=False, **kw) -> float|None:
|
||||
i = 0
|
||||
for i,(b,_) in enumerate(bufs):
|
||||
for real_i, dt in self.arg_dtypes[i]:
|
||||
|
||||
@@ -51,7 +51,7 @@ class CUDAProgram:
|
||||
@suppress_finalizing
|
||||
def __del__(self): check(cuda.cuModuleUnload(self.module))
|
||||
|
||||
def __call__(self, *args, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
|
||||
def __call__(self, *args, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw):
|
||||
check(cuda.cuCtxSetCurrent(self.dev.context))
|
||||
if not hasattr(self, "vargs"):
|
||||
self.c_args, self.vargs = encode_args(args, vals)
|
||||
|
||||
@@ -84,7 +84,7 @@ class DSPProgram:
|
||||
def __init__(self, dev:DSPDevice, name:str, lib:bytes, **kwargs):
|
||||
self.dev, self.lib = dev, lib
|
||||
|
||||
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
|
||||
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw):
|
||||
if len(bufs) >= 16: raise RuntimeError(f"Too many buffers to execute: {len(bufs)}")
|
||||
|
||||
pra, fds, attrs, _ = rpc_prep_args(ins=[var_vals_mv:=memoryview(bytearray((len(bufs)+len(vals))*4)), off_mv:=memoryview(bytearray(len(bufs)*4))],
|
||||
@@ -293,7 +293,7 @@ class MockDSPRenderer(DSPRenderer):
|
||||
|
||||
class MockDSPProgram:
|
||||
def __init__(self, name:str, lib:bytes, **kwargs): self.lib = lib
|
||||
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
|
||||
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw):
|
||||
with tempfile.NamedTemporaryFile(suffix=".out") as dsp_lib:
|
||||
dsp_lib.write(self.lib)
|
||||
dsp_lib.flush()
|
||||
|
||||
@@ -32,7 +32,7 @@ class HIPProgram:
|
||||
def __del__(self):
|
||||
if hasattr(self, 'module'): check(hip.hipModuleUnload(self.module))
|
||||
|
||||
def __call__(self, *args, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
|
||||
def __call__(self, *args, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw):
|
||||
check(hip.hipSetDevice(self.dev.device_id))
|
||||
if not hasattr(self, "vargs"):
|
||||
fields = [(f'f{i}', hip.hipDeviceptr_t, i*8) for i in range(len(args))] + [(f'v{i}', ctypes.c_int, len(args)*8+i*4) for i in range(len(vals))]
|
||||
|
||||
@@ -123,7 +123,7 @@ class MetalProgram:
|
||||
# cache these msg calls
|
||||
self.max_total_threads: int = self.pipeline_state.maxTotalThreadsPerThreadgroup()
|
||||
|
||||
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
|
||||
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw):
|
||||
if prod(local_size) > self.max_total_threads:
|
||||
exec_width = self.pipeline_state.threadExecutionWidth()
|
||||
memory_length = self.pipeline_state.staticThreadgroupMemoryLength()
|
||||
|
||||
@@ -15,7 +15,7 @@ class NullRenderer(CStyleLanguage):
|
||||
|
||||
class NullProgram:
|
||||
def __init__(self, device:str, name:str, lib:bytes, *args, **kwargs): self.device, self.name = device, name
|
||||
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
|
||||
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw):
|
||||
with cpu_profile(self.name, self.device): return 1e-3
|
||||
|
||||
class NullAllocator(Allocator['NullDevice']):
|
||||
|
||||
@@ -312,12 +312,13 @@ class NVProgram(HCQProgram):
|
||||
yield typ, param, sh.content[start_off+4:start_off+sz+4] if typ == 0x4 else sz
|
||||
start_off += (sz if typ == 0x4 else 0) + 4
|
||||
|
||||
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int|None, ...]=(), wait=False):
|
||||
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int|None, ...]=(),
|
||||
wait=False, timeout:int|None=None):
|
||||
if prod(local_size) > 1024 or self.max_threads < prod(local_size) or self.lcmem_usage > cast(NVDevice, self.dev).slm_per_thread:
|
||||
raise RuntimeError(f"Too many resources requested for launch, {prod(local_size)=}, {self.max_threads=}")
|
||||
if any(cur > mx for cur,mx in zip(global_size, [2147483647, 65535, 65535])) or any(cur > mx for cur,mx in zip(local_size, [1024, 1024, 64])):
|
||||
raise RuntimeError(f"Invalid global/local dims {global_size=}, {local_size=}")
|
||||
res = super().__call__(*bufs, global_size=global_size, local_size=local_size, vals=vals, wait=wait)
|
||||
res = super().__call__(*bufs, global_size=global_size, local_size=local_size, vals=vals, wait=wait, timeout=timeout)
|
||||
if self.dev.pma_enabled:
|
||||
self.dev.synchronize()
|
||||
if pma_blob:=self.dev._prof_readback():
|
||||
|
||||
@@ -41,7 +41,7 @@ def generic_wmma_helper(inp, warp_size, WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_
|
||||
class PythonProgram:
|
||||
def __init__(self, name:str, lib:bytes, **kwargs):
|
||||
self.uops: list[tuple[Ops, DType, list[int], Any]] = pickle.loads(lib)
|
||||
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
|
||||
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw):
|
||||
st = time.perf_counter()
|
||||
warp = list(itertools.product(*[range(x) for x in local_size[::-1]]))
|
||||
warp_size = len(warp)
|
||||
|
||||
@@ -266,7 +266,8 @@ class QCOMProgram(HCQProgram):
|
||||
super().__init__(QCOMArgsState, self.dev, self.name, kernargs_alloc_size=kernargs_alloc_size)
|
||||
weakref.finalize(self, self._fini, self.dev, self.lib_gpu, buf_spec)
|
||||
|
||||
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int|None, ...]=(), wait=False):
|
||||
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1),
|
||||
vals:tuple[int|None, ...]=(), wait=False, **kw):
|
||||
if self.max_threads < prod(local_size): raise RuntimeError("Too many resources requested for launch")
|
||||
if any(g*l>mx for g,l,mx in zip(global_size, local_size, [65536, 65536, 65536])) and any(l>mx for l,mx in zip(local_size, [1024, 1024, 1024])):
|
||||
raise RuntimeError(f"Invalid global/local dims {global_size=}, {local_size=}")
|
||||
|
||||
@@ -90,7 +90,7 @@ class WebGPUProgram:
|
||||
|
||||
self.name, self.lib, self.prg = name, lib, shader_module
|
||||
def __call__(self, *bufs:WGPUBufPtr, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1),
|
||||
vals:tuple[int, ...]=(), wait=False) -> float|None:
|
||||
vals:tuple[int, ...]=(), wait=False, **kw) -> float|None:
|
||||
wait = wait and self.timestamp_supported
|
||||
tmp_bufs = [*bufs]
|
||||
buf_patch = False
|
||||
|
||||
@@ -225,13 +225,13 @@ class AMDev(PCIDevImplBase):
|
||||
self.ih.interrupt_handler()
|
||||
self.reg("regSCRATCH_REG6").write(self.is_err_state) # set finalized state.
|
||||
|
||||
def recover(self) -> bool:
|
||||
if not self.is_err_state: return False
|
||||
if DEBUG >= 2: print(f"am {self.devfmt}: Start recovery")
|
||||
def recover(self, force=False) -> bool:
|
||||
if not force and not self.is_err_state: return False
|
||||
if DEBUG >= 3: print(f"am {self.devfmt}: Start recovery")
|
||||
self.ih.interrupt_handler()
|
||||
self.gfx.reset_mec()
|
||||
self.is_err_state = False
|
||||
if DEBUG >= 2: print(f"am {self.devfmt}: Recovery complete")
|
||||
if DEBUG >= 3: print(f"am {self.devfmt}: Recovery complete")
|
||||
return True
|
||||
|
||||
def is_hive(self) -> bool: return self.gmc.xgmi_seg_sz > 0
|
||||
|
||||
@@ -290,13 +290,11 @@ class AM_GFX(AM_IP):
|
||||
def fini_hw(self): self._dequeue_hqds()
|
||||
|
||||
def reset_mec(self):
|
||||
self._dequeue_hqds(reset=True)
|
||||
self._dequeue_hqds()
|
||||
|
||||
# issue a soft reset to reset aql sync counter on multixcc systems.
|
||||
if self.xccs > 1:
|
||||
for xcc in range(self.xccs): self.adev.regGRBM_SOFT_RESET.write(soft_reset_cp=1, soft_reset_gfx=1, inst=xcc)
|
||||
time.sleep(0.05)
|
||||
for xcc in range(self.xccs): self.adev.regGRBM_SOFT_RESET.write(0x0, inst=xcc)
|
||||
for xcc in range(self.xccs): self.adev.regGRBM_SOFT_RESET.write(soft_reset_cp=1, soft_reset_cpc=1, inst=xcc)
|
||||
time.sleep(0.05)
|
||||
for xcc in range(self.xccs): self.adev.regGRBM_SOFT_RESET.write(0x0, inst=xcc)
|
||||
|
||||
self._config_mec()
|
||||
self._enable_mec()
|
||||
@@ -384,13 +382,13 @@ class AM_GFX(AM_IP):
|
||||
if self.adev.ip_ver[am.GC_HWIP] >= (10,0,0):
|
||||
_config_helper(eng_name="MEC", cntl_reg="MEC_RS64", eng_reg="MEC_RS64", pipe_cnt=1, me=1, xcc=xcc)
|
||||
|
||||
def _dequeue_hqds(self, reset=False):
|
||||
def _dequeue_hqds(self):
|
||||
for q in range(2):
|
||||
for xcc in range(self.xccs):
|
||||
self._grbm_select(me=1, pipe=0, queue=q, inst=xcc)
|
||||
if self.adev.regCP_HQD_ACTIVE.read(inst=xcc) & 1:
|
||||
self.adev.regCP_HQD_DEQUEUE_REQUEST.write(0x2, inst=xcc) # 1 - DRAIN_PIPE; 2 - RESET_WAVES
|
||||
if not reset: wait_cond(lambda: self.adev.regCP_HQD_ACTIVE.read(inst=xcc) & 1, value=0, msg="HQD dequeue timeout")
|
||||
if not self.adev.is_err_state: wait_cond(lambda: self.adev.regCP_HQD_ACTIVE.read(inst=xcc) & 1, value=0, msg="HQD dequeue timeout")
|
||||
self._grbm_select()
|
||||
|
||||
class AM_IH(AM_IP):
|
||||
|
||||
@@ -253,7 +253,7 @@ class HCQSignal(Generic[HCQDeviceType]):
|
||||
Raises RuntimeError if a fault is detected.
|
||||
"""
|
||||
|
||||
def wait(self, value:int, timeout:int=getenv("HCQDEV_WAIT_TIMEOUT_MS", 30000)):
|
||||
def wait(self, value:int, timeout:int|None=None):
|
||||
"""
|
||||
Waits the signal is greater than or equal to a specific value.
|
||||
|
||||
@@ -261,6 +261,7 @@ class HCQSignal(Generic[HCQDeviceType]):
|
||||
value: The value to wait for.
|
||||
timeout: Maximum time to wait in milliseconds. Defaults to 30s.
|
||||
"""
|
||||
timeout = timeout or getenv("HCQDEV_WAIT_TIMEOUT_MS", 30000)
|
||||
start_time = int(time.perf_counter() * 1000)
|
||||
while (not_passed:=(prev_value:=self.value) < value) and (cur_time:=int(time.perf_counter() * 1000)) - start_time < timeout:
|
||||
self._sleep(cur_time - start_time)
|
||||
@@ -325,7 +326,7 @@ class HCQProgram(Generic[HCQDeviceType]):
|
||||
return self.args_state_t(argsbuf, self, bufs, vals=vals)
|
||||
|
||||
def __call__(self, *bufs:HCQBuffer, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1),
|
||||
vals:tuple[int|None, ...]=(), wait:bool=False) -> float|None:
|
||||
vals:tuple[int|None, ...]=(), wait:bool=False, timeout:int|None=None) -> float|None:
|
||||
"""
|
||||
Enqueues the program for execution with the given arguments and dimensions.
|
||||
|
||||
@@ -349,7 +350,7 @@ class HCQProgram(Generic[HCQDeviceType]):
|
||||
|
||||
q.signal(self.dev.timeline_signal, self.dev.next_timeline()).submit(self.dev)
|
||||
|
||||
if wait: self.dev.synchronize()
|
||||
if wait: self.dev.synchronize(timeout=timeout)
|
||||
return (float(sig_en.timestamp - sig_st.timestamp) / 1e6) if wait else None
|
||||
|
||||
class HCQCompiled(Compiled, Generic[SignalType]):
|
||||
@@ -362,7 +363,8 @@ class HCQCompiled(Compiled, Generic[SignalType]):
|
||||
cpu_devices: list[HCQCompiled] = []
|
||||
|
||||
def __init__(self, device:str, allocator:HCQAllocatorBase, compilers:CompilerSet, runtime, signal_t:Type[SignalType],
|
||||
comp_queue_t:Callable[..., HWQueue], copy_queue_t:Callable[..., HWQueue]|None=None, kernargs_size=(16 << 20), sigalloc_size=0x1000):
|
||||
comp_queue_t:Callable[..., HWQueue], copy_queue_t:Callable[..., HWQueue]|None=None, kernargs_size=(16 << 20), sigalloc_size=0x1000,
|
||||
can_recover:bool=False):
|
||||
self.device_id:int = int(device.split(":")[1]) if ":" in device else 0
|
||||
|
||||
from tinygrad.runtime.graph.hcq import HCQGraph
|
||||
@@ -386,22 +388,23 @@ class HCQCompiled(Compiled, Generic[SignalType]):
|
||||
self.kernargs_buf:HCQBuffer = self.allocator.alloc(kernargs_size, BufferSpec(cpu_access=True))
|
||||
self.kernargs_offset_allocator:BumpAllocator = BumpAllocator(self.kernargs_buf.size, wrap=True)
|
||||
|
||||
self.can_recover = can_recover # Whether the device can recover from faults or timeouts
|
||||
self.error_state:Exception|None = None # Exception if error is unrecoverable and sync will always fail
|
||||
|
||||
if self._is_cpu(): HCQCompiled.cpu_devices.append(self)
|
||||
|
||||
def synchronize(self):
|
||||
def synchronize(self, timeout:int|None=None):
|
||||
if self.error_state is not None: raise self.error_state
|
||||
|
||||
# If we have any work on CPU devices, need to synchronize them. This is just an optimization to release GIL allowing to finish faster.
|
||||
if not self._is_cpu():
|
||||
for dev in HCQCompiled.cpu_devices: dev.synchronize()
|
||||
|
||||
try: self.timeline_signal.wait(self.timeline_value - 1)
|
||||
try: self.timeline_signal.wait(self.timeline_value - 1, timeout=timeout if timeout is not None and self.can_recover else None)
|
||||
except RuntimeError as e:
|
||||
self.error_state = e
|
||||
if hasattr(self, 'on_device_hang'): self.on_device_hang()
|
||||
else: raise e
|
||||
raise e
|
||||
|
||||
if self.timeline_value > (1 << 31): self._wrap_timeline_signal()
|
||||
if PROFILE:
|
||||
|
||||
Reference in New Issue
Block a user