|
|
|
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|
|
|
|
import multiprocessing
|
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
from collections import defaultdict
|
|
|
|
|
from typing import List, Optional, Dict, Tuple, Any, cast, Protocol
|
|
|
|
|
from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type
|
|
|
|
|
import importlib, inspect, functools, pathlib, os, ctypes, atexit, time, contextlib, array
|
|
|
|
|
from tinygrad.helpers import getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, ProfileLogger, PROFILE
|
|
|
|
|
from tinygrad.dtype import DType, ImageDType
|
|
|
|
|
@@ -225,7 +225,7 @@ class HWCommandQueue:
|
|
|
|
|
def _patch(self, cmd_idx, offset, data):
  """
  Overwrites part of an already queued command in place.

  Args:
    cmd_idx: Index into `self.cmds_offset` selecting the command to patch
    offset: Offset added to that command's start position inside `self.q`
    data: Values to write, converted with `array.array('I', ...)`
  """
  # Build the replacement words first, matching the right-hand-side-first evaluation of the one-line form.
  words = array.array('I', data)
  queue = self.q
  begin = self.cmds_offset[cmd_idx] + offset
  queue[begin:begin + len(data)] = words
|
|
|
|
|
|
|
|
|
|
@hcq_command
|
|
|
|
|
def signal(self, signal, value):
|
|
|
|
|
def signal(self, signal:Any, value:int):
|
|
|
|
|
"""
|
|
|
|
|
Enqueues a signal command which sets the signal to the given value, ensuring all previous operations are completed.
|
|
|
|
|
|
|
|
|
|
@@ -234,10 +234,10 @@ class HWCommandQueue:
|
|
|
|
|
value: The value to set the signal to
|
|
|
|
|
"""
|
|
|
|
|
self._signal(signal, value)
|
|
|
|
|
def _signal(self, signal, value): raise NotImplementedError("backend should overload this function")
|
|
|
|
|
def _signal(self, signal:Any, value:int):
  """
  Backend hook called by `signal`. The base implementation only raises.

  Args:
    signal: The signal passed through from `signal`
    value: The value passed through from `signal`

  Raises:
    NotImplementedError: Always, until a backend overloads this method
  """
  raise NotImplementedError("backend should overload this function")
|
|
|
|
|
|
|
|
|
|
@hcq_command
|
|
|
|
|
def wait(self, signal, value):
|
|
|
|
|
def wait(self, signal:Any, value:int):
|
|
|
|
|
"""
|
|
|
|
|
Enqueues a wait command which halts execution until the signal is greater than or equal to a specific value.
|
|
|
|
|
|
|
|
|
|
@@ -249,7 +249,7 @@ class HWCommandQueue:
|
|
|
|
|
def _wait(self, signal, value):
  """
  Backend-specific part of a wait command. The base implementation only raises.

  Raises:
    NotImplementedError: Always, until a backend overloads this method
  """
  raise NotImplementedError("backend should overload this function")
|
|
|
|
|
|
|
|
|
|
@hcq_command
|
|
|
|
|
def timestamp(self, signal):
|
|
|
|
|
def timestamp(self, signal:Any):
|
|
|
|
|
"""
|
|
|
|
|
Enqueues a timestamp command which records the current time in a signal after all previously enqueued commands are completed.
|
|
|
|
|
|
|
|
|
|
@@ -259,7 +259,7 @@ class HWCommandQueue:
|
|
|
|
|
self._timestamp(signal)
|
|
|
|
|
def _timestamp(self, signal):
  """
  Backend-specific part of a timestamp command. The base implementation only raises.

  Raises:
    NotImplementedError: Always, until a backend overloads this method
  """
  raise NotImplementedError("backend should overload this function")
