mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
symbolic codegen and exec (#1552)
* symbolic codegen and exec * fix and add test * no sketchy * merge_dicts type * dtypes._arg_int32
This commit is contained in:
@@ -80,7 +80,7 @@ class CLProgram:
|
||||
def max_work_group_size(): return CL.cl_ctxs[0].devices[0].max_work_group_size
|
||||
|
||||
def __call__(self, global_size, local_size, *bufs, wait=False) -> Optional[float]:
|
||||
cl_bufs = [x._buf if isinstance(x, CLBuffer) else x for x in bufs]
|
||||
cl_bufs = [x._buf if isinstance(x, CLBuffer) else np.int32(x) if isinstance(x, int) else x for x in bufs]
|
||||
e = self.clprgs[cl_bufs[0].device](CL.cl_queue[cl_bufs[0].device], [g*l for g,l in zip(global_size, local_size)] if local_size is not None else global_size, local_size, *cl_bufs, wait_for=[x.event for x in bufs if isinstance(x, CLBuffer) and hasattr(x, "event")])
|
||||
if wait:
|
||||
e.wait()
|
||||
@@ -91,7 +91,7 @@ class CLProgram:
|
||||
return None
|
||||
|
||||
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
|
||||
kernel_prefix = "__kernel", buffer_prefix = "__global ", smem_prefix = "__local ",
|
||||
kernel_prefix = "__kernel", buffer_prefix = "__global ", smem_prefix = "__local ", arg_int_prefix = "const int",
|
||||
half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable",
|
||||
barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
|
||||
gid = [f'get_group_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True))
|
||||
|
||||
Reference in New Issue
Block a user