mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-16 09:37:11 -05:00
534 lines
25 KiB
Python
534 lines
25 KiB
Python
from __future__ import annotations
|
|
from typing import List, Optional, Dict, Tuple, cast, Protocol, Type, Union, TypeVar, Generic, Any
|
|
import contextlib, decimal, statistics, random, json, atexit, time, ctypes, array
|
|
from tinygrad.helpers import PROFILEPATH, PROFILE, from_mv, getenv, to_mv
|
|
from tinygrad.renderer import Renderer
|
|
from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator
|
|
from tinygrad.ops import sym_infer, sint, Variable
|
|
|
|
# **************** for HCQ Compatible Devices ****************
|
|
|
|
# Type variables linking the generic HCQ classes of this module together. Each bound is given as a string
# (a forward reference) because the bounding class is defined further down in the module.
SignalType = TypeVar("SignalType", bound="HCQSignal")
DeviceType = TypeVar("DeviceType", bound="HCQCompiled")
ProgramType = TypeVar("ProgramType", bound="HCQProgram")
ArgsStateType = TypeVar("ArgsStateType", bound="HCQArgsState")
QueueType = TypeVar("QueueType", bound="HWQueue")
|
|
|
|
class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
  """
  Base class for hardware command queues in the HCQ (Hardware Command Queue) API.

  Raw command words accumulate in `self._q`. A value that is not a plain int (a symbolic `sint`) cannot be written yet, so a
  placeholder is stored instead and its location is remembered; `_apply_var_vals` later patches the resolved value in.
  """

  def __init__(self):
    self._q:Any = []
    self.binded_device:Optional[DeviceType] = None
    # (position in self._q, index into self.syms) for each symbolic value enqueued through `q`
    self.q_sints:List[Tuple[int, int]] = []
    # (view, element index, index into self.syms, optional mask) for each symbolic value bound through `bind_sints_to_ptr`
    self.mv_sints:List[Tuple[memoryview, int, int, Optional[int]]] = []
    # distinct symbolic values, and what each resolved to on the previous `_apply_var_vals` call (None = never resolved)
    self.syms:List[sint] = []
    self._prev_resolved_syms:List[Optional[int]] = []

  def _new_sym(self, sym:sint) -> int:
    """Returns the position of `sym` in `self.syms`, registering it (with no previous resolution) if it is new."""
    try: return self.syms.index(sym)
    except ValueError:
      self.syms.append(sym)
      self._prev_resolved_syms.append(None)
      return len(self.syms) - 1

  def q(self, *values):
    """
    Appends raw values to the command queue.

    Args:
      values: Values to append. Plain ints are stored as-is; symbolic values are stored as the placeholder 0xbadc0ded and
        recorded so they can be filled in by `_apply_var_vals`.
    """
    for val in values:
      if not isinstance(val, int):
        self.q_sints.append((len(self._q), self._new_sym(val)))
        val = 0xbadc0ded
      self._q.append(val)

  # *** common commands ***

  def timestamp(self, signal:SignalType):
    """
    Enqueues a command that, once every earlier command has finished, writes the current time into a signal.

    Args:
      signal: Signal that receives the timestamp
    """

  def signal(self, signal:SignalType, value:sint):
    """
    Enqueues a command that, once every earlier operation has finished, sets a signal to a value.

    Args:
      signal: Signal to update
      value: Value written to the signal
    """

  def wait(self, signal:SignalType, value:sint):
    """
    Enqueues a command that stalls the queue until a signal is greater than or equal to a value.

    Args:
      signal: Signal to watch
      value: Threshold the signal must reach
    """

  # *** commands for compute queues ***

  def memory_barrier(self):
    """
    Enqueues a memory barrier that makes memory coherent between agents. Compute queues only.
    """

  def exec(self, prg:ProgramType, args_state:ArgsStateType, global_size:Tuple[sint, ...], local_size:Tuple[sint, ...]):
    """
    Enqueues a kernel launch. Compute queues only.

    Args:
      prg: Program to launch
      args_state: Kernel argument state used for the launch
      global_size: Global work size
      local_size: Local work size
    """

  # *** commands for copy queues ***

  def copy(self, dest:sint, src:sint, copy_size:int):
    """
    Enqueues a data transfer. Copy queues only.

    Args:
      dest: Destination of the copy
      src: Source of the copy
      copy_size: Number of bytes to copy
    """

  # *** submit and bind commands ***

  def bind(self, dev:DeviceType):
    """
    Ties the queue to one device so a backend can optimize it for that device.

    Backends may implement this to lay the queue out in a form the device can consume directly, removing the need to copy the
    queue into device memory at submit time.

    Args:
      dev: Device the queue is optimized for.

    Note:
      Optional for backends, but recommended for performance.
    """

  def bind_sints(self, *vals:sint, struct:ctypes.Structure, start_field:str, fmt, mask:Optional[int]=None):
    """Binds `vals` into consecutive elements of `struct`, starting at field `start_field`. See `bind_sints_to_ptr`."""
    field_offset = getattr(type(struct), start_field).offset
    self.bind_sints_to_ptr(*vals, ptr=ctypes.addressof(struct) + field_offset, fmt=fmt, mask=mask)

  def bind_sints_to_ptr(self, *vals:sint, ptr:int, fmt, mask:Optional[int]=None):
    """
    Writes `vals` as consecutive `fmt`-typed elements starting at address `ptr`.

    Int values are written immediately. Symbolic values are recorded and written by `_apply_var_vals`. When `mask` is given,
    the bits selected by `mask` are cleared and the value is OR-ed in; the other bits of the element are kept.
    """
    view = to_mv(ptr, 8*len(vals)).cast(fmt)
    for idx, v in enumerate(vals):
      if not isinstance(v, int):
        self.mv_sints.append((view, idx, self._new_sym(v), mask))
        continue
      view[idx] = v if mask is None else ((view[idx] & ~mask) | v)

  def _apply_var_vals(self, var_vals:Dict[Variable, int]):
    """Resolves every symbolic value with `var_vals` and patches only the locations whose resolved value changed."""
    resolved = [sym_infer(s, var_vals) for s in self.syms]
    prev = self._prev_resolved_syms

    for q_off, idx in self.q_sints:
      if prev[idx] != resolved[idx]: self._q[q_off] = resolved[idx]

    for view, off, idx, mask in self.mv_sints:
      if prev[idx] != resolved[idx]:
        view[off] = resolved[idx] if mask is None else ((view[off] & ~mask) | resolved[idx])

    self._prev_resolved_syms = cast(List[Optional[int]], resolved)

  def submit(self, dev:DeviceType, var_vals:Optional[Dict[Variable, int]]=None):
    """
    Submits the command queue to a device for execution.

    Args:
      dev: Device that executes the queue
      var_vals: Optional variable values; when given, symbolic values in the queue are resolved and patched before submission

    Returns:
      This queue, so calls can be chained.
    """
    if var_vals is not None: self._apply_var_vals(var_vals)
    self._submit(dev)
    return self

  def _submit(self, dev:DeviceType): raise NotImplementedError("need _submit")
