fix typing

This commit is contained in:
George Hotz
2022-08-22 16:25:15 -07:00
parent e0a8d0f836
commit 2162cd3383
2 changed files with 5 additions and 6 deletions

View File

@@ -167,7 +167,7 @@ class OpenCLBuffer(GPUBuffer):
assert op == ProcessingOps.CONV, f"{op} isn't supported"
return type(x)(C.out_shape)._processing_op([("input", x.contiguous_op()), ("weight", w.contiguous_op())], "acc", C)
def contiguous_view_constant_fold(x, name:str) -> Tuple[str, str, str]:
def contiguous_view_constant_fold(x, name:str) -> Tuple[str, Optional[str], str]:
if x.is_image():
#print("is image")
return f"""inline float get_{name}(const sampler_t smp, read_only image2d_t x, int gid) {{

View File

@@ -102,7 +102,7 @@ class GPUBuffer:
# Build the OpenCL C source for an inline accessor function named get_<name>.
# The generated function takes a __global const float buffer `x` and an int
# `gid`, initialises valid = 1 and idx = gid, splices in the text returned by
# x.st.expr() (with every '//' replaced by '/'), and finally returns x[idx]
# when `valid` is nonzero, otherwise 0.0.
# NOTE(review): presumably x.st.expr() emits statements that rewrite idx/valid
# for the buffer's view — confirm against its definition, not visible here.
def contiguous_view(x, name:str) -> str:
return f"inline float get_{name}(__global const float *x, int gid) {{ int valid = 1; int idx = gid; {x.st.expr().replace('//', '/')}; return valid ? x[idx] : 0.0;}}"
def contiguous_view_constant_fold(x, name:str) -> Tuple[str, str, str]:
def contiguous_view_constant_fold(x, name:str) -> Tuple[str, Optional[str], str]:
if x._base_shape == (1,) and x._backing is not None:
return f"inline float get_{name}(int gid) {{ int valid = 1; int idx = gid; {x.st.expr().replace('//', '/')}; return valid ? {x._backing[0]} : 0.0;}}", None, f"get_{name}(idx);"
else:
@@ -127,7 +127,8 @@ class GPUBuffer:
kernel_name = "reduce" if red > 1 else "elementwise"
views = {name:buf.contiguous_view_constant_fold(name) for name, buf in bufs}
buf_types = [views[name][1] for name, _ in bufs if views[name][1] is not None]
buf_types : List[str] = [views[name][1] for name, _ in bufs if views[name][1] is not None] # type: ignore
buf_cl = [buf.cl if 'image2d_t' not in views[name][1] else buf.image for name, buf in bufs if views[name][1] is not None] + ([cl.LocalMemory(inter_red*4)] if inter_red > 1 else []) # type: ignore
conv_prg = CLProgram(kernel_name, f"""{chr(10).join([x[0] for x in views.values()])}
__kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types + (["__local float *temp"] if inter_red > 1 else []))}) {{
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
@@ -144,7 +145,5 @@ class GPUBuffer:
output[gid] = {code};
}}
}}""")
conv_prg([prod(ret.shape), inter_red, 1], [1, inter_red, 1] if inter_red > 1 else None, ret.cl,
*([buf.cl if 'image2d_t' not in views[name][1] else buf.image for name, buf in bufs if views[name][1] is not None] + ([cl.LocalMemory(inter_red*4)] if inter_red > 1 else [])),
op_estimate=prod(reduce_shape[0])*len(earlybufs) + prod(reduce_shape[1])*len(bufs))
conv_prg([prod(ret.shape), inter_red, 1], [1, inter_red, 1] if inter_red > 1 else None, ret.cl, *buf_cl, op_estimate=prod(reduce_shape[0])*len(earlybufs) + prod(reduce_shape[1])*len(bufs))
return ret