hcq replace update with sint (#7899)
* try sym hcq
* start with amd
* move to nv
* nv works
* cache and qcom
* fixes
* signals
* fix nv
* qcom fixes
* linter
* linter
* cache + typings
* fixes
* tiny fixes
* linter
* linter
* lntr
* ugh
* comments
@@ -24,27 +24,11 @@ Each runtime should implement the required functions that are defined in the `HWQueue`.
       "signal",
       "wait",
       "timestamp",
-      "update_signal",
-      "update_wait",
       "bind",
       "submit",
       "memory_barrier",
       "exec",
-      "update_exec",
       "copy",
-      "update_copy",
     ]
     show_source: false
-
-#### Implementing custom commands
-
-To implement custom commands in the queue, use the @hcq_command decorator for your command implementations.
-
-::: tinygrad.runtime.support.hcq.hcq_command
-    options:
-      members: [
-        "copy",
-        "update_copy",
-      ]
-      show_source: false
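For context on the section removed above: the `@hcq_command` decorator (deleted later in this diff from `tinygrad/runtime/support/hcq.py`) recorded each command's offset and length in the queue so the `update_*` APIs could patch it later. A minimal sketch of its pre-change usage (`MyQueue` and the raw command words are hypothetical):

```python
from tinygrad.runtime.support.hcq import HWQueue, hcq_command  # pre-change API

class MyQueue(HWQueue):
  @hcq_command           # records this command's offset/length for later update_* patching
  def my_command(self, value):
    self.q(0x1234, value)  # enqueue raw command words
```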
@@ -141,5 +125,5 @@ your_device.timeline_signal.wait(your_device.timeline_value - 1)

 ## HCQGraph

-[HCQGraph](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/graph/hcq.py) is a core feature that implements `GraphRunner` for HCQ-compatible devices. `HCQGraph` builds static `HWQueue` for all operations per device. To optimize enqueue time, only the necessary parts of the queues are updated for each run using the update APIs of the queues, avoiding a complete rebuild.
+[HCQGraph](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/graph/hcq.py) is a core feature that implements `GraphRunner` for HCQ-compatible devices. `HCQGraph` builds static `HWQueue`s for all operations per device. To optimize enqueue time, only the necessary parts of the queues are updated for each run using symbolic variables, avoiding a complete rebuild.

 Optionally, queues can implement a `bind` API, which allows further optimization by eliminating the need to copy the queues into the device ring.
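In practice, the new flow builds a queue once over symbolic values and binds them at each submit. A minimal sketch distilled from the tests in this diff (`dev` stands for any HCQ-compiled device; `signal_t` and `hw_compute_queue_t` come from the HCQ device API):

```python
from tinygrad import Variable, dtypes

# build once: placeholders instead of concrete signal address/value
virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32)
virt_signal = dev.signal_t(base_addr=Variable("sig_addr", 0, 0xffffffffffffffff, dtypes.uint64))
q = dev.hw_compute_queue_t().signal(virt_signal, virt_val)

# every run: bind the variables; only the changed words in the queue are rewritten
q.submit(dev, {virt_signal.base_addr: dev.timeline_signal.base_addr, virt_val: dev.timeline_value})
dev.timeline_signal.wait(dev.timeline_value)
dev.timeline_value += 1
```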
@@ -6,6 +6,7 @@ from tinygrad.runtime.support.hcq import HCQCompiled
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import get_runner, CompiledRunner
 from tinygrad.codegen.kernel import Kernel, Opt, OptOps
+from tinygrad import Variable

 MOCKGPU = getenv("MOCKGPU")
@@ -44,14 +45,19 @@ class TestHCQ(unittest.TestCase):
     for queue_type in [TestHCQ.d0.hw_compute_queue_t, TestHCQ.d0.hw_copy_queue_t]:
       if queue_type is None: continue

-      with self.subTest(name=str(queue_type)):
-        q = queue_type().signal(TestHCQ.d0.signal_t(), 0x1000)
+      virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32)
+      virt_signal = TestHCQ.d0.signal_t(base_addr=Variable("sig_addr", 0, 0xffffffffffffffff, dtypes.uint64))

-        q.update_signal(0, signal=TestHCQ.d0.timeline_signal, value=TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
+      with self.subTest(name=str(queue_type)):
+        q = queue_type().signal(virt_signal, virt_val)
+
+        var_vals = {virt_signal.base_addr: TestHCQ.d0.timeline_signal.base_addr, virt_val: TestHCQ.d0.timeline_value}
+        q.submit(TestHCQ.d0, var_vals)
         TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
         TestHCQ.d0.timeline_value += 1

-        q.update_signal(0, value=TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
+        var_vals = {virt_signal.base_addr: TestHCQ.d0.timeline_signal.base_addr, virt_val: TestHCQ.d0.timeline_value}
+        q.submit(TestHCQ.d0, var_vals)
         TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
         TestHCQ.d0.timeline_value += 1
@@ -91,12 +97,15 @@ class TestHCQ(unittest.TestCase):
       if queue_type is None: continue

       with self.subTest(name=str(queue_type)):
+        virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32)
+        virt_signal = TestHCQ.d0.signal_t(base_addr=Variable("sig_addr", 0, 0xffffffffffffffff, dtypes.uint64))
+
         fake_signal = TestHCQ.d0.signal_t()
-        q = queue_type().wait(TestHCQ.d0.timeline_signal, 0xffffffff).signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
+        q = queue_type().wait(virt_signal, virt_val).signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)

         fake_signal.value = 0x30

-        q.update_wait(0, signal=fake_signal, value=0x30).submit(TestHCQ.d0)
+        q.submit(TestHCQ.d0, {virt_signal.base_addr: fake_signal.base_addr, virt_val: fake_signal.value})
         TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
         TestHCQ.d0.timeline_value += 1
@@ -112,26 +121,30 @@ class TestHCQ(unittest.TestCase):
     assert val == 1.0, f"got val {val}"

   def test_exec_2_kernels_100_times(self):
+    virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32)
+
     q = TestHCQ.d0.hw_compute_queue_t()
-    q.wait(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value - 1) \
+    q.wait(TestHCQ.d0.timeline_signal, virt_val - 1) \
      .exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) \
      .exec(TestHCQ.runner._prg, TestHCQ.kernargs_ab_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) \
-     .signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
+     .signal(TestHCQ.d0.timeline_signal, virt_val)

     for _ in range(100):
-      q.update_wait(0, value=TestHCQ.d0.timeline_value - 1).update_signal(3, value=TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
+      q.submit(TestHCQ.d0, {virt_val: TestHCQ.d0.timeline_value})
       TestHCQ.d0.timeline_value += 1

     val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]
     assert val == 200.0, f"got val {val}"

   def test_exec_update(self):
+    sint_global = (Variable("sint_global", 0, 0xffffffff, dtypes.uint32),) + tuple(TestHCQ.runner.p.global_size[1:])
+    sint_local = (Variable("sint_local", 0, 0xffffffff, dtypes.uint32),) + tuple(TestHCQ.runner.p.local_size[1:])
+
     q = TestHCQ.d0.hw_compute_queue_t()
-    q.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) \
+    q.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, sint_global, sint_local) \
      .signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)

-    q.update_exec(0, (1,1,1), (1,1,1))
-    q.submit(TestHCQ.d0)
+    q.submit(TestHCQ.d0, {sint_global[0]: 1, sint_local[0]: 1})
     TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
     TestHCQ.d0.timeline_value += 1
@@ -141,6 +154,9 @@ class TestHCQ(unittest.TestCase):
     assert val == 0.0, f"got val {val}, should not be updated"

   def test_exec_update_fuzz(self):
+    virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32)
+    virt_local = [Variable(f"local_{i}", 0, 0xffffffff, dtypes.uint32) for i in range(3)]
+
     a = Tensor.randint((3, 3, 3), dtype=dtypes.int, device=Device.DEFAULT).realize()
     b = a + 1
     si = create_schedule([b.lazydata])[-1]
@@ -156,16 +172,15 @@ class TestHCQ(unittest.TestCase):

     q = TestHCQ.d0.hw_compute_queue_t()
     q.memory_barrier() \
-     .exec(runner._prg, kernargs, (1,1,1), (1,1,1)) \
-     .signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
+     .exec(runner._prg, kernargs, (1,1,1), virt_local) \
+     .signal(TestHCQ.d0.timeline_signal, virt_val)

     for x in range(1, 4):
       for y in range(1, 4):
         for z in range(1, 4):
           ctypes.memset(zt._buf.va_addr, 0, zb.nbytes)

-          q.update_exec(1, local_size=(x,y,z)) \
-            .update_signal(2, value=TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
+          q.submit(TestHCQ.d0, {virt_val: TestHCQ.d0.timeline_value, virt_local[0]: x, virt_local[1]: y, virt_local[2]: z})
           TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
           TestHCQ.d0.timeline_value += 1
@@ -207,12 +222,14 @@ class TestHCQ(unittest.TestCase):
   def test_update_copy(self):
     if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")

+    virt_src_addr = Variable("virt_src_addr", 0, 0xffffffffffffffff, dtypes.uint64)
+    virt_dest_addr = Variable("virt_dest_addr", 0, 0xffffffffffffffff, dtypes.uint64)
+
     q = TestHCQ.d0.hw_copy_queue_t().wait(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value - 1) \
-                                    .copy(0x0, 0x0, 8) \
+                                    .copy(virt_dest_addr, virt_src_addr, 8) \
                                     .signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)

-    q.update_copy(1, dest=TestHCQ.b.lazydata.buffer._buf.va_addr, src=TestHCQ.a.lazydata.buffer._buf.va_addr) \
-     .submit(TestHCQ.d0)
+    q.submit(TestHCQ.d0, {virt_src_addr: TestHCQ.a.lazydata.buffer._buf.va_addr, virt_dest_addr: TestHCQ.b.lazydata.buffer._buf.va_addr})

     TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
     TestHCQ.d0.timeline_value += 1
@@ -223,17 +240,19 @@ class TestHCQ(unittest.TestCase):
   def test_update_copy_long(self):
     if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")

+    virt_src_addr = Variable("virt_src_addr", 0, 0xffffffffffffffff, dtypes.uint64)
+    virt_dest_addr = Variable("virt_dest_addr", 0, 0xffffffffffffffff, dtypes.uint64)
+
     sz = 64 << 20
     buf1 = Buffer(Device.DEFAULT, sz, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated()
     buf2 = Buffer(Device.DEFAULT, sz, dtypes.int8, options=BufferSpec(host=True, nolru=True)).ensure_allocated()
     ctypes.memset(buf2._buf.va_addr, 1, sz)

     q = TestHCQ.d0.hw_copy_queue_t().wait(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value - 1) \
-                                    .copy(0x0, 0x0, sz) \
+                                    .copy(virt_dest_addr, virt_src_addr, sz) \
                                     .signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)

-    q.update_copy(1, buf1._buf.va_addr, buf2._buf.va_addr) \
-     .submit(TestHCQ.d0)
+    q.submit(TestHCQ.d0, {virt_src_addr: buf2._buf.va_addr, virt_dest_addr: buf1._buf.va_addr})

     TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
     TestHCQ.d0.timeline_value += 1
@@ -246,14 +265,17 @@ class TestHCQ(unittest.TestCase):
     for queue_type in [TestHCQ.d0.hw_compute_queue_t, TestHCQ.d0.hw_copy_queue_t]:
       if queue_type is None: continue

+      virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32)
+      virt_signal = TestHCQ.d0.signal_t(base_addr=Variable("sig_addr", 0, 0xffffffffffffffff, dtypes.uint64))
+
       with self.subTest(name=str(queue_type)):
         fake_signal = TestHCQ.d0.signal_t()
-        q = queue_type().wait(TestHCQ.d0.timeline_signal, 0xffffffff).signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
+        q = queue_type().wait(virt_signal, virt_val).signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
         q.bind(TestHCQ.d0)

         fake_signal.value = 0x30

-        q.update_wait(0, signal=fake_signal, value=0x30).submit(TestHCQ.d0)
+        q.submit(TestHCQ.d0, {virt_signal.base_addr: fake_signal.base_addr, virt_val: fake_signal.value})
         TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
         TestHCQ.d0.timeline_value += 1
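The rewrite applied across all of these tests is mechanical; schematically (a sketch reusing names from the tests above, not runnable on its own):

```python
# before: patch command slots by index after building the queue
q = queue_type().signal(TestHCQ.d0.signal_t(), 0x1000)
q.update_signal(0, signal=TestHCQ.d0.timeline_signal, value=TestHCQ.d0.timeline_value).submit(TestHCQ.d0)

# after: build over Variables once, bind concrete values at submit time
q = queue_type().signal(virt_signal, virt_val)
q.submit(TestHCQ.d0, {virt_signal.base_addr: TestHCQ.d0.timeline_signal.base_addr, virt_val: TestHCQ.d0.timeline_value})
```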
@@ -3,12 +3,13 @@ from typing import List, Any, Dict, cast, Optional, Tuple, Set
 from tinygrad.helpers import round_up, PROFILE, memsize_to_str
 from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState
 from tinygrad.device import Buffer, BufferSpec, Compiled, Device
-from tinygrad.ops import Variable
+from tinygrad import Variable, dtypes
+from tinygrad.ops import sint, Variable as VariableT
 from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
 from tinygrad.engine.jit import MultiGraphRunner

 class HCQGraph(MultiGraphRunner):
-  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[VariableT, int]):
     super().__init__(jit_cache, input_rawbuffers, var_vals)
     self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))
@@ -40,6 +41,7 @@ class HCQGraph(MultiGraphRunner):

     self.signals: Dict[Any, HCQSignal] = {**{dev: dev.signal_t(value=0) for dev in self.devices}, **{"CPU": self.devices[0].signal_t(value=0)}}
     self.kickoff_value: int = 0
+    self.kickoff_var = Variable("kickoff_var", 0, 0xffffffff, dtype=dtypes.uint32)

     self.prof_signals: List[HCQSignal] = [self.devices[0].signal_t() for i in range(len(jit_cache) * 2)] if PROFILE else []
     self.prof_records: List[Tuple[Tuple[int, bool], Tuple[int, bool], HCQCompiled, str, bool, List[int], Optional[Dict]]] = []
@@ -70,7 +72,7 @@ class HCQGraph(MultiGraphRunner):

       # Ensure device is ready for use in current context: the graph has initialized the device and it's safe to operate on it within this graph.
       for dep_queue, _ in opt_deps: dev_access[enqueue_queue].update(dev_access[dep_queue])
-      sync_signals = [(self.signals[d], self.kickoff_value) for b in ji.bufs if (d:=Device[cast(Buffer, b).device]) not in dev_access[enqueue_queue]]
+      sync_signals = [(self.signals[d], self.kickoff_var) for b in ji.bufs if (d:=Device[cast(Buffer, b).device]) not in dev_access[enqueue_queue]]
       dev_access[enqueue_queue].update(cast(HCQCompiled, Device[cast(Buffer, b).device]) for b in ji.bufs)

       # Remove self-dependency for compute and copy queues.
@@ -99,18 +101,21 @@ class HCQGraph(MultiGraphRunner):
       last_j[enqueue_queue] = j

     # Build hardware queues.
     self.op_cmd_idx: Dict[int, Tuple[HWQueue, int]] = {}
+    self.input_replace_to_var: Dict[Tuple[int, int], VariableT] = {}
     self.copy_to_devs: Dict[HCQCompiled, Set[HCQCompiled]] = {dev: set() for dev in self.devices}
     self.kickoff_wait_cmds: Dict[HWQueue, List] = {q: list() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}

+    # Create variable timeline signals for each device.
+    timeline_sigaddrs = {dev: Variable(f"timeline_sig_{dev.device_id}", 0, 0xffffffffffffffff, dtype=dtypes.uint64) for dev in self.devices}
+    self.virt_timeline_vals = {dev: Variable(f"timeline_var_{dev.device_id}", 0, 0xffffffff, dtype=dtypes.uint32) for dev in self.devices}
+    self.virt_timeline_signals = {dev: dev.signal_t(base_addr=timeline_sigaddrs[dev], timeline_for_device=dev) for dev in self.devices}
+
     for dev in self.devices:
-      self.comp_queues[dev].memory_barrier().wait(dev.timeline_signal, dev.timeline_value - 1) \
-                           .wait(self.signals['CPU'], self.kickoff_value).signal(self.signals[dev], self.kickoff_value)
+      self.comp_queues[dev].memory_barrier().wait(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev]) \
+                           .wait(self.signals['CPU'], self.kickoff_var).signal(self.signals[dev], self.kickoff_var)

     for j,ji in enumerate(jit_cache):
       enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j]

       for i in range(len(sync_signals)): self.kickoff_wait_cmds[enqueue_queue].append(len(enqueue_queue) + i)
       for sig, val in sync_signals + deps: enqueue_queue.wait(sig, val)

       # Encode waits and start profile timestamp (if needed).
@@ -118,13 +123,14 @@ class HCQGraph(MultiGraphRunner):

       # Encode main commands based on ji type.
       if isinstance(ji.prg, CompiledRunner):
-        enqueue_queue.exec(ji.prg._prg, self.ji_args[j], *ji.prg.p.launch_dims(var_vals))
+        enqueue_queue.exec(ji.prg._prg, self.ji_args[j], tuple(ji.prg.p.global_size or (1,1,1)), tuple(ji.prg.p.local_size or (1,1,1)))
       elif isinstance(ji.prg, BufferXfer):
         dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
         cast(HCQAllocator, Device[src.device].allocator).map(dest._buf)
-        enqueue_queue.copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes)
+
+        # TODO: For now sints is only for copies, should refactor to support exec as well.
+        enqueue_queue.copy(self._buf_addr_as_sint(j, 0, dest._buf), self._buf_addr_as_sint(j, 1, src._buf), dest.nbytes)
         self.copy_to_devs[cast(HCQCompiled, Device[dest.device])].add(cast(HCQCompiled, Device[src.device]))
       self.op_cmd_idx[j] = (enqueue_queue, len(enqueue_queue) - 1)

       # Encode finish profile timestamp (if needed).
       if PROFILE and self.prof_records[j][1][1]: enqueue_queue.timestamp(self.prof_signals[self.prof_records[j][1][0]])
@@ -135,13 +141,13 @@ class HCQGraph(MultiGraphRunner):
      for dep_dev in list(self.copy_to_devs[dev]) + [dev]:
        if dep_dev in self.copy_queues: self.comp_queues[dev].wait(self.signals[(copy_q:=self.copy_queues[dep_dev])], cast(int, last_j[copy_q]) + 1)

-      self.comp_queues[dev].signal(dev.timeline_signal, dev.timeline_value).bind(dev)
+      self.comp_queues[dev].signal(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev] + 1).bind(dev)
       if dev in self.copy_queues: self.copy_queues[dev].bind(dev)

     self.last_timeline: Dict[HCQCompiled, Tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
     self.queue_signals_to_reset = [self.signals[q] for q in list(self.comp_queues.values()) + list(self.copy_queues.values()) if q in self.signals]

-  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
+  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[VariableT, int], wait=False) -> Optional[float]:
     # Wait and restore signals
     self.kickoff_value += 1
     for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
@@ -150,28 +156,21 @@ class HCQGraph(MultiGraphRunner):

     if PROFILE and self.kickoff_value > 1: self.collect_timestamps()

+    hcq_var_vals = {self.kickoff_var: self.kickoff_value, **var_vals,
+                    **{var: dev.timeline_value - 1 for dev, var in self.virt_timeline_vals.items()},
+                    **{sig.base_addr: dev.timeline_signal.base_addr for dev, sig in self.virt_timeline_signals.items()}}
+
     # Update rawbuffers
     for (j,i),input_idx in self.input_replace.items():
-      if j in self.ji_args: self.ji_args[j].update_buffer(i, input_rawbuffers[input_idx]._buf)
-      else: self.op_cmd_idx[j][0].update_copy(self.op_cmd_idx[j][1], **{('dest' if i == 0 else 'src'): input_rawbuffers[input_idx]._buf.va_addr})
+      if (var:=self.input_replace_to_var.get((j,i))) is not None: hcq_var_vals[var] = input_rawbuffers[input_idx]._buf.va_addr
+      else: self.ji_args[j].update_buffer(i, input_rawbuffers[input_idx]._buf)

     # Update var_vals
     for j, i, v in self.updated_vars(var_vals): self.ji_args[j].update_var(i, v)

-    # Update launch dims
-    for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
-      queue, cmd_ptr = self.op_cmd_idx[j]
-      queue.update_exec(cmd_ptr, global_dims, local_dims)
-
     for dev in self.devices:
-      comp_queue, copy_queue, need_sig_upd = self.comp_queues[dev], self.copy_queues.get(dev, None), dev.timeline_signal != self.last_timeline[dev][0]
-      comp_queue.update_wait(1, dev.timeline_signal if need_sig_upd else None, dev.timeline_value - 1) \
-                .update_wait(2, value=self.kickoff_value).update_signal(3, value=self.kickoff_value) \
-                .update_signal(len(comp_queue)-1, dev.timeline_signal if need_sig_upd else None, dev.timeline_value).submit(dev)
-
-      if copy_queue is not None:
-        for cmd_idx in self.kickoff_wait_cmds[copy_queue]: copy_queue.update_wait(cmd_idx, value=self.kickoff_value)
-        copy_queue.submit(dev)
+      self.comp_queues[dev].submit(dev, hcq_var_vals)
+      if (copy_queue:=self.copy_queues.get(dev, None)) is not None: copy_queue.submit(dev, hcq_var_vals)

       self.last_timeline[dev] = (dev.timeline_signal, dev.timeline_value)
       dev.timeline_value += 1
@@ -192,6 +191,10 @@ class HCQGraph(MultiGraphRunner):
         (b_st,_), (b_en,_), b_dev, _, b_is_cp, _, _ = self.prof_records[x]
         dev.dep_prof_records += [(timestamps[b_st], timestamps[b_en], b_dev, b_is_cp, timestamps[st], timestamps[en], dev, is_cp)]

+  def _buf_addr_as_sint(self, j:int, i:int, buf:HCQBuffer) -> sint:
+    if (j, i) not in self.input_replace: return buf.va_addr
+    return self.input_replace_to_var.setdefault((j, i), Variable(f"input_{j}_{i}", 0, 0xffffffffffffffff, dtype=dtypes.uint64))
+
   def __del__(self):
     for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
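The per-run update in `__call__` now reduces to building `hcq_var_vals` and letting each queue resolve its symbolic expressions (see `_apply_var_vals` in `hcq.py` further down, which also skips words whose resolved value did not change). A toy illustration of the resolution step:

```python
from tinygrad import Variable, dtypes
from tinygrad.ops import sym_infer

v = Variable("timeline_var", 0, 0xffffffff, dtypes.uint32)
expr = v + 1                      # expressions like this are stored once at queue-build time
print(sym_infer(expr, {v: 41}))   # -> 42, recomputed on every submit from var_vals
```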
@@ -4,6 +4,7 @@ import os, ctypes, ctypes.util, functools, pathlib, mmap, errno, array, contextlib
 assert sys.platform != 'win32'
 from dataclasses import dataclass
 from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, HCQArgsState, HCQSignal, HCQProgram
+from tinygrad.ops import sint
 from tinygrad.device import BufferSpec
 from tinygrad.helpers import getenv, to_mv, round_up, data64_le, mv_address
 from tinygrad.renderer.cstyle import AMDRenderer
@@ -31,7 +32,8 @@ class AMDSignal(HCQSignal):
   def __init__(self, base_addr:Optional[int]=None, **kwargs):
     super().__init__(AMDDevice.signals_pool.pop() if base_addr is None else base_addr, **kwargs, timestamp_divider=100)

-  def __del__(self): AMDDevice.signals_pool.append(self.base_addr)
+  def __del__(self):
+    if isinstance(self.base_addr, int): AMDDevice.signals_pool.append(self.base_addr)

   def _sleep(self, time_spent_waiting_ms:int):
     # Reasonable to sleep for long workloads (which take more than 2s) and only timeline signals.
@@ -39,10 +41,6 @@ class AMDSignal(HCQSignal):
     kfd.AMDKFD_IOC_WAIT_EVENTS(AMDDevice.kfd, events_ptr=self.timeline_for_device.queue_event_arr_ptr, num_events=1, wait_for_all=1, timeout=200)

 class AMDComputeQueue(HWQueue):
-  def __init__(self):
-    self.cmd_idx_to_local_offset, self.cmd_idx_to_global_offset, self.cmd_idx_to_dispatch_packet = {}, {}, {}
-    super().__init__()
-
   def __del__(self):
     if self.binded_device is not None:
       self.binded_device.allocator.free(self.hw_page, self.hw_page.size, BufferSpec(cpu_access=True, nolru=True, uncached=True))
@@ -76,22 +74,23 @@ class AMDComputeQueue(HWQueue):

     self.pkt3(amd_gpu.PACKET3_RELEASE_MEM, event_dw | cache_flags_dw, memsel_dw, *data64_le(address), *data64_le(value), ctxid)

-  def _memory_barrier(self):
+  def memory_barrier(self):
     self.wait_reg_mem(reg_req=nbioreg(regBIF_BX_PF1_GPU_HDP_FLUSH_REQ), reg_done=nbioreg(regBIF_BX_PF1_GPU_HDP_FLUSH_DONE), value=0xffffffff)
     self.acquire_mem()
+    return self

-  def _exec(self, prg:AMDProgram, args_state:AMDArgsState, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1)):
+  def exec(self, prg:AMDProgram, args_state:AMDArgsState, global_size:Tuple[sint, ...], local_size:Tuple[sint, ...]):
     self.acquire_mem(gli=0, gl2=0)

-    cmd_idx = self._cur_cmd_idx()
     user_regs = [*data64_le(prg.dev.scratch.va_addr), 0xffffffff, 0xc00000] if prg.enable_private_segment_sgpr else []
     if prg.enable_dispatch_ptr:
       dp = hsa.hsa_kernel_dispatch_packet_t.from_address(dp_addr:=args_state.ptr + prg.kernargs_segment_size)
-      dp.workgroup_size_x, dp.workgroup_size_y, dp.workgroup_size_z = local_size[0], local_size[1], local_size[2]
-      dp.grid_size_x, dp.grid_size_y, dp.grid_size_z = global_size[0]*local_size[0], global_size[1]*local_size[1], global_size[2]*local_size[2]
+
+      self.bind_sints(*local_size, struct=dp, start_field='workgroup_size_x', fmt='H')
+      self.bind_sints(*[g*l for g,l in zip(global_size, local_size)], struct=dp, start_field='grid_size_x', fmt='I')
       dp.group_segment_size, dp.private_segment_size, dp.kernarg_address = prg.group_segment_size, prg.private_segment_size, args_state.ptr
       user_regs += [*data64_le(dp_addr)]
-      self.cmd_idx_to_dispatch_packet[cmd_idx] = dp

     user_regs += [*data64_le(args_state.ptr)]

     self.pkt3(amd_gpu.PACKET3_SET_SH_REG, gfxreg(amd_gpu.regCOMPUTE_PGM_LO), *data64_le(prg.prog_addr >> 8))
@@ -107,29 +106,22 @@ class AMDComputeQueue(HWQueue):
     self.pkt3(amd_gpu.PACKET3_SET_SH_REG, gfxreg(amd_gpu.regCOMPUTE_STATIC_THREAD_MGMT_SE4), 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF)
     self.pkt3(amd_gpu.PACKET3_SET_SH_REG, gfxreg(amd_gpu.regCOMPUTE_USER_DATA_0), *user_regs)

-    self.cmd_idx_to_local_offset[cmd_idx] = len(self._q) - self.cmds_offset[cmd_idx] + 5 # +1 to skip PACKET3_SET_SH_REG + reg + 3 zeros.
     self.pkt3(amd_gpu.PACKET3_SET_SH_REG, gfxreg(amd_gpu.regCOMPUTE_START_X), 0, 0, 0, *local_size, 0, 0)
     self.pkt3(amd_gpu.PACKET3_SET_SH_REG, gfxreg(amd_gpu.regCOMPUTE_RESOURCE_LIMITS), 0)

-    self.cmd_idx_to_global_offset[cmd_idx] = len(self._q) - self.cmds_offset[cmd_idx] + 1 # +1 to skip PACKET3_DISPATCH_DIRECT.
     self.pkt3(amd_gpu.PACKET3_DISPATCH_DIRECT, *global_size, CS_W32_EN | FORCE_START_AT_000 | COMPUTE_SHADER_EN)
     self.pkt3(amd_gpu.PACKET3_EVENT_WRITE, amd_gpu.EVENT_TYPE(amd_gpu.CS_PARTIAL_FLUSH) | amd_gpu.EVENT_INDEX(EVENT_INDEX_PARTIAL_FLUSH))
+    return self

-  def _update_exec(self, cmd_idx, global_size, local_size):
-    if local_size is not None: self._patch(cmd_idx, offset=self.cmd_idx_to_local_offset[cmd_idx], data=local_size)
-    if global_size is not None: self._patch(cmd_idx, offset=self.cmd_idx_to_global_offset[cmd_idx], data=global_size)
-
-    if (dp:=self.cmd_idx_to_dispatch_packet.get(cmd_idx)) is not None:
-      if local_size is not None: dp.workgroup_size_x, dp.workgroup_size_y, dp.workgroup_size_z = local_size[0], local_size[1], local_size[2]
-      if global_size is not None:
-        dp.grid_size_x,dp.grid_size_y,dp.grid_size_z = [g*l for g,l in zip(global_size,[dp.workgroup_size_x,dp.workgroup_size_y,dp.workgroup_size_z])]
-
-  def _wait(self, signal:AMDSignal, value=0): self.wait_reg_mem(mem=signal.value_addr, value=value, mask=0xffffffff)
+  def wait(self, signal:AMDSignal, value:sint=0):
+    self.wait_reg_mem(mem=signal.value_addr, value=value, mask=0xffffffff)
+    return self

-  def _timestamp(self, signal:AMDSignal):
+  def timestamp(self, signal:AMDSignal):
     self.release_mem(signal.timestamp_addr, 0, amd_gpu.data_sel__mec_release_mem__send_gpu_clock_counter, amd_gpu.int_sel__mec_release_mem__none)
+    return self

-  def _signal(self, signal:AMDSignal, value=0):
+  def signal(self, signal:AMDSignal, value:sint=0):
     # NOTE: this needs an EOP buffer on the queue or it will NULL pointer
     self.release_mem(signal.value_addr, value, amd_gpu.data_sel__mec_release_mem__send_32_bit_low,
                      amd_gpu.int_sel__mec_release_mem__send_interrupt_after_write_confirm, cache_flush=True)
@@ -137,14 +129,7 @@ class AMDComputeQueue(HWQueue):
     if (dev:=signal.timeline_for_device) is not None:
       self.release_mem(dev.queue_event_mailbox_ptr, dev.queue_event.event_id, amd_gpu.data_sel__mec_release_mem__send_32_bit_low,
                        amd_gpu.int_sel__mec_release_mem__send_interrupt_after_write_confirm, ctxid=dev.queue_event.event_id)
-
-  def _update_wait(self, cmd_idx, signal:Optional[AMDSignal]=None, value=None):
-    if signal is not None: self._patch(cmd_idx, offset=2, data=data64_le(signal.value_addr))
-    if value is not None: self._patch(cmd_idx, offset=4, data=[value])
-
-  def _update_signal(self, cmd_idx, signal:Optional[AMDSignal]=None, value=None):
-    if signal is not None: self._patch(cmd_idx, offset=3, data=data64_le(signal.value_addr))
-    if value is not None: self._patch(cmd_idx, offset=5, data=data64_le(value))
+    return self

   def bind(self, dev:AMDDevice):
     self.binded_device = dev
@@ -155,6 +140,7 @@ class AMDComputeQueue(HWQueue):
     self.indirect_cmd = [amd_gpu.PACKET3(amd_gpu.PACKET3_INDIRECT_BUFFER, 2), *data64_le(self.hw_page.va_addr),
                          len(self._q) | amd_gpu.INDIRECT_BUFFER_VALID]
     self._q = hw_view # type: ignore
+    return self

   def _submit(self, dev:AMDDevice):
     cmds = self.indirect_cmd if dev == self.binded_device else self._q
@@ -168,16 +154,16 @@
 SDMA_MAX_COPY_SIZE = 0x400000
 class AMDCopyQueue(HWQueue):
   def __init__(self):
-    self.internal_cmd_sizes, self.copy_cmds_per_copy = [], {}
+    self.internal_cmd_sizes = []
     super().__init__()

   def q(self, *arr):
     super().q(*arr)
     self.internal_cmd_sizes.append(len(arr))

-  def _copy(self, dest, src, copy_size):
+  def copy(self, dest:sint, src:sint, copy_size:int):
     copied, copy_commands = 0, (copy_size + SDMA_MAX_COPY_SIZE - 1) // SDMA_MAX_COPY_SIZE
-    self.copy_cmds_per_copy[len(self) - 1] = copy_commands

     for _ in range(copy_commands):
       step_copy_size = min(copy_size - copied, SDMA_MAX_COPY_SIZE)
@@ -185,34 +171,27 @@ class AMDCopyQueue(HWQueue):
                amd_gpu.SDMA_PKT_COPY_LINEAR_COUNT_COUNT(step_copy_size - 1), 0, *data64_le(src + copied), *data64_le(dest + copied))

       copied += step_copy_size
+    return self

-  def _update_copy(self, cmd_idx, dest=None, src=None):
-    for i in range(self.copy_cmds_per_copy[cmd_idx]):
-      if src is not None: self._patch(cmd_idx, offset=3+i*7, data=[*data64_le(src + SDMA_MAX_COPY_SIZE*i)])
-      if dest is not None: self._patch(cmd_idx, offset=5+i*7, data=[*data64_le(dest + SDMA_MAX_COPY_SIZE*i)])
-
-  def _signal(self, signal:AMDSignal, value=0):
+  def signal(self, signal:AMDSignal, value:sint=0):
     self.q(amd_gpu.SDMA_OP_FENCE | amd_gpu.SDMA_PKT_FENCE_HEADER_MTYPE(3), *data64_le(signal.value_addr), value)

     if (dev:=signal.timeline_for_device) is not None:
       self.q(amd_gpu.SDMA_OP_FENCE | amd_gpu.SDMA_PKT_FENCE_HEADER_MTYPE(3), *data64_le(dev.queue_event_mailbox_ptr), dev.queue_event.event_id)
       self.q(amd_gpu.SDMA_OP_TRAP, amd_gpu.SDMA_PKT_TRAP_INT_CONTEXT_INT_CONTEXT(dev.queue_event.event_id))
+    return self

-  def _wait(self, signal:AMDSignal, value=0):
+  def wait(self, signal:AMDSignal, value:sint=0):
     self.q(amd_gpu.SDMA_OP_POLL_REGMEM | amd_gpu.SDMA_PKT_POLL_REGMEM_HEADER_FUNC(WAIT_REG_MEM_FUNCTION_GEQ) | \
            amd_gpu.SDMA_PKT_POLL_REGMEM_HEADER_MEM_POLL(1), *data64_le(signal.value_addr), value, 0xffffffff,
            amd_gpu.SDMA_PKT_POLL_REGMEM_DW5_INTERVAL(0x04) | amd_gpu.SDMA_PKT_POLL_REGMEM_DW5_RETRY_COUNT(0xfff))
+    return self

-  def _update_signal(self, cmd_idx, signal:Optional[AMDSignal]=None, value=None):
-    return self._update_wait(cmd_idx, signal, value) # the same offsets and commands
-
-  def _update_wait(self, cmd_idx, signal:Optional[AMDSignal]=None, value=None):
-    if signal is not None: self._patch(cmd_idx, offset=1, data=data64_le(signal.value_addr))
-    if value is not None: self._patch(cmd_idx, offset=3, data=[value])
-
-  def _timestamp(self, signal:AMDSignal):
+  def timestamp(self, signal:AMDSignal):
     self.q(amd_gpu.SDMA_OP_TIMESTAMP | amd_gpu.SDMA_PKT_TIMESTAMP_GET_HEADER_SUB_OP(amd_gpu.SDMA_SUBOP_TIMESTAMP_GET_GLOBAL),
            *data64_le(signal.timestamp_addr))
+    return self

   def _submit(self, dev:AMDDevice):
     if dev.sdma_queue.put_value - dev.sdma_queue.read_ptr[0] > dev.sdma_queue.ring.nbytes: raise RuntimeError("SDMA queue overrun")
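The AMD queue now writes symbolic launch dimensions straight into the HSA dispatch packet via `bind_sints`, which views a struct field as a memoryview and patches it when the variable resolves. A minimal sketch of that mechanism (`Pkt` is a hypothetical stand-in for `hsa_kernel_dispatch_packet_t`):

```python
import ctypes
from tinygrad.helpers import to_mv

class Pkt(ctypes.Structure):  # hypothetical two-field dispatch packet
  _fields_ = [("workgroup_size_x", ctypes.c_uint16), ("workgroup_size_y", ctypes.c_uint16)]

pkt = Pkt()
# view the field's memory, exactly as bind_sints does via struct offset + fmt
mv = to_mv(ctypes.addressof(pkt) + Pkt.workgroup_size_x.offset, 2).cast('H')
mv[0] = 64                    # plain ints are written immediately
print(pkt.workgroup_size_x)   # -> 64; a sint would instead be recorded and written at submit
```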
@@ -3,8 +3,8 @@ import os, ctypes, contextlib, re, fcntl, functools, mmap, struct, array, sys
 assert sys.platform != 'win32'
 from typing import Tuple, List, Any, cast, Union, Dict, Type, Optional
 from dataclasses import dataclass
-from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, hcq_command
-from tinygrad.runtime.support.hcq import HCQArgsState, HCQProgram, HCQSignal
+from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, HCQArgsState, HCQProgram, HCQSignal
+from tinygrad.ops import sint
 from tinygrad.device import BufferSpec
 from tinygrad.helpers import getenv, mv_address, init_c_struct_t, to_mv, round_up, data64, data64_le, DEBUG, prod
 from tinygrad.renderer.ptx import PTXRenderer
@@ -77,13 +77,17 @@ class NVSignal(HCQSignal):
   def __init__(self, base_addr:Optional[int]=None, **kwargs):
     super().__init__(NVDevice.signals_pool.pop() if base_addr is None else base_addr, **kwargs, timestamp_divider=1000, value_off=0, timestamp_off=8)

-  def __del__(self): NVDevice.signals_pool.append(self.base_addr)
+  def __del__(self):
+    if isinstance(self.base_addr, int): NVDevice.signals_pool.append(self.base_addr)

 class NVCommandQueue(HWQueue[NVSignal, 'NVDevice', 'NVProgram', 'NVArgsState']):
+  def __init__(self):
+    self.active_qmd = None
+    super().__init__()
+
   def __del__(self):
     if self.binded_device is not None: self.binded_device.allocator.free(self.hw_page, self.hw_page.size, BufferSpec(cpu_access=True, nolru=True))

-  @hcq_command
   def setup(self, compute_class=None, copy_class=None, local_mem_window=None, shared_mem_window=None, local_mem=None, local_mem_tpc_bytes=None):
     if compute_class: self.q(nvmethod(1, nv_gpu.NVC6C0_SET_OBJECT, 1), compute_class)
     if copy_class: self.q(nvmethod(4, nv_gpu.NVC6C0_SET_OBJECT, 1), copy_class)
@@ -91,16 +95,15 @@ class NVCommandQueue(HWQueue[NVSignal, 'NVDevice', 'NVProgram', 'NVArgsState']):
     if shared_mem_window: self.q(nvmethod(1, nv_gpu.NVC6C0_SET_SHADER_SHARED_MEMORY_WINDOW_A, 2), *data64(shared_mem_window))
     if local_mem: self.q(nvmethod(1, nv_gpu.NVC6C0_SET_SHADER_LOCAL_MEMORY_A, 2), *data64(local_mem))
     if local_mem_tpc_bytes: self.q(nvmethod(1, nv_gpu.NVC6C0_SET_SHADER_LOCAL_MEMORY_NON_THROTTLED_A, 3), *data64(local_mem_tpc_bytes), 0xff)
+    return self

-  def _wait(self, signal:NVSignal, value=0):
+  def wait(self, signal:NVSignal, value:sint=0):
     self.q(nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *data64_le(signal.value_addr), *data64_le(value),
            (3 << 0) | (1 << 24)) # ACQUIRE | PAYLOAD_SIZE_64BIT
+    self.active_qmd = None
+    return self

-  def _update_wait(self, cmd_idx, signal=None, value=None):
-    if signal is not None: self._q[(sigoff:=self.cmds_offset[cmd_idx]+1):sigoff+2] = array.array('I', data64_le(signal.value_addr))
-    if value is not None: self._q[(valoff:=self.cmds_offset[cmd_idx]+3):valoff+2] = array.array('I', data64_le(value))
-
-  def _timestamp(self, signal): return self._signal(signal, 0)
+  def timestamp(self, signal:NVSignal): return self.signal(signal, 0)

   def bind(self, dev:NVDevice):
     self.binded_device = dev
@@ -129,79 +132,63 @@ class NVCommandQueue(HWQueue[NVSignal, 'NVDevice', 'NVProgram', 'NVArgsState']):
     gpfifo.put_value += 1

 class NVComputeQueue(NVCommandQueue):
-  def __init__(self):
-    self.cmd_idx_to_qmd, self.cmd_idx_to_signal_id, self.cmd_idx_to_global_dims, self.cmd_idx_to_local_dims = {}, {}, {}, {}
-    super().__init__()
-
-  def _memory_barrier(self): self.q(nvmethod(1, nv_gpu.NVC6C0_INVALIDATE_SHADER_CACHES_NO_WFI, 1), (1 << 12) | (1 << 4) | (1 << 0))
+  def memory_barrier(self):
+    self.q(nvmethod(1, nv_gpu.NVC6C0_INVALIDATE_SHADER_CACHES_NO_WFI, 1), (1 << 12) | (1 << 4) | (1 << 0))
+    self.active_qmd = None
+    return self

-  def _exec(self, prg:NVProgram, args_state:NVArgsState, global_size, local_size):
+  def exec(self, prg:NVProgram, args_state:NVArgsState, global_size:Tuple[sint, ...], local_size:Tuple[sint, ...]):
     ctypes.memmove(qmd_addr:=(args_state.ptr + round_up(prg.constbufs[0][1], 1 << 8)), ctypes.addressof(prg.qmd), 0x40 * 4)
     assert qmd_addr < (1 << 40), f"large qmd addr {qmd_addr:x}"

-    self.cmd_idx_to_qmd[self._cur_cmd_idx()] = qmd = qmd_struct_t.from_address(qmd_addr) # Save qmd for later update
-    self.cmd_idx_to_global_dims[self._cur_cmd_idx()] = to_mv(qmd_addr + nv_gpu.NVC6C0_QMDV03_00_CTA_RASTER_WIDTH[1] // 8, 12).cast('I')
-    self.cmd_idx_to_local_dims[self._cur_cmd_idx()] = to_mv(qmd_addr + nv_gpu.NVC6C0_QMDV03_00_CTA_THREAD_DIMENSION0[1] // 8, 6).cast('H')
+    qmd = qmd_struct_t.from_address(qmd_addr) # Save qmd for later update

-    qmd.cta_raster_width, qmd.cta_raster_height, qmd.cta_raster_depth = global_size
-    qmd.cta_thread_dimension0, qmd.cta_thread_dimension1, qmd.cta_thread_dimension2 = local_size
+    self.bind_sints_to_ptr(*global_size, ptr=qmd_addr + nv_gpu.NVC6C0_QMDV03_00_CTA_RASTER_WIDTH[1] // 8, fmt='I')
+    self.bind_sints_to_ptr(*local_size, ptr=qmd_addr + nv_gpu.NVC6C0_QMDV03_00_CTA_THREAD_DIMENSION0[1] // 8, fmt='H')
     qmd.constant_buffer_addr_upper_0, qmd.constant_buffer_addr_lower_0 = data64(args_state.ptr)

-    if (prev_qmd:=self.cmd_idx_to_qmd.get(self._cur_cmd_idx() - 1)) is None:
+    if self.active_qmd is None:
       self.q(nvmethod(1, nv_gpu.NVC6C0_SEND_PCAS_A, 0x1), qmd_addr >> 8)
       self.q(nvmethod(1, nv_gpu.NVC6C0_SEND_SIGNALING_PCAS2_B, 0x1), 9)
     else:
-      prev_qmd.dependent_qmd0_pointer = qmd_addr >> 8
-      prev_qmd.dependent_qmd0_action = 1
-      prev_qmd.dependent_qmd0_prefetch = 1
-      prev_qmd.dependent_qmd0_enable = 1
+      self.active_qmd.dependent_qmd0_pointer = qmd_addr >> 8
+      self.active_qmd.dependent_qmd0_action = 1
+      self.active_qmd.dependent_qmd0_prefetch = 1
+      self.active_qmd.dependent_qmd0_enable = 1

-  def _update_exec(self, cmd_idx, global_size, local_size):
-    # Patch the exec cmd with new launch dims
-    if global_size is not None: self.cmd_idx_to_global_dims[cmd_idx][:] = array.array('I', global_size)
-    if local_size is not None: self.cmd_idx_to_local_dims[cmd_idx][:] = array.array('H', local_size)
+    self.active_qmd = qmd
+    return self

-  def _signal(self, signal:NVSignal, value=0):
-    if (prev_qmd:=self.cmd_idx_to_qmd.get(self._cur_cmd_idx() - 1)) is not None:
+  def signal(self, signal:NVSignal, value:sint=0):
+    if self.active_qmd is not None:
       for i in range(2):
-        if getattr(prev_qmd, f'release{i}_enable') == 0:
-          setattr(prev_qmd, f'release{i}_enable', 1)
-          setattr(prev_qmd, f'release{i}_address', signal.value_addr)
-          setattr(prev_qmd, f'release{i}_payload', value)
-          self.cmd_idx_to_qmd[self._cur_cmd_idx()] = prev_qmd
-          self.cmd_idx_to_signal_id[self._cur_cmd_idx()] = i
-          return
+        if getattr(self.active_qmd, f'release{i}_enable') == 0:
+          setattr(self.active_qmd, f'release{i}_enable', 1)
+          self.bind_sints(signal.value_addr, struct=self.active_qmd, start_field=f'release{i}_address', fmt='Q', mask=0xfffffffff)
+          self.bind_sints(value, struct=self.active_qmd, start_field=f'release{i}_payload', fmt='Q')
+          return self

     self.q(nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *data64_le(signal.value_addr), *data64_le(value),
            (1 << 0) | (1 << 20) | (1 << 24) | (1 << 25)) # RELEASE | RELEASE_WFI | PAYLOAD_SIZE_64BIT | RELEASE_TIMESTAMP
     self.q(nvmethod(0, nv_gpu.NVC56F_NON_STALL_INTERRUPT, 1), 0x0)
+    self.active_qmd = None
+    return self

-  def _update_signal(self, cmd_idx, signal:Optional[NVSignal]=None, value=None):
-    if (qmd:=self.cmd_idx_to_qmd.get(cmd_idx)) is None: return super()._update_wait(cmd_idx, signal, value) # reuse wait, same offsets to update.
-    if signal is not None: setattr(qmd, f'release{self.cmd_idx_to_signal_id[cmd_idx]}_address', signal.value_addr)
-    if value is not None: setattr(qmd, f'release{self.cmd_idx_to_signal_id[cmd_idx]}_payload', value)
-
-  def _submit(self, dev): self._submit_to_gpfifo(dev, cast(NVDevice, dev).compute_gpfifo)
+  def _submit(self, dev:NVDevice): self._submit_to_gpfifo(dev, dev.compute_gpfifo)

 class NVCopyQueue(NVCommandQueue):
-  def _copy(self, dest, src, copy_size):
+  def copy(self, dest:sint, src:sint, copy_size:int):
     self.q(nvmethod(4, nv_gpu.NVC6B5_OFFSET_IN_UPPER, 4), *data64(src), *data64(dest))
     self.q(nvmethod(4, nv_gpu.NVC6B5_LINE_LENGTH_IN, 1), copy_size)
     self.q(nvmethod(4, nv_gpu.NVC6B5_LAUNCH_DMA, 1), 0x182) # TRANSFER_TYPE_NON_PIPELINED | DST_MEMORY_LAYOUT_PITCH | SRC_MEMORY_LAYOUT_PITCH
+    return self

-  def _update_copy(self, cmd_idx, dest=None, src=None):
-    if dest is not None: self._patch(cmd_idx, offset=3, data=data64(dest))
-    if src is not None: self._patch(cmd_idx, offset=1, data=data64(src))
-
-  def _signal(self, signal, value=0):
+  def signal(self, signal:NVSignal, value:sint=0):
     self.q(nvmethod(4, nv_gpu.NVC6B5_SET_SEMAPHORE_A, 3), *data64(signal.value_addr), value)
     self.q(nvmethod(4, nv_gpu.NVC6B5_LAUNCH_DMA, 1), 0x14)
+    return self

-  def _update_signal(self, cmd_idx, signal=None, value=None):
-    if signal is not None: self._patch(cmd_idx, offset=1, data=data64(signal.value_addr))
-    if value is not None: self._patch(cmd_idx, offset=3, data=[value])
-
-  def _submit(self, dev): self._submit_to_gpfifo(dev, cast(NVDevice, dev).dma_gpfifo)
+  def _submit(self, dev:NVDevice): self._submit_to_gpfifo(dev, dev.dma_gpfifo)

 class NVArgsState(HCQArgsState):
   def __init__(self, ptr:int, prg:NVProgram, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=()):
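Note the `mask=0xfffffffff` on the release-address bind above: the QMD's `release{i}_address` field is narrower than the 64-bit word it lives in, so only the masked bits may be replaced. The core of that masked patch, as implemented in `bind_sints_to_ptr` in `hcq.py` below:

```python
# write val into the bits selected by mask, preserving the rest of the word
def masked_write(mv, i, val, mask=None):
  mv[i] = val if mask is None else ((mv[i] & ~mask) | val)
```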
@@ -39,7 +39,8 @@ class QCOMSignal(HCQSignal):
   def __init__(self, base_addr:Optional[int]=None, **kwargs):
     super().__init__(QCOMDevice.signals_pool.pop() if base_addr is None else base_addr, **kwargs, timestamp_divider=19.2)

-  def __del__(self): QCOMDevice.signals_pool.append(self.base_addr)
+  def __del__(self):
+    if isinstance(self.base_addr, int): QCOMDevice.signals_pool.append(self.base_addr)

   def _sleep(self, time_spent_waiting_ms:int):
     # Sleep only for timeline signals. Do it immediately to free the CPU.
@@ -48,10 +49,6 @@ class QCOMSignal(HCQSignal):
                                       timestamp=self.timeline_for_device.last_cmd, timeout=0xffffffff)

 class QCOMComputeQueue(HWQueue):
-  def __init__(self):
-    self.cmd_idx_to_dims = {}
-    super().__init__()
-
   def __del__(self):
     if self.binded_device is not None: self.binded_device.allocator.free(self.hw_page, self.hw_page.size, BufferSpec(cpu_access=True, nolru=True))
@@ -66,9 +63,11 @@ class QCOMComputeQueue(HWQueue):
     if memsync: self.cmd(adreno.CP_WAIT_MEM_WRITES)
     if sync: self.cmd(adreno.CP_WAIT_FOR_IDLE)

-  def _memory_barrier(self): self._cache_flush(write_back=True, invalidate=True, sync=True, memsync=True)
+  def memory_barrier(self):
+    self._cache_flush(write_back=True, invalidate=True, sync=True, memsync=True)
+    return self

-  def _signal(self, signal:QCOMSignal, value=0, ts=False):
+  def signal(self, signal:QCOMSignal, value=0, ts=False):
     self.cmd(adreno.CP_WAIT_FOR_IDLE)
     if QCOMDevice.gpu_id < 700:
       self.cmd(adreno.CP_EVENT_WRITE, qreg.cp_event_write_0(event=adreno.CACHE_FLUSH_TS, timestamp=ts),
@@ -77,20 +76,14 @@ class QCOMComputeQueue(HWQueue):
     else:
       # TODO: support devices starting with 8 Gen 1. Also, 700th series have convenient CP_GLOBAL_TIMESTAMP and CP_LOCAL_TIMESTAMP
       raise RuntimeError('CP_EVENT_WRITE7 is not supported')
+    return self

-  def _timestamp(self, signal:QCOMSignal): return self._signal(signal, 0, ts=True)
+  def timestamp(self, signal:QCOMSignal): return self.signal(signal, 0, ts=True)

-  def _wait(self, signal:QCOMSignal, value=0):
+  def wait(self, signal:QCOMSignal, value=0):
     self.cmd(adreno.CP_WAIT_REG_MEM, qreg.cp_wait_reg_mem_0(function=adreno.WRITE_GE, poll=adreno.POLL_MEMORY),*data64_le(signal.value_addr),
              qreg.cp_wait_reg_mem_3(ref=value&0xFFFFFFFF), qreg.cp_wait_reg_mem_4(mask=0xFFFFFFFF), qreg.cp_wait_reg_mem_5(delay_loop_cycles=32))
+    return self

-  def _update_signal(self, cmd_idx, signal:Optional[QCOMSignal], value):
-    if signal is not None: self._patch(cmd_idx, offset=3, data=data64_le(signal.value_addr))
-    if value is not None: self._patch(cmd_idx, offset=5, data=[value & 0xFFFFFFFF])
-
-  def _update_wait(self, cmd_idx, signal:Optional[QCOMSignal], value):
-    if signal is not None: self._patch(cmd_idx, offset=2, data=data64_le(signal.value_addr))
-    if value is not None: self._patch(cmd_idx, offset=4, data=[value & 0xFFFFFFFF])
-
   def _build_gpu_command(self, dev:QCOMDevice, hw_addr=None):
     to_mv((hw_page_addr:=hw_addr or dev._alloc_cmd_buf(len(self._q) * 4)), len(self._q) * 4).cast('I')[:] = array.array('I', self._q)
@@ -111,9 +104,9 @@ class QCOMComputeQueue(HWQueue):
     else: submit_req, _ = self._build_gpu_command(dev)
     dev.last_cmd = kgsl.IOCTL_KGSL_GPU_COMMAND(dev.fd, __payload=submit_req).timestamp

-  def _exec(self, prg:QCOMProgram, args_state:QCOMArgsState, global_size, local_size):
-    global_size_mp = [int(g*l) for g,l in zip(global_size, local_size)]
-    self.cmd_idx_to_dims[self._cur_cmd_idx()] = [global_size, local_size]
+  def exec(self, prg:QCOMProgram, args_state:QCOMArgsState, global_size, local_size):
+    def cast_int(x, ceil=False): return (math.ceil(x) if ceil else int(x)) if isinstance(x, float) else x
+    global_size_mp = [cast_int(g*l) for g,l in zip(global_size, local_size)]

     self.cmd(adreno.CP_SET_MARKER, qreg.a6xx_cp_set_marker_0(mode=adreno.RM6_COMPUTE))
     self.reg(adreno.REG_A6XX_HLSQ_INVALIDATE_CMD, qreg.a6xx_hlsq_invalidate_cmd(cs_state=True, cs_ibo=True))
@@ -129,7 +122,7 @@ class QCOMComputeQueue(HWQueue):
     self.reg(adreno.REG_A6XX_HLSQ_CS_NDRANGE_0,
              qreg.a6xx_hlsq_cs_ndrange_0(kerneldim=3, localsizex=local_size[0] - 1, localsizey=local_size[1] - 1, localsizez=local_size[2] - 1),
              global_size_mp[0], 0, global_size_mp[1], 0, global_size_mp[2], 0, 0xccc0cf, 0xfc | qreg.a6xx_hlsq_cs_cntl_1(threadsize=adreno.THREAD64),
-             int(math.ceil(global_size[0])), int(math.ceil(global_size[1])), int(math.ceil(global_size[2])))
+             cast_int(global_size[0], ceil=True), cast_int(global_size[1], ceil=True), cast_int(global_size[2], ceil=True))

     self.reg(adreno.REG_A6XX_SP_CS_CTRL_REG0,
              qreg.a6xx_sp_cs_ctrl_reg0(threadsize=adreno.THREAD64, halfregfootprint=prg.hregs, fullregfootprint=prg.fregs, branchstack=prg.brnchstck),
@@ -172,19 +165,7 @@ class QCOMComputeQueue(HWQueue):
              qreg.a6xx_sp_cs_config(enabled=True, nsamp=args_state.prg.samp_cnt, ntex=args_state.prg.tex_cnt, nibo=args_state.prg.ibo_cnt))
     self.cmd(adreno.CP_RUN_OPENCL, 0)
     self._cache_flush(write_back=True, invalidate=False, sync=False, memsync=False)
-
-  def _update_exec(self, cmd_idx, global_size, local_size):
-    if global_size is not None:
-      self._patch(cmd_idx, offset=29, data=[int(math.ceil(global_size[0])), int(math.ceil(global_size[1])), int(math.ceil(global_size[2]))])
-      self.cmd_idx_to_dims[cmd_idx][0] = global_size
-
-    if local_size is not None:
-      payload = qreg.a6xx_hlsq_cs_ndrange_0(kerneldim=3, localsizex=local_size[0] - 1, localsizey=local_size[1] - 1, localsizez=local_size[2] - 1)
-      self._patch(cmd_idx, offset=20, data=[payload])
-      self.cmd_idx_to_dims[cmd_idx][1] = local_size
-
-    global_size_mp = [int(g*l) for g,l in zip(self.cmd_idx_to_dims[cmd_idx][0], self.cmd_idx_to_dims[cmd_idx][1])]
-    self._patch(cmd_idx, offset=21, data=[global_size_mp[0], 0, global_size_mp[1], 0, global_size_mp[2], 0])
+    return self

 class QCOMArgsState(HCQArgsState):
   def __init__(self, ptr:int, prg:QCOMProgram, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=()):
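The `cast_int` helper introduced in the QCOM `exec` above exists because launch dimensions on this backend can arrive as floats from upstream dimension math; ints pass through untouched. Its behavior, for reference:

```python
import math

def cast_int(x, ceil=False): return (math.ceil(x) if ceil else int(x)) if isinstance(x, float) else x

assert cast_int(2.3) == 2             # floats truncate by default
assert cast_int(2.3, ceil=True) == 3  # or round up for grid counts
assert cast_int(7) == 7               # ints pass through unchanged
```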
@@ -1,9 +1,10 @@
 from __future__ import annotations
-from typing import List, Optional, Dict, Tuple, cast, Protocol, Type, Union, TypeVar, Generic, Callable, ParamSpec, Concatenate
-import contextlib, decimal, statistics, random, json, atexit, time, array, ctypes, functools
+from typing import List, Optional, Dict, Tuple, cast, Protocol, Type, Union, TypeVar, Generic, Any
+import contextlib, decimal, statistics, random, json, atexit, time, ctypes
 from tinygrad.helpers import PROFILEPATH, PROFILE, from_mv, getenv, to_mv
 from tinygrad.renderer import Renderer
 from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator
+from tinygrad.ops import sym_infer, sint, Variable

 # **************** for HCQ Compatible Devices ****************
@@ -13,69 +14,41 @@ ProgramType = TypeVar('ProgramType', bound='HCQProgram')
 ArgsStateType = TypeVar('ArgsStateType', bound='HCQArgsState')
 QueueType = TypeVar('QueueType', bound='HWQueue')

-P = ParamSpec('P')
-def hcq_command(func: Callable[Concatenate[QueueType, P], None]) -> Callable[Concatenate[QueueType, P], QueueType]:
-  """
-  Decorator for HWCommandQueue commands. Enables command indexing and stores metadata for command updates.
-
-  For example:
-  ```python
-  @hcq_command
-  def command_method(self, ...): ...
-  ```
-  """
-  @functools.wraps(func)
-  def __wrapper(self:QueueType, *args:P.args, **kwargs:P.kwargs) -> QueueType:
-    self.cmds_offset.append(len(self._q))
-    func(self, *args, **kwargs)
-    self.cmds_len.append(len(self._q) - self.cmds_offset[-1])
-    self.cmds_meta.append(func.__name__)
-    return self
-  return __wrapper
-
 class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
   """
   A base class for hardware command queues in the HCQ (Hardware Command Queue) API.
   Both compute and copy queues should have the following commands implemented.
   """

-  def __init__(self): self._q, self.binded_device, self.cmds_offset, self.cmds_len, self.cmds_meta = [], None, [], [], []
-  def q(self, *args) -> None: self._q.extend(args)
+  def __init__(self):
+    self._q:Any = []
+    self.binded_device:Optional[DeviceType] = None
+    self.q_sints:List[Tuple[int, int]] = []
+    self.mv_sints:List[Tuple[memoryview, int, int, Optional[int]]] = []
+    self.syms:List[sint] = []
+    self._prev_resolved_syms:List[Optional[int]] = []

-  def __len__(self): return len(self.cmds_offset)
-  def _patch(self, cmd_idx, offset, data): self._q[(st:=self.cmds_offset[cmd_idx]+offset):st+len(data)] = array.array('I', data)
-  def _cur_cmd_idx(self) -> int:
-    """
-    Returns the index of the command currently being enqueued.
-    Should be called only within functions that enqueue commands and are decorated with `@hcq_command`.
-    """
-    return len(self) - 1
+  def _new_sym(self, sym:sint) -> int:
+    if sym not in self.syms:
+      self.syms.append(sym)
+      self._prev_resolved_syms.append(None)
+    return self.syms.index(sym)

-  @hcq_command
-  def signal(self, signal:SignalType, value:int):
+  def q(self, *values):
     """
-    Enqueues a signal command which sets the signal to the given value, ensuring all previous operations are completed.
+    Enqueues values in the queue.

     Args:
-      signal: The signal to set
-      value: The value to set the signal to
+      values: The values to enqueue in the queue.
     """
-    self._signal(signal, value)
-  def _signal(self, signal:SignalType, value:int): raise NotImplementedError("backend should overload this function")

-  @hcq_command
-  def wait(self, signal:SignalType, value:int):
-    """
-    Enqueues a wait command which halts execution until the signal is greater than or equal to a specific value.
+    for v in values:
+      if isinstance(v, int): self._q.append(v)
+      else:
+        self.q_sints.append((len(self._q), self._new_sym(v)))
+        self._q.append(0xbadc0ded)

-    Args:
-      signal: The signal to wait on
-      value: The value to wait for
-    """
-    self._wait(signal, value)
-  def _wait(self, signal, value): raise NotImplementedError("backend should overload this function")
+  # *** common commands ***

-  @hcq_command
   def timestamp(self, signal:SignalType):
     """
     Enqueues a timestamp command which records the current time in a signal after all previously enqueued commands are completed.
@@ -83,38 +56,56 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
     Args:
       signal: The signal to store the timestamp
     """
-    self._timestamp(signal)
-  def _timestamp(self, signal): raise NotImplementedError("backend should overload this function")

-  def update_signal(self, cmd_idx:int, signal:Optional[SignalType]=None, value:Optional[int]=None):
+  def signal(self, signal:SignalType, value:sint):
     """
-    Updates a previously queued signal command.
+    Enqueues a signal command which sets the signal to the given value, ensuring all previous operations are completed.

     Args:
-      cmd_idx: Index of the signal command to update
-      signal: New signal to set (if None, keeps the original)
-      value: New value to set (if None, keeps the original)
+      signal: The signal to set
+      value: The value to set the signal to
     """
-    if self.cmds_meta[cmd_idx] != "signal": raise RuntimeError("called update_signal not on a signal command")
-    self._update_signal(cmd_idx, signal, value)
-    return self
-  def _update_signal(self, cmd_idx:int, signal:Optional[SignalType], value:Optional[int]):
-    raise NotImplementedError("backend should overload this function")

-  def update_wait(self, cmd_idx:int, signal:Optional[SignalType]=None, value:Optional[int]=None):
+  def wait(self, signal:SignalType, value:sint):
     """
-    Updates a previously queued wait command.
+    Enqueues a wait command which halts execution until the signal is greater than or equal to a specific value.

     Args:
-      cmd_idx: Index of the wait command to update
-      signal: New signal to wait on (if None, keeps the original)
-      value: New value to wait for (if None, keeps the original)
+      signal: The signal to wait on
+      value: The value to wait for
     """
-    if self.cmds_meta[cmd_idx] != "wait": raise RuntimeError("called update_wait not on a wait command")
-    self._update_wait(cmd_idx, signal, value)
-    return self
-  def _update_wait(self, cmd_idx:int, signal:Optional[SignalType], value:Optional[int]):
-    raise NotImplementedError("backend should overload this function")

+  # *** commands for compute queues ***
+
+  def memory_barrier(self):
+    """
+    Enqueues a memory barrier command to ensure memory coherence between agents. Only on compute queues.
+    """
+
+  def exec(self, prg:ProgramType, args_state:ArgsStateType, global_size:Tuple[sint, ...], local_size:Tuple[sint, ...]):
+    """
+    Enqueues an execution command for a kernel program. Only on compute queues.
+
+    Args:
+      prg: The program to execute
+      args_state: The args state to execute program with
+      global_size: The global work size
+      local_size: The local work size
+    """
+
+  # *** commands for copy queues ***
+
+  def copy(self, dest:sint, src:sint, copy_size:int):
+    """
+    Enqueues a copy command to transfer data. Only on copy queues.
+
+    Args:
+      dest: The destination of the copy
+      src: The source of the copy
+      copy_size: The size of data to copy
+    """
+
+  # *** submit and bind commands ***
+
   def bind(self, dev:DeviceType):
     """
@@ -130,94 +121,50 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
     Implementing this method is optional but recommended for performance gains.
     """

-  def submit(self, dev:DeviceType):
+  def bind_sints(self, *vals:sint, struct:ctypes.Structure, start_field:str, fmt, mask:Optional[int]=None):
+    self.bind_sints_to_ptr(*vals, ptr=ctypes.addressof(struct) + getattr(type(struct), start_field).offset, fmt=fmt, mask=mask)
+
+  def bind_sints_to_ptr(self, *vals:sint, ptr:int, fmt, mask:Optional[int]=None):
+    mv = to_mv(ptr, 8*len(vals)).cast(fmt)
+    for i, val in enumerate(vals):
+      if isinstance(val, int): mv[i] = val if mask is None else ((mv[i] & ~mask) | val)
+      else: self.mv_sints.append((mv, i, self._new_sym(val), mask))
+
+  def _apply_var_vals(self, var_vals:Dict[Variable, int]):
+    resolved_syms = [sym_infer(sym, var_vals) for sym in self.syms]
+
+    for off, sym_idx in self.q_sints:
+      if self._prev_resolved_syms[sym_idx] == resolved_syms[sym_idx]: continue
+      self._q[off] = resolved_syms[sym_idx]
+
+    for mv, off, sym_idx, mask in self.mv_sints:
+      if self._prev_resolved_syms[sym_idx] == resolved_syms[sym_idx]: continue
+      mv[off] = resolved_syms[sym_idx] if mask is None else ((mv[off] & ~mask) | resolved_syms[sym_idx])
+
+    self._prev_resolved_syms = cast(List[Optional[int]], resolved_syms)
+
+  def submit(self, dev:DeviceType, var_vals:Optional[Dict[Variable, int]]=None):
     """
     Submits the command queue to a specific device for execution.

     Args:
       dev: The device to submit the queue to
     """
-    if self._q: self._submit(dev)
+    if var_vals is not None: self._apply_var_vals(var_vals)
+    self._submit(dev)
     return self
-  def _submit(self, dev:DeviceType): raise NotImplementedError("backend should overload this function")
-
-  # *** commands for compute queues ***
-
-  @hcq_command
-  def memory_barrier(self):
-    """
-    Enqueues a memory barrier command to ensure memory coherence between agents. Only on compute queues.
-    """
-    self._memory_barrier()
-  def _memory_barrier(self): pass
-
-  @hcq_command
-  def exec(self, prg:ProgramType, args_state:ArgsStateType, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int]):
-    """
-    Enqueues an execution command for a kernel program. Only on compute queues.
-
-    Args:
-      prg: The program to execute
-      args_state: The args state to execute program with
-      global_size: The global work size
-      local_size: The local work size
-    """
-    self._exec(prg, args_state, global_size, local_size)
-  def _exec(self, prg:ProgramType, args_state:ArgsStateType, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int]):
-    raise NotImplementedError("backend should overload this function")
-
-  def update_exec(self, cmd_idx:int, global_size:Optional[Tuple[int,int,int]]=None, local_size:Optional[Tuple[int,int,int]]=None):
-    """
-    Updates a previously queued execution command. Only on compute queues.
-
-    Args:
-      cmd_idx: Index of the execution command to update
-      global_size: New global work size (if None, keeps the original)
-      local_size: New local work size (if None, keeps the original)
-    """
-    if self.cmds_meta[cmd_idx] != "exec": raise RuntimeError("called update_exec not on an exec command")
-    self._update_exec(cmd_idx, global_size, local_size)
-    return self
-  def _update_exec(self, cmd_idx, global_size, local_size): raise NotImplementedError("backend should overload this function")
-
-  # *** commands for copy queues ***
-
-  @hcq_command
-  def copy(self, dest:int, src:int, copy_size:int):
-    """
-    Enqueues a copy command to transfer data. Only on copy queues.
-
-    Args:
-      dest: The destination of the copy
-      src: The source of the copy
-      copy_size: The size of data to copy
-    """
-    self._copy(dest, src, copy_size)
-  def _copy(self, dest:int, src:int, copy_size:int): raise NotImplementedError("backend should overload this function")
-
-  def update_copy(self, cmd_idx:int, dest:Optional[int]=None, src:Optional[int]=None):
-    """
-    Updates a previously queued copy command. Only on copy queues.
-
-    Args:
-      cmd_idx: Index of the copy command to update
-      dest: New destination of the copy (if None, keeps the original)
-      src: New source of the copy (if None, keeps the original)
-    """
-    if self.cmds_meta[cmd_idx] != "copy": raise RuntimeError("called update_copy not on a copy command")
-    self._update_copy(cmd_idx, dest, src)
-    return self
-  def _update_copy(self, cmd_idx:int, dest:Optional[int], src:Optional[int]):
-    raise NotImplementedError("backend should overload this function")
+  def _submit(self, dev:DeviceType): raise NotImplementedError("need _submit")

 class HCQSignal(Generic[DeviceType]):
-  def __init__(self, base_addr:int=0, value:int=0, timeline_for_device:Optional[DeviceType]=None, timestamp_divider=1, value_off=0, timestamp_off=8):
+  def __init__(self, base_addr:sint=0, value:int=0, timeline_for_device:Optional[DeviceType]=None, timestamp_divider=1, value_off=0, timestamp_off=8):
     self.base_addr, self.value_addr, self.timestamp_addr = base_addr, base_addr+value_off, base_addr+timestamp_off
     self.timestamp_divider:decimal.Decimal = decimal.Decimal(timestamp_divider)
     self.timeline_for_device:Optional[DeviceType] = timeline_for_device

-    self.value_mv, self.timestamp_mv = to_mv(self.value_addr, 8).cast('Q'), to_mv(self.timestamp_addr, 8).cast('Q')
-    self.value_mv[0] = value
+    if isinstance(base_addr, int):
+      self.value_mv, self.timestamp_mv = to_mv(self.value_addr, 8).cast('Q'), to_mv(self.timestamp_addr, 8).cast('Q')
+      self.value_mv[0] = value

   @property
   def value(self) -> int: return self.value_mv[0]
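Pulling it together, a backend now overrides the command methods directly (no decorator, no `update_*` pair) and lets `q()`/`bind_sints` capture any symbolic operands. A compact sketch of a toy backend against this base class (`ToyDev`, its `ring`, and the opcode words are hypothetical):

```python
from tinygrad.helpers import data64_le
from tinygrad.runtime.support.hcq import HWQueue

class ToyQueue(HWQueue):
  def signal(self, signal, value=0):
    self.q(0x1, *data64_le(signal.value_addr), value)  # sints flow straight into q()
    return self
  def wait(self, signal, value=0):
    self.q(0x2, *data64_le(signal.value_addr), value)
    return self
  def _submit(self, dev):
    dev.ring.extend(self._q)  # ToyDev.ring: wherever the hardware reads commands
```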