|
|
|
|
|
|
|
|
|
|
def update_signal(self, cmd_idx, signal=None, value=None):
|
|
|
|
|
def update_signal(self, cmd_idx:int, signal:Optional[Any]=None, value:Optional[int]=None):
|
|
|
|
|
"""
|
|
|
|
|
Updates a previously queued signal command.
|
|
|
|
|
|
|
|
|
|
@@ -271,9 +271,9 @@ class HWCommandQueue:
|
|
|
|
|
if self.cmds_meta[cmd_idx] != "signal": raise RuntimeError("called update_signal not on a signal command")
|
|
|
|
|
self._update_signal(cmd_idx, signal, value)
|
|
|
|
|
return self
|
|
|
|
|
def _update_signal(self, cmd_idx, signal, value): raise NotImplementedError("backend should overload this function")
|
|
|
|
|
def _update_signal(self, cmd_idx:int, signal:Optional[Any], value:Optional[int]):
  """
  Backend hook called by `update_signal`. The base implementation only raises.

  Args:
    cmd_idx: Index of the queued command, as passed to `update_signal`
    signal: Forwarded from `update_signal` (may be None)
    value: Forwarded from `update_signal` (may be None)

  Raises:
    NotImplementedError: Always, until a backend overloads this method
  """
  raise NotImplementedError("backend should overload this function")
|
|
|
|
|
|
|
|
|
|
def update_wait(self, cmd_idx, signal=None, value=None):
|
|
|
|
|
def update_wait(self, cmd_idx:int, signal:Optional[Any]=None, value:Optional[int]=None):
|
|
|
|
|
"""
|
|
|
|
|
Updates a previously queued wait command.
|
|
|
|
|
|
|
|
|
|
@@ -285,7 +285,7 @@ class HWCommandQueue:
|
|
|
|
|
if self.cmds_meta[cmd_idx] != "wait": raise RuntimeError("called update_wait not on a wait command")
|
|
|
|
|
self._update_wait(cmd_idx, signal, value)
|
|
|
|
|
return self
|
|
|
|
|
def _update_wait(self, cmd_idx, signal, value): raise NotImplementedError("backend should overload this function")
|
|
|
|
|
def _update_wait(self, cmd_idx:int, signal:Optional[Any], value:Optional[int]):
  """
  Backend hook called by `update_wait`. The base implementation only raises.

  Args:
    cmd_idx: Index of the queued command, as passed to `update_wait`
    signal: Forwarded from `update_wait` (may be None)
    value: Forwarded from `update_wait` (may be None)

  Raises:
    NotImplementedError: Always, until a backend overloads this method
  """
  raise NotImplementedError("backend should overload this function")
|
|
|
|
|
|
|
|
|
|
def submit(self, device:HCQCompatCompiled):
|
|
|
|
|
"""
|
|
|
|
|
@@ -308,7 +308,7 @@ class HWComputeQueue(HWCommandQueue):
|
|
|
|
|
def _memory_barrier(self):
  """
  Does nothing here: this implementation has an empty body and returns None.
  """
  return None
|
|
|
|
|
|
|
|
|
|
@hcq_command
|
|
|
|
|
def exec(self, prg, kernargs, global_size, local_size):
|
|
|
|
|
def exec(self, prg:HCQCompatProgram, kernargs:int, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int]):
|
|
|
|
|
"""
|
|
|
|
|
Enqueues an execution command for a kernel program.
|
|
|
|
|
|
|
|
|
|
@@ -321,7 +321,7 @@ class HWComputeQueue(HWCommandQueue):
|
|
|
|
|
self._exec(prg, kernargs, global_size, local_size)
|
|
|
|
|
def _exec(self, prg, kernargs, global_size, local_size):
  """
  Backend hook called by `exec`. The base implementation only raises.

  Args:
    prg: Forwarded from `exec`
    kernargs: Forwarded from `exec`
    global_size: Forwarded from `exec`
    local_size: Forwarded from `exec`

  Raises:
    NotImplementedError: Always, until a backend overloads this method
  """
  raise NotImplementedError("backend should overload this function")
