mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
refactor getters
This commit is contained in:
@@ -52,6 +52,51 @@ def get_replacements(prg_src:str, opencl_type:List[str]) -> Dict[str, str]:
|
||||
replacements["//ARGS"] = ","+','.join(opencl_type)
|
||||
return replacements
|
||||
|
||||
def get_getters(ewbufs, ret):
  """Build the OpenCL `get4_<name>` accessor source for each elementwise buffer.

  Args:
    ewbufs: iterable of `(name, buf)` pairs. Each `buf` must provide
      `contiguous_view_constant_fold(name)`, `is_image()`, `shape`, `st`
      (with `contiguous` and `expr()`), and, for the UNSAFE_FLOAT4 path,
      `_backing`.
    ret: the output buffer; only `ret.shape` is read, to decide whether an
      image-backed input can be sampled directly.

  Returns:
    A `(fakebufs, ewtypes, getters)` tuple of lists:
      fakebufs: names whose accessor takes only `int gid`, so no kernel
        argument declaration is added to `ewtypes` for them.
      ewtypes: OpenCL kernel argument declarations (`... <name>_g`).
      getters: OpenCL source snippets, in emission order.
  """
  fakebufs = []
  ewtypes = []
  getters = []
  for name, buf in ewbufs:
    # NOTE(review): `view` is presumably the source of a scalar `get_<name>`
    # accessor and a falsy `unfolded` presumably means the buffer was folded
    # to a constant -- confirm against contiguous_view_constant_fold.
    view, unfolded = buf.contiguous_view_constant_fold(name)
    if not unfolded:
      # accessor needs no buffer argument: wrap four scalar reads in a float4
      getters.append(view)
      fakebufs.append(name)
      getters.append(f"inline float4 get4_{name}(int gid) {{"+
        f"return (float4)(get_{name}(gid+0), get_{name}(gid+1), get_{name}(gid+2), get_{name}(gid+3)); }}")
    elif buf.is_image() and buf.shape == ret.shape and buf.st.contiguous:
      # use an image here
      ewtypes.append(f"read_only image2d_t {name}_g")
      getters.append(f"inline float4 get4_{name}(read_only image2d_t x, const sampler_t smp, int2 loc, int gid) {{ return read_imagef(x, smp, loc); }}")
    elif buf.st.contiguous:
      # use float4
      ewtypes.append(f"__global const float4 *{name}_g")
      getters.append(f"inline float4 get4_{name}(__global const float4 *x, const sampler_t smp, int2 loc, int gid) {{ return x[gid/4]; }}")
    elif UNSAFE_FLOAT4:
      # aggressive constant folding: embed the backing data as a __constant
      # float4 table inside the program source.
      fakebufs.append(name)
      prt = buf._backing.reshape((-1, 4))
      # NOTE: "%f" formats with 6 decimal places, so the embedded constants
      # carry only that much precision.
      cc = ["(float4)(%ff, %ff, %ff, %ff)" % (row[0], row[1], row[2], row[3]) for row in prt]
      getters.append(f"const __constant float4 const_{name}[] = {{"+', '.join(cc)+"};")
      # NOTE(review): buf.st.expr() is presumably OpenCL code that remaps
      # `idx`; only `idx` is declared here -- confirm it needs nothing else.
      getters.append(f"inline float4 get4_{name}(int gid) {{"+
        "int idx = gid;"+buf.st.expr()+";"+
        f"return const_{name}[idx/4]; }}")
      # A float4-indexed `__global` variant of this accessor (HACK) was
      # disabled here. TODO: work out when this is okay.
    else:
      # fallback to float
      getters.append(view)
      ewtypes.append(f"__global const float *{name}_g")
      getters.append(f"inline float4 get4_{name}(__global const float *x, const sampler_t smp, int2 loc, int gid) {{"+
        f"return (float4)(get_{name}(x,gid+0), get_{name}(x,gid+1), get_{name}(x,gid+2), get_{name}(x,gid+3)); }}")
  return fakebufs, ewtypes, getters