|
|
|
|
class HCQSignal(Generic[DeviceType]):
  """
  A signal: a 64-bit value word at `base_addr + value_off` and a 64-bit timestamp word at `base_addr + timestamp_off`.

  When `base_addr` is not a plain int (a symbolic address), no memory views are created, so `value`, `timestamp` and `wait`
  cannot be used on that instance.
  """
  def __init__(self, base_addr:sint=0, value:int=0, timeline_for_device:Optional[DeviceType]=None, timestamp_divider=1, value_off=0, timestamp_off=8):
    self.base_addr, self.value_addr, self.timestamp_addr = base_addr, base_addr+value_off, base_addr+timestamp_off
    # raw timestamp ticks are divided by this to produce `timestamp`
    self.timestamp_divider:decimal.Decimal = decimal.Decimal(timestamp_divider)
    self.timeline_for_device:Optional[DeviceType] = timeline_for_device

    if isinstance(base_addr, int):
      self.value_mv, self.timestamp_mv = to_mv(self.value_addr, 8).cast('Q'), to_mv(self.timestamp_addr, 8).cast('Q')
      self.value_mv[0] = value

  @property
  def value(self) -> int: return self.value_mv[0]

  @value.setter
  def value(self, new_value:int): self.value_mv[0] = new_value

  @property
  def timestamp(self) -> decimal.Decimal:
    """
    Get the timestamp field of the signal.

    This property provides read-only access to the signal's timestamp: the raw value divided by `timestamp_divider`.

    Returns:
      The timestamp in microseconds (backends choose `timestamp_divider` so that the division yields microseconds).
    """
    return self.timestamp_mv[0] / self.timestamp_divider

  def _sleep(self, time_spent_waiting_ms:int):
    """
    Optional hook called between polls in `wait`. Backends may override it to sleep; the base implementation does nothing.
    """

  def wait(self, value:int, timeout:int=getenv("HCQDEV_WAIT_TIMEOUT_MS", 30000)):
    """
    Waits until the signal is greater than or equal to a specific value.

    Args:
      value: The value to wait for.
      timeout: Maximum time to wait in milliseconds. Defaults to the HCQDEV_WAIT_TIMEOUT_MS environment variable (read once
        at import time), or 30000 ms (30s) when it is unset.

    Raises:
      RuntimeError: If the signal has not reached `value` within `timeout` milliseconds.
    """
    # Measure elapsed time on a monotonic clock: wall-clock adjustments (e.g. NTP steps) must not cut the wait short or stretch it.
    start_time = int(time.monotonic() * 1000)
    while (time_spent:=int(time.monotonic() * 1000) - start_time) < timeout:
      if self.value >= value: return
      self._sleep(time_spent)
    raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")
