mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
cpu: wait on dep signals (#14862)
* cpu: task_done() in case of failures * print * fix * x * f * x * um * ? * u * f * x * gh * f * f * virt * x * simpler
This commit is contained in:
@@ -45,7 +45,7 @@ class AMDSignal(HCQSignal):
|
||||
|
||||
def _sleep(self, time_spent_since_last_sleep_ms:int):
|
||||
# Reasonable to sleep for long workloads (which take more than 200ms) and only timeline signals.
|
||||
if time_spent_since_last_sleep_ms > 200 and self.is_timeline and self.owner is not None: self.owner.iface.sleep(200)
|
||||
if time_spent_since_last_sleep_ms > 200 and self.owner is not None: self.owner.iface.sleep(200)
|
||||
|
||||
class AMDComputeQueue(HWQueue):
|
||||
def __init__(self, dev:AMDDevice):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import platform, sys, ctypes, functools, time, mmap, threading, queue
|
||||
from tinygrad.helpers import to_mv, OSX, WIN, mv_address, wait_cond, suppress_finalizing, unwrap, data64_le
|
||||
from tinygrad.helpers import to_mv, OSX, WIN, mv_address, suppress_finalizing, unwrap, data64_le
|
||||
from tinygrad.helpers import CPU_CC, CPU_LVP, CPU_LLVM
|
||||
from tinygrad.device import BufferSpec, DMACPURef, CompilerSet
|
||||
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, HCQArgsState, HCQSignal, HCQProgram, MMIOInterface
|
||||
@@ -13,7 +13,9 @@ from tinygrad.uop.ops import sint
|
||||
|
||||
class CPUSignal(HCQSignal):
|
||||
def _sleep(self, time_spent_since_last_sleep_ms:int):
|
||||
if self.is_timeline and self.owner is not None: self.owner.tasks.join()
|
||||
if self.is_timeline and self.owner is not None:
|
||||
self.owner.tasks.join()
|
||||
if self.owner.error_state is not None: raise self.owner.error_state
|
||||
|
||||
class CPUWorker(threading.Thread):
|
||||
def __init__(self, dev, tasks, thread_id):
|
||||
@@ -29,13 +31,15 @@ class CPUWorker(threading.Thread):
|
||||
def run(self):
|
||||
while True:
|
||||
cmd_iter = iter(self.tasks.get())
|
||||
for cmd in cmd_iter:
|
||||
threads, args_cnt = next(cmd_iter), next(cmd_iter)
|
||||
args = [next(cmd_iter) for _ in range(args_cnt)]
|
||||
for th in range(threads - 1): self.push_task(th, cmd, args)
|
||||
cmd(self.thread_id, *args)
|
||||
for th in range(threads - 1): self.pool[th].join()
|
||||
self.tasks.task_done()
|
||||
try:
|
||||
for cmd in cmd_iter:
|
||||
threads, args_cnt = next(cmd_iter), next(cmd_iter)
|
||||
args = [next(cmd_iter) for _ in range(args_cnt)]
|
||||
for th in range(threads - 1): self.push_task(th, cmd, args)
|
||||
cmd(self.thread_id, *args)
|
||||
for th in range(threads - 1): self.pool[th].join()
|
||||
except Exception as e: self.dev.error_state = e
|
||||
finally: self.tasks.task_done()
|
||||
|
||||
class CPUComputeQueue(HWQueue):
|
||||
def _exec(self, tid, prg, bufs, *args):
|
||||
@@ -43,7 +47,9 @@ class CPUComputeQueue(HWQueue):
|
||||
if 'core_id' in prg.runtimevars: vals[prg.runtimevars['core_id']] = tid
|
||||
prg.fxn(*map(ctypes.c_uint64, args[:bufs]), *map(ctypes.c_int64 if platform.machine() == "arm64" else ctypes.c_int32, vals))
|
||||
def _signal(self, tid, signal_addr, value): to_mv(signal_addr, 4).cast('I')[0] = value
|
||||
def _wait(self, tid, signal_addr, value): wait_cond(lambda: to_mv(signal_addr, 4).cast('I')[0] >= value, timeout_ms=60000)
|
||||
def _wait(self, tid, tmpl_sig, signal_addr, value):
|
||||
tmpl_sig.base_buf = HCQBuffer(signal_addr, 16, view=MMIOInterface(signal_addr, 16))
|
||||
tmpl_sig.wait(value)
|
||||
def _timestamp(self, tid, timestamp_addr): to_mv(timestamp_addr, 8).cast('Q')[0] = time.perf_counter_ns()
|
||||
def cmd(self, cmd, *args, threads=1):
|
||||
self.q(cmd, threads, len(args), *args)
|
||||
@@ -55,7 +61,7 @@ class CPUComputeQueue(HWQueue):
|
||||
self.bind_args_state(args_state)
|
||||
return self.cmd(self._exec, prg, 1, args_state.buf.va_addr)
|
||||
return self.cmd(self._exec, prg, len(args_state.bufs), *[x.va_addr for x in args_state.bufs], *args_state.vals, threads=(global_size or (1,))[0])
|
||||
def wait(self, signal, value=0): return self.cmd(self._wait, signal.value_addr, value)
|
||||
def wait(self, signal, value=0): return self.cmd(self._wait, type(signal)(signal.base_buf, owner=signal.owner, virt=True), signal.value_addr, value)
|
||||
def timestamp(self, signal): return self.cmd(self._timestamp, signal.timestamp_addr)
|
||||
def signal(self, signal, value:sint=0): return self.cmd(self._signal, signal.value_addr, value)
|
||||
def _submit(self, dev): dev.tasks.put(self._q[:])
|
||||
|
||||
@@ -28,7 +28,7 @@ class ProfilePMAEvent(ProfileEvent): device:str; kern:str; blob:bytes; exec_tag:
|
||||
class NVSignal(HCQSignal):
|
||||
def _sleep(self, time_spent_since_last_sleep_ms:int):
|
||||
# Reasonable to sleep for long workloads (which take more than 200ms) and only timeline signals.
|
||||
if time_spent_since_last_sleep_ms > 200 and self.is_timeline and self.owner is not None: self.owner.iface.sleep(200)
|
||||
if time_spent_since_last_sleep_ms > 200 and self.owner is not None: self.owner.iface.sleep(200)
|
||||
|
||||
def get_error_str(status): return f"{status}: {nv_gpu.nv_status_codes.get(status, 'Unknown error')}"
|
||||
|
||||
|
||||
@@ -214,23 +214,26 @@ class HWQueue(Generic[SignalType, HCQDeviceType, ProgramType, ArgsStateType]):
|
||||
def _submit(self, dev:HCQDeviceType): raise NotImplementedError("need _submit")
|
||||
|
||||
class HCQSignal(Generic[HCQDeviceType]):
|
||||
def __init__(self, base_buf:HCQBuffer, value:int=0, owner:HCQDeviceType|None=None, is_timeline:bool=False, timestamp_divider=1000):
|
||||
self.base_buf, self.value_addr, self.timestamp_addr, self.owner = base_buf, base_buf.va_addr+0, base_buf.va_addr+8, owner
|
||||
self.is_timeline = is_timeline
|
||||
def __init__(self, base_buf:HCQBuffer, value:int=0, owner:HCQDeviceType|None=None, is_timeline:bool=False, timestamp_divider=1000, virt=False):
|
||||
self.base_buf, self.owner, self.is_timeline = base_buf, owner, is_timeline
|
||||
self.should_return = isinstance(self.base_buf.va_addr, int) and self.owner is not None and not virt
|
||||
self.timestamp_divider:decimal.Decimal = decimal.Decimal(timestamp_divider)
|
||||
|
||||
if isinstance(self.base_buf.va_addr, int):
|
||||
self.value_mv, self.timestamp_mv = self.base_buf.cpu_view().view(0, 8, 'Q'), self.base_buf.cpu_view().view(8, 8, 'Q')
|
||||
self.value_mv[0] = value
|
||||
if isinstance(self.base_buf.va_addr, int) and not virt: self.value = value
|
||||
|
||||
def __del__(self):
|
||||
if isinstance(self.base_buf.va_addr, int) and self.owner is not None: HCQCompiled.signal_pool[self.owner.peer_group].append(self.base_buf)
|
||||
if self.should_return: HCQCompiled.signal_pool[unwrap(self.owner).peer_group].append(self.base_buf)
|
||||
|
||||
@property
|
||||
def value(self) -> int: return self.value_mv[0]
|
||||
def value_addr(self) -> sint: return self.base_buf.va_addr
|
||||
|
||||
@property
|
||||
def timestamp_addr(self) -> sint: return self.base_buf.va_addr + 8
|
||||
|
||||
@property
|
||||
def value(self) -> int: return self.base_buf.cpu_view().view(0, 8, 'Q')[0]
|
||||
|
||||
@value.setter
|
||||
def value(self, new_value:int): self.value_mv[0] = new_value
|
||||
def value(self, new_value:int): self.base_buf.cpu_view().view(0, 8, 'Q')[0] = new_value
|
||||
|
||||
@property
|
||||
def timestamp(self) -> decimal.Decimal:
|
||||
@@ -242,7 +245,7 @@ class HCQSignal(Generic[HCQDeviceType]):
|
||||
Returns:
|
||||
The timestamp in microseconds.
|
||||
"""
|
||||
return self.timestamp_mv[0] / self.timestamp_divider
|
||||
return self.base_buf.cpu_view().view(8, 8, 'Q')[0] / self.timestamp_divider
|
||||
|
||||
def _sleep(self, time_spent_since_last_sleep_ms:int):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user