|
|
|
|
|
|
|
|
|
|
def update_exec(self, cmd_idx, global_size, local_size):
|
|
|
|
|
def update_exec(self, cmd_idx:int, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int]):
|
|
|
|
|
"""
|
|
|
|
|
Updates a previously queued execution command.
|
|
|
|
|
|
|
|
|
|
@@ -337,7 +337,7 @@ class HWComputeQueue(HWCommandQueue):
|
|
|
|
|
|
|
|
|
|
class HWCopyQueue(HWCommandQueue):
|
|
|
|
|
@hcq_command
|
|
|
|
|
def copy(self, dest, src, copy_size):
|
|
|
|
|
def copy(self, dest:HCQCompatAllocRes, src:HCQCompatAllocRes, copy_size:int):
|
|
|
|
|
"""
|
|
|
|
|
Enqueues a copy command to transfer data.
|
|
|
|
|
|
|
|
|
|
@@ -347,9 +347,9 @@ class HWCopyQueue(HWCommandQueue):
|
|
|
|
|
copy_size: The size of data to copy
|
|
|
|
|
"""
|
|
|
|
|
self._copy(dest, src, copy_size)
|
|
|
|
|
def _copy(self, dest, src, copy_size): raise NotImplementedError("backend should overload this function")
|
|
|
|
|
def _copy(self, dest:HCQCompatAllocRes, src:HCQCompatAllocRes, copy_size:int):
  """
  Backend hook called by `copy`. The base implementation only raises.

  Args:
    dest: Forwarded from `copy`
    src: Forwarded from `copy`
    copy_size: Forwarded from `copy`

  Raises:
    NotImplementedError: Always, until a backend overloads this method
  """
  raise NotImplementedError("backend should overload this function")
|
|
|
|
|
|
|
|
|
|
def update_copy(self, cmd_idx, dest=None, src=None):
|
|
|
|
|
def update_copy(self, cmd_idx:int, dest:Optional[HCQCompatAllocRes]=None, src:Optional[HCQCompatAllocRes]=None):
|
|
|
|
|
"""
|
|
|
|
|
Updates a previously queued copy command.
|
|
|
|
|
|
|
|
|
|
@@ -378,7 +378,7 @@ def hcq_profile(dev, enabled, desc, queue_type=None, queue=None):
|
|
|
|
|
if enabled and PROFILE: dev.sig_prof_records.append((st, en, desc, queue_type is dev.hw_copy_queue_t))
|
|
|
|
|
|
|
|
|
|
class HCQCompatProgram:
|
|
|
|
|
def __init__(self, kernargs_alloc_size, kernargs_args_offset=0):
|
|
|
|
|
def __init__(self, kernargs_alloc_size:int, kernargs_args_offset:int=0):
|
|
|
|
|
self.kernargs_alloc_size, self.kernargs_args_offset = kernargs_alloc_size, kernargs_args_offset
|
|
|
|
|
def fill_kernargs(self, kernargs_ptr:int, bufs:Tuple[Any, ...], vals:Tuple[int, ...]=()):
  """
  Placeholder that subclasses are expected to replace. The base implementation only raises.

  Args:
    kernargs_ptr: Integer pointer argument (unused here)
    bufs: Tuple of buffers (unused here)
    vals: Tuple of integer values, defaults to an empty tuple (unused here)

  Raises:
    NotImplementedError: Always
  """
  raise NotImplementedError("need fill_kernargs")
