mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
cl: fix multi-image arg kernels (#14920)
This commit is contained in:
committed by
GitHub
parent
24286c5593
commit
815780f72f
@@ -61,8 +61,9 @@ class CLProgram:
|
||||
if isinstance(dt, ImageDType):
|
||||
fmt = cl.cl_image_format(cl.CL_RGBA, {2:cl.CL_HALF_FLOAT, 4:cl.CL_FLOAT}[dt.itemsize])
|
||||
desc = cl.cl_image_desc(cl.CL_MEM_OBJECT_IMAGE2D, dt.shape[1], dt.shape[0], image_row_pitch=dt.pitch, buffer=b)
|
||||
b = checked(cl.clCreateImage(self.dev.context, cl.CL_MEM_READ_WRITE, fmt, desc, None, status:=ctypes.c_int32()), status)
|
||||
check(cl.clSetKernelArg(self.kernel, real_i, ctypes.sizeof(b), ctypes.byref(b)))
|
||||
img = checked(cl.clCreateImage(self.dev.context, cl.CL_MEM_READ_WRITE, fmt, desc, None, status:=ctypes.c_int32()), status)
|
||||
check(cl.clSetKernelArg(self.kernel, real_i, ctypes.sizeof(img), ctypes.byref(img)))
|
||||
else: check(cl.clSetKernelArg(self.kernel, real_i, ctypes.sizeof(b), ctypes.byref(b)))
|
||||
for i,v in enumerate(vals,start=i+1): check(cl.clSetKernelArg(self.kernel, i, 4, ctypes.byref(ctypes.c_int32(v))))
|
||||
if local_size is not None: global_size = cast(tuple[int,int,int], tuple(int(g*l) for g,l in zip(global_size, local_size)))
|
||||
event = cl.cl_event() if wait else None
|
||||
|
||||
Reference in New Issue
Block a user