nv hcq bind api (#4629)

* hcq bind api for nv

* linter

* linter

* add test

* small comment
nimlgen
2024-05-19 23:17:10 +03:00
committed by GitHub
parent d308f4fa9a
commit c9f7f2da70
3 changed files with 103 additions and 53 deletions
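
The change adds a bind API to the NV hardware command queues: a queue is recorded once, bind() copies its commands into a GPU-mapped page owned by a device, and later submit() calls reuse that page instead of restaging the command words. Below is a minimal sketch of the pattern the new tests exercise; dev, prg and kernargs_ptr are placeholders for an initialized NVDevice, a compiled kernel and its argument buffer, and the import path is assumed rather than taken from this diff.

from tinygrad.runtime.ops_nv import HWComputeQueue  # assumed module path, not part of this diff

q = HWComputeQueue()
exec_ptr = q.ptr()                               # index of the next command, usable with update_exec later
q.exec(prg, kernargs_ptr, (1, 1, 1), (1, 1, 1))  # record one kernel launch
q.bind(dev)                                      # commands now live in a GPU-mapped page owned by dev

for _ in range(1000):
  q.update_exec(exec_ptr, (1, 1, 1), (1, 1, 1))  # optional: patch launch dims in place in the bound page
  q.submit(dev)                                  # resubmit; the GPFIFO entry points at the bound page
  HWComputeQueue().signal(dev.timeline_signal, dev.timeline_value).submit(dev)  # one-shot fence, as in the tests
  dev._wait_signal(dev.timeline_signal, dev.timeline_value)
  dev.timeline_value += 1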


@@ -107,6 +107,38 @@ class TestHCQ(unittest.TestCase):
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}"
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 0.0, f"got val {val}, should not be updated"
@unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind")
def test_bind_run(self):
temp_signal = TestHCQ.d0._get_signal(value=0)
q = TestHCQ.compute_queue()
q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
q.signal(temp_signal, 2).wait(temp_signal, 2)
q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size, TestHCQ.runner.p.global_size,
TestHCQ.runner.p.local_size)
q.bind(TestHCQ.d0)
for _ in range(1000):
TestHCQ.d0._set_signal(temp_signal, 1)
q.submit(TestHCQ.d0)
TestHCQ.compute_queue().signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
assert (val:=TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]) == 2000.0, f"got val {val}"
@unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind")
def test_update_exec_binded(self):
q = TestHCQ.compute_queue()
exec_ptr = q.ptr()
q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
q.bind(TestHCQ.d0)
q.update_exec(exec_ptr, (1,1,1), (1,1,1))
q.submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}"
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 0.0, f"got val {val}, should not be updated"
@unittest.skipIf(CI, "Can't handle async update on CPU")
def test_wait_signal(self):
temp_signal = TestHCQ.d0._get_signal(value=0)
@@ -191,6 +223,22 @@ class TestHCQ(unittest.TestCase):
TestHCQ.d0.timeline_value += 1
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 1.0, f"got val {val}"
@unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind")
def test_bind_copy(self):
q = TestHCQ.copy_queue()
q.copy(TestHCQ.a.lazydata.buffer._buf.va_addr, TestHCQ.b.lazydata.buffer._buf.va_addr, 8)
q.copy(TestHCQ.b.lazydata.buffer._buf.va_addr, TestHCQ.a.lazydata.buffer._buf.va_addr, 8)
q.bind(TestHCQ.d0)
for _ in range(1000):
q.submit(TestHCQ.d0)
TestHCQ.copy_queue().signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
# confirm the signal didn't exceed the put value
with self.assertRaises(RuntimeError):
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value + 1, timeout=50)
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 0.0, f"got val {val}"
def test_copy_bandwidth(self):
# THEORY: the bandwidth is low here because it's only using one SDMA queue. I suspect it's more stable like this at least.
SZ = 2_000_000_000


@@ -38,8 +38,6 @@ class HCQGraph(MultiGraphRunner):
if ji.prg.device.dname.startswith("NV"): to_mv(self.kargs_addrs[j], 0x160).cast('I')[:] = array.array('I', ji.prg.clprg.constbuffer_0)
# Build queues.
self.queue_list: List[Tuple[Any, ...]] = []
self.comp_queues: Dict[Compiled, Any] = collections.defaultdict(self.comp_hcq_t)
self.comp_signal = {dev: dev._get_signal(value=0) for dev in self.devices}
self.comp_signal_val = {dev: 0 for dev in self.devices}
@@ -58,7 +56,7 @@ class HCQGraph(MultiGraphRunner):
for j,ji in enumerate(self.jit_cache):
if isinstance(ji.prg, CompiledRunner):
deps = self.access_resources(ji.bufs[(outs:=ji.prg.p.outcount):], ji.bufs[:outs], (self.comp_signal[ji.prg.device], sig_val:=j+1))
deps = [x for x in deps if x != self.comp_signal[ji.prg.device]] # remove wait for the same queue as all operations are ordered.
deps = [x for x in deps if id(x[0]) != id(self.comp_signal[ji.prg.device])] # remove wait for the same queue as all operations are ordered.
self.comp_signal_val[ji.prg.device] = sig_val
for sig, val in deps: self.comp_queues[ji.prg.device].wait(sig, val)
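A note on the dependency filter rewritten just above: deps holds (signal, value) pairs (they are unpacked as sig, val right below), so the old comparison of a whole pair against the bare per-device signal never matched and the filter removed nothing; comparing id(x[0]) against the id of comp_signal is what actually drops waits on the queue's own signal, which are redundant because commands within one queue are already ordered. An equivalent formulation, as a sketch rather than the committed code:

deps = [(sig, val) for sig, val in deps if sig is not self.comp_signal[ji.prg.device]]  # same-queue waits are implicit in queue order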
@@ -81,10 +79,10 @@ class HCQGraph(MultiGraphRunner):
for dev in self.devices:
if self.copy_signal_val[dev] > 0: self.comp_queues[dev].wait(self.copy_signal[dev], self.copy_signal_val[dev])
for dep_dev in self.copy_to_devs: self.comp_queues[dev].wait(self.copy_signal[dep_dev], self.copy_signal_val[dep_dev])
for dep_dev in self.copy_to_devs[dev]: self.comp_queues[dev].wait(self.copy_signal[dep_dev], self.copy_signal_val[dep_dev])
self.queue_list.append((self.comp_queues.pop(dev), dev))
if self.copy_signal_val[dev] > 0: self.queue_list.append((self.copy_queues.pop(dev), dev))
if hasattr(self.comp_queues[dev], 'bind'): self.comp_queues[dev].bind(dev)
if hasattr(self.copy_queues[dev], 'bind') and self.copy_signal_val[dev] > 0: self.copy_queues[dev].bind(dev)
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
# Wait and restore signals
@@ -109,14 +107,17 @@ class HCQGraph(MultiGraphRunner):
queue.update_exec(cmd_ptr, *cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals))
for dev in self.devices:
# Submit sync with world and queues.
self.comp_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \
.wait(self.kickoff_signal, self.kickoff_value).submit(dev)
self.copy_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \
.wait(self.kickoff_signal, self.kickoff_value).submit(dev)
self.comp_queues[dev].submit(dev)
for queue, dev in self.queue_list: queue.submit(dev)
if self.copy_signal_val[dev] > 0:
self.copy_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \
.wait(self.kickoff_signal, self.kickoff_value).submit(dev)
self.copy_queues[dev].submit(dev)
for dev in self.devices:
# Signal the final value
self.comp_hcq_t().signal(dev.timeline_signal, dev.timeline_value).submit(dev)
self.graph_timeline[dev] = dev.timeline_value
dev.timeline_value += 1
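Reading the hunks above together, the per-device compute and copy queues are now bound to their device at graph-construction time when the backend exposes bind, and reused unchanged on every call; __call__ only rebuilds the small sync preambles and patches launch dims in place. A condensed sketch of the per-device call path, with indentation restored for readability and names as in HCQGraph:

for dev in self.devices:
  self.comp_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \
      .wait(self.kickoff_signal, self.kickoff_value).submit(dev)    # fresh per-call sync preamble
  self.comp_queues[dev].submit(dev)                                 # bound compute queue, reused as-is
  if self.copy_signal_val[dev] > 0:
    self.copy_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \
        .wait(self.kickoff_signal, self.kickoff_value).submit(dev)
    self.copy_queues[dev].submit(dev)                               # bound copy queue, only when it has work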


@@ -83,10 +83,46 @@ class NVCompiler(Compiler):
raise CompileError(f"compile failed: {_get_bytes(prog, cuda.nvrtcGetProgramLog, cuda.nvrtcGetProgramLogSize, cuda_check).decode()}")
return _get_bytes(prog, cuda.nvrtcGetCUBIN, cuda.nvrtcGetCUBINSize, cuda_check)
class HWComputeQueue:
def __init__(self): self.q = []
class HWQueue:
def __init__(self): self.q, self.binded_device = [], None
def __del__(self):
if self.binded_device is not None: self.binded_device._gpu_free(self.hw_page)
def ptr(self) -> int: return len(self.q)
def wait(self, signal, value=0):
self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal))), *nvdata64_le(value),
(3 << 0) | (1 << 12) | (1 << 24)] # ACQUIRE | ACQUIRE_SWITCH_TSG | PAYLOAD_SIZE_64BIT
return self
def signal(self, signal, value=0, timestamp=False):
self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal))), *nvdata64_le(value),
(1 << 0) | (1 << 20) | (1 << 24) | ((1 << 25) if timestamp else 0)] # RELEASE | RELEASE_WFI | PAYLOAD_SIZE_64BIT | RELEASE_TIMESTAMP
self.q += [nvmethod(0, nv_gpu.NVC56F_NON_STALL_INTERRUPT, 1), 0x0]
return self
def bind(self, device: NVDevice):
self.binded_device = device
self.hw_page = device._gpu_alloc(len(self.q) * 4, map_to_cpu=True)
hw_view = to_mv(self.hw_page.base, self.hw_page.length).cast("I")
for i, value in enumerate(self.q): hw_view[i] = value
# From now on, the queue is on the device for faster submission.
self.q = hw_view # type: ignore
def _submit(self, dev, gpu_ring, put_value, gpfifo_entries, gpfifo_token, gpu_ring_controls):
if dev == self.binded_device: cmdq_addr = self.hw_page.base
else:
dev.cmdq[dev.cmdq_wptr//4:dev.cmdq_wptr//4+len(self.q)] = array.array('I', self.q)
cmdq_addr = dev.cmdq_page.base+dev.cmdq_wptr
dev.cmdq_wptr += len(self.q) * 4
gpu_ring[put_value % gpfifo_entries] = (cmdq_addr//4 << 2) | (len(self.q) << 42) | (1 << 41)
gpu_ring_controls.GPPut = (put_value + 1) % gpfifo_entries
dev.gpu_mmio[0x90 // 4] = gpfifo_token
return put_value + 1
class HWComputeQueue(HWQueue):
def copy_from_cpu(self, gpuaddr, data):
self.q += [nvmethod(1, nv_gpu.NVC6C0_OFFSET_OUT_UPPER, 2), *nvdata64(gpuaddr)]
self.q += [nvmethod(1, nv_gpu.NVC6C0_LINE_LENGTH_IN, 2), len(data)*4, 0x1]
@@ -107,61 +143,26 @@ class HWComputeQueue:
def update_exec(self, cmd_ptr, global_size, local_size):
# Patch the exec cmd with new launch dims
assert self.q[cmd_ptr + 2] == nvmethod(1, nv_gpu.NVC6C0_SET_INLINE_QMD_ADDRESS_A, 0x42),"The pointer does not point to a packet of this type"
self.q[cmd_ptr + 5 + 12 : cmd_ptr + 5 + 15] = global_size
self.q[cmd_ptr + 5 + 12 : cmd_ptr + 5 + 15] = array.array('I', global_size)
self.q[cmd_ptr + 5 + 18] = (self.q[cmd_ptr + 5 + 18] & 0xffff) | ((local_size[0] & 0xffff) << 16)
self.q[cmd_ptr + 5 + 19] = (local_size[1] & 0xffff) | ((local_size[2] & 0xffff) << 16)
def wait(self, signal, value=0):
self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal))), *nvdata64_le(value),
(3 << 0) | (1 << 12) | (1 << 24)] # ACQUIRE | ACQUIRE_SWITCH_TSG | PAYLOAD_SIZE_64BIT
return self
def signal(self, signal, value=0, timestamp=False):
self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal))), *nvdata64_le(value),
(1 << 0) | (1 << 20) | (1 << 24) | ((1 << 25) if timestamp else 0)] # RELEASE | RELEASE_WFI | PAYLOAD_SIZE_64BIT | RELEASE_TIMESTAMP
self.q += [nvmethod(0, nv_gpu.NVC56F_NON_STALL_INTERRUPT, 1), 0x0]
return self
def submit(self, dev:NVDevice):
if len(self.q) == 0: return
assert len(self.q) < (1 << 21)
dev.cmdq[dev.cmdq_wptr//4:dev.cmdq_wptr//4+len(self.q)] = array.array('I', self.q)
fifo_entry = dev.compute_put_value % dev.compute_gpfifo_entries
dev.compute_gpu_ring[fifo_entry] = ((dev.cmdq_page.base+dev.cmdq_wptr)//4 << 2) | (len(self.q) << 42) | (1 << 41)
dev.compute_gpu_ring_controls.GPPut = (dev.compute_put_value + 1) % dev.compute_gpfifo_entries
dev.compute_put_value += 1
dev.gpu_mmio[0x90 // 4] = dev.compute_gpfifo_token
dev.cmdq_wptr += len(self.q) * 4
class HWCopyQueue:
def __init__(self): self.q = []
dev.compute_put_value = self._submit(dev, dev.compute_gpu_ring, dev.compute_put_value, dev.compute_gpfifo_entries,
dev.compute_gpfifo_token, dev.compute_gpu_ring_controls)
class HWCopyQueue(HWQueue):
def copy(self, dest, src, copy_size):
self.q += [nvmethod(4, nv_gpu.NVC6B5_OFFSET_IN_UPPER, 4), *nvdata64(src), *nvdata64(dest)]
self.q += [nvmethod(4, nv_gpu.NVC6B5_LINE_LENGTH_IN, 1), copy_size]
self.q += [nvmethod(4, nv_gpu.NVC6B5_LAUNCH_DMA, 1), 0x182] # TRANSFER_TYPE_NON_PIPELINED | DST_MEMORY_LAYOUT_PITCH | SRC_MEMORY_LAYOUT_PITCH
return self
def wait(self, signal, value=0):
self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal))), *nvdata64_le(value),
(3 << 0) | (1 << 12) | (1 << 24)] # ACQUIRE | ACQUIRE_SWITCH_TSG | PAYLOAD_SIZE_64BIT
return self
def signal(self, signal, value=0, timestamp=False):
self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal))), *nvdata64_le(value),
(1 << 0) | (1 << 20) | (1 << 24) | ((1 << 25) if timestamp else 0)] # RELEASE | RELEASE_WFI | PAYLOAD_SIZE_64BIT | RELEASE_TIMESTAMP
self.q += [nvmethod(0, nv_gpu.NVC56F_NON_STALL_INTERRUPT, 1), 0x0]
return self
def submit(self, dev:NVDevice):
if len(self.q) == 0: return
dev.cmdq[dev.cmdq_wptr//4:dev.cmdq_wptr//4+len(self.q)] = array.array('I', self.q)
fifo_entry = dev.dma_put_value % dev.dma_gpfifo_entries
dev.dma_gpu_ring[fifo_entry] = ((dev.cmdq_page.base+dev.cmdq_wptr)//4 << 2) | (len(self.q) << 42)
dev.dma_gpu_ring_controls.GPPut = (dev.dma_put_value + 1) % dev.dma_gpfifo_entries
dev.dma_put_value += 1
dev.gpu_mmio[0x90 // 4] = dev.dma_gpfifo_token
dev.cmdq_wptr += len(self.q) * 4
dev.dma_put_value = self._submit(dev, dev.dma_gpu_ring, dev.dma_put_value, dev.dma_gpfifo_entries,
dev.dma_gpfifo_token, dev.dma_gpu_ring_controls)
SHT_PROGBITS, SHT_NOBITS, SHF_ALLOC, SHF_EXECINSTR = 0x1, 0x8, 0x2, 0x4
class NVProgram:
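
Summing up the driver-side mechanics of bind(): the recorded 32-bit command words are copied into a CPU-mapped GPU page and self.q becomes a memoryview over that page, so later in-place patches such as update_exec land directly in GPU-visible memory, and _submit only has to decide where the GPFIFO entry points. A condensed restatement of that branch, as a sketch rather than the committed code:

if dev == self.binded_device:
  cmdq_addr = self.hw_page.base                   # bound: commands already sit in a GPU-mapped page
else:
  # unbound: stage the words into the device's shared command buffer and advance its write pointer
  dev.cmdq[dev.cmdq_wptr//4:dev.cmdq_wptr//4 + len(self.q)] = array.array('I', self.q)
  cmdq_addr = dev.cmdq_page.base + dev.cmdq_wptr
  dev.cmdq_wptr += len(self.q) * 4
# either way, one GPFIFO entry points the engine at cmdq_addr for len(self.q) dwords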