mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-12 23:54:58 -05:00
vals is an argument (#2599)
* vals is an argument * don't even know how that's legal python
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from typing import Tuple, Optional, Union, List, cast
|
||||
from typing import Tuple, Optional, List
|
||||
import ctypes, functools
|
||||
import gpuctypes.opencl as cl
|
||||
from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, diskcache, OSX, ImageDType, DEBUG
|
||||
@@ -40,10 +40,9 @@ class CLProgram:
|
||||
check(cl.clReleaseKernel(self.kernel))
|
||||
check(cl.clReleaseProgram(self.program))
|
||||
|
||||
def __call__(self, *bufs:Union[cl.cl_mem, int], global_size:Tuple[int,...], local_size:Optional[Tuple[int,...]]=None, wait=False) -> Optional[float]:
|
||||
for i,b in enumerate(bufs):
|
||||
bc = ctypes.c_int32(b) if isinstance(b, int) else cast(cl.cl_mem, b)
|
||||
cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(bc), ctypes.byref(bc))
|
||||
def __call__(self, *bufs:cl.cl_mem, global_size:Tuple[int,...], local_size:Optional[Tuple[int,...]]=None, vals:Tuple[int, ...]=(), wait=False) -> Optional[float]:
|
||||
for i,b in enumerate(bufs): cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(b), ctypes.byref(b))
|
||||
for i,b in enumerate(vals,start=len(bufs)): cl.clSetKernelArg(self.kernel, i, 4, ctypes.byref(ctypes.c_int32(b)))
|
||||
if local_size is not None: global_size = tuple(int(g*l) for g,l in zip(global_size, local_size))
|
||||
event = cl.cl_event() if wait else None
|
||||
check(cl.clEnqueueNDRangeKernel(self.device.queue, self.kernel, len(global_size), None, (ctypes.c_size_t * len(global_size))(*global_size), (ctypes.c_size_t * len(local_size))(*local_size) if local_size else None, 0, None, event))
|
||||
|
||||
Reference in New Issue
Block a user