unify conv and reduce

This commit is contained in:
George Hotz
2022-07-08 08:27:30 -07:00
parent 9c34b3eef3
commit e6733286df

View File

@@ -117,40 +117,22 @@ class GPUBuffer:
return type(x)(C.out_shape)._processing_op([("input", x.contiguous_op()), ("weight", w.contiguous_op())], "acc", C)
def reduce_op(x, op:ReduceOps, new_shape:Tuple[int, ...]):
if op == ReduceOps.SUM: code, start = "out += a", "0.0"
elif op == ReduceOps.MAX: code, start = "out = max(a,out)", "-INFINITY"
else: raise Exception(f"{op} isn't supported")
if op == ReduceOps.SUM: code, start = "acc + A", "0.0"
elif op == ReduceOps.MAX: code, start = "max(A, acc)", "-INFINITY"
return type(x)(new_shape)._processing_op([("A", x)], code, None, start)
# reverse operation of expand, this validates inputs
st = ShapeTracker(new_shape).movement_op(MovementOps.EXPAND, x.shape)
# this takes a ret index to an inp index, indexing 0 on the reduced strides
view = View(new_shape, strides_for_shape(x.shape))
# generate loops with combined adjacent reduce axis
acc = 1
def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None, start="0.0") -> GPUBuffer:
ints, params, ewbufs, conv_src = '', [], bufs, ''
global_size = [prod(ret.shape), 1, 1]
loop : List[Tuple[str, str]] = []
for shp,stride in st.views[-1].shape_strides[::-1]:
if stride == 0: loop.append((f"for (int axis_{len(loop)} = 0; axis_{len(loop)} < {shp}; axis_{len(loop)}++) {{", f"idx += {acc}; }} idx -= {shp*acc};"))
acc *= shp
# TODO: support multistage reduces
ret = type(x)(new_shape)
CLProgram("reduce", f"""{x.contiguous_view('A')}
__kernel void reduce(__global const float *a_g, __global float *res_g) {{
int gid = get_global_id(0); int idx = gid; {view.expr.replace('//', '/')};
float out = {start};
{''.join([ls for ls, _ in loop[::-1]])}
float a = get_A(a_g, idx); {code};
{''.join([le for _, le in loop])}
res_g[gid] = out;
}}""")([prod(ret.shape)], None, x.cl, ret.cl)
return ret
def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None) -> GPUBuffer:
if C is not None:
# this takes a ret index to an inp index, indexing 0 on the reduced strides
# if it's not a reduce, this should be a NOOP
view = View(ret.shape, strides_for_shape(bufs[0][1].shape))
if C is not None: # this is a conv
ints = ''.join(f"int {x} = {getattr(C, x)};" for x in ["H", "W", "sy", "sx", "dx", "dy", "px", "py", "groups", "rcout", "cin"])
params = [(f"int {x}", getattr(C, x)) for x in ["oy", "ox", "iy", "ix"]]
global_size = [C.bs*C.cout, C.oy, C.ox]
assert ret.shape == C.out_shape, "output shape is wrong (can't reduce and conv together)"
# now input and weight can be anywhere in bufs
bufs = [(x[0], x[1].contiguous_op()) if x[0] in ["input", "weight"] else x for x in bufs]
@@ -159,7 +141,7 @@ class GPUBuffer:
conv_src = """
int B = gid/(groups*rcout); int g = (gid/rcout)%groups; int c = gid % rcout;
int Y = get_global_id(1); int X = get_global_id(2); gid = gid*oy*ox + Y*ox + X;
int Y = get_global_id(1); int X = get_global_id(2); gid = gid*oy*ox + Y*ox + X; idx = gid;
for (int ci = 0; ci < cin; ci++) {
for (int y = 0; y < H; y++) { for (int x = 0; x < W; x++) {
int idx_y = y*dy + Y*sy - py;
@@ -168,22 +150,28 @@ class GPUBuffer:
acc += valid * input_g[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + clamp(idx_y, 0, iy-1)*ix + clamp(idx_x, 0, ix-1)] * \
weight_g[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x];
} }
}
"""
else:
ints, params = '', []
global_size = [prod(ret.shape), 1, 1]
ewbufs = bufs
conv_src = ""
}"""
elif ret.shape != bufs[0][1].shape: # this is a reduce
# reverse operation of expand, this validates inputs
st = ShapeTracker(ret.shape).movement_op(MovementOps.EXPAND, bufs[0][1].shape)
kernel_name = "conv" if C is not None else "elementwise"
# generate loops with combined adjacent reduce axis
acc = 1
for shp,stride in st.views[-1].shape_strides[::-1]:
if stride == 0: loop.append((f"for (int axis_{len(loop)} = 0; axis_{len(loop)} < {shp}; axis_{len(loop)}++) {{", f"idx += {acc}; }} idx -= {shp*acc};"))
acc *= shp
kernel_name = "conv" if C is not None else ("reduce" if len(loop) > 0 else "elementwise")
views = {name:buf.contiguous_view_constant_fold(name) for name, buf in ewbufs}
buf_types = [f"__global const float *{name}_g" for name, _ in bufs if name not in views or views[name][1]]
conv_prg = CLProgram(kernel_name, f"""{''.join([x[0] for x in views.values()])}
__kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types + [x[0] for x in params])}) {{
float acc = 0.0; int gid = get_global_id(0); {ints} {conv_src}
{''.join([f'float {name} = get_{name}({name}_g, gid);' if views[name][1] else f'float {name} = get_{name}(gid);' for name, _ in ewbufs])}
output[gid] = {code};
__kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types + [x[0] for x in params])}) {{ {ints}
float acc = {start}; int gid = get_global_id(0); int idx = gid; {view.expr.replace('//', '/')}; {conv_src}
{''.join([ls for ls, _ in loop[::-1]])}
{''.join([f'float {name} = get_{name}({name}_g, idx);' if views[name][1] else f'float {name} = get_{name}(idx);' for name, _ in ewbufs])}
acc = {code};
{''.join([le for _, le in loop])}
output[gid] = acc;
}}""", argdtypes=tuple(None if i < 1+len(buf_types) else np.int32 for i in range(1+len(buf_types)+len(params))))
conv_prg(global_size, None, ret.cl, *[buf.cl for name, buf in bufs if name not in views or views[name][1]], *[x[1] for x in params])
return ret