hcq replace update with sint (#7899)

* try sym hcq

* start with amd

* move to nv

* nv works

* cache and qcom

* fixes

* signals

* fix nv

* qcom fixes

* linter

* linter

* cache + typings

* fixes

* tiny fixes

* linter

* linter

* lntr

* ugh

* comments
This commit is contained in:
nimlgen
2024-11-29 20:08:13 +03:00
committed by GitHub
parent aa51f3c14e
commit 10f431b96d
7 changed files with 261 additions and 358 deletions

View File

@@ -24,27 +24,11 @@ Each runtime should implement the required functions that are defined in the `HW
"signal",
"wait",
"timestamp",
"update_signal",
"update_wait",
"bind",
"submit",
"memory_barrier",
"exec",
"update_exec",
"copy",
"update_copy",
]
show_source: false
#### Implementing custom commands
To implement custom commands in the queue, use the @hcq_command decorator for your command implementations.
::: tinygrad.runtime.support.hcq.hcq_command
options:
members: [
"copy",
"update_copy",
]
show_source: false
@@ -141,5 +125,5 @@ your_device.timeline_signal.wait(your_device.timeline_value - 1)
## HCQGraph
[HCQGraph](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/graph/hcq.py) is a core feature that implements `GraphRunner` for HCQ-compatible devices. `HCQGraph` builds static `HWQueue` for all operations per device. To optimize enqueue time, only the necessary parts of the queues are updated for each run using the update APIs of the queues, avoiding a complete rebuild.
[HCQGraph](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/graph/hcq.py) is a core feature that implements `GraphRunner` for HCQ-compatible devices. `HCQGraph` builds static `HWQueue` for all operations per device. To optimize enqueue time, only the necessary parts of the queues are updated for each run using the symbolic variables, avoiding a complete rebuild.
Optionally, queues can implement a `bind` API, which allows further optimization by eliminating the need to copy the queues into the device ring.

View File

@@ -6,6 +6,7 @@ from tinygrad.runtime.support.hcq import HCQCompiled
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import get_runner, CompiledRunner
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
from tinygrad import Variable
MOCKGPU = getenv("MOCKGPU")
@@ -44,14 +45,19 @@ class TestHCQ(unittest.TestCase):
for queue_type in [TestHCQ.d0.hw_compute_queue_t, TestHCQ.d0.hw_copy_queue_t]:
if queue_type is None: continue
with self.subTest(name=str(queue_type)):
q = queue_type().signal(TestHCQ.d0.signal_t(), 0x1000)
virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32)
virt_signal = TestHCQ.d0.signal_t(base_addr=Variable("sig_addr", 0, 0xffffffffffffffff, dtypes.uint64))
q.update_signal(0, signal=TestHCQ.d0.timeline_signal, value=TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
with self.subTest(name=str(queue_type)):
q = queue_type().signal(virt_signal, virt_val)
var_vals = {virt_signal.base_addr: TestHCQ.d0.timeline_signal.base_addr, virt_val: TestHCQ.d0.timeline_value}
q.submit(TestHCQ.d0, var_vals)
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
q.update_signal(0, value=TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
var_vals = {virt_signal.base_addr: TestHCQ.d0.timeline_signal.base_addr, virt_val: TestHCQ.d0.timeline_value}
q.submit(TestHCQ.d0, var_vals)
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@@ -91,12 +97,15 @@ class TestHCQ(unittest.TestCase):
if queue_type is None: continue
with self.subTest(name=str(queue_type)):
virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32)
virt_signal = TestHCQ.d0.signal_t(base_addr=Variable("sig_addr", 0, 0xffffffffffffffff, dtypes.uint64))
fake_signal = TestHCQ.d0.signal_t()
q = queue_type().wait(TestHCQ.d0.timeline_signal, 0xffffffff).signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
q = queue_type().wait(virt_signal, virt_val).signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
fake_signal.value = 0x30
q.update_wait(0, signal=fake_signal, value=0x30).submit(TestHCQ.d0)
q.submit(TestHCQ.d0, {virt_signal.base_addr: fake_signal.base_addr, virt_val: fake_signal.value})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@@ -112,26 +121,30 @@ class TestHCQ(unittest.TestCase):
assert val == 1.0, f"got val {val}"
def test_exec_2_kernels_100_times(self):
virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32)
q = TestHCQ.d0.hw_compute_queue_t()
q.wait(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value - 1) \
q.wait(TestHCQ.d0.timeline_signal, virt_val - 1) \
.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) \
.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ab_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) \
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
.signal(TestHCQ.d0.timeline_signal, virt_val)
for _ in range(100):
q.update_wait(0, value=TestHCQ.d0.timeline_value - 1).update_signal(3, value=TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
q.submit(TestHCQ.d0, {virt_val: TestHCQ.d0.timeline_value})
TestHCQ.d0.timeline_value += 1
val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]
assert val == 200.0, f"got val {val}"
def test_exec_update(self):
sint_global = (Variable("sint_global", 0, 0xffffffff, dtypes.uint32),) + tuple(TestHCQ.runner.p.global_size[1:])
sint_local = (Variable("sint_local", 0, 0xffffffff, dtypes.uint32),) + tuple(TestHCQ.runner.p.local_size[1:])
q = TestHCQ.d0.hw_compute_queue_t()
q.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) \
q.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, sint_global, sint_local) \
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
q.update_exec(0, (1,1,1), (1,1,1))
q.submit(TestHCQ.d0)
q.submit(TestHCQ.d0, {sint_global[0]: 1, sint_local[0]: 1})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@@ -141,6 +154,9 @@ class TestHCQ(unittest.TestCase):
assert val == 0.0, f"got val {val}, should not be updated"
def test_exec_update_fuzz(self):
virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32)
virt_local = [Variable(f"local_{i}", 0, 0xffffffff, dtypes.uint32) for i in range(3)]
a = Tensor.randint((3, 3, 3), dtype=dtypes.int, device=Device.DEFAULT).realize()
b = a + 1
si = create_schedule([b.lazydata])[-1]
@@ -156,16 +172,15 @@ class TestHCQ(unittest.TestCase):
q = TestHCQ.d0.hw_compute_queue_t()
q.memory_barrier() \
.exec(runner._prg, kernargs, (1,1,1), (1,1,1)) \
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
.exec(runner._prg, kernargs, (1,1,1), virt_local) \
.signal(TestHCQ.d0.timeline_signal, virt_val)
for x in range(1, 4):
for y in range(1, 4):
for z in range(1, 4):
ctypes.memset(zt._buf.va_addr, 0, zb.nbytes)
q.update_exec(1, local_size=(x,y,z)) \
.update_signal(2, value=TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
q.submit(TestHCQ.d0, {virt_val: TestHCQ.d0.timeline_value, virt_local[0]: x, virt_local[1]: y, virt_local[2]: z})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@@ -207,12 +222,14 @@ class TestHCQ(unittest.TestCase):
def test_update_copy(self):
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
virt_src_addr = Variable("virt_src_addr", 0, 0xffffffffffffffff, dtypes.uint64)
virt_dest_addr = Variable("virt_dest_addr", 0, 0xffffffffffffffff, dtypes.uint64)
q = TestHCQ.d0.hw_copy_queue_t().wait(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value - 1) \
.copy(0x0, 0x0, 8) \
.copy(virt_dest_addr, virt_src_addr, 8) \
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
q.update_copy(1, dest=TestHCQ.b.lazydata.buffer._buf.va_addr, src=TestHCQ.a.lazydata.buffer._buf.va_addr) \
.submit(TestHCQ.d0)
q.submit(TestHCQ.d0, {virt_src_addr: TestHCQ.a.lazydata.buffer._buf.va_addr, virt_dest_addr: TestHCQ.b.lazydata.buffer._buf.va_addr})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@@ -223,17 +240,19 @@ class TestHCQ(unittest.TestCase):
def test_update_copy_long(self):
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
virt_src_addr = Variable("virt_src_addr", 0, 0xffffffffffffffff, dtypes.uint64)
virt_dest_addr = Variable("virt_dest_addr", 0, 0xffffffffffffffff, dtypes.uint64)
sz = 64 << 20
buf1 = Buffer(Device.DEFAULT, sz, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated()
buf2 = Buffer(Device.DEFAULT, sz, dtypes.int8, options=BufferSpec(host=True, nolru=True)).ensure_allocated()
ctypes.memset(buf2._buf.va_addr, 1, sz)
q = TestHCQ.d0.hw_copy_queue_t().wait(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value - 1) \
.copy(0x0, 0x0, sz) \
.copy(virt_dest_addr, virt_src_addr, sz) \
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
q.update_copy(1, buf1._buf.va_addr, buf2._buf.va_addr) \
.submit(TestHCQ.d0)
q.submit(TestHCQ.d0, {virt_src_addr: buf2._buf.va_addr, virt_dest_addr: buf1._buf.va_addr})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
@@ -246,14 +265,17 @@ class TestHCQ(unittest.TestCase):
for queue_type in [TestHCQ.d0.hw_compute_queue_t, TestHCQ.d0.hw_copy_queue_t]:
if queue_type is None: continue
virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32)
virt_signal = TestHCQ.d0.signal_t(base_addr=Variable("sig_addr", 0, 0xffffffffffffffff, dtypes.uint64))
with self.subTest(name=str(queue_type)):
fake_signal = TestHCQ.d0.signal_t()
q = queue_type().wait(TestHCQ.d0.timeline_signal, 0xffffffff).signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
q = queue_type().wait(virt_signal, virt_val).signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
q.bind(TestHCQ.d0)
fake_signal.value = 0x30
q.update_wait(0, signal=fake_signal, value=0x30).submit(TestHCQ.d0)
q.submit(TestHCQ.d0, {virt_signal.base_addr: fake_signal.base_addr, virt_val: fake_signal.value})
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1

View File

@@ -3,12 +3,13 @@ from typing import List, Any, Dict, cast, Optional, Tuple, Set
from tinygrad.helpers import round_up, PROFILE, memsize_to_str
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState
from tinygrad.device import Buffer, BufferSpec, Compiled, Device
from tinygrad.ops import Variable
from tinygrad import Variable, dtypes
from tinygrad.ops import sint, Variable as VariableT
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
from tinygrad.engine.jit import MultiGraphRunner
class HCQGraph(MultiGraphRunner):
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[VariableT, int]):
super().__init__(jit_cache, input_rawbuffers, var_vals)
self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))
@@ -40,6 +41,7 @@ class HCQGraph(MultiGraphRunner):
self.signals: Dict[Any, HCQSignal] = {**{dev: dev.signal_t(value=0) for dev in self.devices}, **{"CPU": self.devices[0].signal_t(value=0)}}
self.kickoff_value: int = 0
self.kickoff_var = Variable("kickoff_var", 0, 0xffffffff, dtype=dtypes.uint32)
self.prof_signals: List[HCQSignal] = [self.devices[0].signal_t() for i in range(len(jit_cache) * 2)] if PROFILE else []
self.prof_records: List[Tuple[Tuple[int, bool], Tuple[int, bool], HCQCompiled, str, bool, List[int], Optional[Dict]]] = []
@@ -70,7 +72,7 @@ class HCQGraph(MultiGraphRunner):
# Ensure device is ready for use in current context: the graph has initialized the device and it's safe to operate on it within this graph.
for dep_queue, _ in opt_deps: dev_access[enqueue_queue].update(dev_access[dep_queue])
sync_signals = [(self.signals[d], self.kickoff_value) for b in ji.bufs if (d:=Device[cast(Buffer, b).device]) not in dev_access[enqueue_queue]]
sync_signals = [(self.signals[d], self.kickoff_var) for b in ji.bufs if (d:=Device[cast(Buffer, b).device]) not in dev_access[enqueue_queue]]
dev_access[enqueue_queue].update(cast(HCQCompiled, Device[cast(Buffer, b).device]) for b in ji.bufs)
# Remove self-dependency for compute and copy queues.
@@ -99,18 +101,21 @@ class HCQGraph(MultiGraphRunner):
last_j[enqueue_queue] = j
# Build hardware queues.
self.op_cmd_idx: Dict[int, Tuple[HWQueue, int]] = {}
self.input_replace_to_var: Dict[Tuple[int, int], VariableT] = {}
self.copy_to_devs: Dict[HCQCompiled, Set[HCQCompiled]] = {dev: set() for dev in self.devices}
self.kickoff_wait_cmds: Dict[HWQueue, List] = {q: list() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
# Create variable timeline signals for each device.
timeline_sigaddrs = {dev: Variable(f"timeline_sig_{dev.device_id}", 0, 0xffffffffffffffff, dtype=dtypes.uint64) for dev in self.devices}
self.virt_timeline_vals = {dev: Variable(f"timeline_var_{dev.device_id}", 0, 0xffffffff, dtype=dtypes.uint32) for dev in self.devices}
self.virt_timeline_signals = {dev: dev.signal_t(base_addr=timeline_sigaddrs[dev], timeline_for_device=dev) for dev in self.devices}
for dev in self.devices:
self.comp_queues[dev].memory_barrier().wait(dev.timeline_signal, dev.timeline_value - 1) \
.wait(self.signals['CPU'], self.kickoff_value).signal(self.signals[dev], self.kickoff_value)
self.comp_queues[dev].memory_barrier().wait(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev]) \
.wait(self.signals['CPU'], self.kickoff_var).signal(self.signals[dev], self.kickoff_var)
for j,ji in enumerate(jit_cache):
enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j]
for i in range(len(sync_signals)): self.kickoff_wait_cmds[enqueue_queue].append(len(enqueue_queue) + i)
for sig, val in sync_signals + deps: enqueue_queue.wait(sig, val)
# Encode waits and start profile timestamp (if needed).
@@ -118,13 +123,14 @@ class HCQGraph(MultiGraphRunner):
# Encode main commands based on ji type.
if isinstance(ji.prg, CompiledRunner):
enqueue_queue.exec(ji.prg._prg, self.ji_args[j], *ji.prg.p.launch_dims(var_vals))
enqueue_queue.exec(ji.prg._prg, self.ji_args[j], tuple(ji.prg.p.global_size or (1,1,1)), tuple(ji.prg.p.local_size or (1,1,1)))
elif isinstance(ji.prg, BufferXfer):
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
cast(HCQAllocator, Device[src.device].allocator).map(dest._buf)
enqueue_queue.copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes)
# TODO: For now, sints are only supported for copies; should refactor to support exec as well.
enqueue_queue.copy(self._buf_addr_as_sint(j, 0, dest._buf), self._buf_addr_as_sint(j, 1, src._buf), dest.nbytes)
self.copy_to_devs[cast(HCQCompiled, Device[dest.device])].add(cast(HCQCompiled, Device[src.device]))
self.op_cmd_idx[j] = (enqueue_queue, len(enqueue_queue) - 1)
# Encode finish profile timestamp (if needed).
if PROFILE and self.prof_records[j][1][1]: enqueue_queue.timestamp(self.prof_signals[self.prof_records[j][1][0]])
@@ -135,13 +141,13 @@ class HCQGraph(MultiGraphRunner):
for dep_dev in list(self.copy_to_devs[dev]) + [dev]:
if dep_dev in self.copy_queues: self.comp_queues[dev].wait(self.signals[(copy_q:=self.copy_queues[dep_dev])], cast(int, last_j[copy_q]) + 1)
self.comp_queues[dev].signal(dev.timeline_signal, dev.timeline_value).bind(dev)
self.comp_queues[dev].signal(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev] + 1).bind(dev)
if dev in self.copy_queues: self.copy_queues[dev].bind(dev)
self.last_timeline: Dict[HCQCompiled, Tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
self.queue_signals_to_reset = [self.signals[q] for q in list(self.comp_queues.values()) + list(self.copy_queues.values()) if q in self.signals]
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[VariableT, int], wait=False) -> Optional[float]:
# Wait and restore signals
self.kickoff_value += 1
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
@@ -150,28 +156,21 @@ class HCQGraph(MultiGraphRunner):
if PROFILE and self.kickoff_value > 1: self.collect_timestamps()
hcq_var_vals = {self.kickoff_var: self.kickoff_value, **var_vals,
**{var: dev.timeline_value - 1 for dev, var in self.virt_timeline_vals.items()},
**{sig.base_addr: dev.timeline_signal.base_addr for dev, sig in self.virt_timeline_signals.items()}}
# Update rawbuffers
for (j,i),input_idx in self.input_replace.items():
if j in self.ji_args: self.ji_args[j].update_buffer(i, input_rawbuffers[input_idx]._buf)
else: self.op_cmd_idx[j][0].update_copy(self.op_cmd_idx[j][1], **{('dest' if i == 0 else 'src'): input_rawbuffers[input_idx]._buf.va_addr})
if (var:=self.input_replace_to_var.get((j,i))) is not None: hcq_var_vals[var] = input_rawbuffers[input_idx]._buf.va_addr
else: self.ji_args[j].update_buffer(i, input_rawbuffers[input_idx]._buf)
# Update var_vals
for j, i, v in self.updated_vars(var_vals): self.ji_args[j].update_var(i, v)
# Update launch dims
for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
queue, cmd_ptr = self.op_cmd_idx[j]
queue.update_exec(cmd_ptr, global_dims, local_dims)
for dev in self.devices:
comp_queue, copy_queue, need_sig_upd = self.comp_queues[dev], self.copy_queues.get(dev, None), dev.timeline_signal != self.last_timeline[dev][0]
comp_queue.update_wait(1, dev.timeline_signal if need_sig_upd else None, dev.timeline_value - 1) \
.update_wait(2, value=self.kickoff_value).update_signal(3, value=self.kickoff_value) \
.update_signal(len(comp_queue)-1, dev.timeline_signal if need_sig_upd else None, dev.timeline_value).submit(dev)
if copy_queue is not None:
for cmd_idx in self.kickoff_wait_cmds[copy_queue]: copy_queue.update_wait(cmd_idx, value=self.kickoff_value)
copy_queue.submit(dev)
self.comp_queues[dev].submit(dev, hcq_var_vals)
if (copy_queue:=self.copy_queues.get(dev, None)) is not None: copy_queue.submit(dev, hcq_var_vals)
self.last_timeline[dev] = (dev.timeline_signal, dev.timeline_value)
dev.timeline_value += 1
@@ -192,6 +191,10 @@ class HCQGraph(MultiGraphRunner):
(b_st,_), (b_en,_), b_dev, _, b_is_cp, _, _ = self.prof_records[x]
dev.dep_prof_records += [(timestamps[b_st], timestamps[b_en], b_dev, b_is_cp, timestamps[st], timestamps[en], dev, is_cp)]
def _buf_addr_as_sint(self, j:int, i:int, buf:HCQBuffer) -> sint:
if (j, i) not in self.input_replace: return buf.va_addr
return self.input_replace_to_var.setdefault((j, i), Variable(f"input_{j}_{i}", 0, 0xffffffffffffffff, dtype=dtypes.uint64))
def __del__(self):
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])

View File

@@ -4,6 +4,7 @@ import os, ctypes, ctypes.util, functools, pathlib, mmap, errno, array, contextl
assert sys.platform != 'win32'
from dataclasses import dataclass
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, HCQArgsState, HCQSignal, HCQProgram
from tinygrad.ops import sint
from tinygrad.device import BufferSpec
from tinygrad.helpers import getenv, to_mv, round_up, data64_le, mv_address
from tinygrad.renderer.cstyle import AMDRenderer
@@ -31,7 +32,8 @@ class AMDSignal(HCQSignal):
def __init__(self, base_addr:Optional[int]=None, **kwargs):
super().__init__(AMDDevice.signals_pool.pop() if base_addr is None else base_addr, **kwargs, timestamp_divider=100)
def __del__(self): AMDDevice.signals_pool.append(self.base_addr)
def __del__(self):
if isinstance(self.base_addr, int): AMDDevice.signals_pool.append(self.base_addr)
def _sleep(self, time_spent_waiting_ms:int):
# Reasonable to sleep for long workloads (which take more than 2s) and only timeline signals.
@@ -39,10 +41,6 @@ class AMDSignal(HCQSignal):
kfd.AMDKFD_IOC_WAIT_EVENTS(AMDDevice.kfd, events_ptr=self.timeline_for_device.queue_event_arr_ptr, num_events=1, wait_for_all=1, timeout=200)
class AMDComputeQueue(HWQueue):
def __init__(self):
self.cmd_idx_to_local_offset, self.cmd_idx_to_global_offset, self.cmd_idx_to_dispatch_packet = {}, {}, {}
super().__init__()
def __del__(self):
if self.binded_device is not None:
self.binded_device.allocator.free(self.hw_page, self.hw_page.size, BufferSpec(cpu_access=True, nolru=True, uncached=True))
@@ -76,22 +74,23 @@ class AMDComputeQueue(HWQueue):
self.pkt3(amd_gpu.PACKET3_RELEASE_MEM, event_dw | cache_flags_dw, memsel_dw, *data64_le(address), *data64_le(value), ctxid)
def _memory_barrier(self):
def memory_barrier(self):
self.wait_reg_mem(reg_req=nbioreg(regBIF_BX_PF1_GPU_HDP_FLUSH_REQ), reg_done=nbioreg(regBIF_BX_PF1_GPU_HDP_FLUSH_DONE), value=0xffffffff)
self.acquire_mem()
return self
def _exec(self, prg:AMDProgram, args_state:AMDArgsState, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1)):
def exec(self, prg:AMDProgram, args_state:AMDArgsState, global_size:Tuple[sint, ...], local_size:Tuple[sint, ...]):
self.acquire_mem(gli=0, gl2=0)
cmd_idx = self._cur_cmd_idx()
user_regs = [*data64_le(prg.dev.scratch.va_addr), 0xffffffff, 0xc00000] if prg.enable_private_segment_sgpr else []
if prg.enable_dispatch_ptr:
dp = hsa.hsa_kernel_dispatch_packet_t.from_address(dp_addr:=args_state.ptr + prg.kernargs_segment_size)
dp.workgroup_size_x, dp.workgroup_size_y, dp.workgroup_size_z = local_size[0], local_size[1], local_size[2]
dp.grid_size_x, dp.grid_size_y, dp.grid_size_z = global_size[0]*local_size[0], global_size[1]*local_size[1], global_size[2]*local_size[2]
self.bind_sints(*local_size, struct=dp, start_field='workgroup_size_x', fmt='H')
self.bind_sints(*[g*l for g,l in zip(global_size, local_size)], struct=dp, start_field='grid_size_x', fmt='I')
dp.group_segment_size, dp.private_segment_size, dp.kernarg_address = prg.group_segment_size, prg.private_segment_size, args_state.ptr
user_regs += [*data64_le(dp_addr)]
self.cmd_idx_to_dispatch_packet[cmd_idx] = dp
user_regs += [*data64_le(args_state.ptr)]
self.pkt3(amd_gpu.PACKET3_SET_SH_REG, gfxreg(amd_gpu.regCOMPUTE_PGM_LO), *data64_le(prg.prog_addr >> 8))
@@ -107,29 +106,22 @@ class AMDComputeQueue(HWQueue):
self.pkt3(amd_gpu.PACKET3_SET_SH_REG, gfxreg(amd_gpu.regCOMPUTE_STATIC_THREAD_MGMT_SE4), 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF)
self.pkt3(amd_gpu.PACKET3_SET_SH_REG, gfxreg(amd_gpu.regCOMPUTE_USER_DATA_0), *user_regs)
self.cmd_idx_to_local_offset[cmd_idx] = len(self._q) - self.cmds_offset[cmd_idx] + 5 # +1 to skip PACKET3_SET_SH_REG + reg + 3 zeros.
self.pkt3(amd_gpu.PACKET3_SET_SH_REG, gfxreg(amd_gpu.regCOMPUTE_START_X), 0, 0, 0, *local_size, 0, 0)
self.pkt3(amd_gpu.PACKET3_SET_SH_REG, gfxreg(amd_gpu.regCOMPUTE_RESOURCE_LIMITS), 0)
self.cmd_idx_to_global_offset[cmd_idx] = len(self._q) - self.cmds_offset[cmd_idx] + 1 # +1 to skip PACKET3_DISPATCH_DIRECT.
self.pkt3(amd_gpu.PACKET3_DISPATCH_DIRECT, *global_size, CS_W32_EN | FORCE_START_AT_000 | COMPUTE_SHADER_EN)
self.pkt3(amd_gpu.PACKET3_EVENT_WRITE, amd_gpu.EVENT_TYPE(amd_gpu.CS_PARTIAL_FLUSH) | amd_gpu.EVENT_INDEX(EVENT_INDEX_PARTIAL_FLUSH))
return self
def _update_exec(self, cmd_idx, global_size, local_size):
if local_size is not None: self._patch(cmd_idx, offset=self.cmd_idx_to_local_offset[cmd_idx], data=local_size)
if global_size is not None: self._patch(cmd_idx, offset=self.cmd_idx_to_global_offset[cmd_idx], data=global_size)
def wait(self, signal:AMDSignal, value:sint=0):
self.wait_reg_mem(mem=signal.value_addr, value=value, mask=0xffffffff)
return self
if (dp:=self.cmd_idx_to_dispatch_packet.get(cmd_idx)) is not None:
if local_size is not None: dp.workgroup_size_x, dp.workgroup_size_y, dp.workgroup_size_z = local_size[0], local_size[1], local_size[2]
if global_size is not None:
dp.grid_size_x,dp.grid_size_y,dp.grid_size_z = [g*l for g,l in zip(global_size,[dp.workgroup_size_x,dp.workgroup_size_y,dp.workgroup_size_z])]
def _wait(self, signal:AMDSignal, value=0): self.wait_reg_mem(mem=signal.value_addr, value=value, mask=0xffffffff)
def _timestamp(self, signal:AMDSignal):
def timestamp(self, signal:AMDSignal):
self.release_mem(signal.timestamp_addr, 0, amd_gpu.data_sel__mec_release_mem__send_gpu_clock_counter, amd_gpu.int_sel__mec_release_mem__none)
return self
def _signal(self, signal:AMDSignal, value=0):
def signal(self, signal:AMDSignal, value:sint=0):
# NOTE: this needs an EOP buffer on the queue or it will NULL pointer
self.release_mem(signal.value_addr, value, amd_gpu.data_sel__mec_release_mem__send_32_bit_low,
amd_gpu.int_sel__mec_release_mem__send_interrupt_after_write_confirm, cache_flush=True)
@@ -137,14 +129,7 @@ class AMDComputeQueue(HWQueue):
if (dev:=signal.timeline_for_device) is not None:
self.release_mem(dev.queue_event_mailbox_ptr, dev.queue_event.event_id, amd_gpu.data_sel__mec_release_mem__send_32_bit_low,
amd_gpu.int_sel__mec_release_mem__send_interrupt_after_write_confirm, ctxid=dev.queue_event.event_id)
def _update_wait(self, cmd_idx, signal:Optional[AMDSignal]=None, value=None):
if signal is not None: self._patch(cmd_idx, offset=2, data=data64_le(signal.value_addr))
if value is not None: self._patch(cmd_idx, offset=4, data=[value])
def _update_signal(self, cmd_idx, signal:Optional[AMDSignal]=None, value=None):
if signal is not None: self._patch(cmd_idx, offset=3, data=data64_le(signal.value_addr))
if value is not None: self._patch(cmd_idx, offset=5, data=data64_le(value))
return self
def bind(self, dev:AMDDevice):
self.binded_device = dev
@@ -155,6 +140,7 @@ class AMDComputeQueue(HWQueue):
self.indirect_cmd = [amd_gpu.PACKET3(amd_gpu.PACKET3_INDIRECT_BUFFER, 2), *data64_le(self.hw_page.va_addr),
len(self._q) | amd_gpu.INDIRECT_BUFFER_VALID]
self._q = hw_view # type: ignore
return self
def _submit(self, dev:AMDDevice):
cmds = self.indirect_cmd if dev == self.binded_device else self._q
@@ -168,16 +154,16 @@ class AMDComputeQueue(HWQueue):
SDMA_MAX_COPY_SIZE = 0x400000
class AMDCopyQueue(HWQueue):
def __init__(self):
self.internal_cmd_sizes, self.copy_cmds_per_copy = [], {}
self.internal_cmd_sizes = []
super().__init__()
def q(self, *arr):
super().q(*arr)
self.internal_cmd_sizes.append(len(arr))
def _copy(self, dest, src, copy_size):
def copy(self, dest:sint, src:sint, copy_size:int):
copied, copy_commands = 0, (copy_size + SDMA_MAX_COPY_SIZE - 1) // SDMA_MAX_COPY_SIZE
self.copy_cmds_per_copy[len(self) - 1] = copy_commands
for _ in range(copy_commands):
step_copy_size = min(copy_size - copied, SDMA_MAX_COPY_SIZE)
@@ -185,34 +171,27 @@ class AMDCopyQueue(HWQueue):
amd_gpu.SDMA_PKT_COPY_LINEAR_COUNT_COUNT(step_copy_size - 1), 0, *data64_le(src + copied), *data64_le(dest + copied))
copied += step_copy_size
return self
def _update_copy(self, cmd_idx, dest=None, src=None):
for i in range(self.copy_cmds_per_copy[cmd_idx]):
if src is not None: self._patch(cmd_idx, offset=3+i*7, data=[*data64_le(src + SDMA_MAX_COPY_SIZE*i)])
if dest is not None: self._patch(cmd_idx, offset=5+i*7, data=[*data64_le(dest + SDMA_MAX_COPY_SIZE*i)])
def _signal(self, signal:AMDSignal, value=0):
def signal(self, signal:AMDSignal, value:sint=0):
self.q(amd_gpu.SDMA_OP_FENCE | amd_gpu.SDMA_PKT_FENCE_HEADER_MTYPE(3), *data64_le(signal.value_addr), value)
if (dev:=signal.timeline_for_device) is not None:
self.q(amd_gpu.SDMA_OP_FENCE | amd_gpu.SDMA_PKT_FENCE_HEADER_MTYPE(3), *data64_le(dev.queue_event_mailbox_ptr), dev.queue_event.event_id)
self.q(amd_gpu.SDMA_OP_TRAP, amd_gpu.SDMA_PKT_TRAP_INT_CONTEXT_INT_CONTEXT(dev.queue_event.event_id))
def _wait(self, signal:AMDSignal, value=0):
return self
def wait(self, signal:AMDSignal, value:sint=0):
self.q(amd_gpu.SDMA_OP_POLL_REGMEM | amd_gpu.SDMA_PKT_POLL_REGMEM_HEADER_FUNC(WAIT_REG_MEM_FUNCTION_GEQ) | \
amd_gpu.SDMA_PKT_POLL_REGMEM_HEADER_MEM_POLL(1), *data64_le(signal.value_addr), value, 0xffffffff,
amd_gpu.SDMA_PKT_POLL_REGMEM_DW5_INTERVAL(0x04) | amd_gpu.SDMA_PKT_POLL_REGMEM_DW5_RETRY_COUNT(0xfff))
return self
def _update_signal(self, cmd_idx, signal:Optional[AMDSignal]=None, value=None):
return self._update_wait(cmd_idx, signal, value) # the same offsets and commands
def _update_wait(self, cmd_idx, signal:Optional[AMDSignal]=None, value=None):
if signal is not None: self._patch(cmd_idx, offset=1, data=data64_le(signal.value_addr))
if value is not None: self._patch(cmd_idx, offset=3, data=[value])
def _timestamp(self, signal:AMDSignal):
def timestamp(self, signal:AMDSignal):
self.q(amd_gpu.SDMA_OP_TIMESTAMP | amd_gpu.SDMA_PKT_TIMESTAMP_GET_HEADER_SUB_OP(amd_gpu.SDMA_SUBOP_TIMESTAMP_GET_GLOBAL),
*data64_le(signal.timestamp_addr))
return self
def _submit(self, dev:AMDDevice):
if dev.sdma_queue.put_value - dev.sdma_queue.read_ptr[0] > dev.sdma_queue.ring.nbytes: raise RuntimeError("SDMA queue overrun")

View File

@@ -3,8 +3,8 @@ import os, ctypes, contextlib, re, fcntl, functools, mmap, struct, array, sys
assert sys.platform != 'win32'
from typing import Tuple, List, Any, cast, Union, Dict, Type, Optional
from dataclasses import dataclass
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, hcq_command
from tinygrad.runtime.support.hcq import HCQArgsState, HCQProgram, HCQSignal
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, HCQArgsState, HCQProgram, HCQSignal
from tinygrad.ops import sint
from tinygrad.device import BufferSpec
from tinygrad.helpers import getenv, mv_address, init_c_struct_t, to_mv, round_up, data64, data64_le, DEBUG, prod
from tinygrad.renderer.ptx import PTXRenderer
@@ -77,13 +77,17 @@ class NVSignal(HCQSignal):
def __init__(self, base_addr:Optional[int]=None, **kwargs):
super().__init__(NVDevice.signals_pool.pop() if base_addr is None else base_addr, **kwargs, timestamp_divider=1000, value_off=0, timestamp_off=8)
def __del__(self): NVDevice.signals_pool.append(self.base_addr)
def __del__(self):
if isinstance(self.base_addr, int): NVDevice.signals_pool.append(self.base_addr)
class NVCommandQueue(HWQueue[NVSignal, 'NVDevice', 'NVProgram', 'NVArgsState']):
def __init__(self):
self.active_qmd = None
super().__init__()
def __del__(self):
if self.binded_device is not None: self.binded_device.allocator.free(self.hw_page, self.hw_page.size, BufferSpec(cpu_access=True, nolru=True))
@hcq_command
def setup(self, compute_class=None, copy_class=None, local_mem_window=None, shared_mem_window=None, local_mem=None, local_mem_tpc_bytes=None):
if compute_class: self.q(nvmethod(1, nv_gpu.NVC6C0_SET_OBJECT, 1), compute_class)
if copy_class: self.q(nvmethod(4, nv_gpu.NVC6C0_SET_OBJECT, 1), copy_class)
@@ -91,16 +95,15 @@ class NVCommandQueue(HWQueue[NVSignal, 'NVDevice', 'NVProgram', 'NVArgsState']):
if shared_mem_window: self.q(nvmethod(1, nv_gpu.NVC6C0_SET_SHADER_SHARED_MEMORY_WINDOW_A, 2), *data64(shared_mem_window))
if local_mem: self.q(nvmethod(1, nv_gpu.NVC6C0_SET_SHADER_LOCAL_MEMORY_A, 2), *data64(local_mem))
if local_mem_tpc_bytes: self.q(nvmethod(1, nv_gpu.NVC6C0_SET_SHADER_LOCAL_MEMORY_NON_THROTTLED_A, 3), *data64(local_mem_tpc_bytes), 0xff)
return self
def _wait(self, signal:NVSignal, value=0):
def wait(self, signal:NVSignal, value:sint=0):
self.q(nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *data64_le(signal.value_addr), *data64_le(value),
(3 << 0) | (1 << 24)) # ACQUIRE | PAYLOAD_SIZE_64BIT
self.active_qmd = None
return self
def _update_wait(self, cmd_idx, signal=None, value=None):
if signal is not None: self._q[(sigoff:=self.cmds_offset[cmd_idx]+1):sigoff+2] = array.array('I', data64_le(signal.value_addr))
if value is not None: self._q[(valoff:=self.cmds_offset[cmd_idx]+3):valoff+2] = array.array('I', data64_le(value))
def _timestamp(self, signal): return self._signal(signal, 0)
def timestamp(self, signal:NVSignal): return self.signal(signal, 0)
def bind(self, dev:NVDevice):
self.binded_device = dev
@@ -129,79 +132,63 @@ class NVCommandQueue(HWQueue[NVSignal, 'NVDevice', 'NVProgram', 'NVArgsState']):
gpfifo.put_value += 1
class NVComputeQueue(NVCommandQueue):
def __init__(self):
self.cmd_idx_to_qmd, self.cmd_idx_to_signal_id, self.cmd_idx_to_global_dims, self.cmd_idx_to_local_dims = {}, {}, {}, {}
super().__init__()
def memory_barrier(self):
self.q(nvmethod(1, nv_gpu.NVC6C0_INVALIDATE_SHADER_CACHES_NO_WFI, 1), (1 << 12) | (1 << 4) | (1 << 0))
self.active_qmd = None
return self
def _memory_barrier(self): self.q(nvmethod(1, nv_gpu.NVC6C0_INVALIDATE_SHADER_CACHES_NO_WFI, 1), (1 << 12) | (1 << 4) | (1 << 0))
def _exec(self, prg:NVProgram, args_state:NVArgsState, global_size, local_size):
def exec(self, prg:NVProgram, args_state:NVArgsState, global_size:Tuple[sint, ...], local_size:Tuple[sint, ...]):
ctypes.memmove(qmd_addr:=(args_state.ptr + round_up(prg.constbufs[0][1], 1 << 8)), ctypes.addressof(prg.qmd), 0x40 * 4)
assert qmd_addr < (1 << 40), f"large qmd addr {qmd_addr:x}"
self.cmd_idx_to_qmd[self._cur_cmd_idx()] = qmd = qmd_struct_t.from_address(qmd_addr) # Save qmd for later update
self.cmd_idx_to_global_dims[self._cur_cmd_idx()] = to_mv(qmd_addr + nv_gpu.NVC6C0_QMDV03_00_CTA_RASTER_WIDTH[1] // 8, 12).cast('I')
self.cmd_idx_to_local_dims[self._cur_cmd_idx()] = to_mv(qmd_addr + nv_gpu.NVC6C0_QMDV03_00_CTA_THREAD_DIMENSION0[1] // 8, 6).cast('H')
qmd = qmd_struct_t.from_address(qmd_addr) # Save qmd for later update
qmd.cta_raster_width, qmd.cta_raster_height, qmd.cta_raster_depth = global_size
qmd.cta_thread_dimension0, qmd.cta_thread_dimension1, qmd.cta_thread_dimension2 = local_size
self.bind_sints_to_ptr(*global_size, ptr=qmd_addr + nv_gpu.NVC6C0_QMDV03_00_CTA_RASTER_WIDTH[1] // 8, fmt='I')
self.bind_sints_to_ptr(*local_size, ptr=qmd_addr + nv_gpu.NVC6C0_QMDV03_00_CTA_THREAD_DIMENSION0[1] // 8, fmt='H')
qmd.constant_buffer_addr_upper_0, qmd.constant_buffer_addr_lower_0 = data64(args_state.ptr)
if (prev_qmd:=self.cmd_idx_to_qmd.get(self._cur_cmd_idx() - 1)) is None:
if self.active_qmd is None:
self.q(nvmethod(1, nv_gpu.NVC6C0_SEND_PCAS_A, 0x1), qmd_addr >> 8)
self.q(nvmethod(1, nv_gpu.NVC6C0_SEND_SIGNALING_PCAS2_B, 0x1), 9)
else:
prev_qmd.dependent_qmd0_pointer = qmd_addr >> 8
prev_qmd.dependent_qmd0_action = 1
prev_qmd.dependent_qmd0_prefetch = 1
prev_qmd.dependent_qmd0_enable = 1
self.active_qmd.dependent_qmd0_pointer = qmd_addr >> 8
self.active_qmd.dependent_qmd0_action = 1
self.active_qmd.dependent_qmd0_prefetch = 1
self.active_qmd.dependent_qmd0_enable = 1
def _update_exec(self, cmd_idx, global_size, local_size):
# Patch the exec cmd with new launch dims
if global_size is not None: self.cmd_idx_to_global_dims[cmd_idx][:] = array.array('I', global_size)
if local_size is not None: self.cmd_idx_to_local_dims[cmd_idx][:] = array.array('H', local_size)
self.active_qmd = qmd
return self
def _signal(self, signal:NVSignal, value=0):
if (prev_qmd:=self.cmd_idx_to_qmd.get(self._cur_cmd_idx() - 1)) is not None:
def signal(self, signal:NVSignal, value:sint=0):
if self.active_qmd is not None:
for i in range(2):
if getattr(prev_qmd, f'release{i}_enable') == 0:
setattr(prev_qmd, f'release{i}_enable', 1)
setattr(prev_qmd, f'release{i}_address', signal.value_addr)
setattr(prev_qmd, f'release{i}_payload', value)
self.cmd_idx_to_qmd[self._cur_cmd_idx()] = prev_qmd
self.cmd_idx_to_signal_id[self._cur_cmd_idx()] = i
return
if getattr(self.active_qmd, f'release{i}_enable') == 0:
setattr(self.active_qmd, f'release{i}_enable', 1)
self.bind_sints(signal.value_addr, struct=self.active_qmd, start_field=f'release{i}_address', fmt='Q', mask=0xfffffffff)
self.bind_sints(value, struct=self.active_qmd, start_field=f'release{i}_payload', fmt='Q')
return self
self.q(nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *data64_le(signal.value_addr), *data64_le(value),
(1 << 0) | (1 << 20) | (1 << 24) | (1 << 25)) # RELEASE | RELEASE_WFI | PAYLOAD_SIZE_64BIT | RELEASE_TIMESTAMP
self.q(nvmethod(0, nv_gpu.NVC56F_NON_STALL_INTERRUPT, 1), 0x0)
self.active_qmd = None
return self
def _update_signal(self, cmd_idx, signal:Optional[NVSignal]=None, value=None):
if (qmd:=self.cmd_idx_to_qmd.get(cmd_idx)) is None: return super()._update_wait(cmd_idx, signal, value) # reuse wait, same offsets to update.
if signal is not None: setattr(qmd, f'release{self.cmd_idx_to_signal_id[cmd_idx]}_address', signal.value_addr)
if value is not None: setattr(qmd, f'release{self.cmd_idx_to_signal_id[cmd_idx]}_payload', value)
def _submit(self, dev): self._submit_to_gpfifo(dev, cast(NVDevice, dev).compute_gpfifo)
def _submit(self, dev:NVDevice): self._submit_to_gpfifo(dev, dev.compute_gpfifo)
class NVCopyQueue(NVCommandQueue):
def _copy(self, dest, src, copy_size):
def copy(self, dest:sint, src:sint, copy_size:int):
self.q(nvmethod(4, nv_gpu.NVC6B5_OFFSET_IN_UPPER, 4), *data64(src), *data64(dest))
self.q(nvmethod(4, nv_gpu.NVC6B5_LINE_LENGTH_IN, 1), copy_size)
self.q(nvmethod(4, nv_gpu.NVC6B5_LAUNCH_DMA, 1), 0x182) # TRANSFER_TYPE_NON_PIPELINED | DST_MEMORY_LAYOUT_PITCH | SRC_MEMORY_LAYOUT_PITCH
return self
def _update_copy(self, cmd_idx, dest=None, src=None):
if dest is not None: self._patch(cmd_idx, offset=3, data=data64(dest))
if src is not None: self._patch(cmd_idx, offset=1, data=data64(src))
def _signal(self, signal, value=0):
def signal(self, signal:NVSignal, value:sint=0):
self.q(nvmethod(4, nv_gpu.NVC6B5_SET_SEMAPHORE_A, 3), *data64(signal.value_addr), value)
self.q(nvmethod(4, nv_gpu.NVC6B5_LAUNCH_DMA, 1), 0x14)
return self
def _update_signal(self, cmd_idx, signal=None, value=None):
if signal is not None: self._patch(cmd_idx, offset=1, data=data64(signal.value_addr))
if value is not None: self._patch(cmd_idx, offset=3, data=[value])
def _submit(self, dev): self._submit_to_gpfifo(dev, cast(NVDevice, dev).dma_gpfifo)
def _submit(self, dev:NVDevice): self._submit_to_gpfifo(dev, dev.dma_gpfifo)
class NVArgsState(HCQArgsState):
def __init__(self, ptr:int, prg:NVProgram, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=()):

View File

@@ -39,7 +39,8 @@ class QCOMSignal(HCQSignal):
def __init__(self, base_addr:Optional[int]=None, **kwargs):
super().__init__(QCOMDevice.signals_pool.pop() if base_addr is None else base_addr, **kwargs, timestamp_divider=19.2)
def __del__(self): QCOMDevice.signals_pool.append(self.base_addr)
def __del__(self):
if isinstance(self.base_addr, int): QCOMDevice.signals_pool.append(self.base_addr)
def _sleep(self, time_spent_waiting_ms:int):
# Sleep only for timeline signals. Do it immediately to free the cpu.
@@ -48,10 +49,6 @@ class QCOMSignal(HCQSignal):
timestamp=self.timeline_for_device.last_cmd, timeout=0xffffffff)
class QCOMComputeQueue(HWQueue):
def __init__(self):
self.cmd_idx_to_dims = {}
super().__init__()
def __del__(self):
if self.binded_device is not None: self.binded_device.allocator.free(self.hw_page, self.hw_page.size, BufferSpec(cpu_access=True, nolru=True))
@@ -66,9 +63,11 @@ class QCOMComputeQueue(HWQueue):
if memsync: self.cmd(adreno.CP_WAIT_MEM_WRITES)
if sync: self.cmd(adreno.CP_WAIT_FOR_IDLE)
def _memory_barrier(self): self._cache_flush(write_back=True, invalidate=True, sync=True, memsync=True)
def memory_barrier(self):
self._cache_flush(write_back=True, invalidate=True, sync=True, memsync=True)
return self
def _signal(self, signal:QCOMSignal, value=0, ts=False):
def signal(self, signal:QCOMSignal, value=0, ts=False):
self.cmd(adreno.CP_WAIT_FOR_IDLE)
if QCOMDevice.gpu_id < 700:
self.cmd(adreno.CP_EVENT_WRITE, qreg.cp_event_write_0(event=adreno.CACHE_FLUSH_TS, timestamp=ts),
@@ -77,20 +76,14 @@ class QCOMComputeQueue(HWQueue):
else:
# TODO: support devices starting with 8 Gen 1. Also, 700th series have convenient CP_GLOBAL_TIMESTAMP and CP_LOCAL_TIMESTAMP
raise RuntimeError('CP_EVENT_WRITE7 is not supported')
return self
def _timestamp(self, signal:QCOMSignal): return self._signal(signal, 0, ts=True)
def timestamp(self, signal:QCOMSignal): return self.signal(signal, 0, ts=True)
def _wait(self, signal:QCOMSignal, value=0):
def wait(self, signal:QCOMSignal, value=0):
self.cmd(adreno.CP_WAIT_REG_MEM, qreg.cp_wait_reg_mem_0(function=adreno.WRITE_GE, poll=adreno.POLL_MEMORY),*data64_le(signal.value_addr),
qreg.cp_wait_reg_mem_3(ref=value&0xFFFFFFFF), qreg.cp_wait_reg_mem_4(mask=0xFFFFFFFF), qreg.cp_wait_reg_mem_5(delay_loop_cycles=32))
def _update_signal(self, cmd_idx, signal:Optional[QCOMSignal], value):
if signal is not None: self._patch(cmd_idx, offset=3, data=data64_le(signal.value_addr))
if value is not None: self._patch(cmd_idx, offset=5, data=[value & 0xFFFFFFFF])
def _update_wait(self, cmd_idx, signal:Optional[QCOMSignal], value):
if signal is not None: self._patch(cmd_idx, offset=2, data=data64_le(signal.value_addr))
if value is not None: self._patch(cmd_idx, offset=4, data=[value & 0xFFFFFFFF])
return self
def _build_gpu_command(self, dev:QCOMDevice, hw_addr=None):
to_mv((hw_page_addr:=hw_addr or dev._alloc_cmd_buf(len(self._q) * 4)), len(self._q) * 4).cast('I')[:] = array.array('I', self._q)
@@ -111,9 +104,9 @@ class QCOMComputeQueue(HWQueue):
else: submit_req, _ = self._build_gpu_command(dev)
dev.last_cmd = kgsl.IOCTL_KGSL_GPU_COMMAND(dev.fd, __payload=submit_req).timestamp
def _exec(self, prg:QCOMProgram, args_state:QCOMArgsState, global_size, local_size):
global_size_mp = [int(g*l) for g,l in zip(global_size, local_size)]
self.cmd_idx_to_dims[self._cur_cmd_idx()] = [global_size, local_size]
def exec(self, prg:QCOMProgram, args_state:QCOMArgsState, global_size, local_size):
def cast_int(x, ceil=False): return (math.ceil(x) if ceil else int(x)) if isinstance(x, float) else x
global_size_mp = [cast_int(g*l) for g,l in zip(global_size, local_size)]
self.cmd(adreno.CP_SET_MARKER, qreg.a6xx_cp_set_marker_0(mode=adreno.RM6_COMPUTE))
self.reg(adreno.REG_A6XX_HLSQ_INVALIDATE_CMD, qreg.a6xx_hlsq_invalidate_cmd(cs_state=True, cs_ibo=True))
@@ -129,7 +122,7 @@ class QCOMComputeQueue(HWQueue):
self.reg(adreno.REG_A6XX_HLSQ_CS_NDRANGE_0,
qreg.a6xx_hlsq_cs_ndrange_0(kerneldim=3, localsizex=local_size[0] - 1, localsizey=local_size[1] - 1, localsizez=local_size[2] - 1),
global_size_mp[0], 0, global_size_mp[1], 0, global_size_mp[2], 0, 0xccc0cf, 0xfc | qreg.a6xx_hlsq_cs_cntl_1(threadsize=adreno.THREAD64),
int(math.ceil(global_size[0])), int(math.ceil(global_size[1])), int(math.ceil(global_size[2])))
cast_int(global_size[0], ceil=True), cast_int(global_size[1], ceil=True), cast_int(global_size[2], ceil=True))
self.reg(adreno.REG_A6XX_SP_CS_CTRL_REG0,
qreg.a6xx_sp_cs_ctrl_reg0(threadsize=adreno.THREAD64, halfregfootprint=prg.hregs, fullregfootprint=prg.fregs, branchstack=prg.brnchstck),
@@ -172,19 +165,7 @@ class QCOMComputeQueue(HWQueue):
qreg.a6xx_sp_cs_config(enabled=True, nsamp=args_state.prg.samp_cnt, ntex=args_state.prg.tex_cnt, nibo=args_state.prg.ibo_cnt))
self.cmd(adreno.CP_RUN_OPENCL, 0)
self._cache_flush(write_back=True, invalidate=False, sync=False, memsync=False)
def _update_exec(self, cmd_idx, global_size, local_size):
if global_size is not None:
self._patch(cmd_idx, offset=29, data=[int(math.ceil(global_size[0])), int(math.ceil(global_size[1])), int(math.ceil(global_size[2]))])
self.cmd_idx_to_dims[cmd_idx][0] = global_size
if local_size is not None:
payload = qreg.a6xx_hlsq_cs_ndrange_0(kerneldim=3, localsizex=local_size[0] - 1, localsizey=local_size[1] - 1, localsizez=local_size[2] - 1)
self._patch(cmd_idx, offset=20, data=[payload])
self.cmd_idx_to_dims[cmd_idx][1] = local_size
global_size_mp = [int(g*l) for g,l in zip(self.cmd_idx_to_dims[cmd_idx][0], self.cmd_idx_to_dims[cmd_idx][1])]
self._patch(cmd_idx, offset=21, data=[global_size_mp[0], 0, global_size_mp[1], 0, global_size_mp[2], 0])
return self
class QCOMArgsState(HCQArgsState):
def __init__(self, ptr:int, prg:QCOMProgram, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=()):

View File

@@ -1,9 +1,10 @@
from __future__ import annotations
from typing import List, Optional, Dict, Tuple, cast, Protocol, Type, Union, TypeVar, Generic, Callable, ParamSpec, Concatenate
import contextlib, decimal, statistics, random, json, atexit, time, array, ctypes, functools
from typing import List, Optional, Dict, Tuple, cast, Protocol, Type, Union, TypeVar, Generic, Any
import contextlib, decimal, statistics, random, json, atexit, time, ctypes
from tinygrad.helpers import PROFILEPATH, PROFILE, from_mv, getenv, to_mv
from tinygrad.renderer import Renderer
from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator
from tinygrad.ops import sym_infer, sint, Variable
# **************** for HCQ Compatible Devices ****************
@@ -13,69 +14,41 @@ ProgramType = TypeVar('ProgramType', bound='HCQProgram')
ArgsStateType = TypeVar('ArgsStateType', bound='HCQArgsState')
QueueType = TypeVar('QueueType', bound='HWQueue')
P = ParamSpec('P')
def hcq_command(func: Callable[Concatenate[QueueType, P], None]) -> Callable[Concatenate[QueueType, P], QueueType]:
"""
Decorator for HWCommandQueue commands. Enables command indexing and stores metadata for command updates.
For example:
```python
@hcq_command
def command_method(self, ...): ...
```
"""
@functools.wraps(func)
def __wrapper(self:QueueType, *args:P.args, **kwargs:P.kwargs) -> QueueType:
self.cmds_offset.append(len(self._q))
func(self, *args, **kwargs)
self.cmds_len.append(len(self._q) - self.cmds_offset[-1])
self.cmds_meta.append(func.__name__)
return self
return __wrapper
class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
"""
A base class for hardware command queues in the HCQ (Hardware Command Queue) API.
Both compute and copy queues should have the following commands implemented.
"""
def __init__(self): self._q, self.binded_device, self.cmds_offset, self.cmds_len, self.cmds_meta = [], None, [], [], []
def q(self, *args) -> None: self._q.extend(args)
def __init__(self):
self._q:Any = []
self.binded_device:Optional[DeviceType] = None
self.q_sints:List[Tuple[int, int]] = []
self.mv_sints:List[Tuple[memoryview, int, int, Optional[int]]] = []
self.syms:List[sint] = []
self._prev_resolved_syms:List[Optional[int]] = []
def __len__(self): return len(self.cmds_offset)
def _patch(self, cmd_idx, offset, data): self._q[(st:=self.cmds_offset[cmd_idx]+offset):st+len(data)] = array.array('I', data)
def _cur_cmd_idx(self) -> int:
"""
Returns the index of the command currently being enqueued.
Should be called only within functions that enqueue commands and are decorated with `@hcq_command`.
"""
return len(self) - 1
def _new_sym(self, sym:sint) -> int:
if sym not in self.syms:
self.syms.append(sym)
self._prev_resolved_syms.append(None)
return self.syms.index(sym)
@hcq_command
def signal(self, signal:SignalType, value:int):
def q(self, *values):
"""
Enqueues a signal command which sets the signal to the given value, ensuring all previous operations are completed.
Enqueues values in the queue.
Args:
signal: The signal to set
value: The value to set the signal to
values: The values to enqueue in the queue.
"""
self._signal(signal, value)
def _signal(self, signal:SignalType, value:int): raise NotImplementedError("backend should overload this function")
@hcq_command
def wait(self, signal:SignalType, value:int):
"""
Enqueues a wait command which halts execution until the signal is greater than or equal to a specific value.
for v in values:
if isinstance(v, int): self._q.append(v)
else:
self.q_sints.append((len(self._q), self._new_sym(v)))
self._q.append(0xbadc0ded)
Args:
signal: The signal to wait on
value: The value to wait for
"""
self._wait(signal, value)
def _wait(self, signal, value): raise NotImplementedError("backend should overload this function")
# *** common commands ***
@hcq_command
def timestamp(self, signal:SignalType):
"""
Enqueues a timestamp command which records the current time in a signal after all previously enqueued commands are completed.
@@ -83,38 +56,56 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
Args:
signal: The signal to store the timestamp
"""
self._timestamp(signal)
def _timestamp(self, signal): raise NotImplementedError("backend should overload this function")
def update_signal(self, cmd_idx:int, signal:Optional[SignalType]=None, value:Optional[int]=None):
def signal(self, signal:SignalType, value:sint):
"""
Updates a previously queued signal command.
Enqueues a signal command which sets the signal to the given value, ensuring all previous operations are completed.
Args:
cmd_idx: Index of the signal command to update
signal: New signal to set (if None, keeps the original)
value: New value to set (if None, keeps the original)
signal: The signal to set
value: The value to set the signal to
"""
if self.cmds_meta[cmd_idx] != "signal": raise RuntimeError("called update_signal not on a signal command")
self._update_signal(cmd_idx, signal, value)
return self
def _update_signal(self, cmd_idx:int, signal:Optional[SignalType], value:Optional[int]):
raise NotImplementedError("backend should overload this function")
def update_wait(self, cmd_idx:int, signal:Optional[SignalType]=None, value:Optional[int]=None):
def wait(self, signal:SignalType, value:sint):
"""
Updates a previously queued wait command.
Enqueues a wait command which halts execution until the signal is greater than or equal to a specific value.
Args:
cmd_idx: Index of the wait command to update
signal: New signal to wait on (if None, keeps the original)
value: New value to wait for (if None, keeps the original)
signal: The signal to wait on
value: The value to wait for
"""
if self.cmds_meta[cmd_idx] != "wait": raise RuntimeError("called update_wait not on a wait command")
self._update_wait(cmd_idx, signal, value)
return self
def _update_wait(self, cmd_idx:int, signal:Optional[SignalType], value:Optional[int]):
raise NotImplementedError("backend should overload this function")
# *** commands for compute queues ***
def memory_barrier(self):
"""
Enqueues a memory barrier command to ensure memory coherence between agents. Only on compute queues.
"""
def exec(self, prg:ProgramType, args_state:ArgsStateType, global_size:Tuple[sint, ...], local_size:Tuple[sint, ...]):
"""
Enqueues an execution command for a kernel program. Only on compute queues.
Args:
prg: The program to execute
args_state: The args state to execute program with
global_size: The global work size
local_size: The local work size
"""
# *** commands for copy queues ***
def copy(self, dest:sint, src:sint, copy_size:int):
"""
Enqueues a copy command to transfer data. Only on copy queues.
Args:
dest: The destination of the copy
src: The source of the copy
copy_size: The size of data to copy
"""
# *** submit and bind commands ***
def bind(self, dev:DeviceType):
"""
@@ -130,94 +121,50 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
Implementing this method is optional but recommended for performance gains.
"""
def submit(self, dev:DeviceType):
def bind_sints(self, *vals:sint, struct:ctypes.Structure, start_field:str, fmt, mask:Optional[int]=None):
self.bind_sints_to_ptr(*vals, ptr=ctypes.addressof(struct) + getattr(type(struct), start_field).offset, fmt=fmt, mask=mask)
def bind_sints_to_ptr(self, *vals:sint, ptr:int, fmt, mask:Optional[int]=None):
mv = to_mv(ptr, 8*len(vals)).cast(fmt)
for i, val in enumerate(vals):
if isinstance(val, int): mv[i] = val if mask is None else ((mv[i] & ~mask) | val)
else: self.mv_sints.append((mv, i, self._new_sym(val), mask))
def _apply_var_vals(self, var_vals:Dict[Variable, int]):
resolved_syms = [sym_infer(sym, var_vals) for sym in self.syms]
for off, sym_idx in self.q_sints:
if self._prev_resolved_syms[sym_idx] == resolved_syms[sym_idx]: continue
self._q[off] = resolved_syms[sym_idx]
for mv, off, sym_idx, mask in self.mv_sints:
if self._prev_resolved_syms[sym_idx] == resolved_syms[sym_idx]: continue
mv[off] = resolved_syms[sym_idx] if mask is None else ((mv[off] & ~mask) | resolved_syms[sym_idx])
self._prev_resolved_syms = cast(List[Optional[int]], resolved_syms)
def submit(self, dev:DeviceType, var_vals:Optional[Dict[Variable, int]]=None):
"""
Submits the command queue to a specific device for execution.
Args:
dev: The device to submit the queue to
"""
if self._q: self._submit(dev)
if var_vals is not None: self._apply_var_vals(var_vals)
self._submit(dev)
return self
def _submit(self, dev:DeviceType): raise NotImplementedError("backend should overload this function")
# *** commands for compute queues ***
@hcq_command
def memory_barrier(self):
"""
Enqueues a memory barrier command to ensure memory coherence between agents. Only on compute queues.
"""
self._memory_barrier()
def _memory_barrier(self): pass
@hcq_command
def exec(self, prg:ProgramType, args_state:ArgsStateType, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int]):
"""
Enqueues an execution command for a kernel program. Only on compute queues.
Args:
prg: The program to execute
args_state: The args state to execute program with
global_size: The global work size
local_size: The local work size
"""
self._exec(prg, args_state, global_size, local_size)
def _exec(self, prg:ProgramType, args_state:ArgsStateType, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int]):
raise NotImplementedError("backend should overload this function")
def update_exec(self, cmd_idx:int, global_size:Optional[Tuple[int,int,int]]=None, local_size:Optional[Tuple[int,int,int]]=None):
"""
Updates a previously queued execution command. Only on compute queues.
Args:
cmd_idx: Index of the execution command to update
global_size: New global work size (if None, keeps the original)
local_size: New local work size (if None, keeps the original)
"""
if self.cmds_meta[cmd_idx] != "exec": raise RuntimeError("called update_exec not on an exec command")
self._update_exec(cmd_idx, global_size, local_size)
return self
def _update_exec(self, cmd_idx, global_size, local_size): raise NotImplementedError("backend should overload this function")
# *** commands for copy queues ***
@hcq_command
def copy(self, dest:int, src:int, copy_size:int):
"""
Enqueues a copy command to transfer data. Only on copy queues.
Args:
dest: The destination of the copy
src: The source of the copy
copy_size: The size of data to copy
"""
self._copy(dest, src, copy_size)
def _copy(self, dest:int, src:int, copy_size:int): raise NotImplementedError("backend should overload this function")
def update_copy(self, cmd_idx:int, dest:Optional[int]=None, src:Optional[int]=None):
"""
Updates a previously queued copy command. Only on copy queues.
Args:
cmd_idx: Index of the copy command to update
dest: New destination of the copy (if None, keeps the original)
src: New source of the copy (if None, keeps the original)
"""
if self.cmds_meta[cmd_idx] != "copy": raise RuntimeError("called update_copy not on an copy command")
self._update_copy(cmd_idx, dest, src)
return self
def _update_copy(self, cmd_idx:int, dest:Optional[int], src:Optional[int]):
raise NotImplementedError("backend should overload this function")
def _submit(self, dev:DeviceType): raise NotImplementedError("need _submit")
class HCQSignal(Generic[DeviceType]):
def __init__(self, base_addr:int=0, value:int=0, timeline_for_device:Optional[DeviceType]=None, timestamp_divider=1, value_off=0, timestamp_off=8):
def __init__(self, base_addr:sint=0, value:int=0, timeline_for_device:Optional[DeviceType]=None, timestamp_divider=1, value_off=0, timestamp_off=8):
self.base_addr, self.value_addr, self.timestamp_addr = base_addr, base_addr+value_off, base_addr+timestamp_off
self.timestamp_divider:decimal.Decimal = decimal.Decimal(timestamp_divider)
self.timeline_for_device:Optional[DeviceType] = timeline_for_device
self.value_mv, self.timestamp_mv = to_mv(self.value_addr, 8).cast('Q'), to_mv(self.timestamp_addr, 8).cast('Q')
self.value_mv[0] = value
if isinstance(base_addr, int):
self.value_mv, self.timestamp_mv = to_mv(self.value_addr, 8).cast('Q'), to_mv(self.timestamp_addr, 8).cast('Q')
self.value_mv[0] = value
@property
def value(self) -> int: return self.value_mv[0]