diff --git a/docs/developer/hcq.md b/docs/developer/hcq.md index c6ac8aa9c6..a2c725499a 100644 --- a/docs/developer/hcq.md +++ b/docs/developer/hcq.md @@ -24,27 +24,11 @@ Each runtime should implement the required functions that are defined in the `HW "signal", "wait", "timestamp", - "update_signal", - "update_wait", "bind", "submit", "memory_barrier", "exec", - "update_exec", "copy", - "update_copy", - ] - show_source: false - -#### Implementing custom commands - -To implement custom commands in the queue, use the @hcq_command decorator for your command implementations. - -::: tinygrad.runtime.support.hcq.hcq_command - options: - members: [ - "copy", - "update_copy", ] show_source: false @@ -141,5 +125,5 @@ your_device.timeline_signal.wait(your_device.timeline_value - 1) ## HCQGraph -[HCQGraph](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/graph/hcq.py) is a core feature that implements `GraphRunner` for HCQ-compatible devices. `HCQGraph` builds static `HWQueue` for all operations per device. To optimize enqueue time, only the necessary parts of the queues are updated for each run using the update APIs of the queues, avoiding a complete rebuild. +[HCQGraph](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/graph/hcq.py) is a core feature that implements `GraphRunner` for HCQ-compatible devices. `HCQGraph` builds static `HWQueue` for all operations per device. To optimize enqueue time, only the necessary parts of the queues are updated for each run using the symbolic variables, avoiding a complete rebuild. Optionally, queues can implement a `bind` API, which allows further optimization by eliminating the need to copy the queues into the device ring. 
diff --git a/test/test_hcq.py b/test/test_hcq.py index e2b307920f..a3832f0b5b 100644 --- a/test/test_hcq.py +++ b/test/test_hcq.py @@ -6,6 +6,7 @@ from tinygrad.runtime.support.hcq import HCQCompiled from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import get_runner, CompiledRunner from tinygrad.codegen.kernel import Kernel, Opt, OptOps +from tinygrad import Variable MOCKGPU = getenv("MOCKGPU") @@ -44,14 +45,19 @@ class TestHCQ(unittest.TestCase): for queue_type in [TestHCQ.d0.hw_compute_queue_t, TestHCQ.d0.hw_copy_queue_t]: if queue_type is None: continue - with self.subTest(name=str(queue_type)): - q = queue_type().signal(TestHCQ.d0.signal_t(), 0x1000) + virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32) + virt_signal = TestHCQ.d0.signal_t(base_addr=Variable("sig_addr", 0, 0xffffffffffffffff, dtypes.uint64)) - q.update_signal(0, signal=TestHCQ.d0.timeline_signal, value=TestHCQ.d0.timeline_value).submit(TestHCQ.d0) + with self.subTest(name=str(queue_type)): + q = queue_type().signal(virt_signal, virt_val) + + var_vals = {virt_signal.base_addr: TestHCQ.d0.timeline_signal.base_addr, virt_val: TestHCQ.d0.timeline_value} + q.submit(TestHCQ.d0, var_vals) TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - q.update_signal(0, value=TestHCQ.d0.timeline_value).submit(TestHCQ.d0) + var_vals = {virt_signal.base_addr: TestHCQ.d0.timeline_signal.base_addr, virt_val: TestHCQ.d0.timeline_value} + q.submit(TestHCQ.d0, var_vals) TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 @@ -91,12 +97,15 @@ class TestHCQ(unittest.TestCase): if queue_type is None: continue with self.subTest(name=str(queue_type)): + virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32) + virt_signal = TestHCQ.d0.signal_t(base_addr=Variable("sig_addr", 0, 0xffffffffffffffff, dtypes.uint64)) + fake_signal = TestHCQ.d0.signal_t() - q = queue_type().wait(TestHCQ.d0.timeline_signal, 
0xffffffff).signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) + q = queue_type().wait(virt_signal, virt_val).signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) fake_signal.value = 0x30 - q.update_wait(0, signal=fake_signal, value=0x30).submit(TestHCQ.d0) + q.submit(TestHCQ.d0, {virt_signal.base_addr: fake_signal.base_addr, virt_val: fake_signal.value}) TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 @@ -112,26 +121,30 @@ class TestHCQ(unittest.TestCase): assert val == 1.0, f"got val {val}" def test_exec_2_kernels_100_times(self): + virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32) + q = TestHCQ.d0.hw_compute_queue_t() - q.wait(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value - 1) \ + q.wait(TestHCQ.d0.timeline_signal, virt_val - 1) \ .exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) \ .exec(TestHCQ.runner._prg, TestHCQ.kernargs_ab_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) \ - .signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) + .signal(TestHCQ.d0.timeline_signal, virt_val) for _ in range(100): - q.update_wait(0, value=TestHCQ.d0.timeline_value - 1).update_signal(3, value=TestHCQ.d0.timeline_value).submit(TestHCQ.d0) + q.submit(TestHCQ.d0, {virt_val: TestHCQ.d0.timeline_value}) TestHCQ.d0.timeline_value += 1 val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0] assert val == 200.0, f"got val {val}" def test_exec_update(self): + sint_global = (Variable("sint_global", 0, 0xffffffff, dtypes.uint32),) + tuple(TestHCQ.runner.p.global_size[1:]) + sint_local = (Variable("sint_local", 0, 0xffffffff, dtypes.uint32),) + tuple(TestHCQ.runner.p.local_size[1:]) + q = TestHCQ.d0.hw_compute_queue_t() - q.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size) \ + q.exec(TestHCQ.runner._prg, TestHCQ.kernargs_ba_ptr, sint_global, sint_local) \ 
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) - q.update_exec(0, (1,1,1), (1,1,1)) - q.submit(TestHCQ.d0) + q.submit(TestHCQ.d0, {sint_global[0]: 1, sint_local[0]: 1}) TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 @@ -141,6 +154,9 @@ class TestHCQ(unittest.TestCase): assert val == 0.0, f"got val {val}, should not be updated" def test_exec_update_fuzz(self): + virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32) + virt_local = [Variable(f"local_{i}", 0, 0xffffffff, dtypes.uint32) for i in range(3)] + a = Tensor.randint((3, 3, 3), dtype=dtypes.int, device=Device.DEFAULT).realize() b = a + 1 si = create_schedule([b.lazydata])[-1] @@ -156,16 +172,15 @@ class TestHCQ(unittest.TestCase): q = TestHCQ.d0.hw_compute_queue_t() q.memory_barrier() \ - .exec(runner._prg, kernargs, (1,1,1), (1,1,1)) \ - .signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) + .exec(runner._prg, kernargs, (1,1,1), virt_local) \ + .signal(TestHCQ.d0.timeline_signal, virt_val) for x in range(1, 4): for y in range(1, 4): for z in range(1, 4): ctypes.memset(zt._buf.va_addr, 0, zb.nbytes) - q.update_exec(1, local_size=(x,y,z)) \ - .update_signal(2, value=TestHCQ.d0.timeline_value).submit(TestHCQ.d0) + q.submit(TestHCQ.d0, {virt_val: TestHCQ.d0.timeline_value, virt_local[0]: x, virt_local[1]: y, virt_local[2]: z}) TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 @@ -207,12 +222,14 @@ class TestHCQ(unittest.TestCase): def test_update_copy(self): if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue") + virt_src_addr = Variable("virt_src_addr", 0, 0xffffffffffffffff, dtypes.uint64) + virt_dest_addr = Variable("virt_dest_addr", 0, 0xffffffffffffffff, dtypes.uint64) + q = TestHCQ.d0.hw_copy_queue_t().wait(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value - 1) \ - .copy(0x0, 0x0, 8) \ + .copy(virt_dest_addr, virt_src_addr, 8) \ 
.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) - q.update_copy(1, dest=TestHCQ.b.lazydata.buffer._buf.va_addr, src=TestHCQ.a.lazydata.buffer._buf.va_addr) \ - .submit(TestHCQ.d0) + q.submit(TestHCQ.d0, {virt_src_addr: TestHCQ.a.lazydata.buffer._buf.va_addr, virt_dest_addr: TestHCQ.b.lazydata.buffer._buf.va_addr}) TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 @@ -223,17 +240,19 @@ class TestHCQ(unittest.TestCase): def test_update_copy_long(self): if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue") + virt_src_addr = Variable("virt_src_addr", 0, 0xffffffffffffffff, dtypes.uint64) + virt_dest_addr = Variable("virt_dest_addr", 0, 0xffffffffffffffff, dtypes.uint64) + sz = 64 << 20 buf1 = Buffer(Device.DEFAULT, sz, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated() buf2 = Buffer(Device.DEFAULT, sz, dtypes.int8, options=BufferSpec(host=True, nolru=True)).ensure_allocated() ctypes.memset(buf2._buf.va_addr, 1, sz) q = TestHCQ.d0.hw_copy_queue_t().wait(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value - 1) \ - .copy(0x0, 0x0, sz) \ + .copy(virt_dest_addr, virt_src_addr, sz) \ .signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) - q.update_copy(1, buf1._buf.va_addr, buf2._buf.va_addr) \ - .submit(TestHCQ.d0) + q.submit(TestHCQ.d0, {virt_src_addr: buf2._buf.va_addr, virt_dest_addr: buf1._buf.va_addr}) TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 @@ -246,14 +265,17 @@ class TestHCQ(unittest.TestCase): for queue_type in [TestHCQ.d0.hw_compute_queue_t, TestHCQ.d0.hw_copy_queue_t]: if queue_type is None: continue + virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32) + virt_signal = TestHCQ.d0.signal_t(base_addr=Variable("sig_addr", 0, 0xffffffffffffffff, dtypes.uint64)) + with self.subTest(name=str(queue_type)): fake_signal = TestHCQ.d0.signal_t() - q = queue_type().wait(TestHCQ.d0.timeline_signal, 
0xffffffff).signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) + q = queue_type().wait(virt_signal, virt_val).signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) q.bind(TestHCQ.d0) fake_signal.value = 0x30 - q.update_wait(0, signal=fake_signal, value=0x30).submit(TestHCQ.d0) + q.submit(TestHCQ.d0, {virt_signal.base_addr: fake_signal.base_addr, virt_val: fake_signal.value}) TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 diff --git a/tinygrad/runtime/graph/hcq.py b/tinygrad/runtime/graph/hcq.py index 4dc3b75e33..4b0c57ec59 100644 --- a/tinygrad/runtime/graph/hcq.py +++ b/tinygrad/runtime/graph/hcq.py @@ -3,12 +3,13 @@ from typing import List, Any, Dict, cast, Optional, Tuple, Set from tinygrad.helpers import round_up, PROFILE, memsize_to_str from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState from tinygrad.device import Buffer, BufferSpec, Compiled, Device -from tinygrad.ops import Variable +from tinygrad import Variable, dtypes +from tinygrad.ops import sint, Variable as VariableT from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner from tinygrad.engine.jit import MultiGraphRunner class HCQGraph(MultiGraphRunner): - def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]): + def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[VariableT, int]): super().__init__(jit_cache, input_rawbuffers, var_vals) self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs])) @@ -40,6 +41,7 @@ class HCQGraph(MultiGraphRunner): self.signals: Dict[Any, HCQSignal] = {**{dev: dev.signal_t(value=0) for dev in self.devices}, **{"CPU": self.devices[0].signal_t(value=0)}} self.kickoff_value: int = 0 + self.kickoff_var = Variable("kickoff_var", 0, 0xffffffff, dtype=dtypes.uint32) 
self.prof_signals: List[HCQSignal] = [self.devices[0].signal_t() for i in range(len(jit_cache) * 2)] if PROFILE else [] self.prof_records: List[Tuple[Tuple[int, bool], Tuple[int, bool], HCQCompiled, str, bool, List[int], Optional[Dict]]] = [] @@ -70,7 +72,7 @@ class HCQGraph(MultiGraphRunner): # Ensure device is ready for use in current context: the graph has initialized the device and it's safe to operate on it within this graph. for dep_queue, _ in opt_deps: dev_access[enqueue_queue].update(dev_access[dep_queue]) - sync_signals = [(self.signals[d], self.kickoff_value) for b in ji.bufs if (d:=Device[cast(Buffer, b).device]) not in dev_access[enqueue_queue]] + sync_signals = [(self.signals[d], self.kickoff_var) for b in ji.bufs if (d:=Device[cast(Buffer, b).device]) not in dev_access[enqueue_queue]] dev_access[enqueue_queue].update(cast(HCQCompiled, Device[cast(Buffer, b).device]) for b in ji.bufs) # Remove self-dependency for compute and copy queues. @@ -99,18 +101,21 @@ class HCQGraph(MultiGraphRunner): last_j[enqueue_queue] = j # Build hardware queues. - self.op_cmd_idx: Dict[int, Tuple[HWQueue, int]] = {} + self.input_replace_to_var: Dict[Tuple[int, int], VariableT] = {} self.copy_to_devs: Dict[HCQCompiled, Set[HCQCompiled]] = {dev: set() for dev in self.devices} - self.kickoff_wait_cmds: Dict[HWQueue, List] = {q: list() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())} + + # Create variable timeline signals for each device. 
+ timeline_sigaddrs = {dev: Variable(f"timeline_sig_{dev.device_id}", 0, 0xffffffffffffffff, dtype=dtypes.uint64) for dev in self.devices} + self.virt_timeline_vals = {dev: Variable(f"timeline_var_{dev.device_id}", 0, 0xffffffff, dtype=dtypes.uint32) for dev in self.devices} + self.virt_timeline_signals = {dev: dev.signal_t(base_addr=timeline_sigaddrs[dev], timeline_for_device=dev) for dev in self.devices} for dev in self.devices: - self.comp_queues[dev].memory_barrier().wait(dev.timeline_signal, dev.timeline_value - 1) \ - .wait(self.signals['CPU'], self.kickoff_value).signal(self.signals[dev], self.kickoff_value) + self.comp_queues[dev].memory_barrier().wait(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev]) \ + .wait(self.signals['CPU'], self.kickoff_var).signal(self.signals[dev], self.kickoff_var) for j,ji in enumerate(jit_cache): enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j] - for i in range(len(sync_signals)): self.kickoff_wait_cmds[enqueue_queue].append(len(enqueue_queue) + i) for sig, val in sync_signals + deps: enqueue_queue.wait(sig, val) # Encode waits and start profile timestamp (if needed). @@ -118,13 +123,14 @@ class HCQGraph(MultiGraphRunner): # Encode main commands based on ji type. if isinstance(ji.prg, CompiledRunner): - enqueue_queue.exec(ji.prg._prg, self.ji_args[j], *ji.prg.p.launch_dims(var_vals)) + enqueue_queue.exec(ji.prg._prg, self.ji_args[j], tuple(ji.prg.p.global_size or (1,1,1)), tuple(ji.prg.p.local_size or (1,1,1))) elif isinstance(ji.prg, BufferXfer): dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]] cast(HCQAllocator, Device[src.device].allocator).map(dest._buf) - enqueue_queue.copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes) + + # TODO: For now sints is only for copies, should refactor to support exec as well. 
+ enqueue_queue.copy(self._buf_addr_as_sint(j, 0, dest._buf), self._buf_addr_as_sint(j, 1, src._buf), dest.nbytes) self.copy_to_devs[cast(HCQCompiled, Device[dest.device])].add(cast(HCQCompiled, Device[src.device])) - self.op_cmd_idx[j] = (enqueue_queue, len(enqueue_queue) - 1) # Encode finish profile timestamp (if needed). if PROFILE and self.prof_records[j][1][1]: enqueue_queue.timestamp(self.prof_signals[self.prof_records[j][1][0]]) @@ -135,13 +141,13 @@ class HCQGraph(MultiGraphRunner): for dep_dev in list(self.copy_to_devs[dev]) + [dev]: if dep_dev in self.copy_queues: self.comp_queues[dev].wait(self.signals[(copy_q:=self.copy_queues[dep_dev])], cast(int, last_j[copy_q]) + 1) - self.comp_queues[dev].signal(dev.timeline_signal, dev.timeline_value).bind(dev) + self.comp_queues[dev].signal(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev] + 1).bind(dev) if dev in self.copy_queues: self.copy_queues[dev].bind(dev) self.last_timeline: Dict[HCQCompiled, Tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices} self.queue_signals_to_reset = [self.signals[q] for q in list(self.comp_queues.values()) + list(self.copy_queues.values()) if q in self.signals] - def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]: + def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[VariableT, int], wait=False) -> Optional[float]: # Wait and restore signals self.kickoff_value += 1 for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1]) @@ -150,28 +156,21 @@ class HCQGraph(MultiGraphRunner): if PROFILE and self.kickoff_value > 1: self.collect_timestamps() + hcq_var_vals = {self.kickoff_var: self.kickoff_value, **var_vals, + **{var: dev.timeline_value - 1 for dev, var in self.virt_timeline_vals.items()}, + **{sig.base_addr: dev.timeline_signal.base_addr for dev, sig in self.virt_timeline_signals.items()}} + # Update rawbuffers for (j,i),input_idx in 
self.input_replace.items(): - if j in self.ji_args: self.ji_args[j].update_buffer(i, input_rawbuffers[input_idx]._buf) - else: self.op_cmd_idx[j][0].update_copy(self.op_cmd_idx[j][1], **{('dest' if i == 0 else 'src'): input_rawbuffers[input_idx]._buf.va_addr}) + if (var:=self.input_replace_to_var.get((j,i))) is not None: hcq_var_vals[var] = input_rawbuffers[input_idx]._buf.va_addr + else: self.ji_args[j].update_buffer(i, input_rawbuffers[input_idx]._buf) # Update var_vals for j, i, v in self.updated_vars(var_vals): self.ji_args[j].update_var(i, v) - # Update launch dims - for j, global_dims, local_dims in self.updated_launch_dims(var_vals): - queue, cmd_ptr = self.op_cmd_idx[j] - queue.update_exec(cmd_ptr, global_dims, local_dims) - for dev in self.devices: - comp_queue, copy_queue, need_sig_upd = self.comp_queues[dev], self.copy_queues.get(dev, None), dev.timeline_signal != self.last_timeline[dev][0] - comp_queue.update_wait(1, dev.timeline_signal if need_sig_upd else None, dev.timeline_value - 1) \ - .update_wait(2, value=self.kickoff_value).update_signal(3, value=self.kickoff_value) \ - .update_signal(len(comp_queue)-1, dev.timeline_signal if need_sig_upd else None, dev.timeline_value).submit(dev) - - if copy_queue is not None: - for cmd_idx in self.kickoff_wait_cmds[copy_queue]: copy_queue.update_wait(cmd_idx, value=self.kickoff_value) - copy_queue.submit(dev) + self.comp_queues[dev].submit(dev, hcq_var_vals) + if (copy_queue:=self.copy_queues.get(dev, None)) is not None: copy_queue.submit(dev, hcq_var_vals) self.last_timeline[dev] = (dev.timeline_signal, dev.timeline_value) dev.timeline_value += 1 @@ -192,6 +191,10 @@ class HCQGraph(MultiGraphRunner): (b_st,_), (b_en,_), b_dev, _, b_is_cp, _, _ = self.prof_records[x] dev.dep_prof_records += [(timestamps[b_st], timestamps[b_en], b_dev, b_is_cp, timestamps[st], timestamps[en], dev, is_cp)] + def _buf_addr_as_sint(self, j:int, i:int, buf:HCQBuffer) -> sint: + if (j, i) not in self.input_replace: return 
buf.va_addr + return self.input_replace_to_var.setdefault((j, i), Variable(f"input_{j}_{i}", 0, 0xffffffffffffffff, dtype=dtypes.uint64)) + def __del__(self): for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1]) diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index fea52ba38f..043818dd6f 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -4,6 +4,7 @@ import os, ctypes, ctypes.util, functools, pathlib, mmap, errno, array, contextl assert sys.platform != 'win32' from dataclasses import dataclass from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, HCQArgsState, HCQSignal, HCQProgram +from tinygrad.ops import sint from tinygrad.device import BufferSpec from tinygrad.helpers import getenv, to_mv, round_up, data64_le, mv_address from tinygrad.renderer.cstyle import AMDRenderer @@ -31,7 +32,8 @@ class AMDSignal(HCQSignal): def __init__(self, base_addr:Optional[int]=None, **kwargs): super().__init__(AMDDevice.signals_pool.pop() if base_addr is None else base_addr, **kwargs, timestamp_divider=100) - def __del__(self): AMDDevice.signals_pool.append(self.base_addr) + def __del__(self): + if isinstance(self.base_addr, int): AMDDevice.signals_pool.append(self.base_addr) def _sleep(self, time_spent_waiting_ms:int): # Resonable to sleep for long workloads (which take more than 2s) and only timeline signals. 
@@ -39,10 +41,6 @@ class AMDSignal(HCQSignal): kfd.AMDKFD_IOC_WAIT_EVENTS(AMDDevice.kfd, events_ptr=self.timeline_for_device.queue_event_arr_ptr, num_events=1, wait_for_all=1, timeout=200) class AMDComputeQueue(HWQueue): - def __init__(self): - self.cmd_idx_to_local_offset, self.cmd_idx_to_global_offset, self.cmd_idx_to_dispatch_packet = {}, {}, {} - super().__init__() - def __del__(self): if self.binded_device is not None: self.binded_device.allocator.free(self.hw_page, self.hw_page.size, BufferSpec(cpu_access=True, nolru=True, uncached=True)) @@ -76,22 +74,23 @@ class AMDComputeQueue(HWQueue): self.pkt3(amd_gpu.PACKET3_RELEASE_MEM, event_dw | cache_flags_dw, memsel_dw, *data64_le(address), *data64_le(value), ctxid) - def _memory_barrier(self): + def memory_barrier(self): self.wait_reg_mem(reg_req=nbioreg(regBIF_BX_PF1_GPU_HDP_FLUSH_REQ), reg_done=nbioreg(regBIF_BX_PF1_GPU_HDP_FLUSH_DONE), value=0xffffffff) self.acquire_mem() + return self - def _exec(self, prg:AMDProgram, args_state:AMDArgsState, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1)): + def exec(self, prg:AMDProgram, args_state:AMDArgsState, global_size:Tuple[sint, ...], local_size:Tuple[sint, ...]): self.acquire_mem(gli=0, gl2=0) - cmd_idx = self._cur_cmd_idx() user_regs = [*data64_le(prg.dev.scratch.va_addr), 0xffffffff, 0xc00000] if prg.enable_private_segment_sgpr else [] if prg.enable_dispatch_ptr: dp = hsa.hsa_kernel_dispatch_packet_t.from_address(dp_addr:=args_state.ptr + prg.kernargs_segment_size) - dp.workgroup_size_x, dp.workgroup_size_y, dp.workgroup_size_z = local_size[0], local_size[1], local_size[2] - dp.grid_size_x, dp.grid_size_y, dp.grid_size_z = global_size[0]*local_size[0], global_size[1]*local_size[1], global_size[2]*local_size[2] + + self.bind_sints(*local_size, struct=dp, start_field='workgroup_size_x', fmt='H') + self.bind_sints(*[g*l for g,l in zip(global_size, local_size)], struct=dp, start_field='grid_size_x', fmt='I') dp.group_segment_size, 
dp.private_segment_size, dp.kernarg_address = prg.group_segment_size, prg.private_segment_size, args_state.ptr user_regs += [*data64_le(dp_addr)] - self.cmd_idx_to_dispatch_packet[cmd_idx] = dp + user_regs += [*data64_le(args_state.ptr)] self.pkt3(amd_gpu.PACKET3_SET_SH_REG, gfxreg(amd_gpu.regCOMPUTE_PGM_LO), *data64_le(prg.prog_addr >> 8)) @@ -107,29 +106,22 @@ class AMDComputeQueue(HWQueue): self.pkt3(amd_gpu.PACKET3_SET_SH_REG, gfxreg(amd_gpu.regCOMPUTE_STATIC_THREAD_MGMT_SE4), 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF) self.pkt3(amd_gpu.PACKET3_SET_SH_REG, gfxreg(amd_gpu.regCOMPUTE_USER_DATA_0), *user_regs) - self.cmd_idx_to_local_offset[cmd_idx] = len(self._q) - self.cmds_offset[cmd_idx] + 5 # +1 to skip PACKET3_SET_SH_REG + reg + 3 zeros. self.pkt3(amd_gpu.PACKET3_SET_SH_REG, gfxreg(amd_gpu.regCOMPUTE_START_X), 0, 0, 0, *local_size, 0, 0) self.pkt3(amd_gpu.PACKET3_SET_SH_REG, gfxreg(amd_gpu.regCOMPUTE_RESOURCE_LIMITS), 0) - self.cmd_idx_to_global_offset[cmd_idx] = len(self._q) - self.cmds_offset[cmd_idx] + 1 # +1 to skip PACKET3_DISPATCH_DIRECT. 
self.pkt3(amd_gpu.PACKET3_DISPATCH_DIRECT, *global_size, CS_W32_EN | FORCE_START_AT_000 | COMPUTE_SHADER_EN) self.pkt3(amd_gpu.PACKET3_EVENT_WRITE, amd_gpu.EVENT_TYPE(amd_gpu.CS_PARTIAL_FLUSH) | amd_gpu.EVENT_INDEX(EVENT_INDEX_PARTIAL_FLUSH)) + return self - def _update_exec(self, cmd_idx, global_size, local_size): - if local_size is not None: self._patch(cmd_idx, offset=self.cmd_idx_to_local_offset[cmd_idx], data=local_size) - if global_size is not None: self._patch(cmd_idx, offset=self.cmd_idx_to_global_offset[cmd_idx], data=global_size) + def wait(self, signal:AMDSignal, value:sint=0): + self.wait_reg_mem(mem=signal.value_addr, value=value, mask=0xffffffff) + return self - if (dp:=self.cmd_idx_to_dispatch_packet.get(cmd_idx)) is not None: - if local_size is not None: dp.workgroup_size_x, dp.workgroup_size_y, dp.workgroup_size_z = local_size[0], local_size[1], local_size[2] - if global_size is not None: - dp.grid_size_x,dp.grid_size_y,dp.grid_size_z = [g*l for g,l in zip(global_size,[dp.workgroup_size_x,dp.workgroup_size_y,dp.workgroup_size_z])] - - def _wait(self, signal:AMDSignal, value=0): self.wait_reg_mem(mem=signal.value_addr, value=value, mask=0xffffffff) - - def _timestamp(self, signal:AMDSignal): + def timestamp(self, signal:AMDSignal): self.release_mem(signal.timestamp_addr, 0, amd_gpu.data_sel__mec_release_mem__send_gpu_clock_counter, amd_gpu.int_sel__mec_release_mem__none) + return self - def _signal(self, signal:AMDSignal, value=0): + def signal(self, signal:AMDSignal, value:sint=0): # NOTE: this needs an EOP buffer on the queue or it will NULL pointer self.release_mem(signal.value_addr, value, amd_gpu.data_sel__mec_release_mem__send_32_bit_low, amd_gpu.int_sel__mec_release_mem__send_interrupt_after_write_confirm, cache_flush=True) @@ -137,14 +129,7 @@ class AMDComputeQueue(HWQueue): if (dev:=signal.timeline_for_device) is not None: self.release_mem(dev.queue_event_mailbox_ptr, dev.queue_event.event_id, 
amd_gpu.data_sel__mec_release_mem__send_32_bit_low, amd_gpu.int_sel__mec_release_mem__send_interrupt_after_write_confirm, ctxid=dev.queue_event.event_id) - - def _update_wait(self, cmd_idx, signal:Optional[AMDSignal]=None, value=None): - if signal is not None: self._patch(cmd_idx, offset=2, data=data64_le(signal.value_addr)) - if value is not None: self._patch(cmd_idx, offset=4, data=[value]) - - def _update_signal(self, cmd_idx, signal:Optional[AMDSignal]=None, value=None): - if signal is not None: self._patch(cmd_idx, offset=3, data=data64_le(signal.value_addr)) - if value is not None: self._patch(cmd_idx, offset=5, data=data64_le(value)) + return self def bind(self, dev:AMDDevice): self.binded_device = dev @@ -155,6 +140,7 @@ class AMDComputeQueue(HWQueue): self.indirect_cmd = [amd_gpu.PACKET3(amd_gpu.PACKET3_INDIRECT_BUFFER, 2), *data64_le(self.hw_page.va_addr), len(self._q) | amd_gpu.INDIRECT_BUFFER_VALID] self._q = hw_view # type: ignore + return self def _submit(self, dev:AMDDevice): cmds = self.indirect_cmd if dev == self.binded_device else self._q @@ -168,16 +154,16 @@ class AMDComputeQueue(HWQueue): SDMA_MAX_COPY_SIZE = 0x400000 class AMDCopyQueue(HWQueue): def __init__(self): - self.internal_cmd_sizes, self.copy_cmds_per_copy = [], {} + self.internal_cmd_sizes = [] super().__init__() def q(self, *arr): super().q(*arr) self.internal_cmd_sizes.append(len(arr)) - def _copy(self, dest, src, copy_size): + def copy(self, dest:sint, src:sint, copy_size:int): copied, copy_commands = 0, (copy_size + SDMA_MAX_COPY_SIZE - 1) // SDMA_MAX_COPY_SIZE - self.copy_cmds_per_copy[len(self) - 1] = copy_commands + for _ in range(copy_commands): step_copy_size = min(copy_size - copied, SDMA_MAX_COPY_SIZE) @@ -185,34 +171,27 @@ class AMDCopyQueue(HWQueue): amd_gpu.SDMA_PKT_COPY_LINEAR_COUNT_COUNT(step_copy_size - 1), 0, *data64_le(src + copied), *data64_le(dest + copied)) copied += step_copy_size + return self - def _update_copy(self, cmd_idx, dest=None, src=None): - for i in 
range(self.copy_cmds_per_copy[cmd_idx]): - if src is not None: self._patch(cmd_idx, offset=3+i*7, data=[*data64_le(src + SDMA_MAX_COPY_SIZE*i)]) - if dest is not None: self._patch(cmd_idx, offset=5+i*7, data=[*data64_le(dest + SDMA_MAX_COPY_SIZE*i)]) - - def _signal(self, signal:AMDSignal, value=0): + def signal(self, signal:AMDSignal, value:sint=0): self.q(amd_gpu.SDMA_OP_FENCE | amd_gpu.SDMA_PKT_FENCE_HEADER_MTYPE(3), *data64_le(signal.value_addr), value) if (dev:=signal.timeline_for_device) is not None: self.q(amd_gpu.SDMA_OP_FENCE | amd_gpu.SDMA_PKT_FENCE_HEADER_MTYPE(3), *data64_le(dev.queue_event_mailbox_ptr), dev.queue_event.event_id) self.q(amd_gpu.SDMA_OP_TRAP, amd_gpu.SDMA_PKT_TRAP_INT_CONTEXT_INT_CONTEXT(dev.queue_event.event_id)) - def _wait(self, signal:AMDSignal, value=0): + return self + + def wait(self, signal:AMDSignal, value:sint=0): self.q(amd_gpu.SDMA_OP_POLL_REGMEM | amd_gpu.SDMA_PKT_POLL_REGMEM_HEADER_FUNC(WAIT_REG_MEM_FUNCTION_GEQ) | \ amd_gpu.SDMA_PKT_POLL_REGMEM_HEADER_MEM_POLL(1), *data64_le(signal.value_addr), value, 0xffffffff, amd_gpu.SDMA_PKT_POLL_REGMEM_DW5_INTERVAL(0x04) | amd_gpu.SDMA_PKT_POLL_REGMEM_DW5_RETRY_COUNT(0xfff)) + return self - def _update_signal(self, cmd_idx, signal:Optional[AMDSignal]=None, value=None): - return self._update_wait(cmd_idx, signal, value) # the same offsets and commands - - def _update_wait(self, cmd_idx, signal:Optional[AMDSignal]=None, value=None): - if signal is not None: self._patch(cmd_idx, offset=1, data=data64_le(signal.value_addr)) - if value is not None: self._patch(cmd_idx, offset=3, data=[value]) - - def _timestamp(self, signal:AMDSignal): + def timestamp(self, signal:AMDSignal): self.q(amd_gpu.SDMA_OP_TIMESTAMP | amd_gpu.SDMA_PKT_TIMESTAMP_GET_HEADER_SUB_OP(amd_gpu.SDMA_SUBOP_TIMESTAMP_GET_GLOBAL), *data64_le(signal.timestamp_addr)) + return self def _submit(self, dev:AMDDevice): if dev.sdma_queue.put_value - dev.sdma_queue.read_ptr[0] > dev.sdma_queue.ring.nbytes: raise RuntimeError("SDMA 
queue overrun") diff --git a/tinygrad/runtime/ops_nv.py b/tinygrad/runtime/ops_nv.py index b164c5c4df..9f4113f4e7 100644 --- a/tinygrad/runtime/ops_nv.py +++ b/tinygrad/runtime/ops_nv.py @@ -3,8 +3,8 @@ import os, ctypes, contextlib, re, fcntl, functools, mmap, struct, array, sys assert sys.platform != 'win32' from typing import Tuple, List, Any, cast, Union, Dict, Type, Optional from dataclasses import dataclass -from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, hcq_command -from tinygrad.runtime.support.hcq import HCQArgsState, HCQProgram, HCQSignal +from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, HCQArgsState, HCQProgram, HCQSignal +from tinygrad.ops import sint from tinygrad.device import BufferSpec from tinygrad.helpers import getenv, mv_address, init_c_struct_t, to_mv, round_up, data64, data64_le, DEBUG, prod from tinygrad.renderer.ptx import PTXRenderer @@ -77,13 +77,17 @@ class NVSignal(HCQSignal): def __init__(self, base_addr:Optional[int]=None, **kwargs): super().__init__(NVDevice.signals_pool.pop() if base_addr is None else base_addr, **kwargs, timestamp_divider=1000, value_off=0, timestamp_off=8) - def __del__(self): NVDevice.signals_pool.append(self.base_addr) + def __del__(self): + if isinstance(self.base_addr, int): NVDevice.signals_pool.append(self.base_addr) class NVCommandQueue(HWQueue[NVSignal, 'NVDevice', 'NVProgram', 'NVArgsState']): + def __init__(self): + self.active_qmd = None + super().__init__() + def __del__(self): if self.binded_device is not None: self.binded_device.allocator.free(self.hw_page, self.hw_page.size, BufferSpec(cpu_access=True, nolru=True)) - @hcq_command def setup(self, compute_class=None, copy_class=None, local_mem_window=None, shared_mem_window=None, local_mem=None, local_mem_tpc_bytes=None): if compute_class: self.q(nvmethod(1, nv_gpu.NVC6C0_SET_OBJECT, 1), compute_class) if copy_class: self.q(nvmethod(4, nv_gpu.NVC6C0_SET_OBJECT, 1), 
copy_class) @@ -91,16 +95,15 @@ class NVCommandQueue(HWQueue[NVSignal, 'NVDevice', 'NVProgram', 'NVArgsState']): if shared_mem_window: self.q(nvmethod(1, nv_gpu.NVC6C0_SET_SHADER_SHARED_MEMORY_WINDOW_A, 2), *data64(shared_mem_window)) if local_mem: self.q(nvmethod(1, nv_gpu.NVC6C0_SET_SHADER_LOCAL_MEMORY_A, 2), *data64(local_mem)) if local_mem_tpc_bytes: self.q(nvmethod(1, nv_gpu.NVC6C0_SET_SHADER_LOCAL_MEMORY_NON_THROTTLED_A, 3), *data64(local_mem_tpc_bytes), 0xff) + return self - def _wait(self, signal:NVSignal, value=0): + def wait(self, signal:NVSignal, value:sint=0): self.q(nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *data64_le(signal.value_addr), *data64_le(value), (3 << 0) | (1 << 24)) # ACQUIRE | PAYLOAD_SIZE_64BIT + self.active_qmd = None + return self - def _update_wait(self, cmd_idx, signal=None, value=None): - if signal is not None: self._q[(sigoff:=self.cmds_offset[cmd_idx]+1):sigoff+2] = array.array('I', data64_le(signal.value_addr)) - if value is not None: self._q[(valoff:=self.cmds_offset[cmd_idx]+3):valoff+2] = array.array('I', data64_le(value)) - - def _timestamp(self, signal): return self._signal(signal, 0) + def timestamp(self, signal:NVSignal): return self.signal(signal, 0) def bind(self, dev:NVDevice): self.binded_device = dev @@ -129,79 +132,63 @@ class NVCommandQueue(HWQueue[NVSignal, 'NVDevice', 'NVProgram', 'NVArgsState']): gpfifo.put_value += 1 class NVComputeQueue(NVCommandQueue): - def __init__(self): - self.cmd_idx_to_qmd, self.cmd_idx_to_signal_id, self.cmd_idx_to_global_dims, self.cmd_idx_to_local_dims = {}, {}, {}, {} - super().__init__() + def memory_barrier(self): + self.q(nvmethod(1, nv_gpu.NVC6C0_INVALIDATE_SHADER_CACHES_NO_WFI, 1), (1 << 12) | (1 << 4) | (1 << 0)) + self.active_qmd = None + return self - def _memory_barrier(self): self.q(nvmethod(1, nv_gpu.NVC6C0_INVALIDATE_SHADER_CACHES_NO_WFI, 1), (1 << 12) | (1 << 4) | (1 << 0)) - - def _exec(self, prg:NVProgram, args_state:NVArgsState, global_size, local_size): + def 
exec(self, prg:NVProgram, args_state:NVArgsState, global_size:Tuple[sint, ...], local_size:Tuple[sint, ...]): ctypes.memmove(qmd_addr:=(args_state.ptr + round_up(prg.constbufs[0][1], 1 << 8)), ctypes.addressof(prg.qmd), 0x40 * 4) assert qmd_addr < (1 << 40), f"large qmd addr {qmd_addr:x}" - self.cmd_idx_to_qmd[self._cur_cmd_idx()] = qmd = qmd_struct_t.from_address(qmd_addr) # Save qmd for later update - self.cmd_idx_to_global_dims[self._cur_cmd_idx()] = to_mv(qmd_addr + nv_gpu.NVC6C0_QMDV03_00_CTA_RASTER_WIDTH[1] // 8, 12).cast('I') - self.cmd_idx_to_local_dims[self._cur_cmd_idx()] = to_mv(qmd_addr + nv_gpu.NVC6C0_QMDV03_00_CTA_THREAD_DIMENSION0[1] // 8, 6).cast('H') + qmd = qmd_struct_t.from_address(qmd_addr) # Save qmd for later update - qmd.cta_raster_width, qmd.cta_raster_height, qmd.cta_raster_depth = global_size - qmd.cta_thread_dimension0, qmd.cta_thread_dimension1, qmd.cta_thread_dimension2 = local_size + self.bind_sints_to_ptr(*global_size, ptr=qmd_addr + nv_gpu.NVC6C0_QMDV03_00_CTA_RASTER_WIDTH[1] // 8, fmt='I') + self.bind_sints_to_ptr(*local_size, ptr=qmd_addr + nv_gpu.NVC6C0_QMDV03_00_CTA_THREAD_DIMENSION0[1] // 8, fmt='H') qmd.constant_buffer_addr_upper_0, qmd.constant_buffer_addr_lower_0 = data64(args_state.ptr) - if (prev_qmd:=self.cmd_idx_to_qmd.get(self._cur_cmd_idx() - 1)) is None: + if self.active_qmd is None: self.q(nvmethod(1, nv_gpu.NVC6C0_SEND_PCAS_A, 0x1), qmd_addr >> 8) self.q(nvmethod(1, nv_gpu.NVC6C0_SEND_SIGNALING_PCAS2_B, 0x1), 9) else: - prev_qmd.dependent_qmd0_pointer = qmd_addr >> 8 - prev_qmd.dependent_qmd0_action = 1 - prev_qmd.dependent_qmd0_prefetch = 1 - prev_qmd.dependent_qmd0_enable = 1 + self.active_qmd.dependent_qmd0_pointer = qmd_addr >> 8 + self.active_qmd.dependent_qmd0_action = 1 + self.active_qmd.dependent_qmd0_prefetch = 1 + self.active_qmd.dependent_qmd0_enable = 1 - def _update_exec(self, cmd_idx, global_size, local_size): - # Patch the exec cmd with new launch dims - if global_size is not None: 
self.cmd_idx_to_global_dims[cmd_idx][:] = array.array('I', global_size) - if local_size is not None: self.cmd_idx_to_local_dims[cmd_idx][:] = array.array('H', local_size) + self.active_qmd = qmd + return self - def _signal(self, signal:NVSignal, value=0): - if (prev_qmd:=self.cmd_idx_to_qmd.get(self._cur_cmd_idx() - 1)) is not None: + def signal(self, signal:NVSignal, value:sint=0): + if self.active_qmd is not None: for i in range(2): - if getattr(prev_qmd, f'release{i}_enable') == 0: - setattr(prev_qmd, f'release{i}_enable', 1) - setattr(prev_qmd, f'release{i}_address', signal.value_addr) - setattr(prev_qmd, f'release{i}_payload', value) - self.cmd_idx_to_qmd[self._cur_cmd_idx()] = prev_qmd - self.cmd_idx_to_signal_id[self._cur_cmd_idx()] = i - return + if getattr(self.active_qmd, f'release{i}_enable') == 0: + setattr(self.active_qmd, f'release{i}_enable', 1) + self.bind_sints(signal.value_addr, struct=self.active_qmd, start_field=f'release{i}_address', fmt='Q', mask=0xfffffffff) + self.bind_sints(value, struct=self.active_qmd, start_field=f'release{i}_payload', fmt='Q') + return self self.q(nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *data64_le(signal.value_addr), *data64_le(value), (1 << 0) | (1 << 20) | (1 << 24) | (1 << 25)) # RELEASE | RELEASE_WFI | PAYLOAD_SIZE_64BIT | RELEASE_TIMESTAMP self.q(nvmethod(0, nv_gpu.NVC56F_NON_STALL_INTERRUPT, 1), 0x0) + self.active_qmd = None + return self - def _update_signal(self, cmd_idx, signal:Optional[NVSignal]=None, value=None): - if (qmd:=self.cmd_idx_to_qmd.get(cmd_idx)) is None: return super()._update_wait(cmd_idx, signal, value) # reuse wait, same offsets to update. 
- if signal is not None: setattr(qmd, f'release{self.cmd_idx_to_signal_id[cmd_idx]}_address', signal.value_addr) - if value is not None: setattr(qmd, f'release{self.cmd_idx_to_signal_id[cmd_idx]}_payload', value) - - def _submit(self, dev): self._submit_to_gpfifo(dev, cast(NVDevice, dev).compute_gpfifo) + def _submit(self, dev:NVDevice): self._submit_to_gpfifo(dev, dev.compute_gpfifo) class NVCopyQueue(NVCommandQueue): - def _copy(self, dest, src, copy_size): + def copy(self, dest:sint, src:sint, copy_size:int): self.q(nvmethod(4, nv_gpu.NVC6B5_OFFSET_IN_UPPER, 4), *data64(src), *data64(dest)) self.q(nvmethod(4, nv_gpu.NVC6B5_LINE_LENGTH_IN, 1), copy_size) self.q(nvmethod(4, nv_gpu.NVC6B5_LAUNCH_DMA, 1), 0x182) # TRANSFER_TYPE_NON_PIPELINED | DST_MEMORY_LAYOUT_PITCH | SRC_MEMORY_LAYOUT_PITCH + return self - def _update_copy(self, cmd_idx, dest=None, src=None): - if dest is not None: self._patch(cmd_idx, offset=3, data=data64(dest)) - if src is not None: self._patch(cmd_idx, offset=1, data=data64(src)) - - def _signal(self, signal, value=0): + def signal(self, signal:NVSignal, value:sint=0): self.q(nvmethod(4, nv_gpu.NVC6B5_SET_SEMAPHORE_A, 3), *data64(signal.value_addr), value) self.q(nvmethod(4, nv_gpu.NVC6B5_LAUNCH_DMA, 1), 0x14) + return self - def _update_signal(self, cmd_idx, signal=None, value=None): - if signal is not None: self._patch(cmd_idx, offset=1, data=data64(signal.value_addr)) - if value is not None: self._patch(cmd_idx, offset=3, data=[value]) - - def _submit(self, dev): self._submit_to_gpfifo(dev, cast(NVDevice, dev).dma_gpfifo) + def _submit(self, dev:NVDevice): self._submit_to_gpfifo(dev, dev.dma_gpfifo) class NVArgsState(HCQArgsState): def __init__(self, ptr:int, prg:NVProgram, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=()): diff --git a/tinygrad/runtime/ops_qcom.py b/tinygrad/runtime/ops_qcom.py index 3ade2d7432..a44cf01aa1 100644 --- a/tinygrad/runtime/ops_qcom.py +++ b/tinygrad/runtime/ops_qcom.py @@ -39,7 +39,8 @@ class 
QCOMSignal(HCQSignal): def __init__(self, base_addr:Optional[int]=None, **kwargs): super().__init__(QCOMDevice.signals_pool.pop() if base_addr is None else base_addr, **kwargs, timestamp_divider=19.2) - def __del__(self): QCOMDevice.signals_pool.append(self.base_addr) + def __del__(self): + if isinstance(self.base_addr, int): QCOMDevice.signals_pool.append(self.base_addr) def _sleep(self, time_spent_waiting_ms:int): # Sleep only for timeline signals. Do it immediately to free cpu. @@ -48,10 +49,6 @@ class QCOMSignal(HCQSignal): timestamp=self.timeline_for_device.last_cmd, timeout=0xffffffff) class QCOMComputeQueue(HWQueue): - def __init__(self): - self.cmd_idx_to_dims = {} - super().__init__() - def __del__(self): if self.binded_device is not None: self.binded_device.allocator.free(self.hw_page, self.hw_page.size, BufferSpec(cpu_access=True, nolru=True)) @@ -66,9 +63,11 @@ class QCOMComputeQueue(HWQueue): if memsync: self.cmd(adreno.CP_WAIT_MEM_WRITES) if sync: self.cmd(adreno.CP_WAIT_FOR_IDLE) - def _memory_barrier(self): self._cache_flush(write_back=True, invalidate=True, sync=True, memsync=True) + def memory_barrier(self): + self._cache_flush(write_back=True, invalidate=True, sync=True, memsync=True) + return self - def _signal(self, signal:QCOMSignal, value=0, ts=False): + def signal(self, signal:QCOMSignal, value=0, ts=False): self.cmd(adreno.CP_WAIT_FOR_IDLE) if QCOMDevice.gpu_id < 700: self.cmd(adreno.CP_EVENT_WRITE, qreg.cp_event_write_0(event=adreno.CACHE_FLUSH_TS, timestamp=ts),
Also, 700th series have convenient CP_GLOBAL_TIMESTAMP and CP_LOCAL_TIMESTAMP raise RuntimeError('CP_EVENT_WRITE7 is not supported') + return self - def _timestamp(self, signal:QCOMSignal): return self._signal(signal, 0, ts=True) + def timestamp(self, signal:QCOMSignal): return self.signal(signal, 0, ts=True) - def _wait(self, signal:QCOMSignal, value=0): + def wait(self, signal:QCOMSignal, value=0): self.cmd(adreno.CP_WAIT_REG_MEM, qreg.cp_wait_reg_mem_0(function=adreno.WRITE_GE, poll=adreno.POLL_MEMORY),*data64_le(signal.value_addr), qreg.cp_wait_reg_mem_3(ref=value&0xFFFFFFFF), qreg.cp_wait_reg_mem_4(mask=0xFFFFFFFF), qreg.cp_wait_reg_mem_5(delay_loop_cycles=32)) - - def _update_signal(self, cmd_idx, signal:Optional[QCOMSignal], value): - if signal is not None: self._patch(cmd_idx, offset=3, data=data64_le(signal.value_addr)) - if value is not None: self._patch(cmd_idx, offset=5, data=[value & 0xFFFFFFFF]) - - def _update_wait(self, cmd_idx, signal:Optional[QCOMSignal], value): - if signal is not None: self._patch(cmd_idx, offset=2, data=data64_le(signal.value_addr)) - if value is not None: self._patch(cmd_idx, offset=4, data=[value & 0xFFFFFFFF]) + return self def _build_gpu_command(self, dev:QCOMDevice, hw_addr=None): to_mv((hw_page_addr:=hw_addr or dev._alloc_cmd_buf(len(self._q) * 4)), len(self._q) * 4).cast('I')[:] = array.array('I', self._q) @@ -111,9 +104,9 @@ class QCOMComputeQueue(HWQueue): else: submit_req, _ = self._build_gpu_command(dev) dev.last_cmd = kgsl.IOCTL_KGSL_GPU_COMMAND(dev.fd, __payload=submit_req).timestamp - def _exec(self, prg:QCOMProgram, args_state:QCOMArgsState, global_size, local_size): - global_size_mp = [int(g*l) for g,l in zip(global_size, local_size)] - self.cmd_idx_to_dims[self._cur_cmd_idx()] = [global_size, local_size] + def exec(self, prg:QCOMProgram, args_state:QCOMArgsState, global_size, local_size): + def cast_int(x, ceil=False): return (math.ceil(x) if ceil else int(x)) if isinstance(x, float) else x + global_size_mp = 
[cast_int(g*l) for g,l in zip(global_size, local_size)] self.cmd(adreno.CP_SET_MARKER, qreg.a6xx_cp_set_marker_0(mode=adreno.RM6_COMPUTE)) self.reg(adreno.REG_A6XX_HLSQ_INVALIDATE_CMD, qreg.a6xx_hlsq_invalidate_cmd(cs_state=True, cs_ibo=True)) @@ -129,7 +122,7 @@ class QCOMComputeQueue(HWQueue): self.reg(adreno.REG_A6XX_HLSQ_CS_NDRANGE_0, qreg.a6xx_hlsq_cs_ndrange_0(kerneldim=3, localsizex=local_size[0] - 1, localsizey=local_size[1] - 1, localsizez=local_size[2] - 1), global_size_mp[0], 0, global_size_mp[1], 0, global_size_mp[2], 0, 0xccc0cf, 0xfc | qreg.a6xx_hlsq_cs_cntl_1(threadsize=adreno.THREAD64), - int(math.ceil(global_size[0])), int(math.ceil(global_size[1])), int(math.ceil(global_size[2]))) + cast_int(global_size[0], ceil=True), cast_int(global_size[1], ceil=True), cast_int(global_size[2], ceil=True)) self.reg(adreno.REG_A6XX_SP_CS_CTRL_REG0, qreg.a6xx_sp_cs_ctrl_reg0(threadsize=adreno.THREAD64, halfregfootprint=prg.hregs, fullregfootprint=prg.fregs, branchstack=prg.brnchstck), @@ -172,19 +165,7 @@ class QCOMComputeQueue(HWQueue): qreg.a6xx_sp_cs_config(enabled=True, nsamp=args_state.prg.samp_cnt, ntex=args_state.prg.tex_cnt, nibo=args_state.prg.ibo_cnt)) self.cmd(adreno.CP_RUN_OPENCL, 0) self._cache_flush(write_back=True, invalidate=False, sync=False, memsync=False) - - def _update_exec(self, cmd_idx, global_size, local_size): - if global_size is not None: - self._patch(cmd_idx, offset=29, data=[int(math.ceil(global_size[0])), int(math.ceil(global_size[1])), int(math.ceil(global_size[2]))]) - self.cmd_idx_to_dims[cmd_idx][0] = global_size - - if local_size is not None: - payload = qreg.a6xx_hlsq_cs_ndrange_0(kerneldim=3, localsizex=local_size[0] - 1, localsizey=local_size[1] - 1, localsizez=local_size[2] - 1) - self._patch(cmd_idx, offset=20, data=[payload]) - self.cmd_idx_to_dims[cmd_idx][1] = local_size - - global_size_mp = [int(g*l) for g,l in zip(self.cmd_idx_to_dims[cmd_idx][0], self.cmd_idx_to_dims[cmd_idx][1])] - self._patch(cmd_idx, offset=21, 
data=[global_size_mp[0], 0, global_size_mp[1], 0, global_size_mp[2], 0]) + return self class QCOMArgsState(HCQArgsState): def __init__(self, ptr:int, prg:QCOMProgram, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=()): diff --git a/tinygrad/runtime/support/hcq.py b/tinygrad/runtime/support/hcq.py index bfd86bee00..16250d415f 100644 --- a/tinygrad/runtime/support/hcq.py +++ b/tinygrad/runtime/support/hcq.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import List, Optional, Dict, Tuple, cast, Protocol, Type, Union, TypeVar, Generic, Callable, ParamSpec, Concatenate -import contextlib, decimal, statistics, random, json, atexit, time, array, ctypes, functools +from typing import List, Optional, Dict, Tuple, cast, Protocol, Type, Union, TypeVar, Generic, Any +import contextlib, decimal, statistics, random, json, atexit, time, ctypes from tinygrad.helpers import PROFILEPATH, PROFILE, from_mv, getenv, to_mv from tinygrad.renderer import Renderer from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator +from tinygrad.ops import sym_infer, sint, Variable # **************** for HCQ Compatible Devices **************** @@ -13,69 +14,41 @@ ProgramType = TypeVar('ProgramType', bound='HCQProgram') ArgsStateType = TypeVar('ArgsStateType', bound='HCQArgsState') QueueType = TypeVar('QueueType', bound='HWQueue') -P = ParamSpec('P') -def hcq_command(func: Callable[Concatenate[QueueType, P], None]) -> Callable[Concatenate[QueueType, P], QueueType]: - """ - Decorator for HWCommandQueue commands. Enables command indexing and stores metadata for command updates. - - For example: - ```python - @hcq_command - def command_method(self, ...): ... 
- ``` - """ - @functools.wraps(func) - def __wrapper(self:QueueType, *args:P.args, **kwargs:P.kwargs) -> QueueType: - self.cmds_offset.append(len(self._q)) - func(self, *args, **kwargs) - self.cmds_len.append(len(self._q) - self.cmds_offset[-1]) - self.cmds_meta.append(func.__name__) - return self - return __wrapper - class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]): """ A base class for hardware command queues in the HCQ (Hardware Command Queue) API. - Both compute and copy queues should have the following commands implemented. """ - def __init__(self): self._q, self.binded_device, self.cmds_offset, self.cmds_len, self.cmds_meta = [], None, [], [], [] - def q(self, *args) -> None: self._q.extend(args) + def __init__(self): + self._q:Any = [] + self.binded_device:Optional[DeviceType] = None + self.q_sints:List[Tuple[int, int]] = [] + self.mv_sints:List[Tuple[memoryview, int, int, Optional[int]]] = [] + self.syms:List[sint] = [] + self._prev_resolved_syms:List[Optional[int]] = [] - def __len__(self): return len(self.cmds_offset) - def _patch(self, cmd_idx, offset, data): self._q[(st:=self.cmds_offset[cmd_idx]+offset):st+len(data)] = array.array('I', data) - def _cur_cmd_idx(self) -> int: - """ - Returns the index of the command currently being enqueued. - Should be called only within functions that enqueue commands and are decorated with `@hcq_command`. - """ - return len(self) - 1 + def _new_sym(self, sym:sint) -> int: + if sym not in self.syms: + self.syms.append(sym) + self._prev_resolved_syms.append(None) + return self.syms.index(sym) - @hcq_command - def signal(self, signal:SignalType, value:int): + def q(self, *values): """ - Enqueues a signal command which sets the signal to the given value, ensuring all previous operations are completed. + Enqueues values in the queue. Args: - signal: The signal to set - value: The value to set the signal to + values: The values to enqueue in the queue. 
""" - self._signal(signal, value) - def _signal(self, signal:SignalType, value:int): raise NotImplementedError("backend should overload this function") - @hcq_command - def wait(self, signal:SignalType, value:int): - """ - Enqueues a wait command which halts execution until the signal is greater than or equal to a specific value. + for v in values: + if isinstance(v, int): self._q.append(v) + else: + self.q_sints.append((len(self._q), self._new_sym(v))) + self._q.append(0xbadc0ded) - Args: - signal: The signal to wait on - value: The value to wait for - """ - self._wait(signal, value) - def _wait(self, signal, value): raise NotImplementedError("backend should overload this function") + # *** common commands *** - @hcq_command def timestamp(self, signal:SignalType): """ Enqueues a timestamp command which records the current time in a signal after all previously enqueued commands are completed. @@ -83,38 +56,56 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]): Args: signal: The signal to store the timestamp """ - self._timestamp(signal) - def _timestamp(self, signal): raise NotImplementedError("backend should overload this function") - def update_signal(self, cmd_idx:int, signal:Optional[SignalType]=None, value:Optional[int]=None): + def signal(self, signal:SignalType, value:sint): """ - Updates a previously queued signal command. + Enqueues a signal command which sets the signal to the given value, ensuring all previous operations are completed. 
Args: - cmd_idx: Index of the signal command to update - signal: New signal to set (if None, keeps the original) - value: New value to set (if None, keeps the original) + signal: The signal to set + value: The value to set the signal to """ - if self.cmds_meta[cmd_idx] != "signal": raise RuntimeError("called update_signal not on a signal command") - self._update_signal(cmd_idx, signal, value) - return self - def _update_signal(self, cmd_idx:int, signal:Optional[SignalType], value:Optional[int]): - raise NotImplementedError("backend should overload this function") - def update_wait(self, cmd_idx:int, signal:Optional[SignalType]=None, value:Optional[int]=None): + def wait(self, signal:SignalType, value:sint): """ - Updates a previously queued wait command. + Enqueues a wait command which halts execution until the signal is greater than or equal to a specific value. Args: - cmd_idx: Index of the wait command to update - signal: New signal to wait on (if None, keeps the original) - value: New value to wait for (if None, keeps the original) + signal: The signal to wait on + value: The value to wait for """ - if self.cmds_meta[cmd_idx] != "wait": raise RuntimeError("called update_wait not on a wait command") - self._update_wait(cmd_idx, signal, value) - return self - def _update_wait(self, cmd_idx:int, signal:Optional[SignalType], value:Optional[int]): - raise NotImplementedError("backend should overload this function") + + # *** commands for compute queues *** + + def memory_barrier(self): + """ + Enqueues a memory barrier command to ensure memory coherence between agents. Only on compute queues. + """ + + def exec(self, prg:ProgramType, args_state:ArgsStateType, global_size:Tuple[sint, ...], local_size:Tuple[sint, ...]): + """ + Enqueues an execution command for a kernel program. Only on compute queues. 
+ + Args: + prg: The program to execute + args_state: The args state to execute program with + global_size: The global work size + local_size: The local work size + """ + + # *** commands for copy queues *** + + def copy(self, dest:sint, src:sint, copy_size:int): + """ + Enqueues a copy command to transfer data. Only on copy queues. + + Args: + dest: The destination of the copy + src: The source of the copy + copy_size: The size of data to copy + """ + + # *** submit and bind commands *** def bind(self, dev:DeviceType): """ @@ -130,94 +121,50 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]): Implementing this method is optional but recommended for performance gains. """ - def submit(self, dev:DeviceType): + def bind_sints(self, *vals:sint, struct:ctypes.Structure, start_field:str, fmt, mask:Optional[int]=None): + self.bind_sints_to_ptr(*vals, ptr=ctypes.addressof(struct) + getattr(type(struct), start_field).offset, fmt=fmt, mask=mask) + + def bind_sints_to_ptr(self, *vals:sint, ptr:int, fmt, mask:Optional[int]=None): + mv = to_mv(ptr, 8*len(vals)).cast(fmt) + for i, val in enumerate(vals): + if isinstance(val, int): mv[i] = val if mask is None else ((mv[i] & ~mask) | val) + else: self.mv_sints.append((mv, i, self._new_sym(val), mask)) + + def _apply_var_vals(self, var_vals:Dict[Variable, int]): + resolved_syms = [sym_infer(sym, var_vals) for sym in self.syms] + + for off, sym_idx in self.q_sints: + if self._prev_resolved_syms[sym_idx] == resolved_syms[sym_idx]: continue + self._q[off] = resolved_syms[sym_idx] + + for mv, off, sym_idx, mask in self.mv_sints: + if self._prev_resolved_syms[sym_idx] == resolved_syms[sym_idx]: continue + mv[off] = resolved_syms[sym_idx] if mask is None else ((mv[off] & ~mask) | resolved_syms[sym_idx]) + + self._prev_resolved_syms = cast(List[Optional[int]], resolved_syms) + + def submit(self, dev:DeviceType, var_vals:Optional[Dict[Variable, int]]=None): """ Submits the command queue to a specific device for 
execution. Args: dev: The device to submit the queue to """ - if self._q: self._submit(dev) + + if var_vals is not None: self._apply_var_vals(var_vals) + self._submit(dev) return self - def _submit(self, dev:DeviceType): raise NotImplementedError("backend should overload this function") - - # *** commands for compute queues *** - - @hcq_command - def memory_barrier(self): - """ - Enqueues a memory barrier command to ensure memory coherence between agents. Only on compute queues. - """ - self._memory_barrier() - def _memory_barrier(self): pass - - @hcq_command - def exec(self, prg:ProgramType, args_state:ArgsStateType, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int]): - """ - Enqueues an execution command for a kernel program. Only on compute queues. - - Args: - prg: The program to execute - args_state: The args state to execute program with - global_size: The global work size - local_size: The local work size - """ - self._exec(prg, args_state, global_size, local_size) - def _exec(self, prg:ProgramType, args_state:ArgsStateType, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int]): - raise NotImplementedError("backend should overload this function") - - def update_exec(self, cmd_idx:int, global_size:Optional[Tuple[int,int,int]]=None, local_size:Optional[Tuple[int,int,int]]=None): - """ - Updates a previously queued execution command. Only on compute queues. 
- - Args: - cmd_idx: Index of the execution command to update - global_size: New global work size (if None, keeps the original) - local_size: New local work size (if None, keeps the original) - """ - if self.cmds_meta[cmd_idx] != "exec": raise RuntimeError("called update_exec not on an exec command") - self._update_exec(cmd_idx, global_size, local_size) - return self - def _update_exec(self, cmd_idx, global_size, local_size): raise NotImplementedError("backend should overload this function") - - # *** commands for copy queues *** - - @hcq_command - def copy(self, dest:int, src:int, copy_size:int): - """ - Enqueues a copy command to transfer data. Only on copy queues. - - Args: - dest: The destination of the copy - src: The source of the copy - copy_size: The size of data to copy - """ - self._copy(dest, src, copy_size) - def _copy(self, dest:int, src:int, copy_size:int): raise NotImplementedError("backend should overload this function") - - def update_copy(self, cmd_idx:int, dest:Optional[int]=None, src:Optional[int]=None): - """ - Updates a previously queued copy command. Only on copy queues. 
- - Args: - cmd_idx: Index of the copy command to update - dest: New destination of the copy (if None, keeps the original) - src: New source of the copy (if None, keeps the original) - """ - if self.cmds_meta[cmd_idx] != "copy": raise RuntimeError("called update_copy not on an copy command") - self._update_copy(cmd_idx, dest, src) - return self - def _update_copy(self, cmd_idx:int, dest:Optional[int], src:Optional[int]): - raise NotImplementedError("backend should overload this function") + def _submit(self, dev:DeviceType): raise NotImplementedError("need _submit") class HCQSignal(Generic[DeviceType]): - def __init__(self, base_addr:int=0, value:int=0, timeline_for_device:Optional[DeviceType]=None, timestamp_divider=1, value_off=0, timestamp_off=8): + def __init__(self, base_addr:sint=0, value:int=0, timeline_for_device:Optional[DeviceType]=None, timestamp_divider=1, value_off=0, timestamp_off=8): self.base_addr, self.value_addr, self.timestamp_addr = base_addr, base_addr+value_off, base_addr+timestamp_off self.timestamp_divider:decimal.Decimal = decimal.Decimal(timestamp_divider) self.timeline_for_device:Optional[DeviceType] = timeline_for_device - self.value_mv, self.timestamp_mv = to_mv(self.value_addr, 8).cast('Q'), to_mv(self.timestamp_addr, 8).cast('Q') - self.value_mv[0] = value + if isinstance(base_addr, int): + self.value_mv, self.timestamp_mv = to_mv(self.value_addr, 8).cast('Q'), to_mv(self.timestamp_addr, 8).cast('Q') + self.value_mv[0] = value @property def value(self) -> int: return self.value_mv[0]