docs: add more info on HCQProgram (#5683)

* docs: add more info on HCQProgram

* linter

* linter2

* one more type
This commit is contained in:
nimlgen
2024-07-24 17:20:18 +03:00
committed by GitHub
parent baface413a
commit 2ea54176e2
2 changed files with 31 additions and 11 deletions

View File

@@ -416,28 +416,48 @@ class HCQProgram:
def __init__(self, device:HCQCompiled, name:str, kernargs_alloc_size:int, kernargs_args_offset:int=0):
self.device, self.name, self.kernargs_alloc_size, self.kernargs_args_offset = device, name, kernargs_alloc_size, kernargs_args_offset
def fill_kernargs(self, bufs:Tuple[Any, ...], vals:Tuple[int, ...]=(), kernargs_ptr:Optional[int]=None):
def fill_kernargs(self, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=(), kernargs_ptr:Optional[int]=None) -> int:
"""
Fills arguments for the kernel, optionally allocating space from device if kernargs_ptr is not set.
Fills arguments for the kernel, optionally allocating space from the device if `kernargs_ptr` is not provided.
Args:
bufs: Buffers to be written to kernel arguments.
vals: Values to be written to kernel arguments.
kernargs_ptr: Optional pointer to pre-allocated kernel arguments memory.
Returns:
Pointer to the filled kernel arguments.
"""
self._fill_kernargs(ptr:=(kernargs_ptr or self.device._alloc_kernargs(self.kernargs_alloc_size)), bufs, vals)
return ptr
def _fill_kernargs(self, kernargs_ptr:int, bufs:Tuple[Any, ...], vals:Tuple[int, ...]=()): raise NotImplementedError("need fill_kernargs")
def _fill_kernargs(self, kernargs_ptr:int, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=()): raise NotImplementedError("need fill_kernargs")
def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
kernargs_ptr = self.fill_kernargs(args, vals)
def __call__(self, *bufs:HCQBuffer, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1),
vals:Tuple[int, ...]=(), wait:bool=False) -> Optional[float]:
"""
Enqueues the program for execution with the given arguments and dimensions.
Args:
bufs: Buffer arguments to execute the kernel with.
global_size: Specifies the global work size for kernel execution (equivalent to CUDA's grid size).
local_size: Specifies the local work size for kernel execution (equivalent to CUDA's block size).
vals: Value arguments to execute the kernel with.
wait: If True, waits for the kernel to complete execution.
Returns:
Execution time of the kernel if `wait` is True, otherwise None.
"""
q = self.device.hw_compute_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1).memory_barrier()
with hcq_profile(self.device, queue=q, desc=self.name, enabled=wait or PROFILE) as (sig_st, sig_en):
q.exec(self, kernargs_ptr, global_size, local_size)
q.exec(self, self.fill_kernargs(bufs, vals), global_size, local_size)
q.signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
self.device.timeline_value += 1
if wait:
self.device.timeline_signal.wait(self.device.timeline_value - 1)
return (sig_en.timestamp - sig_st.timestamp) / 1e6
if wait: self.device.timeline_signal.wait(self.device.timeline_value - 1)
return ((sig_en.timestamp - sig_st.timestamp) / 1e6) if wait else None
class HCQCompiled(Compiled):
"""

View File

@@ -301,11 +301,11 @@ class NVProgram(HCQProgram):
kernargs = [arg_half for arg in bufs for arg_half in data64_le(arg.va_addr)] + list(vals)
to_mv(kernargs_ptr, (len(self.constbuffer_0) + len(kernargs)) * 4).cast('I')[:] = array.array('I', self.constbuffer_0 + kernargs)
def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
if prod(local_size) > 1024 or self.max_threads < prod(local_size): raise RuntimeError("Too many resources requsted for launch")
if any(cur > mx for cur,mx in zip(global_size, [2147483647, 65535, 65535])) or any(cur > mx for cur,mx in zip(local_size, [1024, 1024, 64])):
raise RuntimeError(f"Invalid global/local dims {global_size=}, {local_size=}")
return super().__call__(*args, global_size=global_size, local_size=local_size, vals=vals, wait=wait)
return super().__call__(*bufs, global_size=global_size, local_size=local_size, vals=vals, wait=wait)
class NVAllocator(HCQAllocator):
def __init__(self, device:NVDevice): super().__init__(device)