Mirror of https://github.com/tinygrad/tinygrad.git
hcq update queue in place (#4626)

* do not self wait in hcq
* faster enqueue
* comments
* tests
* linter
* fix typo
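Note: the core of this change is a record-once, patch-in-place contract: `ptr()` bookmarks where the next command will be recorded, and `update_exec()` later rewrites the launch dims at that bookmark instead of re-recording the whole queue. A minimal runnable sketch of that contract (the `SketchQueue` class and its command encoding are hypothetical; only the `ptr()`/`update_exec()` shape mirrors the PR):

```python
# Hypothetical toy queue; only the ptr()/update_exec() contract mirrors the PR.
class SketchQueue:
  def __init__(self): self.q = []
  def ptr(self) -> int: return len(self.q)  # bookmark: index where the next command starts
  def exec(self, prg, global_size, local_size):
    self.q += ["EXEC", prg, list(global_size), list(local_size)]
    return self
  def update_exec(self, cmd_ptr, global_size, local_size):
    assert self.q[cmd_ptr] == "EXEC", "pointer does not reference an exec command"
    self.q[cmd_ptr + 2] = list(global_size)  # patch in place, no re-record
    self.q[cmd_ptr + 3] = list(local_size)

q = SketchQueue()
p = q.ptr()                          # remember where the exec command starts
q.exec("kernel", (4,4,1), (8,8,1))
q.update_exec(p, (1,1,1), (1,1,1))   # later: swap launch dims in place
assert q.q[p + 2] == [1,1,1]
```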
@@ -77,6 +77,7 @@ class PM4Executor(AMDQueue):
       elif op == amd_gpu.PACKET3_RELEASE_MEM: self._exec_release_mem(n)
       elif op == amd_gpu.PACKET3_WAIT_REG_MEM: cont = self._exec_wait_reg_mem(n)
       elif op == amd_gpu.PACKET3_DISPATCH_DIRECT: self._exec_dispatch_direct(n)
+      elif op == amd_gpu.PACKET3_EVENT_WRITE: self._exec_event_write(n)
       else: raise RuntimeError(f"PM4: Unknown opcode: {op}")
       if not cont: return
@@ -153,6 +154,10 @@ class PM4Executor(AMDQueue):
     assert prg_sz > 0, "Invalid prg ptr (not found in mapped ranges)"
     remu.run_asm(prg_addr, prg_sz, *gl, *lc, args_addr)
 
+  def _exec_event_write(self, n):
+    assert n == 0
+    _ = self._next_dword() # do not emulate events for now
+
 class SDMAExecutor(AMDQueue):
   def __init__(self, gpu, base, size, rptr, wptr):
     self.gpu, self.base = gpu, base
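Note: the emulator reads and discards a dword here because PM4 packets carry fixed-length payloads, so even an unemulated EVENT_WRITE must be consumed to keep the stream aligned on the next packet. A toy model of that (names and values invented for illustration):

```python
# Toy packet stream: an EVENT_WRITE payload dword followed by the next header.
stream, i = [0xDEADBEEF, 0x12345678], 0
def next_dword():
  global i
  v = stream[i]; i += 1
  return v

_ = next_dword()  # consume the EVENT_WRITE payload without emulating the event
assert stream[i] == 0x12345678  # the stream is still aligned on the next packet
```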
test/external/external_test_hcq.py
@@ -96,6 +96,17 @@ class TestHCQ(unittest.TestCase):
     TestHCQ.d0.timeline_value += 1
     assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 3.0, f"got val {val}"
 
+  def test_update_exec(self):
+    q = TestHCQ.compute_queue()
+    exec_ptr = q.ptr()
+    q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
+    q.update_exec(exec_ptr, (1,1,1), (1,1,1))
+    q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
+    TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
+    TestHCQ.d0.timeline_value += 1
+    assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}"
+    assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 0.0, f"got val {val}, should not be updated"
+
   @unittest.skipIf(CI, "Can't handle async update on CPU")
   def test_wait_signal(self):
     temp_signal = TestHCQ.d0._get_signal(value=0)
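Note: the second assert holds because update_exec shrank the launch to a single workitem, so only element 0 of the output buffer gets written. A toy model of that expectation (pure Python, no GPU):

```python
# One output element per workitem; a (1,1,1)x(1,1,1) launch touches only index 0.
out = [0.0, 0.0]
global_size, local_size = (1,1,1), (1,1,1)
n_workitems = global_size[0]*global_size[1]*global_size[2] * local_size[0]*local_size[1]*local_size[2]
for gid in range(n_workitems): out[gid] = 1.0  # the kernel writes its own element
assert out == [1.0, 0.0]  # matches the test: b[0] updated, b[1] untouched
```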
@@ -52,22 +52,20 @@ class HCQGraph(MultiGraphRunner):
     self.kickoff_value = 0
     self.graph_timeline = {dev: 0 for dev in self.devices}
 
+    self.exec_ptrs: Dict[int, Tuple[Any, int]] = {}
     self.copy_to_devs: Dict[Compiled, Set[Compiled]] = {dev: set() for dev in self.devices}
 
     for j,ji in enumerate(self.jit_cache):
       if isinstance(ji.prg, CompiledRunner):
         deps = self.access_resources(ji.bufs[(outs:=ji.prg.p.outcount):], ji.bufs[:outs], (self.comp_signal[ji.prg.device], sig_val:=j+1))
-        deps.append((self.comp_signal[ji.prg.device], self.comp_signal_val[ji.prg.device]))
+        deps = [x for x in deps if x[0] != self.comp_signal[ji.prg.device]] # remove wait for the same queue as all operations are ordered.
         self.comp_signal_val[ji.prg.device] = sig_val
 
-        # Rebuilt runners with dynamic launch dims online.
-        if j in self.jc_idx_with_updatable_launch_dims:
-          if ji.prg.device in self.comp_queues: self.queue_list.append((self.comp_queues.pop(ji.prg.device), ji.prg.device))
-          self.queue_list.append((j, deps))
-        else:
-          for sig, val in deps: self.comp_queues[ji.prg.device].wait(sig, val)
-          self.comp_queues[ji.prg.device].exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals)) \
-            .signal(self.comp_signal[ji.prg.device], sig_val)
+        for sig, val in deps: self.comp_queues[ji.prg.device].wait(sig, val)
+
+        self.exec_ptrs[j] = (self.comp_queues[ji.prg.device], self.comp_queues[ji.prg.device].ptr())
+        self.comp_queues[ji.prg.device].exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals)) \
+          .signal(self.comp_signal[ji.prg.device], sig_val)
       elif isinstance(ji.prg, BufferXfer):
         dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
         Device[src.device]._gpu_map(dest._buf) #type: ignore
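Note: dropping the self-wait is safe because commands on a single queue execute in submission order, so a dependency on that queue's own signal is already satisfied by ordering. A toy version of the filter above (signal names hypothetical; deps are (signal, value) pairs, hence the `x[0]`):

```python
# Keep only waits that cross queues; a queue never needs to wait on itself.
deps = [("sig_q0", 3), ("sig_q1", 7)]
own_signal = "sig_q0"                  # the signal this queue itself produces
deps = [x for x in deps if x[0] != own_signal]
assert deps == [("sig_q1", 7)]         # same-queue wait removed, cross-queue kept
```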
@@ -106,22 +104,17 @@ class HCQGraph(MultiGraphRunner):
       for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
         self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v])
 
+    for j in self.jc_idx_with_updatable_launch_dims:
+      queue, cmd_ptr = self.exec_ptrs[j]
+      queue.update_exec(cmd_ptr, *cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals))
+
     for dev in self.devices:
       self.comp_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \
         .wait(self.kickoff_signal, self.kickoff_value).submit(dev)
       self.copy_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \
         .wait(self.kickoff_signal, self.kickoff_value).submit(dev)
 
-    for entry in self.queue_list:
-      if isinstance(entry[0], self.comp_hcq_t) or isinstance(entry[0], self.copy_hcq_t): queue, dev = entry
-      else:
-        # Kernel with dynamic launch bounds, rebuild it.
-        j, ji, deps, dev = entry[0], self.jit_cache[entry[0]], entry[1], self.jit_cache[entry[0]].prg.device
-        queue = self.comp_hcq_t()
-        for sig, val in deps: queue.wait(sig, val)
-        queue.exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals)) \
-          .signal(self.comp_signal[dev], value=j+1)
-      queue.submit(dev)
+    for queue, dev in self.queue_list: queue.submit(dev)
 
     for dev in self.devices:
       self.comp_hcq_t().signal(dev.timeline_signal, dev.timeline_value).submit(dev)
@@ -113,6 +113,7 @@ CS_W32_EN = 1 << 15
 
 class HWPM4Queue:
   def __init__(self): self.q = []
+  def ptr(self) -> int: return len(self.q)
 
   def hdp_flush(self):
     self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_WAIT_REG_MEM, 5),
@@ -163,8 +164,15 @@ class HWPM4Queue:
     self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_SET_SH_REG, 8), regCOMPUTE_START_X, 0, 0, 0, *local_size, 0, 0]
     self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_SET_SH_REG, 1), regCOMPUTE_RESOURCE_LIMITS, 0]
     self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_DISPATCH_DIRECT, 3), *global_size, CS_W32_EN | FORCE_START_AT_000 | COMPUTE_SHADER_EN]
     self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_EVENT_WRITE, 0), amd_gpu.EVENT_TYPE(7) | amd_gpu.EVENT_INDEX(4)]
     return self
 
+  def update_exec(self, cmd_ptr, global_size, local_size):
+    # Patch the exec cmd with new launch dims
+    assert self.q[cmd_ptr + 67] == amd_gpu.PACKET3(amd_gpu.PACKET3_DISPATCH_DIRECT, 3), "The pointer does not point to a packet of this type"
+    self.q[cmd_ptr + 59 : cmd_ptr + 62] = local_size
+    self.q[cmd_ptr + 68 : cmd_ptr + 71] = global_size
+
   def wait(self, signal:hsa.amd_signal_t, value=0):
     addr = ctypes.addressof(signal) + SIGNAL_VALUE_OFFSET
     self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_WAIT_REG_MEM, 5),
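Note: the dword offsets are fixed because exec() always emits the same packet sequence, so launch dims live at constant distances from the bookmark: +67 lines up with the PACKET3_DISPATCH_DIRECT header (hence the assert), +68..+70 are its global_size payload, and +59..+61 fall inside the SET_SH_REG group starting at regCOMPUTE_START_X, where local_size lands in the COMPUTE_NUM_THREAD_X/Y/Z registers. A toy showing the equal-length slice patch (offsets reused for illustration, buffer contents invented):

```python
# Equal-length slice assignment patches dwords in place without resizing the list.
q = [0] * 72                      # stand-in command buffer (contents invented)
cmd_ptr = 0
q[cmd_ptr + 59 : cmd_ptr + 62] = [8, 8, 1]   # local_size -> COMPUTE_NUM_THREAD_X/Y/Z
q[cmd_ptr + 68 : cmd_ptr + 71] = [4, 4, 1]   # global_size -> DISPATCH_DIRECT payload
assert len(q) == 72 and q[59:62] == [8, 8, 1] and q[68:71] == [4, 4, 1]
```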
@@ -85,6 +85,8 @@ class NVCompiler(Compiler):
 
 class HWComputeQueue:
   def __init__(self): self.q = []
+  def ptr(self) -> int: return len(self.q)
 
   def copy_from_cpu(self, gpuaddr, data):
     self.q += [nvmethod(1, nv_gpu.NVC6C0_OFFSET_OUT_UPPER, 2), *nvdata64(gpuaddr)]
     self.q += [nvmethod(1, nv_gpu.NVC6C0_LINE_LENGTH_IN, 2), len(data)*4, 0x1]
@@ -102,6 +104,13 @@ class HWComputeQueue:
     self.q += [x for x in to_mv(ctypes.addressof(prg.qmd), ctypes.sizeof(prg.qmd)).cast("I")]
     return self
 
+  def update_exec(self, cmd_ptr, global_size, local_size):
+    # Patch the exec cmd with new launch dims
+    assert self.q[cmd_ptr + 2] == nvmethod(1, nv_gpu.NVC6C0_SET_INLINE_QMD_ADDRESS_A, 0x42), "The pointer does not point to a packet of this type"
+    self.q[cmd_ptr + 5 + 12 : cmd_ptr + 5 + 15] = global_size
+    self.q[cmd_ptr + 5 + 18] = (self.q[cmd_ptr + 5 + 18] & 0xffff) | ((local_size[0] & 0xffff) << 16)
+    self.q[cmd_ptr + 5 + 19] = (local_size[1] & 0xffff) | ((local_size[2] & 0xffff) << 16)
+
   def wait(self, signal, value=0):
     self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal))), *nvdata64_le(value),
       (3 << 0) | (1 << 12) | (1 << 24)] # ACQUIRE | ACQUIRE_SWITCH_TSG | PAYLOAD_SIZE_64BIT
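Note: on NV the CTA (local) dimensions are 16-bit fields packed two to a 32-bit QMD dword, which is why local_size is patched with masks and shifts rather than a slice: dim x goes in the high half of one dword (preserving its low bits), dims y and z share the next. A standalone demonstration of that packing, mirroring the two lines above:

```python
# Pack three 16-bit CTA dims into two 32-bit QMD-style dwords, as update_exec does.
local_size = (8, 4, 2)
word_18 = 0x0000beef                 # pretend the low half holds unrelated QMD bits
word_18 = (word_18 & 0xffff) | ((local_size[0] & 0xffff) << 16)
word_19 = (local_size[1] & 0xffff) | ((local_size[2] & 0xffff) << 16)
assert word_18 >> 16 == 8 and word_18 & 0xffff == 0xbeef  # dim x written, low bits preserved
assert (word_19 & 0xffff, word_19 >> 16) == (4, 2)        # dims y and z
```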