mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
make gpu code readable
This commit is contained in:
@@ -78,8 +78,10 @@ class CLProgram:
|
||||
|
||||
class GPUBuffer:
|
||||
code_for_op = {
|
||||
UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.EXP: "exp(A)", UnaryOps.LOG: "log(A)", UnaryOps.SIGN: "sign(A)", UnaryOps.RECIPROCAL: "((float)1.0/A)",
|
||||
BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)",
|
||||
UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)",
|
||||
UnaryOps.EXP: "exp(A)", UnaryOps.LOG: "log(A)", UnaryOps.SIGN: "sign(A)", UnaryOps.RECIPROCAL: "((float)1.0/A)",
|
||||
BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)",
|
||||
BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)",
|
||||
ReduceOps.SUM: "(acc + A)", ReduceOps.MAX: "max(A, acc)"
|
||||
}
|
||||
start_for_op = {ReduceOps.SUM: "0.0", ReduceOps.MAX: "-INFINITY"}
|
||||
@@ -140,27 +142,37 @@ class GPUBuffer:
|
||||
# if it's a partial reduce, assert last non reduced axis is before the first reduced axis
|
||||
if red > 1 and prod(ret.shape) != 1:
|
||||
assert max([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s == n and n != 1]) < min([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s != 1 and n == 1])
|
||||
inter_red = 256 if (prod(ret.shape) < 8192 and red >= 256) else 1
|
||||
|
||||
kernel_name = "reduce" if red > 1 else "elementwise"
|
||||
views = {name:buf.contiguous_view_constant_fold(name) for name, buf in bufs}
|
||||
|
||||
buf_types : List[str] = [views[name][1] for name, _ in bufs if views[name][1] is not None] # type: ignore
|
||||
buf_cl = [buf.cl if 'image2d_t' not in views[name][1] else buf.image for name, buf in bufs if views[name][1] is not None] + ([cl.LocalMemory(inter_red*4)] if inter_red > 1 else []) # type: ignore
|
||||
buf_cl = [buf.cl if 'image2d_t' not in views[name][1] else buf.image for name, buf in bufs if views[name][1] is not None]
|
||||
|
||||
# use local memory if it's a multistage reduce
|
||||
inter_red = 256 if (prod(ret.shape) < 8192 and red >= 256) else 1
|
||||
if inter_red > 1:
|
||||
buf_cl.append(cl.LocalMemory(inter_red*4))
|
||||
|
||||
conv_prg = CLProgram(kernel_name, f"""{chr(10).join([x[0] for x in views.values()])}
|
||||
__kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types + (["__local float *temp"] if inter_red > 1 else []))}) {{
|
||||
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
|
||||
float acc = {GPUBuffer.start_for_op[op]}; int gid = get_global_id(0); {'int mid = get_global_id(1);' if inter_red > 1 else 'int mid = 0;'}
|
||||
float acc = {GPUBuffer.start_for_op[op]};
|
||||
int gid = get_global_id(0);
|
||||
{'int mid = get_global_id(1);' if inter_red > 1 else 'int mid = 0;'}
|
||||
for (int idx = gid * {red} + {red//inter_red + 1} * mid; idx < gid * {red} + min({red}, {red//inter_red + 1} * (mid+1)); idx++) {{
|
||||
{chr(10).join([f' float {name} = ' + views[name][2] for name, _ in bufs if name in earlybufs])}
|
||||
acc = {earlycode};
|
||||
}} int idx = gid;"""+(f"""
|
||||
temp[mid] = acc; barrier(CLK_LOCAL_MEM_FENCE);
|
||||
if (mid == 0) {{ acc = {GPUBuffer.start_for_op[op]};
|
||||
for (int rdx = 0; rdx < {inter_red}; rdx++) {{ acc = {GPUBuffer.code_for_op[op].replace('A', 'temp[rdx]')}; }}
|
||||
""" if inter_red != 1 else "{")+f"""
|
||||
for (int rdx = 0; rdx < {inter_red}; rdx++) {{
|
||||
acc = {GPUBuffer.code_for_op[op].replace('A', 'temp[rdx]')};
|
||||
}}""" if inter_red != 1 else "{")+f"""
|
||||
{chr(10).join([f' float {name} = ' + views[name][2] for name, _ in bufs if name not in earlybufs])}
|
||||
output[gid] = {code};
|
||||
}}
|
||||
}}""")
|
||||
|
||||
conv_prg([prod(ret.shape), inter_red, 1], [1, inter_red, 1] if inter_red > 1 else None, ret.cl, *buf_cl, op_estimate=prod(reduce_shape[0])*len(earlybufs) + prod(reduce_shape[1])*len(bufs))
|
||||
return ret
|
||||
|
||||
Reference in New Issue
Block a user