|
|
|
|
@contextlib.contextmanager
def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type:Optional[Type[HWQueue]]=None, queue:Optional[HWQueue]=None):
  """
  Context manager that brackets the enclosed work with a start and an end timestamp signal.

  If `queue` is given, timestamp commands are appended to it and the caller is responsible for submitting it. Otherwise a fresh
  `queue_type` queue is submitted for each timestamp, ordered on the device timeline: it waits for `timeline_value - 1` and
  signals `timeline_value`.

  Args:
    dev: Device whose signal type and timeline are used.
    enabled: If falsy, nothing is enqueued or recorded and `(None, None)` is yielded.
    desc: Label stored with the profiling record.
    queue_type: Queue class used for standalone timestamp submissions when `queue` is None.
    queue: Existing queue to append the timestamp commands to.

  Yields:
    `(start_signal, end_signal)`, or `(None, None)` when disabled.
  """
  st, en = (dev.signal_t(), dev.signal_t()) if enabled else (None, None)

  def _timestamp(sig):
    # Enqueue one timestamp, either into the caller's queue or as its own timeline-ordered submission.
    if not enabled: return
    if queue is not None: queue.timestamp(sig)
    else:
      assert queue_type is not None
      queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(sig).signal(dev.timeline_signal, dev.timeline_value).submit(dev)
      dev.timeline_value += 1

  _timestamp(st)
  try: yield (st, en)
  finally:
    _timestamp(en)

    if enabled and PROFILE:
      # Count a record as copy-queue work only when an explicit copy queue type was used. A bare `queue_type is dev.hw_copy_queue_t`
      # is True for `None is None`, which would label compute work on a caller-provided queue as DMA on devices without a copy queue.
      is_copy = queue_type is not None and queue_type is dev.hw_copy_queue_t
      dev.sig_prof_records.append((cast(HCQSignal, st), cast(HCQSignal, en), desc, is_copy))
|
|
|
|
class HCQArgsState(Generic[ProgramType]):
  """
  Base class for the kernel-argument state of an HCQ program.

  Stores the address of the kernel-argument memory and the owning program. This base class does not write `bufs` or `vals`
  anywhere; subclasses lay them out and implement the update methods.
  """
  def __init__(self, ptr:int, prg:ProgramType, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=()):
    self.ptr = ptr
    self.prg = prg

  def update_buffer(self, index:int, buf:HCQBuffer):
    """Replaces the buffer argument at `index` with `buf`. Must be provided by a subclass."""
    raise NotImplementedError("need update_buffer")

  def update_var(self, index:int, val:int):
    """Replaces the value argument at `index` with `val`. Must be provided by a subclass."""
    raise NotImplementedError("need update_var")
|
|
|
|
class CLikeArgsState(HCQArgsState[ProgramType]):
  """
  Packs kernel arguments back-to-back at `ptr`: optional 32-bit prefix words, then one 64-bit address per buffer, then one
  32-bit word per value. `self.bufs` and `self.vals` are live views into that memory, so updates write straight into it.
  """
  def __init__(self, ptr:int, prg:ProgramType, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=(), prefix:Optional[List[int]]=None):
    super().__init__(ptr, prg, bufs, vals=vals)

    prefix_bytes = 0 if prefix is None else len(prefix) * 4
    if prefix is not None: to_mv(self.ptr, prefix_bytes).cast('I')[:] = array.array('I', prefix)

    bufs_addr = self.ptr + prefix_bytes
    vals_addr = bufs_addr + len(bufs) * 8
    self.bufs = to_mv(bufs_addr, len(bufs) * 8).cast('Q')
    self.vals = to_mv(vals_addr, len(vals) * 4).cast('I')

    self.bufs[:] = array.array('Q', [b.va_addr for b in bufs])
    self.vals[:] = array.array('I', vals)

  def update_buffer(self, index:int, buf:HCQBuffer):
    """Rewrites the address of buffer argument `index`."""
    self.bufs[index] = buf.va_addr

  def update_var(self, index:int, val:int):
    """Rewrites value argument `index`."""
    self.vals[index] = val
