diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py
index 61f75cdc05..373cf6c380 100644
--- a/tinygrad/runtime/ops_amd.py
+++ b/tinygrad/runtime/ops_amd.py
@@ -45,7 +45,7 @@ class AMDSignal(HCQSignal):
 
   def _sleep(self, time_spent_since_last_sleep_ms:int):
     # Reasonable to sleep for long workloads (which take more than 200ms) and only timeline signals.
-    if time_spent_since_last_sleep_ms > 200 and self.is_timeline and self.owner is not None: self.owner.iface.sleep(200)
+    if time_spent_since_last_sleep_ms > 200 and self.owner is not None: self.owner.iface.sleep(200)
 
 class AMDComputeQueue(HWQueue):
   def __init__(self, dev:AMDDevice):
diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py
index 682266e858..738d0aa262 100644
--- a/tinygrad/runtime/ops_cpu.py
+++ b/tinygrad/runtime/ops_cpu.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 import platform, sys, ctypes, functools, time, mmap, threading, queue
-from tinygrad.helpers import to_mv, OSX, WIN, mv_address, wait_cond, suppress_finalizing, unwrap, data64_le
+from tinygrad.helpers import to_mv, OSX, WIN, mv_address, suppress_finalizing, unwrap, data64_le
 from tinygrad.helpers import CPU_CC, CPU_LVP, CPU_LLVM
 from tinygrad.device import BufferSpec, DMACPURef, CompilerSet
 from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, HCQArgsState, HCQSignal, HCQProgram, MMIOInterface
@@ -13,7 +13,9 @@ from tinygrad.uop.ops import sint
 
 class CPUSignal(HCQSignal):
   def _sleep(self, time_spent_since_last_sleep_ms:int):
-    if self.is_timeline and self.owner is not None: self.owner.tasks.join()
+    if self.is_timeline and self.owner is not None:
+      self.owner.tasks.join()
+      if self.owner.error_state is not None: raise self.owner.error_state
 
 class CPUWorker(threading.Thread):
   def __init__(self, dev, tasks, thread_id):
@@ -29,13 +31,15 @@ class CPUWorker(threading.Thread):
   def run(self):
     while True:
       cmd_iter = iter(self.tasks.get())
-      for cmd in cmd_iter:
-        threads, args_cnt = next(cmd_iter), next(cmd_iter)
-        args = [next(cmd_iter) for _ in range(args_cnt)]
-        for th in range(threads - 1): self.push_task(th, cmd, args)
-        cmd(self.thread_id, *args)
-        for th in range(threads - 1): self.pool[th].join()
-      self.tasks.task_done()
+      try:
+        for cmd in cmd_iter:
+          threads, args_cnt = next(cmd_iter), next(cmd_iter)
+          args = [next(cmd_iter) for _ in range(args_cnt)]
+          for th in range(threads - 1): self.push_task(th, cmd, args)
+          cmd(self.thread_id, *args)
+          for th in range(threads - 1): self.pool[th].join()
+      except Exception as e: self.dev.error_state = e
+      finally: self.tasks.task_done()
 
 class CPUComputeQueue(HWQueue):
   def _exec(self, tid, prg, bufs, *args):
@@ -43,7 +47,9 @@ class CPUComputeQueue(HWQueue):
     if 'core_id' in prg.runtimevars: vals[prg.runtimevars['core_id']] = tid
     prg.fxn(*map(ctypes.c_uint64, args[:bufs]), *map(ctypes.c_int64 if platform.machine() == "arm64" else ctypes.c_int32, vals))
   def _signal(self, tid, signal_addr, value): to_mv(signal_addr, 4).cast('I')[0] = value
-  def _wait(self, tid, signal_addr, value): wait_cond(lambda: to_mv(signal_addr, 4).cast('I')[0] >= value, timeout_ms=60000)
+  def _wait(self, tid, tmpl_sig, signal_addr, value):
+    tmpl_sig.base_buf = HCQBuffer(signal_addr, 16, view=MMIOInterface(signal_addr, 16))
+    tmpl_sig.wait(value)
   def _timestamp(self, tid, timestamp_addr): to_mv(timestamp_addr, 8).cast('Q')[0] = time.perf_counter_ns()
   def cmd(self, cmd, *args, threads=1): self.q(cmd, threads, len(args), *args)
 
@@ -55,7 +61,7 @@ class CPUComputeQueue(HWQueue):
       self.bind_args_state(args_state)
       return self.cmd(self._exec, prg, 1, args_state.buf.va_addr)
     return self.cmd(self._exec, prg, len(args_state.bufs), *[x.va_addr for x in args_state.bufs], *args_state.vals, threads=(global_size or (1,))[0])
