mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-02-10 14:45:35 -05:00
docs: add more info on HCQProgram (#5683)
* docs: add more info on HCQProgram
* linter
* linter2
* one more type
@@ -416,28 +416,48 @@ class HCQProgram:
   def __init__(self, device:HCQCompiled, name:str, kernargs_alloc_size:int, kernargs_args_offset:int=0):
     self.device, self.name, self.kernargs_alloc_size, self.kernargs_args_offset = device, name, kernargs_alloc_size, kernargs_args_offset

-  def fill_kernargs(self, bufs:Tuple[Any, ...], vals:Tuple[int, ...]=(), kernargs_ptr:Optional[int]=None):
+  def fill_kernargs(self, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=(), kernargs_ptr:Optional[int]=None) -> int:
     """
-    Fills arguments for the kernel, optionally allocating space from device if kernargs_ptr is not set.
+    Fills arguments for the kernel, optionally allocating space from the device if `kernargs_ptr` is not provided.
+
+    Args:
+      bufs: Buffers to be written to kernel arguments.
+      vals: Values to be written to kernel arguments.
+      kernargs_ptr: Optional pointer to pre-allocated kernel arguments memory.
+
+    Returns:
+      Pointer to the filled kernel arguments.
     """
     self._fill_kernargs(ptr:=(kernargs_ptr or self.device._alloc_kernargs(self.kernargs_alloc_size)), bufs, vals)
     return ptr
-  def _fill_kernargs(self, kernargs_ptr:int, bufs:Tuple[Any, ...], vals:Tuple[int, ...]=()): raise NotImplementedError("need fill_kernargs")
+  def _fill_kernargs(self, kernargs_ptr:int, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=()): raise NotImplementedError("need fill_kernargs")

-  def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
-    kernargs_ptr = self.fill_kernargs(args, vals)
+  def __call__(self, *bufs:HCQBuffer, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1),
+               vals:Tuple[int, ...]=(), wait:bool=False) -> Optional[float]:
+    """
+    Enqueues the program for execution with the given arguments and dimensions.
+
+    Args:
+      bufs: Buffer arguments to execute the kernel with.
+      global_size: Specifies the global work size for kernel execution (equivalent to CUDA's grid size).
+      local_size: Specifies the local work size for kernel execution (equivalent to CUDA's block size).
+      vals: Value arguments to execute the kernel with.
+      wait: If True, waits for the kernel to complete execution.
+
+    Returns:
+      Execution time of the kernel if 'wait' is True, otherwise None.
+    """

     q = self.device.hw_compute_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1).memory_barrier()

     with hcq_profile(self.device, queue=q, desc=self.name, enabled=wait or PROFILE) as (sig_st, sig_en):
-      q.exec(self, kernargs_ptr, global_size, local_size)
+      q.exec(self, self.fill_kernargs(bufs, vals), global_size, local_size)

     q.signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
     self.device.timeline_value += 1

-    if wait:
-      self.device.timeline_signal.wait(self.device.timeline_value - 1)
-      return (sig_en.timestamp - sig_st.timestamp) / 1e6
+    if wait: self.device.timeline_signal.wait(self.device.timeline_value - 1)
+    return ((sig_en.timestamp - sig_st.timestamp) / 1e6) if wait else None

 class HCQCompiled(Compiled):
   """
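Taken together, the new `__call__` is the whole launch path: fill kernargs (allocating from the device when no pointer is passed), enqueue on a hardware compute queue that waits behind the device's timeline signal, then optionally block. A minimal usage sketch, assuming `prog` is an already-constructed instance of an HCQProgram subclass and `a`, `b` are buffers from its device's allocator (these names are illustrative, not from the diff):

    # fire-and-forget: enqueues the kernel asynchronously and returns None
    prog(a, b, global_size=(256,1,1), local_size=(64,1,1), vals=(16,))

    # blocking: waits on the timeline signal and returns the elapsed time
    # computed from the two profiling signals (sig_st, sig_en)
    et = prog(a, b, global_size=(256,1,1), local_size=(64,1,1), vals=(16,), wait=True)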
@@ -301,11 +301,11 @@ class NVProgram(HCQProgram):
     kernargs = [arg_half for arg in bufs for arg_half in data64_le(arg.va_addr)] + list(vals)
     to_mv(kernargs_ptr, (len(self.constbuffer_0) + len(kernargs)) * 4).cast('I')[:] = array.array('I', self.constbuffer_0 + kernargs)

-  def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
+  def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
     if prod(local_size) > 1024 or self.max_threads < prod(local_size): raise RuntimeError("Too many resources requested for launch")
     if any(cur > mx for cur,mx in zip(global_size, [2147483647, 65535, 65535])) or any(cur > mx for cur,mx in zip(local_size, [1024, 1024, 64])):
       raise RuntimeError(f"Invalid global/local dims {global_size=}, {local_size=}")
-    return super().__call__(*args, global_size=global_size, local_size=local_size, vals=vals, wait=wait)
+    return super().__call__(*bufs, global_size=global_size, local_size=local_size, vals=vals, wait=wait)

 class NVAllocator(HCQAllocator):
   def __init__(self, device:NVDevice): super().__init__(device)
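The two context lines at the top of this hunk are the body of NVProgram's `_fill_kernargs`: each buffer's 64-bit `va_addr` is split into two little-endian 32-bit words and written, together with the integer `vals`, after the kernel's `constbuffer_0`. A self-contained sketch of that packing; the `to_mv` and `data64_le` definitions below are illustrative stand-ins for tinygrad's helpers, and the real code also prepends `constbuffer_0`:

    import array, ctypes

    # stand-ins for tinygrad's helpers (illustrative, not the library code)
    def to_mv(ptr:int, sz:int) -> memoryview: return memoryview((ctypes.c_char * sz).from_address(ptr))
    def data64_le(data:int): return (data & 0xFFFFFFFF, data >> 32)  # low dword first

    mem = (ctypes.c_char * 64)()                # pretend kernargs memory
    kernargs_ptr = ctypes.addressof(mem)

    va_addrs, vals = [0x7f0000001000], (16,)    # one buffer address + one int value
    words = [w for va in va_addrs for w in data64_le(va)] + list(vals)
    to_mv(kernargs_ptr, len(words) * 4).cast('I')[:] = array.array('I', words)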