|
|
|
|
class HCQProgram(Generic[DeviceType]):
  """
  Base class for a compiled program launched on an HCQ device through its compute queue.
  """
  def __init__(self, args_state_t:Type[HCQArgsState], dev:DeviceType, name:str, kernargs_alloc_size:int):
    self.args_state_t = args_state_t
    self.dev = dev
    self.name = name
    self.kernargs_alloc_size = kernargs_alloc_size

  def fill_kernargs(self, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=(), kernargs_ptr:Optional[int]=None) -> HCQArgsState:
    """
    Fills the kernel arguments. Space comes from the device's kernargs allocator unless `kernargs_ptr` is supplied.

    Args:
      bufs: Buffers to be written to kernel arguments.
      vals: Values to be written to kernel arguments.
      kernargs_ptr: Optional pointer to pre-allocated kernel arguments memory.

    Returns:
      Arguments state with the given buffers and values set for the program.
    """
    ptr = kernargs_ptr if kernargs_ptr else self.dev._alloc_kernargs(self.kernargs_alloc_size)
    return self.args_state_t(ptr, self, bufs, vals=vals)

  def __call__(self, *bufs:HCQBuffer, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1),
               vals:Tuple[int, ...]=(), wait:bool=False) -> Optional[float]:
    """
    Launches the program with the given arguments and dimensions.

    Args:
      bufs: Buffer arguments to execute the kernel with.
      global_size: Global work size (like CUDA's grid size).
      local_size: Local work size (like CUDA's block size).
      vals: Value arguments to execute the kernel with.
      wait: If True, blocks until the kernel has finished.

    Returns:
      Kernel execution time in seconds when `wait` is True, otherwise None.
    """
    dev = self.dev
    kernargs = self.fill_kernargs(bufs, vals)
    q = dev.hw_compute_queue_t().wait(dev.timeline_signal, dev.timeline_value - 1).memory_barrier()

    with hcq_profile(dev, queue=q, desc=self.name, enabled=wait or PROFILE) as (sig_st, sig_en):
      q.exec(self, kernargs, global_size, local_size)

    # the timeline signal goes after the end timestamp that hcq_profile appended to q
    q.signal(dev.timeline_signal, dev.timeline_value).submit(dev)
    dev.timeline_value += 1

    if not wait: return None
    dev.synchronize()
    # signal timestamps are in microseconds; convert the difference to seconds
    return float(sig_en.timestamp - sig_st.timestamp) / 1e6
|
|
|
|
class ProfileLogger:
  """
  Collects profiling events and dependency arrows and writes them out as a trace-event JSON file for Perfetto.

  `mjson` and `actors` are class-level, so all logger instances contribute to one shared trace. `writers` counts live loggers;
  the trace is written to PROFILEPATH when the last one is destroyed.
  """
  writers: int = 0
  mjson: List[Dict] = []
  actors: Dict[Union[str, Tuple[str, str]], int] = {}

  def __init__(self):
    self.events = []  # (name, start, end, actor, subactor, args)
    self.deps = []    # (dep_end, start, dep_actor, dep_subactor, actor, subactor)
    ProfileLogger.writers += 1

  def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None, args=None):
    self.events.append((ev_name, ev_start, ev_end, actor, subactor, args))

  def _ensure_actor(self, actor_name, subactor_name):
    """Returns (pid, tid) for an actor/subactor pair, emitting the naming metadata the first time each one is seen."""
    actors = self.actors
    if actor_name not in actors:
      pid = len(actors)
      actors[actor_name] = pid
      self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})

    key = (actor_name, subactor_name)
    if key not in actors:
      tid = len(actors)
      actors[key] = tid
      self.mjson.append({"name": "thread_name", "ph": "M", "pid": actors[actor_name], "tid": tid, "args": {"name": subactor_name}})

    return actors[actor_name], actors.get(key, -1)

  def __del__(self):
    # perfetto json docs: https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview
    for name, st, et, actor_name, subactor_name, args in self.events:
      pid, tid = self._ensure_actor(actor_name, subactor_name)
      # string args are kept verbatim; any other arg is called with the event duration
      if args is not None: args = {k: (v if v.__class__ is str else v(et-st)) for k, v in args.items()}
      self.mjson.append({"name": name, "ph": "X", "pid": pid, "tid": tid, "ts": st, "dur": et-st, "args": args})

    for en, st, dep_actor_name, dep_subactor_name, actor_name, subactor_name in self.deps:
      dep_pid, dep_tid = self._ensure_actor(dep_actor_name, dep_subactor_name)
      pid, tid = self._ensure_actor(actor_name, subactor_name)
      # a flow arrow is a start ("s") / finish ("f") pair sharing one id
      flow_id = len(self.mjson)
      self.mjson.append({"ph": "s", "pid": dep_pid, "tid": dep_tid, "id": flow_id, "ts": en, "bp": "e"})
      self.mjson.append({"ph": "f", "pid": pid, "tid": tid, "id": flow_id, "ts": st, "bp": "e"})

    ProfileLogger.writers -= 1
    if ProfileLogger.writers != 0 or not self.mjson: return
    with open(PROFILEPATH.value, "w") as f: f.write(json.dumps({"traceEvents": self.mjson}))
    print(f"Saved profile to {PROFILEPATH.value}. Use https://ui.perfetto.dev/ to open it.")
