symbolic codegen and exec (#1552)

* symbolic codegen and exec

* fix and add test

* no sketchy

* merge_dicts type

* dtypes._arg_int32
This commit is contained in:
chenyu
2023-08-16 14:43:41 -07:00
committed by GitHub
parent 1e1d48b4e6
commit 11dd9b1741
16 changed files with 201 additions and 43 deletions

View File

@@ -80,7 +80,7 @@ class CLProgram:
def max_work_group_size(): return CL.cl_ctxs[0].devices[0].max_work_group_size
def __call__(self, global_size, local_size, *bufs, wait=False) -> Optional[float]:
cl_bufs = [x._buf if isinstance(x, CLBuffer) else x for x in bufs]
cl_bufs = [x._buf if isinstance(x, CLBuffer) else np.int32(x) if isinstance(x, int) else x for x in bufs]
e = self.clprgs[cl_bufs[0].device](CL.cl_queue[cl_bufs[0].device], [g*l for g,l in zip(global_size, local_size)] if local_size is not None else global_size, local_size, *cl_bufs, wait_for=[x.event for x in bufs if isinstance(x, CLBuffer) and hasattr(x, "event")])
if wait:
e.wait()
@@ -91,7 +91,7 @@ class CLProgram:
return None
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
kernel_prefix = "__kernel", buffer_prefix = "__global ", smem_prefix = "__local ",
kernel_prefix = "__kernel", buffer_prefix = "__global ", smem_prefix = "__local ", arg_int_prefix = "const int",
half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable",
barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
gid = [f'get_group_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True))