mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
hcq: refactor into peer_groups (#11277)
* hcq: refactor into peer_groups * fix fors * fixes * ooops * mypy * tiny fixes
This commit is contained in:
@@ -52,7 +52,7 @@ Signals are device-dependent structures used for synchronization and timing in H
|
||||
The following Python code demonstrates the usage of signals:
|
||||
|
||||
```python
|
||||
signal = your_device.signal_t()
|
||||
signal = your_device.new_signal(value=0)
|
||||
|
||||
HWQueue().timestamp(signal) \
|
||||
.signal(signal, value_to_fire) \
|
||||
|
||||
@@ -68,7 +68,7 @@ class TestHCQ(unittest.TestCase):
|
||||
if queue_type is None: continue
|
||||
|
||||
with self.subTest(name=str(queue_type)):
|
||||
fake_signal = TestHCQ.d0.signal_t()
|
||||
fake_signal = TestHCQ.d0.new_signal()
|
||||
fake_signal.value = 1
|
||||
queue_type().wait(fake_signal, 1) \
|
||||
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
|
||||
@@ -81,7 +81,7 @@ class TestHCQ(unittest.TestCase):
|
||||
if queue_type is None: continue
|
||||
|
||||
with self.subTest(name=str(queue_type)):
|
||||
fake_signal = TestHCQ.d0.signal_t()
|
||||
fake_signal = TestHCQ.d0.new_signal()
|
||||
queue_type().wait(fake_signal, 1) \
|
||||
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
|
||||
|
||||
@@ -101,7 +101,7 @@ class TestHCQ(unittest.TestCase):
|
||||
virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32)
|
||||
virt_signal = TestHCQ.d0.signal_t(base_buf=HCQBuffer(Variable("sig_addr", 0, 0xffffffffffffffff, dtypes.uint64), 16))
|
||||
|
||||
fake_signal = TestHCQ.d0.signal_t()
|
||||
fake_signal = TestHCQ.d0.new_signal()
|
||||
q = queue_type().wait(virt_signal, virt_val).signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||
|
||||
fake_signal.value = 0x30
|
||||
@@ -293,7 +293,7 @@ class TestHCQ(unittest.TestCase):
|
||||
virt_signal = TestHCQ.d0.signal_t(base_buf=HCQBuffer(Variable("sig_addr", 0, 0xffffffffffffffff, dtypes.uint64), 16))
|
||||
|
||||
with self.subTest(name=str(queue_type)):
|
||||
fake_signal = TestHCQ.d0.signal_t()
|
||||
fake_signal = TestHCQ.d0.new_signal()
|
||||
q = queue_type().wait(virt_signal, virt_val).signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||
q.bind(TestHCQ.d0)
|
||||
|
||||
@@ -310,7 +310,7 @@ class TestHCQ(unittest.TestCase):
|
||||
try: d1 = Device[f"{Device.DEFAULT}:1"]
|
||||
except Exception: self.skipTest("no multidevice, test skipped")
|
||||
|
||||
TestHCQ.d0.hw_copy_queue_t().signal(sig:=TestHCQ.d0.signal_t(value=0), value=0xfff) \
|
||||
TestHCQ.d0.hw_copy_queue_t().signal(sig:=TestHCQ.d0.new_signal(value=0), value=0xfff) \
|
||||
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
|
||||
|
||||
d1.hw_copy_queue_t().wait(sig, value=0xfff) \
|
||||
@@ -324,7 +324,7 @@ class TestHCQ(unittest.TestCase):
|
||||
|
||||
# Test profile api
|
||||
def test_speed_exec_time(self):
|
||||
sig_st, sig_en = TestHCQ.d0.signal_t(), TestHCQ.d0.signal_t()
|
||||
sig_st, sig_en = TestHCQ.d0.new_signal(), TestHCQ.d0.new_signal()
|
||||
TestHCQ.d0.hw_compute_queue_t().timestamp(sig_st) \
|
||||
.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) \
|
||||
.timestamp(sig_en) \
|
||||
@@ -346,7 +346,7 @@ class TestHCQ(unittest.TestCase):
|
||||
a = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
|
||||
b = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
|
||||
|
||||
sig_st, sig_en = TestHCQ.d0.signal_t(), TestHCQ.d0.signal_t()
|
||||
sig_st, sig_en = TestHCQ.d0.new_signal(), TestHCQ.d0.new_signal()
|
||||
TestHCQ.d0.hw_copy_queue_t().timestamp(sig_st) \
|
||||
.copy(a._buf.va_addr, b._buf.va_addr, SZ) \
|
||||
.timestamp(sig_en) \
|
||||
@@ -373,7 +373,7 @@ class TestHCQ(unittest.TestCase):
|
||||
a = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
|
||||
TestHCQ.d0.allocator.map(b._buf)
|
||||
|
||||
sig_st, sig_en = TestHCQ.d0.signal_t(), TestHCQ.d0.signal_t()
|
||||
sig_st, sig_en = TestHCQ.d0.new_signal(), TestHCQ.d0.new_signal()
|
||||
TestHCQ.d0.hw_copy_queue_t().timestamp(sig_st) \
|
||||
.copy(a._buf.va_addr, b._buf.va_addr, SZ) \
|
||||
.timestamp(sig_en) \
|
||||
@@ -531,8 +531,8 @@ class TestHCQ(unittest.TestCase):
|
||||
try: nv_dev = Device["NV"]
|
||||
except Exception: self.skipTest("no NV device, test skipped")
|
||||
|
||||
x = amd_dev.signal_t()
|
||||
y = nv_dev.signal_t()
|
||||
x = amd_dev.new_signal()
|
||||
y = nv_dev.new_signal()
|
||||
assert type(x) is amd_dev.signal_t
|
||||
assert type(y) is nv_dev.signal_t
|
||||
|
||||
|
||||
@@ -48,13 +48,13 @@ class HCQGraph(MultiGraphRunner):
|
||||
self.comp_queues: dict[HCQCompiled, HWQueue] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
|
||||
self.copy_queues: dict[HCQCompiled, HWQueue] = {} # lazy allocation
|
||||
|
||||
self.signals: dict[Any, HCQSignal] = {**{dev: dev.signal_t(value=0) for dev in self.devices}, **{"KICK": self.devices[0].signal_t(value=0)}}
|
||||
self.signals: dict[Any, HCQSignal] = {**{dev: dev.new_signal(value=0) for dev in self.devices}, **{"KICK": self.devices[0].new_signal(value=0)}}
|
||||
self.kickoff_value: int = 0
|
||||
self.kickoff_var = UOp.variable("kickoff_var", 0, 0xffffffff, dtype=dtypes.uint32)
|
||||
|
||||
# When profiling allocate 2 signals for each jit item to measure speed. The jth jit item have signals at 2*j and 2*j+1.
|
||||
# TODO: This logic might allocate a few extra signals...
|
||||
self.prof_signals: list[HCQSignal] = [self.devices[0].signal_t() for i in range(len(jit_cache) * 2)] if PROFILE else []
|
||||
self.prof_signals: list[HCQSignal] = [self.devices[0].new_signal() for i in range(len(jit_cache) * 2)] if PROFILE else []
|
||||
self.prog_graph_deps: list[list[int]] = []
|
||||
self.prof_graph_entries: list[ProfileGraphEntry] = []
|
||||
|
||||
@@ -78,7 +78,7 @@ class HCQGraph(MultiGraphRunner):
|
||||
assert (enqueue_dev.hw_copy_queue_t is not None), "device must implement a copy queue"
|
||||
enqueue_queue = self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
|
||||
|
||||
out_signal = self.signals.setdefault(enqueue_queue, enqueue_dev.signal_t(value=0))
|
||||
out_signal = self.signals.setdefault(enqueue_queue, enqueue_dev.new_signal(value=0))
|
||||
|
||||
# Get dependencies based on input and output buffers.
|
||||
rdeps = self._access_resources(ji.bufs, ji.prg.p.outs if is_exec_prg else [0], (enqueue_queue, j + 1)) #type:ignore
|
||||
@@ -131,7 +131,7 @@ class HCQGraph(MultiGraphRunner):
|
||||
# Create variable timeline signals for each device.
|
||||
timeline_sigaddrs = {dev: UOp.variable(f"timeline_sig_{dev.device_id}", 0, 0xffffffffffffffff, dtype=dtypes.uint64) for dev in self.devices}
|
||||
self.virt_timeline_vals = {dev: UOp.variable(f"timeline_var_{dev.device_id}", 0, 0xffffffff, dtype=dtypes.uint32) for dev in self.devices}
|
||||
self.virt_timeline_signals = {dev: dev.signal_t(base_buf=HCQBuffer(timeline_sigaddrs[dev], 16), timeline_for_device=dev) for dev in self.devices}
|
||||
self.virt_timeline_signals = {dev: dev.signal_t(HCQBuffer(timeline_sigaddrs[dev], 16), owner=dev, is_timeline=True) for dev in self.devices}
|
||||
|
||||
for dev in self.devices:
|
||||
self.comp_queues[dev].memory_barrier().wait(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev]) \
|
||||
|
||||
@@ -26,12 +26,11 @@ WAIT_REG_MEM_FUNCTION_NEQ = 4 # !=
|
||||
WAIT_REG_MEM_FUNCTION_GEQ = 5 # >=
|
||||
|
||||
class AMDSignal(HCQSignal):
|
||||
def __init__(self, base_buf:HCQBuffer|None=None, **kwargs):
|
||||
super().__init__(base_buf, **kwargs, timestamp_divider=100, dev_t=AMDDevice)
|
||||
def __init__(self, *args, **kwargs): super().__init__(*args, **{**kwargs, 'timestamp_divider': 100})
|
||||
|
||||
def _sleep(self, time_spent_waiting_ms:int):
|
||||
# Resonable to sleep for long workloads (which take more than 2s) and only timeline signals.
|
||||
if time_spent_waiting_ms > 2000 and self.timeline_for_device is not None: self.timeline_for_device.iface.sleep(200)
|
||||
if time_spent_waiting_ms > 2000 and self.is_timeline and self.owner is not None: self.owner.iface.sleep(200)
|
||||
|
||||
class AMDComputeQueue(HWQueue):
|
||||
def __init__(self, dev:AMDDevice):
|
||||
@@ -299,7 +298,7 @@ class AMDComputeQueue(HWQueue):
|
||||
self.release_mem(signal.value_addr, value, self.pm4.data_sel__mec_release_mem__send_32_bit_low,
|
||||
self.pm4.int_sel__mec_release_mem__send_interrupt_after_write_confirm, cache_flush=True)
|
||||
|
||||
if (dev:=signal.timeline_for_device) is not None and not dev.is_am():
|
||||
if (dev:=signal.owner) is not None and signal.is_timeline and not dev.is_am():
|
||||
self.release_mem(dev.queue_event_mailbox_ptr, dev.queue_event.event_id, self.pm4.data_sel__mec_release_mem__send_32_bit_low,
|
||||
self.pm4.int_sel__mec_release_mem__send_interrupt_after_write_confirm, ctxid=dev.queue_event.event_id)
|
||||
return self
|
||||
@@ -355,7 +354,7 @@ class AMDCopyQueue(HWQueue):
|
||||
fence_flags = self.sdma.SDMA_PKT_FENCE_HEADER_MTYPE(3) if self.dev.target >= (10,0,0) else 0
|
||||
self.q(self.sdma.SDMA_OP_FENCE | fence_flags, *data64_le(signal.value_addr), value)
|
||||
|
||||
if (dev:=signal.timeline_for_device) is not None and not dev.is_am():
|
||||
if (dev:=signal.owner) is not None and signal.is_timeline and not dev.is_am():
|
||||
self.q(self.sdma.SDMA_OP_FENCE | fence_flags, *data64_le(dev.queue_event_mailbox_ptr), dev.queue_event.event_id)
|
||||
self.q(self.sdma.SDMA_OP_TRAP, self.sdma.SDMA_PKT_TRAP_INT_CONTEXT_INT_CONTEXT(dev.queue_event.event_id))
|
||||
elif dev is not None and dev.is_am(): self.q(self.sdma.SDMA_OP_TRAP, self.sdma.SDMA_PKT_TRAP_INT_CONTEXT_INT_CONTEXT(0))
|
||||
@@ -682,7 +681,8 @@ class PCIIface(PCIIfaceBase):
|
||||
self.dev_impl.ih.interrupt_handler()
|
||||
|
||||
def on_device_hang(self):
|
||||
for d in self.dev.devices: d.iface.dev_impl.gmc.on_interrupt()
|
||||
devs:list[AMDDevice] = [d for pg in HCQCompiled.peer_groups.values() for d in pg if isinstance(d, AMDDevice) and d.is_am()]
|
||||
for d in devs: d.iface.dev_impl.gmc.on_interrupt()
|
||||
raise RuntimeError("Device hang detected")
|
||||
|
||||
def device_fini(self): self.dev_impl.fini()
|
||||
@@ -730,10 +730,6 @@ class USBIface(PCIIface):
|
||||
def sleep(self, timeout): pass
|
||||
|
||||
class AMDDevice(HCQCompiled):
|
||||
devices: ClassVar[list[HCQCompiled]] = []
|
||||
signal_pages: ClassVar[list[HCQBuffer]] = []
|
||||
signal_pool: ClassVar[list[HCQBuffer]] = []
|
||||
|
||||
def is_am(self) -> bool: return isinstance(self.iface, (PCIIface, USBIface))
|
||||
def is_usb(self) -> bool: return isinstance(self.iface, USBIface)
|
||||
|
||||
|
||||
@@ -72,8 +72,7 @@ class QMD:
|
||||
else: self.write(**{f'constant_buffer_addr_upper_shifted6_{i}':hi32(addr >> 6), f'constant_buffer_addr_lower_shifted6_{i}':lo32(addr >> 6)})
|
||||
|
||||
class NVSignal(HCQSignal):
|
||||
def __init__(self, base_buf:HCQBuffer|None=None, **kwargs):
|
||||
super().__init__(base_buf, **kwargs, timestamp_divider=1000, dev_t=NVDevice)
|
||||
def __init__(self, *args, **kwargs): super().__init__(*args, **{**kwargs, 'timestamp_divider': 1000})
|
||||
|
||||
class NVCommandQueue(HWQueue[NVSignal, 'NVDevice', 'NVProgram', 'NVArgsState']):
|
||||
def __init__(self):
|
||||
@@ -365,7 +364,7 @@ class NVKIface:
|
||||
uvm.register_gpu(self.fd_uvm, rmCtrlFd=-1, gpu_uuid=self.gpu_uuid)
|
||||
uvm.register_gpu_vaspace(self.fd_uvm, gpuUuid=self.gpu_uuid, rmCtrlFd=self.fd_ctl.fd, hClient=self.root, hVaSpace=vaspace)
|
||||
|
||||
for dev in cast(list[NVDevice], self.dev.devices):
|
||||
for dev in cast(list[NVDevice], [d for pg in HCQCompiled.peer_groups.values() for d in pg if isinstance(d, NVDevice) and not d.is_nvd()]):
|
||||
try: uvm.enable_peer_access(self.fd_uvm, gpuUuidA=self.gpu_uuid, gpuUuidB=dev.iface.gpu_uuid)
|
||||
except RuntimeError as e: raise RuntimeError(f"{e}. Make sure GPUs #{self.gpu_minor} & #{dev.iface.gpu_minor} have P2P enabled.") from e
|
||||
|
||||
@@ -481,10 +480,6 @@ class PCIIface(PCIIfaceBase):
|
||||
def device_fini(self): self.dev_impl.fini()
|
||||
|
||||
class NVDevice(HCQCompiled[NVSignal]):
|
||||
devices: ClassVar[list[HCQCompiled]] = []
|
||||
signal_pages: ClassVar[list[HCQBuffer]] = []
|
||||
signal_pool: ClassVar[list[HCQBuffer]] = []
|
||||
|
||||
def is_nvd(self) -> bool: return isinstance(self.iface, PCIIface)
|
||||
|
||||
def __init__(self, device:str=""):
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
import os, ctypes, functools, mmap, struct, array, math, sys, weakref
|
||||
assert sys.platform != 'win32'
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast, ClassVar
|
||||
from typing import Any, cast
|
||||
from tinygrad.device import BufferSpec
|
||||
from tinygrad.runtime.support.hcq import HCQBuffer, HWQueue, HCQProgram, HCQCompiled, HCQAllocatorBase, HCQSignal, HCQArgsState, BumpAllocator
|
||||
from tinygrad.runtime.support.hcq import FileIOInterface, MMIOInterface
|
||||
@@ -37,14 +37,12 @@ class QCOMCompiler(CLCompiler):
|
||||
def disassemble(self, lib:bytes): fromimport('extra.disassemblers.adreno', 'disasm')(lib)
|
||||
|
||||
class QCOMSignal(HCQSignal):
|
||||
def __init__(self, base_buf:HCQBuffer|None=None, **kwargs):
|
||||
super().__init__(base_buf, **kwargs, timestamp_divider=19.2, dev_t=QCOMDevice)
|
||||
def __init__(self, *args, **kwargs): super().__init__(*args, **{**kwargs, 'timestamp_divider': 19.2})
|
||||
|
||||
def _sleep(self, time_spent_waiting_ms:int):
|
||||
# Sleep only for only timeline signals. Do it immediately to free cpu.
|
||||
if self.timeline_for_device is not None:
|
||||
kgsl.IOCTL_KGSL_DEVICE_WAITTIMESTAMP_CTXTID(self.timeline_for_device.fd, context_id=self.timeline_for_device.ctx,
|
||||
timestamp=self.timeline_for_device.last_cmd, timeout=0xffffffff)
|
||||
# Sleep only for timeline signals. Do it immediately to free cpu.
|
||||
if self.is_timeline and self.owner is not None:
|
||||
kgsl.IOCTL_KGSL_DEVICE_WAITTIMESTAMP_CTXTID(self.owner.fd, context_id=self.owner.ctx, timestamp=self.owner.last_cmd, timeout=0xffffffff)
|
||||
|
||||
class QCOMComputeQueue(HWQueue):
|
||||
def __del__(self):
|
||||
@@ -315,10 +313,6 @@ class QCOMAllocator(HCQAllocatorBase):
|
||||
self.dev._gpu_free(opaque)
|
||||
|
||||
class QCOMDevice(HCQCompiled):
|
||||
devices: ClassVar[list[HCQCompiled]] = []
|
||||
signal_pages: ClassVar[list[HCQBuffer]] = []
|
||||
signal_pool: ClassVar[list[HCQBuffer]] = []
|
||||
|
||||
gpu_id: int = 0
|
||||
dummy_addr: int = 0
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from typing import cast, Callable, Type, TypeVar, Generic, Any, ClassVar
|
||||
import contextlib, decimal, statistics, time, ctypes, array, os, fcntl, struct, traceback
|
||||
from typing import cast, Callable, Type, TypeVar, Generic, Any
|
||||
import contextlib, decimal, statistics, time, ctypes, array, os, fcntl, struct, traceback, collections
|
||||
from tinygrad.helpers import PROFILE, getenv, to_mv, round_up, ProfileRangeEvent
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator, ProfileDeviceEvent, ProfileProgramEvent
|
||||
@@ -217,19 +217,17 @@ class HWQueue(Generic[SignalType, HCQDeviceType, ProgramType, ArgsStateType]):
|
||||
def _submit(self, dev:HCQDeviceType): raise NotImplementedError("need _submit")
|
||||
|
||||
class HCQSignal(Generic[HCQDeviceType]):
|
||||
def __init__(self, base_buf:HCQBuffer|None=None, value:int=0, dev_t:Type[HCQDeviceType]|None=None, timeline_for_device:HCQDeviceType|None=None,
|
||||
timestamp_divider=1, value_off=0, timestamp_off=8):
|
||||
self.base_buf = cast(HCQBuffer, dev_t._alloc_signal() if dev_t is not None and base_buf is None else base_buf)
|
||||
self.value_addr, self.timestamp_addr, self.dev_t = self.base_buf.va_addr+value_off, self.base_buf.va_addr+timestamp_off, dev_t
|
||||
def __init__(self, base_buf:HCQBuffer, value:int=0, owner:HCQDeviceType|None=None, is_timeline:bool=False, timestamp_divider=1):
|
||||
self.base_buf, self.value_addr, self.timestamp_addr, self.owner = base_buf, base_buf.va_addr+0, base_buf.va_addr+8, owner
|
||||
self.is_timeline = is_timeline
|
||||
self.timestamp_divider:decimal.Decimal = decimal.Decimal(timestamp_divider)
|
||||
self.timeline_for_device:HCQDeviceType|None = timeline_for_device
|
||||
|
||||
if isinstance(self.base_buf.va_addr, int):
|
||||
self.value_mv, self.timestamp_mv = self.base_buf.cpu_view().view(value_off, 8, 'Q'), self.base_buf.cpu_view().view(timestamp_off, 8, 'Q')
|
||||
self.value_mv, self.timestamp_mv = self.base_buf.cpu_view().view(0, 8, 'Q'), self.base_buf.cpu_view().view(8, 8, 'Q')
|
||||
self.value_mv[0] = value
|
||||
|
||||
def __del__(self):
|
||||
if isinstance(self.base_buf.va_addr, int) and self.dev_t is not None: self.dev_t.signal_pool.append(self.base_buf)
|
||||
if isinstance(self.base_buf.va_addr, int) and self.owner is not None: HCQCompiled.signal_pool[self.owner.peer_group].append(self.base_buf)
|
||||
|
||||
@property
|
||||
def value(self) -> int: return self.value_mv[0]
|
||||
@@ -270,7 +268,7 @@ class HCQSignal(Generic[HCQDeviceType]):
|
||||
|
||||
@contextlib.contextmanager
|
||||
def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type:Callable[[], HWQueue]|None=None, queue:HWQueue|None=None):
|
||||
st, en = (dev.signal_t(), dev.signal_t()) if enabled else (None, None)
|
||||
st, en = (dev.new_signal(), dev.new_signal()) if enabled else (None, None)
|
||||
|
||||
if enabled and queue is not None: queue.timestamp(st)
|
||||
elif enabled:
|
||||
@@ -355,9 +353,9 @@ class HCQCompiled(Compiled, Generic[SignalType]):
|
||||
"""
|
||||
A base class for devices compatible with the HCQ (Hardware Command Queue) API.
|
||||
"""
|
||||
devices: ClassVar[list[HCQCompiled]] = []
|
||||
signal_pages: ClassVar[list[HCQBuffer]] = []
|
||||
signal_pool: ClassVar[list[HCQBuffer]] = []
|
||||
peer_groups: dict[str, list[HCQCompiled]] = collections.defaultdict(list)
|
||||
signal_pages: dict[str, list[HCQBuffer]] = collections.defaultdict(list) # per peer group
|
||||
signal_pool: dict[str, list[HCQBuffer]] = collections.defaultdict(list) # per peer group
|
||||
|
||||
def __init__(self, device:str, allocator:HCQAllocatorBase, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[SignalType],
|
||||
comp_queue_t:Callable[[], HWQueue], copy_queue_t:Callable[[], HWQueue]|None, kernargs_size=(16 << 20), sigalloc_size=0x1000):
|
||||
@@ -366,15 +364,17 @@ class HCQCompiled(Compiled, Generic[SignalType]):
|
||||
from tinygrad.runtime.graph.hcq import HCQGraph
|
||||
super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)
|
||||
|
||||
# TODO: peer logic is determined based on device name.
|
||||
self.peer_group = device.split(":")[0]
|
||||
HCQCompiled.peer_groups[self.peer_group].append(self)
|
||||
|
||||
# Map signals if any
|
||||
for sig_page in self.signal_pages: cast(HCQAllocator, self.allocator).map(sig_page)
|
||||
self.devices.append(self)
|
||||
for sig_page in HCQCompiled.signal_pages[self.peer_group]: cast(HCQAllocator, self.allocator).map(sig_page)
|
||||
|
||||
self.sigalloc_size = sigalloc_size
|
||||
self.signal_t, self.hw_compute_queue_t, self.hw_copy_queue_t = signal_t, comp_queue_t, copy_queue_t
|
||||
self.timeline_value:int = 1
|
||||
self.timeline_signal:SignalType = self.signal_t(value=0, timeline_for_device=self)
|
||||
self._shadow_timeline_signal:SignalType = self.signal_t(value=0, timeline_for_device=self)
|
||||
self.timeline_signal, self._shadow_timeline_signal = self.new_signal(value=0, is_timeline=True), self.new_signal(value=0, is_timeline=True)
|
||||
self.sig_prof_records:list[tuple[HCQSignal, HCQSignal, str, bool]] = []
|
||||
|
||||
self.kernargs_buf:HCQBuffer = self.allocator.alloc(kernargs_size, BufferSpec(cpu_access=True))
|
||||
@@ -395,13 +395,12 @@ class HCQCompiled(Compiled, Generic[SignalType]):
|
||||
self.timeline_value += 1
|
||||
return self.timeline_value - 1
|
||||
|
||||
@classmethod
|
||||
def _alloc_signal(cls) -> HCQBuffer:
|
||||
if not cls.signal_pool:
|
||||
cls.signal_pages.append(alc:=cls.devices[0].allocator.alloc(cls.devices[0].sigalloc_size, BufferSpec(host=True,uncached=True,cpu_access=True)))
|
||||
cls.signal_pool += [alc.offset(offset=off, size=16) for off in range(0, alc.size, 16)]
|
||||
for dev in cls.devices: cast(HCQAllocator, dev.allocator).map(alc)
|
||||
return cls.signal_pool.pop()
|
||||
def new_signal(self, **kwargs) -> SignalType:
|
||||
if not HCQCompiled.signal_pool[pg:=self.peer_group]:
|
||||
HCQCompiled.signal_pages[pg].append(alc:=self.allocator.alloc(self.sigalloc_size, BufferSpec(host=True, uncached=True, cpu_access=True)))
|
||||
HCQCompiled.signal_pool[pg] += [alc.offset(offset=off, size=16) for off in range(0, alc.size, 16)]
|
||||
for dev in HCQCompiled.peer_groups[pg]: cast(HCQAllocator, dev.allocator).map(alc)
|
||||
return self.signal_t(base_buf=HCQCompiled.signal_pool[pg].pop(), owner=self, **kwargs)
|
||||
|
||||
def _at_profile_finalize(self):
|
||||
def _sync(d:HCQCompiled, q_t:Callable[[], HWQueue]):
|
||||
|
||||
Reference in New Issue
Block a user