cpu: wait on dep signals (#14862)

* cpu: task_done() in case of failures

* print

* fix

* x

* f

* x

* um

* ?

* u

* f

* x

* gh

* f

* f

* virt

* x

* simpler
This commit is contained in:
nimlgen
2026-02-23 21:09:41 +03:00
committed by GitHub
parent 127136421d
commit 77db8e1c07
4 changed files with 33 additions and 24 deletions

View File

@@ -45,7 +45,7 @@ class AMDSignal(HCQSignal):
def _sleep(self, time_spent_since_last_sleep_ms:int):
# Reasonable to sleep for long workloads (which take more than 200ms); after this change any signal with an owner sleeps, not only timeline signals.
if time_spent_since_last_sleep_ms > 200 and self.is_timeline and self.owner is not None: self.owner.iface.sleep(200)
if time_spent_since_last_sleep_ms > 200 and self.owner is not None: self.owner.iface.sleep(200)
class AMDComputeQueue(HWQueue):
def __init__(self, dev:AMDDevice):

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
import platform, sys, ctypes, functools, time, mmap, threading, queue
from tinygrad.helpers import to_mv, OSX, WIN, mv_address, wait_cond, suppress_finalizing, unwrap, data64_le
from tinygrad.helpers import to_mv, OSX, WIN, mv_address, suppress_finalizing, unwrap, data64_le
from tinygrad.helpers import CPU_CC, CPU_LVP, CPU_LLVM
from tinygrad.device import BufferSpec, DMACPURef, CompilerSet
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, HCQArgsState, HCQSignal, HCQProgram, MMIOInterface
@@ -13,7 +13,9 @@ from tinygrad.uop.ops import sint
class CPUSignal(HCQSignal):
def _sleep(self, time_spent_since_last_sleep_ms:int):
if self.is_timeline and self.owner is not None: self.owner.tasks.join()
if self.is_timeline and self.owner is not None:
self.owner.tasks.join()
if self.owner.error_state is not None: raise self.owner.error_state
class CPUWorker(threading.Thread):
def __init__(self, dev, tasks, thread_id):
@@ -29,13 +31,15 @@ class CPUWorker(threading.Thread):
def run(self):
while True:
cmd_iter = iter(self.tasks.get())
for cmd in cmd_iter:
threads, args_cnt = next(cmd_iter), next(cmd_iter)
args = [next(cmd_iter) for _ in range(args_cnt)]
for th in range(threads - 1): self.push_task(th, cmd, args)
cmd(self.thread_id, *args)
for th in range(threads - 1): self.pool[th].join()
self.tasks.task_done()
try:
for cmd in cmd_iter:
threads, args_cnt = next(cmd_iter), next(cmd_iter)
args = [next(cmd_iter) for _ in range(args_cnt)]
for th in range(threads - 1): self.push_task(th, cmd, args)
cmd(self.thread_id, *args)
for th in range(threads - 1): self.pool[th].join()
except Exception as e: self.dev.error_state = e
finally: self.tasks.task_done()
class CPUComputeQueue(HWQueue):
def _exec(self, tid, prg, bufs, *args):
@@ -43,7 +47,9 @@ class CPUComputeQueue(HWQueue):
if 'core_id' in prg.runtimevars: vals[prg.runtimevars['core_id']] = tid
prg.fxn(*map(ctypes.c_uint64, args[:bufs]), *map(ctypes.c_int64 if platform.machine() == "arm64" else ctypes.c_int32, vals))
def _signal(self, tid, signal_addr, value): to_mv(signal_addr, 4).cast('I')[0] = value
def _wait(self, tid, signal_addr, value): wait_cond(lambda: to_mv(signal_addr, 4).cast('I')[0] >= value, timeout_ms=60000)
def _wait(self, tid, tmpl_sig, signal_addr, value):
tmpl_sig.base_buf = HCQBuffer(signal_addr, 16, view=MMIOInterface(signal_addr, 16))
tmpl_sig.wait(value)
def _timestamp(self, tid, timestamp_addr): to_mv(timestamp_addr, 8).cast('Q')[0] = time.perf_counter_ns()
def cmd(self, cmd, *args, threads=1):
self.q(cmd, threads, len(args), *args)
@@ -55,7 +61,7 @@ class CPUComputeQueue(HWQueue):
self.bind_args_state(args_state)
return self.cmd(self._exec, prg, 1, args_state.buf.va_addr)
return self.cmd(self._exec, prg, len(args_state.bufs), *[x.va_addr for x in args_state.bufs], *args_state.vals, threads=(global_size or (1,))[0])
def wait(self, signal, value=0): return self.cmd(self._wait, signal.value_addr, value)
def wait(self, signal, value=0): return self.cmd(self._wait, type(signal)(signal.base_buf, owner=signal.owner, virt=True), signal.value_addr, value)
def timestamp(self, signal): return self.cmd(self._timestamp, signal.timestamp_addr)
def signal(self, signal, value:sint=0): return self.cmd(self._signal, signal.value_addr, value)
def _submit(self, dev): dev.tasks.put(self._q[:])

View File

@@ -28,7 +28,7 @@ class ProfilePMAEvent(ProfileEvent): device:str; kern:str; blob:bytes; exec_tag:
class NVSignal(HCQSignal):
def _sleep(self, time_spent_since_last_sleep_ms:int):
# Reasonable to sleep for long workloads (which take more than 200ms); after this change any signal with an owner sleeps, not only timeline signals.
if time_spent_since_last_sleep_ms > 200 and self.is_timeline and self.owner is not None: self.owner.iface.sleep(200)
if time_spent_since_last_sleep_ms > 200 and self.owner is not None: self.owner.iface.sleep(200)
def get_error_str(status): return f"{status}: {nv_gpu.nv_status_codes.get(status, 'Unknown error')}"

View File

@@ -214,23 +214,26 @@ class HWQueue(Generic[SignalType, HCQDeviceType, ProgramType, ArgsStateType]):
def _submit(self, dev:HCQDeviceType): raise NotImplementedError("need _submit")
class HCQSignal(Generic[HCQDeviceType]):
def __init__(self, base_buf:HCQBuffer, value:int=0, owner:HCQDeviceType|None=None, is_timeline:bool=False, timestamp_divider=1000):
self.base_buf, self.value_addr, self.timestamp_addr, self.owner = base_buf, base_buf.va_addr+0, base_buf.va_addr+8, owner
self.is_timeline = is_timeline
def __init__(self, base_buf:HCQBuffer, value:int=0, owner:HCQDeviceType|None=None, is_timeline:bool=False, timestamp_divider=1000, virt=False):
self.base_buf, self.owner, self.is_timeline = base_buf, owner, is_timeline
self.should_return = isinstance(self.base_buf.va_addr, int) and self.owner is not None and not virt
self.timestamp_divider:decimal.Decimal = decimal.Decimal(timestamp_divider)
if isinstance(self.base_buf.va_addr, int):
self.value_mv, self.timestamp_mv = self.base_buf.cpu_view().view(0, 8, 'Q'), self.base_buf.cpu_view().view(8, 8, 'Q')
self.value_mv[0] = value
if isinstance(self.base_buf.va_addr, int) and not virt: self.value = value
def __del__(self):
if isinstance(self.base_buf.va_addr, int) and self.owner is not None: HCQCompiled.signal_pool[self.owner.peer_group].append(self.base_buf)
if self.should_return: HCQCompiled.signal_pool[unwrap(self.owner).peer_group].append(self.base_buf)
@property
def value(self) -> int: return self.value_mv[0]
def value_addr(self) -> sint: return self.base_buf.va_addr
@property
def timestamp_addr(self) -> sint: return self.base_buf.va_addr + 8
@property
def value(self) -> int: return self.base_buf.cpu_view().view(0, 8, 'Q')[0]
@value.setter
def value(self, new_value:int): self.value_mv[0] = new_value
def value(self, new_value:int): self.base_buf.cpu_view().view(0, 8, 'Q')[0] = new_value
@property
def timestamp(self) -> decimal.Decimal:
@@ -242,7 +245,7 @@ class HCQSignal(Generic[HCQDeviceType]):
Returns:
The timestamp in microseconds.
"""
return self.timestamp_mv[0] / self.timestamp_divider
return self.base_buf.cpu_view().view(8, 8, 'Q')[0] / self.timestamp_divider
def _sleep(self, time_spent_since_last_sleep_ms:int):
"""