refactor getters

This commit is contained in:
Comma Device
2022-08-22 13:29:08 -07:00
parent 13646ae07a
commit 1b5f4e52d9

View File

@@ -52,6 +52,51 @@ def get_replacements(prg_src:str, opencl_type:List[str]) -> Dict[str, str]:
replacements["//ARGS"] = ","+','.join(opencl_type)
return replacements
def get_getters(ewbufs, ret):
  """Generate the OpenCL `get4_<name>` accessors for a set of elementwise inputs.

  Args:
    ewbufs: iterable of (name, buf) pairs. Each buf is expected to provide
      contiguous_view_constant_fold(name), is_image(), shape and st (with
      .contiguous and .expr()); buf._backing is read only in the
      UNSAFE_FLOAT4 branch.
    ret: the output buffer. Only ret.shape is read, to check whether an image
      input has the same shape as the output.

  Returns:
    A (fakebufs, ewtypes, getters) tuple of lists:
      fakebufs: names whose accessor takes only `int gid`; no kernel argument
        declaration is emitted for them.
      ewtypes: OpenCL kernel argument declarations, one per non-fake input.
      getters: OpenCL source snippets, in the order they were generated.
  """
  fakebufs, ewtypes, getters = [], [], []

  for name, buf in ewbufs:
    view, unfolded = buf.contiguous_view_constant_fold(name)

    if not unfolded:
      # constant folded: the view reads no memory, so get4 is four scalar calls
      getters.append(view)
      fakebufs.append(name)
      getters.append(
        f"inline float4 get4_{name}(int gid) {{"
        f"return (float4)(get_{name}(gid+0), get_{name}(gid+1), get_{name}(gid+2), get_{name}(gid+3)); }}")
    elif buf.is_image() and buf.shape == ret.shape and buf.st.contiguous:
      # use an image here
      ewtypes.append(f"read_only image2d_t {name}_g")
      getters.append(
        f"inline float4 get4_{name}(read_only image2d_t x, const sampler_t smp, int2 loc, int gid) {{ return read_imagef(x, smp, loc); }}")
    elif buf.st.contiguous:
      # use float4
      ewtypes.append(f"__global const float4 *{name}_g")
      getters.append(
        f"inline float4 get4_{name}(__global const float4 *x, const sampler_t smp, int2 loc, int gid) {{ return x[gid/4]; }}")
    elif UNSAFE_FLOAT4:
      # aggressive constant folding: the backing array is written into the
      # kernel source as a __constant float4 table
      fakebufs.append(name)
      quads = buf._backing.reshape((-1, 4))
      # NOTE: %f prints six decimal places, so the embedded values are rounded
      literals = ["(float4)(%ff, %ff, %ff, %ff)" % (q[0], q[1], q[2], q[3]) for q in quads]
      getters.append(f"const __constant float4 const_{name}[] = {{" + ', '.join(literals) + "};")
      getters.append(
        f"inline float4 get4_{name}(int gid) {{"
        + "int idx = gid;" + buf.st.expr() + ";"
        + f"return const_{name}[idx/4]; }}")
      # Disabled alternative, kept for reference:
      #   use float4 indexed (HACK!)
      #   TODO: work out when this is okay
      #   ewtypes.append(f"__global const float4 *{name}_g")
      #   getters.append(f"inline float4 get4_{name}(__global const float4 *x, const sampler_t smp, int2 loc, int gid) {{"+
      #     "int valid = 1; int idx = gid;"+buf.st.expr()+";"+
      #     f"return x[idx/4]; }}")
    else:
      # fallback to float: go through the scalar view getter four times
      getters.append(view)
      ewtypes.append(f"__global const float *{name}_g")
      getters.append(
        f"inline float4 get4_{name}(__global const float *x, const sampler_t smp, int2 loc, int gid) {{"
        f"return (float4)(get_{name}(x,gid+0), get_{name}(x,gid+1), get_{name}(x,gid+2), get_{name}(x,gid+3)); }}")

  return fakebufs, ewtypes, getters
def roundup(x, n=4):
  """Return x rounded up to a multiple of n (n defaults to 4)."""
  # adding n-1 before the floor division pushes any partial group up to the next multiple
  groups = (x + (n - 1)) // n
  return groups * n
class OpenCLBuffer(GPUBuffer):
def __init__(self, shape, hostbuf:Optional[OpenCLBuffer]=None, backing:Optional[np.ndarray]=None):
@@ -137,50 +182,8 @@ class OpenCLBuffer(GPUBuffer):
print("WARNING: recomputing CONV with", x, w)
OpenCLBuffer.seen.add((x,w))
fakebufs = []
ewtypes = []
getters = []
for name, buf in ewbufs:
view, unfolded = buf.contiguous_view_constant_fold(name)
if not unfolded:
getters.append(view)
fakebufs.append(name)
getters.append(f"inline float4 get4_{name}(int gid) {{"+
f"return (float4)(get_{name}(gid+0), get_{name}(gid+1), get_{name}(gid+2), get_{name}(gid+3)); }}")
elif buf.is_image() and buf.shape == ret.shape and buf.st.contiguous:
# use an image here
ewtypes.append(f"read_only image2d_t {name}_g")
getters.append(f"inline float4 get4_{name}(read_only image2d_t x, const sampler_t smp, int2 loc, int gid) {{ return read_imagef(x, smp, loc); }}")
elif buf.st.contiguous:
# use float4
ewtypes.append(f"__global const float4 *{name}_g")
getters.append(f"inline float4 get4_{name}(__global const float4 *x, const sampler_t smp, int2 loc, int gid) {{ return x[gid/4]; }}")
elif UNSAFE_FLOAT4:
# aggressive constant folding
fakebufs.append(name)
prt = buf._backing.reshape((-1, 4))
cc = []
for ii in range(prt.shape[0]): cc.append("(float4)(%ff, %ff, %ff, %ff)" % (prt[ii][0], prt[ii][1], prt[ii][2], prt[ii][3]))
getters.append(f"const __constant float4 const_{name}[] = {{"+', '.join(cc)+"};")
getters.append(f"inline float4 get4_{name}(int gid) {{"+
"int idx = gid;"+buf.st.expr()+";"+
f"return const_{name}[idx/4]; }}")
"""
# use float4 indexed (HACK!)
# TODO: work out when this is okay
ewtypes.append(f"__global const float4 *{name}_g")
getters.append(f"inline float4 get4_{name}(__global const float4 *x, const sampler_t smp, int2 loc, int gid) {{"+
"int valid = 1; int idx = gid;"+buf.st.expr()+";"+
f"return x[idx/4]; }}")
"""
else:
# fallback to float
getters.append(view)
ewtypes.append(f"__global const float *{name}_g")
getters.append(f"inline float4 get4_{name}(__global const float *x, const sampler_t smp, int2 loc, int gid) {{"+
f"return (float4)(get_{name}(x,gid+0), get_{name}(x,gid+1), get_{name}(x,gid+2), get_{name}(x,gid+3)); }}")
# remove fakebufs
fakebufs, ewtypes, getters = get_getters(ewbufs, ret)
ewbufs = [x for x in ewbufs if x[0] not in fakebufs]
elementwise_prefix = '\n'.join(getters)+ \