|
|
|
|
class HCQCompiled(Compiled, Generic[SignalType]):
  """
  A base class for devices compatible with the HCQ (Hardware Command Queue) API.

  Each device owns a timeline: `timeline_signal` plus the integer `timeline_value`, the next value to be signaled. Work submitted
  in this module waits for `timeline_value - 1`, signals `timeline_value`, and then increments `timeline_value`. Waiting on the
  CPU for `timeline_value - 1` therefore waits for everything submitted so far.
  """
  # every constructed HCQCompiled is appended here (class-level list shared by all devices)
  devices: List[HCQCompiled] = []
  # offsets added to this device's GPU timestamps (copy / compute queue) to express them in CPU time;
  # NaN until _ensure_shared_time_base assigns per-device values
  gpu2cpu_copy_time_diff: decimal.Decimal = decimal.Decimal('nan')
  gpu2cpu_compute_time_diff: decimal.Decimal = decimal.Decimal('nan')

  def __init__(self, device:str, allocator:HCQAllocator, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[SignalType],
               comp_queue_t:Type[HWQueue], copy_queue_t:Optional[Type[HWQueue]]):
    """
    Args:
      device: Device name; the part after ':' (if present) is parsed as `device_id`, otherwise `device_id` is 0.
      allocator, renderer, compiler, runtime: Forwarded to `Compiled.__init__`.
      signal_t: Signal class used for the timeline signals and for profiling signals.
      comp_queue_t: Compute queue class.
      copy_queue_t: Copy queue class, or None when the device has no copy queue.
    """
    self.device_id:int = int(device.split(":")[1]) if ":" in device else 0
    self.signal_t, self.hw_compute_queue_t, self.hw_copy_queue_t = signal_t, comp_queue_t, copy_queue_t
    self.timeline_value:int = 1
    self.timeline_signal:SignalType = self.signal_t(value=0, timeline_for_device=self)
    # spare signal swapped in by _wrap_timeline_signal once the timeline value grows large
    self._shadow_timeline_signal:SignalType = self.signal_t(value=0, timeline_for_device=self)
    # profiling: signal pairs not yet read back, resolved (start, end) timestamps, and cross-queue dependency pairs
    self.sig_prof_records:List[Tuple[HCQSignal, HCQSignal, str, bool]] = []
    self.raw_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, str, bool, Optional[Dict]]] = []
    self.dep_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, HCQCompiled, bool, decimal.Decimal, decimal.Decimal, HCQCompiled, bool]] = []
    if PROFILE: self._prof_setup()

    # NOTE(review): imported here rather than at module top, presumably to avoid a circular import — confirm
    from tinygrad.runtime.graph.hcq import HCQGraph
    super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)

    # 16 MiB CPU-accessible page that _alloc_kernargs hands out kernel-argument space from
    self.kernargs_page:HCQBuffer = self.allocator.alloc(16 << 20, BufferSpec(cpu_access=True))
    self.kernargs_ptr:int = self.kernargs_page.va_addr
    self.devices.append(self)

  def synchronize(self):
    """
    Blocks until all work submitted to this device has completed.

    A timeout (RuntimeError from the signal wait) is passed to `on_device_hang` when a subclass defines it; otherwise it is
    re-raised. Afterwards the timeline is wrapped if needed, and pending profiling signals are read back into raw records.
    """
    try: self.timeline_signal.wait(self.timeline_value - 1)
    except RuntimeError as e:
      if hasattr(self, 'on_device_hang'): self.on_device_hang()
      else: raise e

    # keep timeline values bounded; safe here because all submitted work has just completed
    if self.timeline_value > (1 << 31): self._wrap_timeline_signal()
    if PROFILE:
      self.raw_prof_records += [(st.timestamp, en.timestamp, name, is_cp, None) for st, en, name, is_cp in self.sig_prof_records]
      self.sig_prof_records = []

  def _alloc_kernargs(self, alloc_size:int) -> int:
    """
    Allocates space for arguments passed to the kernel.

    Bump-allocates `alloc_size` bytes from `kernargs_page`, restarting at the page start when the allocation would reach the end
    of the page. No alignment is applied and no check is made that earlier allocations are no longer in use.
    NOTE(review): presumably the 16 MiB page is large enough that wraparound never overwrites in-flight arguments — confirm.
    """
    if self.kernargs_ptr >= (self.kernargs_page.va_addr + self.kernargs_page.size - alloc_size): self.kernargs_ptr = self.kernargs_page.va_addr
    self.kernargs_ptr = (res:=self.kernargs_ptr) + alloc_size
    return res

  def _ensure_shared_time_base(self):
    """
    Computes, once, each device's offset from GPU timestamps to CPU time (per queue type) and prints a GPU-to-GPU jitter matrix.

    Returns immediately if this device's compute offset is already known. Note that the matrix is printed unconditionally.
    """
    if not self.gpu2cpu_compute_time_diff.is_nan(): return

    # Submits a timestamp on queue type q_t and returns CPU time (midpoint of the CPU wait, ns -> us) minus the GPU timestamp.
    def _sync_cpu_queue(d:HCQCompiled, q_t:Type[HWQueue]):
      q_t().timestamp(d.timeline_signal).signal(d.timeline_signal, d.timeline_value).submit(d)
      d.timeline_value += 1
      st = time.perf_counter_ns()
      d.timeline_signal.wait(d.timeline_value - 1)  # average of the two
      et = time.perf_counter_ns()
      return (decimal.Decimal(et+st) / 2000) - d.timeline_signal.timestamp

    # randomly sample the timing from GPU to CPU
    choices: List = [(d, d.hw_compute_queue_t, []) for d in self.devices]
    choices += [(d, d.hw_copy_queue_t, []) for d in self.devices if d.hw_copy_queue_t is not None]
    for _ in range(100*len(self.devices)):
      d,q,l = random.choice(choices)
      l.append(_sync_cpu_queue(d,q))
    # NOTE(review): statistics.median raises on an empty list; a queue never picked above would fail here (very unlikely)
    for d,q,l in choices:
      if q == d.hw_compute_queue_t: d.gpu2cpu_compute_time_diff = statistics.median(l)
      if q == d.hw_copy_queue_t: d.gpu2cpu_copy_time_diff = statistics.median(l)

    # Each queue signals its own timeline and then waits for the other's, so both timestamps are taken only after both queues
    # have started. Returns d2's timestamp minus d1's.
    def _sync_gpu_to_gpu_queue(d1:HCQCompiled, d2:HCQCompiled, q1_t:Type[HWQueue], q2_t:Type[HWQueue]):
      q1_t().signal(d1.timeline_signal, d1.timeline_value).wait(d2.timeline_signal, d2.timeline_value) \
            .timestamp(d1.timeline_signal).signal(d1.timeline_signal, d1.timeline_value+1).submit(d1)
      q2_t().signal(d2.timeline_signal, d2.timeline_value).wait(d1.timeline_signal, d1.timeline_value) \
            .timestamp(d2.timeline_signal).signal(d2.timeline_signal, d2.timeline_value+1).submit(d2)
      d1.timeline_value += 2
      d2.timeline_value += 2
      d1.timeline_signal.wait(d1.timeline_value - 1)
      d2.timeline_signal.wait(d2.timeline_value - 1)
      return d2.timeline_signal.timestamp - d1.timeline_signal.timestamp

    # then test it by timing the GPU to GPU times
    # jitter = measured GPU-to-GPU offset minus the offset implied by the CPU calibration above (diagonal stays NaN)
    jitter_matrix = [[float('nan')]*len(self.devices) for _ in range(len(self.devices))]
    for i1, d1 in enumerate(self.devices):
      for i2, d2 in enumerate(self.devices):
        if d1 == d2: continue
        d1_to_d2 = statistics.median(_sync_gpu_to_gpu_queue(d1, d2, d1.hw_compute_queue_t, d2.hw_compute_queue_t) - \
                                     _sync_gpu_to_gpu_queue(d2, d1, d2.hw_compute_queue_t, d1.hw_compute_queue_t) for _ in range(20)) / 2
        jitter_matrix[i1][i2] = d1_to_d2 - (d1.gpu2cpu_compute_time_diff - d2.gpu2cpu_compute_time_diff)
    print("pairwise clock jitter matrix (us):\n" + '\n'.join([''.join([f'{float(item):8.3f}' for item in row]) for row in jitter_matrix]))

  def _gpu2cpu_time(self, gpu_time:decimal.Decimal, is_copy:bool) -> float:
    """
    Translates local gpu time (timestamp) into global cpu time.

    Uses the copy-queue offset when `is_copy` is true, otherwise the compute-queue offset. Runs calibration on first use.
    """
    self._ensure_shared_time_base()
    return float(gpu_time + (self.gpu2cpu_copy_time_diff if is_copy else self.gpu2cpu_compute_time_diff))

  def _prof_setup(self):
    """Creates this device's ProfileLogger (once) and registers `_prof_finalize` to run at interpreter exit."""
    if hasattr(self, 'profile_logger'): return
    atexit.register(self._prof_finalize)
    self.profile_logger = ProfileLogger()

  def _prof_finalize(self):
    """Converts recorded GPU timestamps to CPU time, hands them to the ProfileLogger, and releases the logger."""
    # queue labels indexed by the is-copy flag (False -> 0, True -> 1)
    qname = ["COMPUTE", "DMA"]

    # Sync to be sure all events on the device are recorded.
    self.synchronize()

    for st, en, name, is_cp, args in self.raw_prof_records:
      self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.device, qname[is_cp], args)]
    for a_st, a_en, a_dev, a_is_copy, b_st, b_en, b_dev, b_is_copy in self.dep_prof_records:
      # Perfetto connects nodes based on timing data, ensuring every choice is valid by averaging times to a midpoint.
      a_tm, b_tm = a_dev._gpu2cpu_time((a_st+a_en)/decimal.Decimal(2), a_is_copy), b_dev._gpu2cpu_time((b_st+b_en)/decimal.Decimal(2), b_is_copy)
      self.profile_logger.deps += [(a_tm, b_tm, a_dev.device, qname[a_is_copy], b_dev.device, qname[b_is_copy])]
    self.raw_prof_records, self.dep_prof_records = [], []

    # Remove the logger, this flushes all data written by the device.
    del self.profile_logger

  def _wrap_timeline_signal(self):
    """
    Restarts the timeline at 1 on the spare signal so timeline values stay bounded.

    The active and shadow signals are swapped, the newly active signal is reset to 0, and the allocator's staging-buffer
    timeline entries are reset to 0 so they match the new numbering. Called from `synchronize` right after all work completed.
    """
    self.timeline_signal, self._shadow_timeline_signal, self.timeline_value = self._shadow_timeline_signal, self.timeline_signal, 1
    self.timeline_signal.value = 0
    cast(HCQAllocator, self.allocator).b_timeline = [0] * len(cast(HCQAllocator, self.allocator).b)
