mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 13:58:00 -05:00
fix typing
This commit is contained in:
@@ -167,7 +167,7 @@ class OpenCLBuffer(GPUBuffer):
|
||||
assert op == ProcessingOps.CONV, f"{op} isn't supported"
|
||||
return type(x)(C.out_shape)._processing_op([("input", x.contiguous_op()), ("weight", w.contiguous_op())], "acc", C)
|
||||
|
||||
def contiguous_view_constant_fold(x, name:str) -> Tuple[str, str, str]:
|
||||
def contiguous_view_constant_fold(x, name:str) -> Tuple[str, Optional[str], str]:
|
||||
if x.is_image():
|
||||
#print("is image")
|
||||
return f"""inline float get_{name}(const sampler_t smp, read_only image2d_t x, int gid) {{
|
||||
|
||||
@@ -102,7 +102,7 @@ class GPUBuffer:
|
||||
def contiguous_view(x, name:str) -> str:
|
||||
return f"inline float get_{name}(__global const float *x, int gid) {{ int valid = 1; int idx = gid; {x.st.expr().replace('//', '/')}; return valid ? x[idx] : 0.0;}}"
|
||||
|
||||
def contiguous_view_constant_fold(x, name:str) -> Tuple[str, str, str]:
|
||||
def contiguous_view_constant_fold(x, name:str) -> Tuple[str, Optional[str], str]:
|
||||
if x._base_shape == (1,) and x._backing is not None:
|
||||
return f"inline float get_{name}(int gid) {{ int valid = 1; int idx = gid; {x.st.expr().replace('//', '/')}; return valid ? {x._backing[0]} : 0.0;}}", None, f"get_{name}(idx);"
|
||||
else:
|
||||
@@ -127,7 +127,8 @@ class GPUBuffer:
|
||||
|
||||
kernel_name = "reduce" if red > 1 else "elementwise"
|
||||
views = {name:buf.contiguous_view_constant_fold(name) for name, buf in bufs}
|
||||
buf_types = [views[name][1] for name, _ in bufs if views[name][1] is not None]
|
||||
buf_types : List[str] = [views[name][1] for name, _ in bufs if views[name][1] is not None] # type: ignore
|
||||
buf_cl = [buf.cl if 'image2d_t' not in views[name][1] else buf.image for name, buf in bufs if views[name][1] is not None] + ([cl.LocalMemory(inter_red*4)] if inter_red > 1 else []) # type: ignore
|
||||
conv_prg = CLProgram(kernel_name, f"""{chr(10).join([x[0] for x in views.values()])}
|
||||
__kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types + (["__local float *temp"] if inter_red > 1 else []))}) {{
|
||||
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
|
||||
@@ -144,7 +145,5 @@ class GPUBuffer:
|
||||
output[gid] = {code};
|
||||
}}
|
||||
}}""")
|
||||
conv_prg([prod(ret.shape), inter_red, 1], [1, inter_red, 1] if inter_red > 1 else None, ret.cl,
|
||||
*([buf.cl if 'image2d_t' not in views[name][1] else buf.image for name, buf in bufs if views[name][1] is not None] + ([cl.LocalMemory(inter_red*4)] if inter_red > 1 else [])),
|
||||
op_estimate=prod(reduce_shape[0])*len(earlybufs) + prod(reduce_shape[1])*len(bufs))
|
||||
conv_prg([prod(ret.shape), inter_red, 1], [1, inter_red, 1] if inter_red > 1 else None, ret.cl, *buf_cl, op_estimate=prod(reduce_shape[0])*len(earlybufs) + prod(reduce_shape[1])*len(bufs))
|
||||
return ret
|
||||
|
||||
Reference in New Issue
Block a user