hcq update queue in place (#4626)

* do not self wait in hcq

* faster enqueue

* comments

* tests

* linter

* fix typo
nimlgen
2024-05-17 22:18:20 +03:00
committed by GitHub
parent ca1df20fa9
commit 10cf8e459b
5 changed files with 45 additions and 19 deletions
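
The core idea of the change: a hardware command queue here is just a Python list of dwords, so a kernel's launch dimensions can be changed by patching the already-recorded command instead of rebuilding the queue. ptr() returns the index at which the next command will start, and update_exec() rewrites the dwords at a remembered index. A minimal sketch of the pattern, with a toy command layout standing in for the real PM4/QMD packets:

class ToyQueue:
  def __init__(self): self.q = []
  def ptr(self) -> int: return len(self.q)                # index of the next command
  def exec(self, prg, args, global_size, local_size):
    self.q += ["EXEC", prg, args, *local_size, *global_size]
    return self
  def update_exec(self, cmd_ptr, global_size, local_size):
    assert self.q[cmd_ptr] == "EXEC", "pointer does not reference an exec command"
    self.q[cmd_ptr + 3:cmd_ptr + 6] = list(local_size)    # patch dims in place
    self.q[cmd_ptr + 6:cmd_ptr + 9] = list(global_size)

q = ToyQueue()
p = q.ptr()                                    # remember where the exec command starts
q.exec("kernel", 0x1000, (4, 4, 1), (8, 8, 1))
q.update_exec(p, (1, 1, 1), (1, 1, 1))         # new launch dims without re-enqueuing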

View File

@@ -77,6 +77,7 @@ class PM4Executor(AMDQueue):
elif op == amd_gpu.PACKET3_RELEASE_MEM: self._exec_release_mem(n)
elif op == amd_gpu.PACKET3_WAIT_REG_MEM: cont = self._exec_wait_reg_mem(n)
elif op == amd_gpu.PACKET3_DISPATCH_DIRECT: self._exec_dispatch_direct(n)
elif op == amd_gpu.PACKET3_EVENT_WRITE: self._exec_event_write(n)
else: raise RuntimeError(f"PM4: Unknown opcode: {op}")
if not cont: return
@@ -153,6 +154,10 @@ class PM4Executor(AMDQueue):
assert prg_sz > 0, "Invalid prg ptr (not found in mapped ranges)"
remu.run_asm(prg_addr, prg_sz, *gl, *lc, args_addr)
def _exec_event_write(self, n):
assert n == 0
_ = self._next_dword() # do not emulate events for now
class SDMAExecutor(AMDQueue):
def __init__(self, gpu, base, size, rptr, wptr):
self.gpu, self.base = gpu, base

View File

@@ -96,6 +96,17 @@ class TestHCQ(unittest.TestCase):
TestHCQ.d0.timeline_value += 1
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 3.0, f"got val {val}"
def test_update_exec(self):
q = TestHCQ.compute_queue()
exec_ptr = q.ptr()
q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
q.update_exec(exec_ptr, (1,1,1), (1,1,1))
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}"
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 0.0, f"got val {val}, should not be updated"
@unittest.skipIf(CI, "Can't handle async update on CPU")
def test_wait_signal(self):
temp_signal = TestHCQ.d0._get_signal(value=0)

View File

@@ -52,22 +52,20 @@ class HCQGraph(MultiGraphRunner):
self.kickoff_value = 0
self.graph_timeline = {dev: 0 for dev in self.devices}
self.exec_ptrs: Dict[int, Tuple[Any, int]] = {}
self.copy_to_devs: Dict[Compiled, Set[Compiled]] = {dev: set() for dev in self.devices}
for j,ji in enumerate(self.jit_cache):
if isinstance(ji.prg, CompiledRunner):
deps = self.access_resources(ji.bufs[(outs:=ji.prg.p.outcount):], ji.bufs[:outs], (self.comp_signal[ji.prg.device], sig_val:=j+1))
deps.append((self.comp_signal[ji.prg.device], self.comp_signal_val[ji.prg.device]))
deps = [x for x in deps if x != self.comp_signal[ji.prg.device]] # remove wait for the same queue as all operations are ordered.
self.comp_signal_val[ji.prg.device] = sig_val
# Rebuilt runners with dynamic launch dims online.
if j in self.jc_idx_with_updatable_launch_dims:
if ji.prg.device in self.comp_queues: self.queue_list.append((self.comp_queues.pop(ji.prg.device), ji.prg.device))
self.queue_list.append((j, deps))
else:
for sig, val in deps: self.comp_queues[ji.prg.device].wait(sig, val)
self.comp_queues[ji.prg.device].exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals)) \
.signal(self.comp_signal[ji.prg.device], sig_val)
for sig, val in deps: self.comp_queues[ji.prg.device].wait(sig, val)
self.exec_ptrs[j] = (self.comp_queues[ji.prg.device], self.comp_queues[ji.prg.device].ptr())
self.comp_queues[ji.prg.device].exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals)) \
.signal(self.comp_signal[ji.prg.device], sig_val)
elif isinstance(ji.prg, BufferXfer):
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
Device[src.device]._gpu_map(dest._buf) #type: ignore
@@ -106,22 +104,17 @@ class HCQGraph(MultiGraphRunner):
for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v])
for j in self.jc_idx_with_updatable_launch_dims:
queue, cmd_ptr = self.exec_ptrs[j]
queue.update_exec(cmd_ptr, *cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals))
for dev in self.devices:
self.comp_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \
.wait(self.kickoff_signal, self.kickoff_value).submit(dev)
self.copy_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \
.wait(self.kickoff_signal, self.kickoff_value).submit(dev)
for entry in self.queue_list:
if isinstance(entry[0], self.comp_hcq_t) or isinstance(entry[0], self.copy_hcq_t): queue, dev = entry
else:
# Kernel with dynamic launch bounds, rebuild it.
j, ji, deps, dev = entry[0], self.jit_cache[entry[0]], entry[1], self.jit_cache[entry[0]].prg.device
queue = self.comp_hcq_t()
for sig, val in deps: queue.wait(sig, val)
queue.exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals)) \
.signal(self.comp_signal[dev], value=j+1)
queue.submit(dev)
for queue, dev in self.queue_list: queue.submit(dev)
for dev in self.devices:
self.comp_hcq_t().signal(dev.timeline_signal, dev.timeline_value).submit(dev)

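The deps filtering above is the "do not self wait" part of the commit: a wait on the kernel's own queue signal is dropped, since commands on the same queue already execute in order. The other half is the "faster enqueue": at graph-build time every kernel with updatable launch dims records (queue, ptr()) in exec_ptrs right before its exec command is appended, and at call time the graph patches those commands in place and simply resubmits the prebuilt queues instead of rebuilding a fresh queue per dynamic kernel. A self-contained sketch of that call-time path, with stand-in objects in place of the real queues and jit cache (only the attribute names follow the diff):

class FakeQueue:
  def __init__(self, name): self.name, self.patched = name, None
  def update_exec(self, cmd_ptr, global_size, local_size): self.patched = (cmd_ptr, global_size, local_size)
  def submit(self, dev): print(f"submit {self.name} to {dev}, patched={self.patched}")

comp_queue = FakeQueue("comp_queue_dev0")
exec_ptrs = {2: (comp_queue, 40)}              # job 2's exec command starts at dword 40
jc_idx_with_updatable_launch_dims = [2]
queue_list = [(comp_queue, "dev0")]

# __call__: patch launch dims where they already sit, then replay every queue.
for j in jc_idx_with_updatable_launch_dims:
  queue, cmd_ptr = exec_ptrs[j]
  queue.update_exec(cmd_ptr, (8, 8, 1), (32, 1, 1))
for queue, dev in queue_list: queue.submit(dev)
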
View File

@@ -113,6 +113,7 @@ CS_W32_EN = 1 << 15
class HWPM4Queue:
def __init__(self): self.q = []
def ptr(self) -> int: return len(self.q)
def hdp_flush(self):
self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_WAIT_REG_MEM, 5),
@@ -163,8 +164,15 @@ class HWPM4Queue:
self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_SET_SH_REG, 8), regCOMPUTE_START_X, 0, 0, 0, *local_size, 0, 0]
self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_SET_SH_REG, 1), regCOMPUTE_RESOURCE_LIMITS, 0]
self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_DISPATCH_DIRECT, 3), *global_size, CS_W32_EN | FORCE_START_AT_000 | COMPUTE_SHADER_EN]
self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_EVENT_WRITE, 0), amd_gpu.EVENT_TYPE(7) | amd_gpu.EVENT_INDEX(4)]
return self
def update_exec(self, cmd_ptr, global_size, local_size):
# Patch the exec cmd with new launch dims
assert self.q[cmd_ptr + 67] == amd_gpu.PACKET3(amd_gpu.PACKET3_DISPATCH_DIRECT, 3),"The pointer does not point to a packet of this type"
self.q[cmd_ptr + 59 : cmd_ptr + 62] = local_size
self.q[cmd_ptr + 68 : cmd_ptr + 71] = global_size
def wait(self, signal:hsa.amd_signal_t, value=0):
addr = ctypes.addressof(signal) + SIGNAL_VALUE_OFFSET
self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_WAIT_REG_MEM, 5),

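The offsets in update_exec are dword positions relative to the value ptr() returned just before exec() started appending its packets: per the exec packets shown above, local_size lives at cmd_ptr + 59..61 inside the SET_SH_REG(regCOMPUTE_START_X) payload, the PACKET3_DISPATCH_DIRECT header sits at cmd_ptr + 67 (which the assert checks so a wrong or stale pointer fails loudly), and global_size at cmd_ptr + 68..70. A stand-in stream makes the slice arithmetic concrete (placeholder values, not real packet words):

q = [0] * 80                               # stand-in for the recorded dword stream
cmd_ptr = 0                                # what ptr() returned before exec()
q[cmd_ptr + 67] = 0xDEAD                   # pretend this is the DISPATCH_DIRECT header
assert q[cmd_ptr + 67] == 0xDEAD           # the real code checks the actual packet header
q[cmd_ptr + 59:cmd_ptr + 62] = [8, 8, 1]   # new local_size
q[cmd_ptr + 68:cmd_ptr + 71] = [32, 1, 1]  # new global_size
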
View File

@@ -85,6 +85,8 @@ class NVCompiler(Compiler):
class HWComputeQueue:
def __init__(self): self.q = []
def ptr(self) -> int: return len(self.q)
def copy_from_cpu(self, gpuaddr, data):
self.q += [nvmethod(1, nv_gpu.NVC6C0_OFFSET_OUT_UPPER, 2), *nvdata64(gpuaddr)]
self.q += [nvmethod(1, nv_gpu.NVC6C0_LINE_LENGTH_IN, 2), len(data)*4, 0x1]
@@ -102,6 +104,13 @@ class HWComputeQueue:
self.q += [x for x in to_mv(ctypes.addressof(prg.qmd), ctypes.sizeof(prg.qmd)).cast("I")]
return self
def update_exec(self, cmd_ptr, global_size, local_size):
# Patch the exec cmd with new launch dims
assert self.q[cmd_ptr + 2] == nvmethod(1, nv_gpu.NVC6C0_SET_INLINE_QMD_ADDRESS_A, 0x42),"The pointer does not point to a packet of this type"
self.q[cmd_ptr + 5 + 12 : cmd_ptr + 5 + 15] = global_size
self.q[cmd_ptr + 5 + 18] = (self.q[cmd_ptr + 5 + 18] & 0xffff) | ((local_size[0] & 0xffff) << 16)
self.q[cmd_ptr + 5 + 19] = (local_size[1] & 0xffff) | ((local_size[2] & 0xffff) << 16)
def wait(self, signal, value=0):
self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal))), *nvdata64_le(value),
(3 << 0) | (1 << 12) | (1 << 24)] # ACQUIRE | ACQUIRE_SWITCH_TSG | PAYLOAD_SIZE_64BIT
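
On NV the exec command copies the whole QMD structure inline into the push buffer, so update_exec patches the QMD where it already sits: in the indexing above the QMD data is addressed relative to cmd_ptr + 5, the grid (global) dimensions are three full 32-bit words at QMD offsets 12..14, and the local (thread block) dimensions are 16-bit fields packed two per word at offsets 18 and 19, with dim 0 in the upper half of word 18 and dims 1 and 2 in the lower/upper halves of word 19. A small sketch of just that bit packing, with field placement taken from the diff and the values made up:

def pack_local_dims(word18_old, local_size):
  # local_size[0] replaces the upper half of QMD word 18, keeping its lower half;
  # local_size[1] and local_size[2] fill the lower/upper halves of word 19.
  w18 = (word18_old & 0xffff) | ((local_size[0] & 0xffff) << 16)
  w19 = (local_size[1] & 0xffff) | ((local_size[2] & 0xffff) << 16)
  return w18, w19

w18, w19 = pack_local_dims(0x1234_0042, (32, 4, 1))
assert w18 == (32 << 16) | 0x0042 and w19 == (1 << 16) | 4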