diff --git a/accel/opencl/ops_opencl.py b/accel/opencl/ops_opencl.py index b697e46404..cd1c4cdde4 100644 --- a/accel/opencl/ops_opencl.py +++ b/accel/opencl/ops_opencl.py @@ -167,7 +167,7 @@ class OpenCLBuffer(GPUBuffer): assert op == ProcessingOps.CONV, f"{op} isn't supported" return type(x)(C.out_shape)._processing_op([("input", x.contiguous_op()), ("weight", w.contiguous_op())], "acc", C) - def contiguous_view_constant_fold(x, name:str) -> Tuple[str, str, str]: + def contiguous_view_constant_fold(x, name:str) -> Tuple[str, Optional[str], str]: if x.is_image(): #print("is image") return f"""inline float get_{name}(const sampler_t smp, read_only image2d_t x, int gid) {{ diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py index 871ca3a819..92983654e9 100644 --- a/tinygrad/llops/ops_gpu.py +++ b/tinygrad/llops/ops_gpu.py @@ -102,7 +102,7 @@ class GPUBuffer: def contiguous_view(x, name:str) -> str: return f"inline float get_{name}(__global const float *x, int gid) {{ int valid = 1; int idx = gid; {x.st.expr().replace('//', '/')}; return valid ? x[idx] : 0.0;}}" - def contiguous_view_constant_fold(x, name:str) -> Tuple[str, str, str]: + def contiguous_view_constant_fold(x, name:str) -> Tuple[str, Optional[str], str]: if x._base_shape == (1,) and x._backing is not None: return f"inline float get_{name}(int gid) {{ int valid = 1; int idx = gid; {x.st.expr().replace('//', '/')}; return valid ? {x._backing[0]} : 0.0;}}", None, f"get_{name}(idx);" else: @@ -127,7 +127,8 @@ class GPUBuffer: kernel_name = "reduce" if red > 1 else "elementwise" views = {name:buf.contiguous_view_constant_fold(name) for name, buf in bufs} - buf_types = [views[name][1] for name, _ in bufs if views[name][1] is not None] + buf_types : List[str] = [views[name][1] for name, _ in bufs if views[name][1] is not None] # type: ignore + buf_cl = [buf.cl if 'image2d_t' not in views[name][1] else buf.image for name, buf in bufs if views[name][1] is not None] + ([cl.LocalMemory(inter_red*4)] if inter_red > 1 else []) # type: ignore conv_prg = CLProgram(kernel_name, f"""{chr(10).join([x[0] for x in views.values()])} __kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types + (["__local float *temp"] if inter_red > 1 else []))}) {{ const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; @@ -144,7 +145,5 @@ class GPUBuffer: output[gid] = {code}; }} }}""") - conv_prg([prod(ret.shape), inter_red, 1], [1, inter_red, 1] if inter_red > 1 else None, ret.cl, - *([buf.cl if 'image2d_t' not in views[name][1] else buf.image for name, buf in bufs if views[name][1] is not None] + ([cl.LocalMemory(inter_red*4)] if inter_red > 1 else [])), - op_estimate=prod(reduce_shape[0])*len(earlybufs) + prod(reduce_shape[1])*len(bufs)) + conv_prg([prod(ret.shape), inter_red, 1], [1, inter_red, 1] if inter_red > 1 else None, ret.cl, *buf_cl, op_estimate=prod(reduce_shape[0])*len(earlybufs) + prod(reduce_shape[1])*len(bufs)) return ret