-  def wait(self, signal, value=0): return self.cmd(self._wait, signal.value_addr, value)
+  def wait(self, signal, value=0): return self.cmd(self._wait, type(signal)(signal.base_buf, owner=signal.owner, virt=True), signal.value_addr, value)
   def timestamp(self, signal): return self.cmd(self._timestamp, signal.timestamp_addr)
   def signal(self, signal, value:sint=0): return self.cmd(self._signal, signal.value_addr, value)
   def _submit(self, dev): dev.tasks.put(self._q[:])
diff --git a/tinygrad/runtime/ops_nv.py b/tinygrad/runtime/ops_nv.py
index 2ca17163f5..eb645bd0d2 100644
--- a/tinygrad/runtime/ops_nv.py
+++ b/tinygrad/runtime/ops_nv.py
@@ -28,7 +28,7 @@ class ProfilePMAEvent(ProfileEvent): device:str; kern:str; blob:bytes; exec_tag:
 class NVSignal(HCQSignal):
   def _sleep(self, time_spent_since_last_sleep_ms:int):
     # Reasonable to sleep for long workloads (which take more than 200ms) and only timeline signals.
-    if time_spent_since_last_sleep_ms > 200 and self.is_timeline and self.owner is not None: self.owner.iface.sleep(200)
+    if time_spent_since_last_sleep_ms > 200 and self.owner is not None: self.owner.iface.sleep(200)
 
 def get_error_str(status): return f"{status}: {nv_gpu.nv_status_codes.get(status, 'Unknown error')}"
 
diff --git a/tinygrad/runtime/support/hcq.py b/tinygrad/runtime/support/hcq.py
index c37b336fa8..a2d08a92dc 100644
--- a/tinygrad/runtime/support/hcq.py
+++ b/tinygrad/runtime/support/hcq.py
@@ -214,23 +214,26 @@ class HWQueue(Generic[SignalType, HCQDeviceType, ProgramType, ArgsStateType]):
   def _submit(self, dev:HCQDeviceType): raise NotImplementedError("need _submit")
 
 class HCQSignal(Generic[HCQDeviceType]):
-  def __init__(self, base_buf:HCQBuffer, value:int=0, owner:HCQDeviceType|None=None, is_timeline:bool=False, timestamp_divider=1000):
-    self.base_buf, self.value_addr, self.timestamp_addr, self.owner = base_buf, base_buf.va_addr+0, base_buf.va_addr+8, owner
-    self.is_timeline = is_timeline
+  def __init__(self, base_buf:HCQBuffer, value:int=0, owner:HCQDeviceType|None=None, is_timeline:bool=False, timestamp_divider=1000, virt=False):
+    self.base_buf, self.owner, self.is_timeline = base_buf, owner, is_timeline
+    self.should_return = isinstance(self.base_buf.va_addr, int) and self.owner is not None and not virt
     self.timestamp_divider:decimal.Decimal = decimal.Decimal(timestamp_divider)
-
-    if isinstance(self.base_buf.va_addr, int):
-      self.value_mv, self.timestamp_mv = self.base_buf.cpu_view().view(0, 8, 'Q'), self.base_buf.cpu_view().view(8, 8, 'Q')
-      self.value_mv[0] = value
+    if isinstance(self.base_buf.va_addr, int) and not virt: self.value = value
 
   def __del__(self):
-    if isinstance(self.base_buf.va_addr, int) and self.owner is not None: HCQCompiled.signal_pool[self.owner.peer_group].append(self.base_buf)
+    if self.should_return: HCQCompiled.signal_pool[unwrap(self.owner).peer_group].append(self.base_buf)
 
   @property
-  def value(self) -> int: return self.value_mv[0]
+  def value_addr(self) -> sint: return self.base_buf.va_addr
+
+  @property
+  def timestamp_addr(self) -> sint: return self.base_buf.va_addr + 8
+
+  @property
+  def value(self) -> int: return self.base_buf.cpu_view().view(0, 8, 'Q')[0]
 
   @value.setter
-  def value(self, new_value:int): self.value_mv[0] = new_value
+  def value(self, new_value:int): self.base_buf.cpu_view().view(0, 8, 'Q')[0] = new_value
 
   @property
   def timestamp(self) -> decimal.Decimal:
@@ -242,7 +245,7 @@ class HCQSignal(Generic[HCQDeviceType]):
     Returns:
       The timestamp in microseconds.
     """
-    return self.timestamp_mv[0] / self.timestamp_divider
+    return self.base_buf.cpu_view().view(8, 8, 'Q')[0] / self.timestamp_divider
 
   def _sleep(self, time_spent_since_last_sleep_ms:int):
     """