|
|
|
|
|
|
|
|
|
|
@@ -387,7 +387,8 @@ class HCQCompatCompiled(Compiled):
|
|
|
|
|
A base class for devices compatible with the HCQ (Hardware Command Queue) API.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, device:str, allocator:Allocator, renderer:Renderer, compiler:Compiler, runtime, comp_queue_t, copy_queue_t, timeline_signals):
|
|
|
|
|
def __init__(self, device:str, allocator:Allocator, renderer:Renderer, compiler:Compiler, runtime,
|
|
|
|
|
comp_queue_t:Type[HWComputeQueue], copy_queue_t:Type[HWCopyQueue], timeline_signals:Tuple[Any, Any]):
|
|
|
|
|
self.hw_compute_queue_t, self.hw_copy_queue_t = comp_queue_t, copy_queue_t
|
|
|
|
|
self.timeline_value:int = 1
|
|
|
|
|
self.timeline_signal, self._shadow_timeline_signal = timeline_signals
|
|
|
|
|
@@ -399,7 +400,7 @@ class HCQCompatCompiled(Compiled):
|
|
|
|
|
super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _read_signal(cls, signal:Any) -> Any:
|
|
|
|
|
def _read_signal(cls, signal:Any) -> int:
|
|
|
|
|
"""
|
|
|
|
|
Read a value for a signal.
|
|
|
|
|
"""
|
|
|
|
|
@@ -413,34 +414,34 @@ class HCQCompatCompiled(Compiled):
|
|
|
|
|
raise NotImplementedError("_read_timestamp needs to be implemented")
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _set_signal(cls, signal:Any, value:Any) -> None:
|
|
|
|
|
def _set_signal(cls, signal:Any, value:int):
|
|
|
|
|
"""
|
|
|
|
|
Set a value for a signal.
|
|
|
|
|
"""
|
|
|
|
|
raise NotImplementedError("_set_signal needs to be implemented")
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _alloc_signal(cls, value:Any = 0, **kwargs) -> Any:
|
|
|
|
|
def _alloc_signal(cls, value:int = 0, **kwargs) -> Any:
|
|
|
|
|
"""
|
|
|
|
|
Allocate a new signal.
|
|
|
|
|
"""
|
|
|
|
|
raise NotImplementedError("_alloc_signal needs to be implemented")
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _free_signal(cls, signal:Any) -> None:
|
|
|
|
|
def _free_signal(cls, signal:Any):
|
|
|
|
|
"""
|
|
|
|
|
Free a signal.
|
|
|
|
|
"""
|
|
|
|
|
raise NotImplementedError("_free_signal needs to be implemented")
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _wait_signal(cls, signal:Any, value:Any = 0, timeout:int = 10000) -> None:
|
|
|
|
|
def _wait_signal(cls, signal:Any, value:int = 0, timeout:int = 10000):
|
|
|
|
|
"""
|
|
|
|
|
Wait for a signal to reach a specific value. Signals
|
|
|
|
|
"""
|
|
|
|
|
raise NotImplementedError("_wait_signal needs to be implemented")
|
|
|
|
|
|
|
|
|
|
def _gpu2cpu_time(self, gpu_time:float, is_copy:bool) -> float:
|
|
|
|
|
def _gpu2cpu_time(self, gpu_time:int, is_copy:bool) -> float:
|
|
|
|
|
"""
|
|
|
|
|
Convert GPU time to CPU time. `is_copy` flag indicating if this is a copy queue.
|
|
|
|
|
"""
|
|
|
|
|
@@ -475,7 +476,7 @@ class HCQCompatCompiled(Compiled):
|
|
|
|
|
cast(HCQCompatAllocator, self.allocator).b_timeline = [0] * len(cast(HCQCompatAllocator, self.allocator).b)
|
|
|
|
|
|
|
|
|
|
# Protocol for HCQ-compatible allocators: an allocated buffer must expose its VA address and its size.
|
|
|
|
|
class HCQCompatAllocRes(Protocol): va_addr: int; size: int # noqa: E702
|
|
|
|
|
class HCQCompatAllocRes(Protocol):
  # Structural type with two integer attributes; one declaration per line instead of a `;`-joined body.
  va_addr:int
  size:int
