nv hcq bind api (#4629)
* hcq bind api for nv
* linter
* linter
* add test
* small comment
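In short, the new api lets a command queue be recorded once, bound to a device-resident page, and then re-submitted cheaply. A rough usage sketch based on the queue methods the new tests below exercise (dev, prg and args_ptr are stand-ins for the NVDevice, compiled program and kernel-args pointer that the tests set up; this is a sketch, not code from the diff):

    q = HWComputeQueue()
    q.exec(prg, args_ptr, (1,1,1), (1,1,1))   # record the work once
    q.bind(dev)                               # upload the recorded words to a device-resident page
    for _ in range(1000):
      q.submit(dev)                           # fast path: no per-submit copy into dev.cmdq
      HWComputeQueue().signal(dev.timeline_signal, dev.timeline_value).submit(dev)
      dev._wait_signal(dev.timeline_signal, dev.timeline_value)
      dev.timeline_value += 1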
test/external/external_test_hcq.py (vendored, 48 changed lines)
@@ -107,6 +107,38 @@ class TestHCQ(unittest.TestCase):
    assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}"
    assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 0.0, f"got val {val}, should not be updated"

  @unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind")
  def test_bind_run(self):
    temp_signal = TestHCQ.d0._get_signal(value=0)
    q = TestHCQ.compute_queue()
    q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
    q.signal(temp_signal, 2).wait(temp_signal, 2)
    q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr+TestHCQ.kernargs_size, TestHCQ.runner.p.global_size,
           TestHCQ.runner.p.local_size)
    q.bind(TestHCQ.d0)
    for _ in range(1000):
      TestHCQ.d0._set_signal(temp_signal, 1)
      q.submit(TestHCQ.d0)
      TestHCQ.compute_queue().signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
      TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
      TestHCQ.d0.timeline_value += 1
    assert (val:=TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]) == 2000.0, f"got val {val}"

  @unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind")
  def test_update_exec_binded(self):
    q = TestHCQ.compute_queue()
    exec_ptr = q.ptr()
    q.exec(TestHCQ.runner.clprg, TestHCQ.d0.kernargs_ptr, TestHCQ.runner.p.global_size, TestHCQ.runner.p.local_size)
    q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
    q.bind(TestHCQ.d0)

    q.update_exec(exec_ptr, (1,1,1), (1,1,1))
    q.submit(TestHCQ.d0)
    TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
    TestHCQ.d0.timeline_value += 1
    assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}"
    assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 0.0, f"got val {val}, should not be updated"

  @unittest.skipIf(CI, "Can't handle async update on CPU")
  def test_wait_signal(self):
    temp_signal = TestHCQ.d0._get_signal(value=0)
@@ -191,6 +223,22 @@ class TestHCQ(unittest.TestCase):
    TestHCQ.d0.timeline_value += 1
    assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 1.0, f"got val {val}"

  @unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind")
  def test_bind_copy(self):
    q = TestHCQ.copy_queue()
    q.copy(TestHCQ.a.lazydata.buffer._buf.va_addr, TestHCQ.b.lazydata.buffer._buf.va_addr, 8)
    q.copy(TestHCQ.b.lazydata.buffer._buf.va_addr, TestHCQ.a.lazydata.buffer._buf.va_addr, 8)
    q.bind(TestHCQ.d0)
    for _ in range(1000):
      q.submit(TestHCQ.d0)
    TestHCQ.copy_queue().signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
    TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
    TestHCQ.d0.timeline_value += 1
    # confirm the signal didn't exceed the put value
    with self.assertRaises(RuntimeError):
      TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value + 1, timeout=50)
    assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 0.0, f"got val {val}"

  def test_copy_bandwidth(self):
    # THEORY: the bandwidth is low here because it's only using one SDMA queue. I suspect it's more stable like this at least.
    SZ = 2_000_000_000
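The tests above all rely on the same timeline-signal handshake: each batch of work releases dev.timeline_signal to the current dev.timeline_value, the host waits for that value and then bumps the counter. A CPU-only toy model of that handshake (illustration only; ToyDevice and its methods are not tinygrad APIs):

    import threading

    class ToyDevice:
      def __init__(self):
        self.timeline_signal, self.timeline_value = 0, 1
        self._cv = threading.Condition()
      def gpu_signal(self, value):                 # what a queued .signal() does once the GPU reaches it
        with self._cv:
          self.timeline_signal = value
          self._cv.notify_all()
      def wait_signal(self, value, timeout=1.0):   # host-side analogue of _wait_signal
        with self._cv:
          if not self._cv.wait_for(lambda: self.timeline_signal >= value, timeout):
            raise RuntimeError("wait timed out")   # same failure mode the copy test checks for

    dev = ToyDevice()
    threading.Timer(0.01, dev.gpu_signal, args=(dev.timeline_value,)).start()  # "GPU" finishes later
    dev.wait_signal(dev.timeline_value)
    dev.timeline_value += 1
    print("synced, next timeline value is", dev.timeline_value)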
@@ -38,8 +38,6 @@ class HCQGraph(MultiGraphRunner):
      if ji.prg.device.dname.startswith("NV"): to_mv(self.kargs_addrs[j], 0x160).cast('I')[:] = array.array('I', ji.prg.clprg.constbuffer_0)

    # Build queues.
    self.queue_list: List[Tuple[Any, ...]] = []

    self.comp_queues: Dict[Compiled, Any] = collections.defaultdict(self.comp_hcq_t)
    self.comp_signal = {dev: dev._get_signal(value=0) for dev in self.devices}
    self.comp_signal_val = {dev: 0 for dev in self.devices}
@@ -58,7 +56,7 @@ class HCQGraph(MultiGraphRunner):
    for j,ji in enumerate(self.jit_cache):
      if isinstance(ji.prg, CompiledRunner):
        deps = self.access_resources(ji.bufs[(outs:=ji.prg.p.outcount):], ji.bufs[:outs], (self.comp_signal[ji.prg.device], sig_val:=j+1))
        deps = [x for x in deps if x != self.comp_signal[ji.prg.device]] # remove wait for the same queue as all operations are ordered.
        deps = [x for x in deps if id(x[0]) != id(self.comp_signal[ji.prg.device])] # remove wait for the same queue as all operations are ordered.
        self.comp_signal_val[ji.prg.device] = sig_val

        for sig, val in deps: self.comp_queues[ji.prg.device].wait(sig, val)
@@ -81,10 +79,10 @@ class HCQGraph(MultiGraphRunner):

    for dev in self.devices:
      if self.copy_signal_val[dev] > 0: self.comp_queues[dev].wait(self.copy_signal[dev], self.copy_signal_val[dev])
      for dep_dev in self.copy_to_devs: self.comp_queues[dev].wait(self.copy_signal[dep_dev], self.copy_signal_val[dep_dev])
      for dep_dev in self.copy_to_devs[dev]: self.comp_queues[dev].wait(self.copy_signal[dep_dev], self.copy_signal_val[dep_dev])

      self.queue_list.append((self.comp_queues.pop(dev), dev))
      if self.copy_signal_val[dev] > 0: self.queue_list.append((self.copy_queues.pop(dev), dev))
      if hasattr(self.comp_queues[dev], 'bind'): self.comp_queues[dev].bind(dev)
      if hasattr(self.copy_queues[dev], 'bind') and self.copy_signal_val[dev] > 0: self.copy_queues[dev].bind(dev)

  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
    # Wait and restore signals
@@ -109,14 +107,17 @@ class HCQGraph(MultiGraphRunner):
      queue.update_exec(cmd_ptr, *cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals))

    for dev in self.devices:
      # Submit sync with world and queues.
      self.comp_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \
                       .wait(self.kickoff_signal, self.kickoff_value).submit(dev)
      self.copy_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \
                       .wait(self.kickoff_signal, self.kickoff_value).submit(dev)
      self.comp_queues[dev].submit(dev)

    for queue, dev in self.queue_list: queue.submit(dev)
      if self.copy_signal_val[dev] > 0:
        self.copy_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \
                         .wait(self.kickoff_signal, self.kickoff_value).submit(dev)
        self.copy_queues[dev].submit(dev)

    for dev in self.devices:
      # Signal the final value
      self.comp_hcq_t().signal(dev.timeline_signal, dev.timeline_value).submit(dev)
      self.graph_timeline[dev] = dev.timeline_value
      dev.timeline_value += 1

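The graph now binds each device's compute/copy queue at build time (when the backend's queue type has a bind method) and re-submits those bound queues on every __call__. The point of binding: an unbound queue copies its command words into the device's shared cmdq on every submit, while a bound queue already sits in a GPU-visible page, so a submit only has to push a ring entry and ring the doorbell. A toy model of that trade-off (illustration only, not tinygrad code):

    import array

    class ToyQueue:
      def __init__(self, words): self.q, self.binded_page = list(words), None
      def bind(self, page):                               # one-time upload, like HWQueue.bind below
        page[:len(self.q)] = array.array("I", self.q)
        self.q, self.binded_page = page[:len(self.q)], page
      def submit(self, cmdq, ring, put):
        if self.binded_page is None: cmdq.extend(self.q)  # unbound: copy the words every time
        ring.append(put)                                  # bound or not, a submit adds one ring entry
        return put + 1

    page = memoryview(bytearray(64)).cast("I")
    q = ToyQueue([0x20018000, 0x1234, 0x0])               # some pretend command words
    q.bind(page)
    cmdq, ring = [], []
    put = q.submit(cmdq, ring, 0)
    print(len(cmdq), ring, put)                           # 0 words copied, one ring entry, put -> 1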
@@ -83,10 +83,46 @@ class NVCompiler(Compiler):
      raise CompileError(f"compile failed: {_get_bytes(prog, cuda.nvrtcGetProgramLog, cuda.nvrtcGetProgramLogSize, cuda_check).decode()}")
    return _get_bytes(prog, cuda.nvrtcGetCUBIN, cuda.nvrtcGetCUBINSize, cuda_check)

class HWComputeQueue:
  def __init__(self): self.q = []
class HWQueue:
  def __init__(self): self.q, self.binded_device = [], None
  def __del__(self):
    if self.binded_device is not None: self.binded_device._gpu_free(self.hw_page)

  def ptr(self) -> int: return len(self.q)

  def wait(self, signal, value=0):
    self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal))), *nvdata64_le(value),
               (3 << 0) | (1 << 12) | (1 << 24)] # ACQUIRE | ACQUIRE_SWITCH_TSG | PAYLOAD_SIZE_64BIT
    return self

  def signal(self, signal, value=0, timestamp=False):
    self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal))), *nvdata64_le(value),
               (1 << 0) | (1 << 20) | (1 << 24) | ((1 << 25) if timestamp else 0)] # RELEASE | RELEASE_WFI | PAYLOAD_SIZE_64BIT | RELEASE_TIMESTAMP
    self.q += [nvmethod(0, nv_gpu.NVC56F_NON_STALL_INTERRUPT, 1), 0x0]
    return self

  def bind(self, device: NVDevice):
    self.binded_device = device
    self.hw_page = device._gpu_alloc(len(self.q) * 4, map_to_cpu=True)
    hw_view = to_mv(self.hw_page.base, self.hw_page.length).cast("I")
    for i, value in enumerate(self.q): hw_view[i] = value

    # From now on, the queue is on the device for faster submission.
    self.q = hw_view # type: ignore

  def _submit(self, dev, gpu_ring, put_value, gpfifo_entries, gpfifo_token, gpu_ring_controls):
    if dev == self.binded_device: cmdq_addr = self.hw_page.base
    else:
      dev.cmdq[dev.cmdq_wptr//4:dev.cmdq_wptr//4+len(self.q)] = array.array('I', self.q)
      cmdq_addr = dev.cmdq_page.base+dev.cmdq_wptr
      dev.cmdq_wptr += len(self.q) * 4

    gpu_ring[put_value % gpfifo_entries] = (cmdq_addr//4 << 2) | (len(self.q) << 42) | (1 << 41)
    gpu_ring_controls.GPPut = (put_value + 1) % gpfifo_entries
    dev.gpu_mmio[0x90 // 4] = gpfifo_token
    return put_value + 1

class HWComputeQueue(HWQueue):
  def copy_from_cpu(self, gpuaddr, data):
    self.q += [nvmethod(1, nv_gpu.NVC6C0_OFFSET_OUT_UPPER, 2), *nvdata64(gpuaddr)]
    self.q += [nvmethod(1, nv_gpu.NVC6C0_LINE_LENGTH_IN, 2), len(data)*4, 0x1]
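The packed GPFIFO entry built in _submit encodes the command buffer address in the low bits, a flag at bit 41 (always set on this path), and the command length in 32-bit words starting at bit 42. A worked example of that packing (the address and length are made up, and the field naming is my reading of the bit positions above, not taken from nv_gpu):

    cmdq_addr, n_words = 0x200400000, 6                   # hypothetical GPU VA and queue length
    entry = (cmdq_addr//4 << 2) | (n_words << 42) | (1 << 41)
    assert hex(entry) == "0x1a0200400000"
    # low bits : 4-byte-aligned command buffer address (0x200400000)
    # bit 41   : the flag _submit always sets on this path
    # bit 42+  : number of 32-bit command words to fetch (6)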
@@ -107,61 +143,26 @@ class HWComputeQueue:
  def update_exec(self, cmd_ptr, global_size, local_size):
    # Patch the exec cmd with new launch dims
    assert self.q[cmd_ptr + 2] == nvmethod(1, nv_gpu.NVC6C0_SET_INLINE_QMD_ADDRESS_A, 0x42),"The pointer does not point to a packet of this type"
    self.q[cmd_ptr + 5 + 12 : cmd_ptr + 5 + 15] = global_size
    self.q[cmd_ptr + 5 + 12 : cmd_ptr + 5 + 15] = array.array('I', global_size)
    self.q[cmd_ptr + 5 + 18] = (self.q[cmd_ptr + 5 + 18] & 0xffff) | ((local_size[0] & 0xffff) << 16)
    self.q[cmd_ptr + 5 + 19] = (local_size[1] & 0xffff) | ((local_size[2] & 0xffff) << 16)

  def wait(self, signal, value=0):
    self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal))), *nvdata64_le(value),
               (3 << 0) | (1 << 12) | (1 << 24)] # ACQUIRE | ACQUIRE_SWITCH_TSG | PAYLOAD_SIZE_64BIT
    return self

  def signal(self, signal, value=0, timestamp=False):
    self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal))), *nvdata64_le(value),
               (1 << 0) | (1 << 20) | (1 << 24) | ((1 << 25) if timestamp else 0)] # RELEASE | RELEASE_WFI | PAYLOAD_SIZE_64BIT | RELEASE_TIMESTAMP
    self.q += [nvmethod(0, nv_gpu.NVC56F_NON_STALL_INTERRUPT, 1), 0x0]
    return self

  def submit(self, dev:NVDevice):
    if len(self.q) == 0: return
    assert len(self.q) < (1 << 21)
    dev.cmdq[dev.cmdq_wptr//4:dev.cmdq_wptr//4+len(self.q)] = array.array('I', self.q)
    fifo_entry = dev.compute_put_value % dev.compute_gpfifo_entries
    dev.compute_gpu_ring[fifo_entry] = ((dev.cmdq_page.base+dev.cmdq_wptr)//4 << 2) | (len(self.q) << 42) | (1 << 41)
    dev.compute_gpu_ring_controls.GPPut = (dev.compute_put_value + 1) % dev.compute_gpfifo_entries
    dev.compute_put_value += 1
    dev.gpu_mmio[0x90 // 4] = dev.compute_gpfifo_token
    dev.cmdq_wptr += len(self.q) * 4
    dev.compute_put_value = self._submit(dev, dev.compute_gpu_ring, dev.compute_put_value, dev.compute_gpfifo_entries,
                                         dev.compute_gpfifo_token, dev.compute_gpu_ring_controls)

class HWCopyQueue:
  def __init__(self): self.q = []
class HWCopyQueue(HWQueue):
  def copy(self, dest, src, copy_size):
    self.q += [nvmethod(4, nv_gpu.NVC6B5_OFFSET_IN_UPPER, 4), *nvdata64(src), *nvdata64(dest)]
    self.q += [nvmethod(4, nv_gpu.NVC6B5_LINE_LENGTH_IN, 1), copy_size]
    self.q += [nvmethod(4, nv_gpu.NVC6B5_LAUNCH_DMA, 1), 0x182] # TRANSFER_TYPE_NON_PIPELINED | DST_MEMORY_LAYOUT_PITCH | SRC_MEMORY_LAYOUT_PITCH
    return self

  def wait(self, signal, value=0):
    self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal))), *nvdata64_le(value),
               (3 << 0) | (1 << 12) | (1 << 24)] # ACQUIRE | ACQUIRE_SWITCH_TSG | PAYLOAD_SIZE_64BIT
    return self

  def signal(self, signal, value=0, timestamp=False):
    self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal))), *nvdata64_le(value),
               (1 << 0) | (1 << 20) | (1 << 24) | ((1 << 25) if timestamp else 0)] # RELEASE | RELEASE_WFI | PAYLOAD_SIZE_64BIT | RELEASE_TIMESTAMP
    self.q += [nvmethod(0, nv_gpu.NVC56F_NON_STALL_INTERRUPT, 1), 0x0]
    return self

  def submit(self, dev:NVDevice):
    if len(self.q) == 0: return
    dev.cmdq[dev.cmdq_wptr//4:dev.cmdq_wptr//4+len(self.q)] = array.array('I', self.q)
    fifo_entry = dev.dma_put_value % dev.dma_gpfifo_entries
    dev.dma_gpu_ring[fifo_entry] = ((dev.cmdq_page.base+dev.cmdq_wptr)//4 << 2) | (len(self.q) << 42)
    dev.dma_gpu_ring_controls.GPPut = (dev.dma_put_value + 1) % dev.dma_gpfifo_entries
    dev.dma_put_value += 1
    dev.gpu_mmio[0x90 // 4] = dev.dma_gpfifo_token
    dev.cmdq_wptr += len(self.q) * 4
    dev.dma_put_value = self._submit(dev, dev.dma_gpu_ring, dev.dma_put_value, dev.dma_gpfifo_entries,
                                     dev.dma_gpfifo_token, dev.dma_gpu_ring_controls)

SHT_PROGBITS, SHT_NOBITS, SHF_ALLOC, SHF_EXECINSTR = 0x1, 0x8, 0x2, 0x4
class NVProgram:
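One side effect of bind() shows up in the update_exec change above: once a queue is bound, self.q is no longer a Python list but a uint32 memoryview over the device page, and slice-assigning a plain tuple into a cast memoryview raises TypeError, hence the array.array('I', global_size) wrapper. A standalone demonstration of just that behavior:

    import array

    buf = memoryview(bytearray(16)).cast("I")    # stands in for the bound queue's hw_view
    try:
      buf[0:3] = (4, 5, 6)                       # what the old update_exec effectively did after bind
    except TypeError as e:
      print("tuple assignment fails:", e)
    buf[0:3] = array.array("I", (4, 5, 6))       # the fixed path
    print(list(buf))                             # [4, 5, 6, 0]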