|
|
|
|
# Protocol for HCQ-compatible allocators: every allocated buffer must expose its virtual address (VA) and its size.
|
|
class HCQBuffer(Protocol):
  """Structural type for buffers returned by HCQ allocators."""
  va_addr:int  # virtual address of the buffer
  size:int     # size of the buffer in bytes
|
|
|
|
class HCQAllocator(LRUAllocator, Generic[DeviceType]):
  """
  A base allocator class compatible with the HCQ (Hardware Command Queue) API.

  This class implements basic copy operations following the HCQ API, utilizing both types of `HWQueue`.
  Host<->device copies go through a ring of host staging buffers (`self.b`). `self.b_timeline[i]` holds the device timeline
  value that must be reached before staging buffer `i` may be reused, and `self.b_next` is the index of the buffer used last.
  """

  def __init__(self, dev:DeviceType, batch_size:int=(2 << 20), batch_cnt:int=32):
    """
    Args:
      dev: Device this allocator serves.
      batch_size: Size in bytes of each host staging buffer (default 2 MiB).
      batch_cnt: Number of host staging buffers.
    """
    self.dev:DeviceType = dev
    self.b = [self._alloc(batch_size, BufferSpec(host=True)) for _ in range(batch_cnt)]
    self.b_timeline, self.b_next = [0] * len(self.b), 0
    super().__init__()

  def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer: raise NotImplementedError("need hcq compat alloc")

  def _copyin(self, dest:HCQBuffer, src:memoryview):
    """
    Copies host memory `src` into device buffer `dest`, one staging-buffer-sized chunk at a time.

    Each chunk waits until its staging buffer is free, is memmoved into it on the CPU, and is then copied by the copy queue.
    The call returns without waiting for the last copies to finish.
    """
    assert self.dev.hw_copy_queue_t is not None
    with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"CPU -> {self.dev.device}", enabled=PROFILE):
      for i in range(0, src.nbytes, self.b[0].size):
        self.b_next = (self.b_next + 1) % len(self.b)
        # block until the previous copy that read this staging buffer has completed
        self.dev.timeline_signal.wait(self.b_timeline[self.b_next])
        ctypes.memmove(self.b[self.b_next].va_addr, from_mv(src[i:]), lsize:=min(self.b[self.b_next].size, src.nbytes-i))
        self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
                                  .copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
                                  .signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
        # the staging buffer becomes reusable once this copy signals
        self.b_timeline[self.b_next] = self.dev.timeline_value
        self.dev.timeline_value += 1

  def copy_from_disk(self, dest:HCQBuffer, src, size):
    """
    Copies `size` bytes from a disk-backed buffer `src` into device buffer `dest` through the staging buffers.

    NOTE(review): relies on `src.device.allocator._copyout_sharded` calling `_get_temp_buf` to obtain staging buffers and
    yielding `(batch_info, dst_off, src_off, copy_size)` once a staging buffer is filled — confirm against the disk allocator.
    """
    def _get_temp_buf():
      # Check if the next buffer is safe to be used (its signal has passed) and reserve it.
      # The reservation stores 1 << 64, a timeline value that is never reached, until a copy from the buffer is queued below.
      if self.b_timeline[(self.b_next + 1) % len(self.b)] <= self.dev.timeline_signal.value:
        self.b_timeline[(self.b_next + 1) % len(self.b)], self.b_next = (1 << 64), (self.b_next + 1) % len(self.b)
        return (self.b[self.b_next].va_addr, self.b_next)
      return None

    assert self.dev.hw_copy_queue_t is not None
    with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"DISK -> {self.dev.device}", enabled=PROFILE):
      for (batch_info, dst_off, src_off, copy_size) in src.device.allocator._copyout_sharded(src, size, _get_temp_buf, seg_len=self.b[0].size):
        self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
                                  .copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
                                  .signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
        # batch_info is (staging buffer va_addr, staging buffer index) as returned by _get_temp_buf
        self.b_timeline[batch_info[1]] = self.dev.timeline_value
        self.dev.timeline_value += 1

  def _copyout(self, dest:memoryview, src:HCQBuffer):
    """
    Copies device buffer `src` into host memory `dest`. Synchronous.

    Waits for all prior work first, then for each chunk copies into staging buffer `b[0]`, waits for that copy, and memmoves
    the chunk to `dest`.
    """
    self.dev.synchronize()

    assert self.dev.hw_copy_queue_t is not None
    with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"{self.dev.device} -> CPU", enabled=PROFILE):
      for i in range(0, dest.nbytes, self.b[0].size):
        self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
                                  .copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \
                                  .signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
        self.dev.timeline_signal.wait(self.dev.timeline_value)
        self.dev.timeline_value += 1

        ctypes.memmove(from_mv(dest[i:]), self.b[0].va_addr, lsize)

  def _transfer(self, dest:HCQBuffer, src:HCQBuffer, sz:int, src_dev:DeviceType, dest_dev:DeviceType):
    """
    Copies `sz` bytes from `src` on `src_dev` to `dest` on `dest_dev` using `src_dev`'s copy queue.

    `dest` is first mapped into the source device's allocator. The copy waits on both devices' timelines. When the devices
    differ, `dest_dev` then submits a compute-queue entry that waits for the copy and advances its own timeline, so later work
    on `dest_dev` is ordered after the transfer.
    """
    cast(HCQAllocator, src_dev.allocator).map(dest)

    assert src_dev.hw_copy_queue_t is not None
    with hcq_profile(src_dev, queue_type=src_dev.hw_copy_queue_t, desc=f"{src_dev.device} -> {dest_dev.device}", enabled=PROFILE):
      src_dev.hw_copy_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
                               .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
                               .copy(dest.va_addr, src.va_addr, sz) \
                               .signal(src_dev.timeline_signal, src_dev.timeline_value).submit(src_dev)
      src_dev.timeline_value += 1

    if src_dev != dest_dev:
      dest_dev.hw_compute_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
                                   .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
                                   .signal(dest_dev.timeline_signal, dest_dev.timeline_value).submit(dest_dev)
      dest_dev.timeline_value += 1

  # Hook for making a buffer accessible to this allocator's device; the base implementation does nothing.
  def map(self, buf:HCQBuffer): pass

  def _offset(self, buf, size:int, offset:int) -> HCQBuffer:
    """
    Returns a buffer of the same type that views `size` bytes of `buf` starting at byte `offset`.

    Other attributes are copied from `buf` (instance `__dict__` entries and ctypes `_fields_`), and `_base` is set to `buf`.
    """
    return type(buf)(va_addr=buf.va_addr + offset, size=size, **{k:v for k,v in buf.__dict__.items() if k not in ['va_addr', 'size']},
                     **{x[0]:getattr(buf, x[0]) for x in getattr(buf, '_fields_', []) if x[0] not in ['va_addr', 'size']}, _base=buf)
|