|
||||
|
||||
def roundup(x, n=4):
  """Return x rounded up to a multiple of n (n defaults to 4)."""
  # adding n-1 before the floor division bumps any partial group up by one
  groups = (x + (n - 1)) // n
  return groups * n
|
||||
class OpenCLBuffer(GPUBuffer):
|
||||
def __init__(self, shape, hostbuf:Optional[OpenCLBuffer]=None, backing:Optional[np.ndarray]=None):
|
||||
@@ -137,50 +182,8 @@ class OpenCLBuffer(GPUBuffer):
|
||||
print("WARNING: recomputing CONV with", x, w)
|
||||
OpenCLBuffer.seen.add((x,w))
|
||||
|
||||
fakebufs = []
|
||||
ewtypes = []
|
||||
getters = []
|
||||
for name, buf in ewbufs:
|
||||
view, unfolded = buf.contiguous_view_constant_fold(name)
|
||||
if not unfolded:
|
||||
getters.append(view)
|
||||
fakebufs.append(name)
|
||||
getters.append(f"inline float4 get4_{name}(int gid) {{"+
|
||||
f"return (float4)(get_{name}(gid+0), get_{name}(gid+1), get_{name}(gid+2), get_{name}(gid+3)); }}")
|
||||
elif buf.is_image() and buf.shape == ret.shape and buf.st.contiguous:
|
||||
# use an image here
|
||||
ewtypes.append(f"read_only image2d_t {name}_g")
|
||||
getters.append(f"inline float4 get4_{name}(read_only image2d_t x, const sampler_t smp, int2 loc, int gid) {{ return read_imagef(x, smp, loc); }}")
|
||||
elif buf.st.contiguous:
|
||||
# use float4
|
||||
ewtypes.append(f"__global const float4 *{name}_g")
|
||||
getters.append(f"inline float4 get4_{name}(__global const float4 *x, const sampler_t smp, int2 loc, int gid) {{ return x[gid/4]; }}")
|
||||
elif UNSAFE_FLOAT4:
|
||||
# aggressive constant folding
|
||||
fakebufs.append(name)
|
||||
prt = buf._backing.reshape((-1, 4))
|
||||
cc = []
|
||||
for ii in range(prt.shape[0]): cc.append("(float4)(%ff, %ff, %ff, %ff)" % (prt[ii][0], prt[ii][1], prt[ii][2], prt[ii][3]))
|
||||
getters.append(f"const __constant float4 const_{name}[] = {{"+', '.join(cc)+"};")
|
||||
getters.append(f"inline float4 get4_{name}(int gid) {{"+
|
||||
"int idx = gid;"+buf.st.expr()+";"+
|
||||
f"return const_{name}[idx/4]; }}")
|
||||
"""
|
||||
# use float4 indexed (HACK!)
|
||||
# TODO: work out when this is okay
|
||||
ewtypes.append(f"__global const float4 *{name}_g")
|
||||
getters.append(f"inline float4 get4_{name}(__global const float4 *x, const sampler_t smp, int2 loc, int gid) {{"+
|
||||
"int valid = 1; int idx = gid;"+buf.st.expr()+";"+
|
||||
f"return x[idx/4]; }}")
|
||||
"""
|
||||
else:
|
||||
# fallback to float
|
||||
getters.append(view)
|
||||
ewtypes.append(f"__global const float *{name}_g")
|
||||
getters.append(f"inline float4 get4_{name}(__global const float *x, const sampler_t smp, int2 loc, int gid) {{"+
|
||||
f"return (float4)(get_{name}(x,gid+0), get_{name}(x,gid+1), get_{name}(x,gid+2), get_{name}(x,gid+3)); }}")
|
||||
|
||||
# remove fakebufs
|
||||
fakebufs, ewtypes, getters = get_getters(ewbufs, ret)
|
||||
ewbufs = [x for x in ewbufs if x[0] not in fakebufs]
|
||||
|
||||
elementwise_prefix = '\n'.join(getters)+ \
|
||||
|
||||
Reference in New Issue
Block a user