|
|
|
|
|
|
|
|
|
|
class HCQCompatAllocator(LRUAllocator): # pylint: disable=abstract-method
|
|
|
|
|
"""
|
|
|
|
|
@@ -484,15 +485,15 @@ class HCQCompatAllocator(LRUAllocator): # pylint: disable=abstract-method
|
|
|
|
|
This class implements basic copy operations following the HCQ API, utilizing both `HWComputeQueue` and `HWCopyQueue`.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, device, batch_size=(2 << 20), batch_cnt=32):
|
|
|
|
|
self.device = device
|
|
|
|
|
def __init__(self, device:HCQCompatCompiled, batch_size:int=(2 << 20), batch_cnt:int=32):
|
|
|
|
|
self.device:Any = device
|
|
|
|
|
self.b = [self._alloc(batch_size, BufferOptions(host=True)) for _ in range(batch_cnt)]
|
|
|
|
|
self.b_timeline, self.b_next = [0] * len(self.b), 0
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
|
|
def _alloc(self, size:int, options:BufferOptions) -> HCQCompatAllocRes:
  """
  Placeholder that subclasses are expected to replace. The base implementation only raises.

  Args:
    size: Requested size (unused here)
    options: Buffer options (unused here)

  Raises:
    NotImplementedError: Always
  """
  raise NotImplementedError("need hcq compat alloc")
|
|
|
|
|
|
|
|
|
|
def copyin(self, dest: HCQCompatAllocRes, src: memoryview):
|
|
|
|
|
def copyin(self, dest:HCQCompatAllocRes, src:memoryview):
|
|
|
|
|
with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"CPU -> {self.device.dname}", enabled=PROFILE):
|
|
|
|
|
for i in range(0, src.nbytes, self.b[0].size):
|
|
|
|
|
self.b_next = (self.b_next + 1) % len(self.b)
|
|
|
|
|
@@ -504,7 +505,7 @@ class HCQCompatAllocator(LRUAllocator): # pylint: disable=abstract-method
|
|
|
|
|
self.b_timeline[self.b_next] = self.device.timeline_value
|
|
|
|
|
self.device.timeline_value += 1
|
|
|
|
|
|
|
|
|
|
def copy_from_disk(self, dest: HCQCompatAllocRes, src, size):
|
|
|
|
|
def copy_from_disk(self, dest:HCQCompatAllocRes, src, size):
|
|
|
|
|
def _get_temp_buf():
|
|
|
|
|
# Check if the next buffer is safe to be used (its signal has passed) and reserve it.
|
|
|
|
|
if self.b_timeline[(self.b_next + 1) % len(self.b)] <= self.device._read_signal(self.device.timeline_signal):
|
|
|
|
|
@@ -520,7 +521,7 @@ class HCQCompatAllocator(LRUAllocator): # pylint: disable=abstract-method
|
|
|
|
|
self.b_timeline[batch_info[1]] = self.device.timeline_value
|
|
|
|
|
self.device.timeline_value += 1
|
|
|
|
|
|
|
|
|
|
def copyout(self, dest:memoryview, src: HCQCompatAllocRes):
|
|
|
|
|
def copyout(self, dest:memoryview, src:HCQCompatAllocRes):
|
|
|
|
|
self.device.synchronize()
|
|
|
|
|
|
|
|
|
|
with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"{self.device.dname} -> CPU", enabled=PROFILE):
|
|
|
|
|
@@ -533,7 +534,7 @@ class HCQCompatAllocator(LRUAllocator): # pylint: disable=abstract-method
|
|
|
|
|
|
|
|
|
|
ctypes.memmove(from_mv(dest[i:]), self.b[0].va_addr, lsize)
|
|
|
|
|
|
|
|
|
|
def transfer(self, dest: HCQCompatAllocRes, src: HCQCompatAllocRes, sz: int, src_dev, dest_dev):
|
|
|
|
|
def transfer(self, dest:HCQCompatAllocRes, src:HCQCompatAllocRes, sz:int, src_dev, dest_dev):
|
|
|
|
|
src_dev._gpu_map(dest)
|
|
|
|
|
|
|
|
|
|
with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"{src_dev.dname} -> {dest_dev.dname}", enabled=PROFILE):
|
|
|
|
|
|