make gpu code readable

This commit is contained in:
George Hotz
2022-09-06 21:17:36 -07:00
parent 790af99a48
commit 1c92a6da22

View File

@@ -78,8 +78,10 @@ class CLProgram:
class GPUBuffer:
# Per-op expression templates, looked up as GPUBuffer.code_for_op[op] when building kernel source.
# `A` is a placeholder that gets textually replaced (e.g. with 'temp[rdx]' in the reduce path);
# `B` is presumably the second operand of binary ops -- TODO confirm against the caller.
# The reduce templates combine `A` with the running accumulator `acc`.
# NOTE(review): the first two entry lines repeat the same key/value pairs as the four lines
# that follow them (this looks like the pre-change side of a diff shown without -/+ markers).
# In a dict literal the later duplicate wins and the values here are identical -- confirm and keep one set.
code_for_op = {
UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.EXP: "exp(A)", UnaryOps.LOG: "log(A)", UnaryOps.SIGN: "sign(A)", UnaryOps.RECIPROCAL: "((float)1.0/A)",
BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)",
UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)",
UnaryOps.EXP: "exp(A)", UnaryOps.LOG: "log(A)", UnaryOps.SIGN: "sign(A)", UnaryOps.RECIPROCAL: "((float)1.0/A)",
BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)",
BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)",
ReduceOps.SUM: "(acc + A)", ReduceOps.MAX: "max(A, acc)"
}
# Initial accumulator text for each reduce op; it is spliced into the generated source as `float acc = ...;`.
start_for_op = {ReduceOps.SUM: "0.0", ReduceOps.MAX: "-INFINITY"}
@@ -140,27 +142,37 @@ class GPUBuffer:
# if it's a partial reduce, assert last non reduced axis is before the first reduced axis
if red > 1 and prod(ret.shape) != 1:
assert max([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s == n and n != 1]) < min([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s != 1 and n == 1])
inter_red = 256 if (prod(ret.shape) < 8192 and red >= 256) else 1
kernel_name = "reduce" if red > 1 else "elementwise"
views = {name:buf.contiguous_view_constant_fold(name) for name, buf in bufs}
buf_types : List[str] = [views[name][1] for name, _ in bufs if views[name][1] is not None] # type: ignore
buf_cl = [buf.cl if 'image2d_t' not in views[name][1] else buf.image for name, buf in bufs if views[name][1] is not None] + ([cl.LocalMemory(inter_red*4)] if inter_red > 1 else []) # type: ignore
buf_cl = [buf.cl if 'image2d_t' not in views[name][1] else buf.image for name, buf in bufs if views[name][1] is not None]
# use local memory if it's a multistage reduce
inter_red = 256 if (prod(ret.shape) < 8192 and red >= 256) else 1
if inter_red > 1:
buf_cl.append(cl.LocalMemory(inter_red*4))
conv_prg = CLProgram(kernel_name, f"""{chr(10).join([x[0] for x in views.values()])}
__kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types + (["__local float *temp"] if inter_red > 1 else []))}) {{
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
float acc = {GPUBuffer.start_for_op[op]}; int gid = get_global_id(0); {'int mid = get_global_id(1);' if inter_red > 1 else 'int mid = 0;'}
float acc = {GPUBuffer.start_for_op[op]};
int gid = get_global_id(0);
{'int mid = get_global_id(1);' if inter_red > 1 else 'int mid = 0;'}
for (int idx = gid * {red} + {red//inter_red + 1} * mid; idx < gid * {red} + min({red}, {red//inter_red + 1} * (mid+1)); idx++) {{
{chr(10).join([f' float {name} = ' + views[name][2] for name, _ in bufs if name in earlybufs])}
acc = {earlycode};
}} int idx = gid;"""+(f"""
temp[mid] = acc; barrier(CLK_LOCAL_MEM_FENCE);
if (mid == 0) {{ acc = {GPUBuffer.start_for_op[op]};
for (int rdx = 0; rdx < {inter_red}; rdx++) {{ acc = {GPUBuffer.code_for_op[op].replace('A', 'temp[rdx]')}; }}
""" if inter_red != 1 else "{")+f"""
for (int rdx = 0; rdx < {inter_red}; rdx++) {{
acc = {GPUBuffer.code_for_op[op].replace('A', 'temp[rdx]')};
}}""" if inter_red != 1 else "{")+f"""
{chr(10).join([f' float {name} = ' + views[name][2] for name, _ in bufs if name not in earlybufs])}
output[gid] = {code};
}}
}}""")
conv_prg([prod(ret.shape), inter_red, 1], [1, inter_red, 1] if inter_red > 1 else None, ret.cl, *buf_cl, op_estimate=prod(reduce_shape[0])*len(earlybufs) + prod(reduce_shape[1])*len(bufs))
return ret