mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-23 05:48:08 -05:00)

unify conv and reduce
@@ -117,40 +117,22 @@ class GPUBuffer:
     return type(x)(C.out_shape)._processing_op([("input", x.contiguous_op()), ("weight", w.contiguous_op())], "acc", C)
 
   def reduce_op(x, op:ReduceOps, new_shape:Tuple[int, ...]):
-    if op == ReduceOps.SUM: code, start = "out += a", "0.0"
-    elif op == ReduceOps.MAX: code, start = "out = max(a,out)", "-INFINITY"
+    if op == ReduceOps.SUM: code, start = "acc + A", "0.0"
+    elif op == ReduceOps.MAX: code, start = "max(A, acc)", "-INFINITY"
     else: raise Exception(f"{op} isn't supported")
+    return type(x)(new_shape)._processing_op([("A", x)], code, None, start)
-    # reverse operation of expand, this validates inputs
-    st = ShapeTracker(new_shape).movement_op(MovementOps.EXPAND, x.shape)
-    # this takes a ret index to an inp index, indexing 0 on the reduced strides
-    view = View(new_shape, strides_for_shape(x.shape))
-    # generate loops with combined adjacent reduce axis
-    acc = 1
-    for shp,stride in st.views[-1].shape_strides[::-1]:
-      if stride == 0: loop.append((f"for (int axis_{len(loop)} = 0; axis_{len(loop)} < {shp}; axis_{len(loop)}++) {{", f"idx += {acc}; }} idx -= {shp*acc};"))
-      acc *= shp
-    # TODO: support multistage reduces
-    ret = type(x)(new_shape)
-    CLProgram("reduce", f"""{x.contiguous_view('A')}
-    __kernel void reduce(__global const float *a_g, __global float *res_g) {{
-      int gid = get_global_id(0); int idx = gid; {view.expr.replace('//', '/')};
-      float out = {start};
-      {''.join([ls for ls, _ in loop[::-1]])}
-        float a = get_A(a_g, idx); {code};
-      {''.join([le for _, le in loop])}
-      res_g[gid] = out;
-    }}""")([prod(ret.shape)], None, x.cl, ret.cl)
-    return ret
 
-  def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None) -> GPUBuffer:
-    if C is not None:
+  def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None, start="0.0") -> GPUBuffer:
+    ints, params, ewbufs, conv_src = '', [], bufs, ''
+    global_size = [prod(ret.shape), 1, 1]
+    loop : List[Tuple[str, str]] = []
+
+    # this takes a ret index to an inp index, indexing 0 on the reduced strides
+    # if it's not a reduce, this should be a NOOP
+    view = View(ret.shape, strides_for_shape(bufs[0][1].shape))
+
+    if C is not None: # this is a conv
       ints = ''.join(f"int {x} = {getattr(C, x)};" for x in ["H", "W", "sy", "sx", "dx", "dy", "px", "py", "groups", "rcout", "cin"])
       params = [(f"int {x}", getattr(C, x)) for x in ["oy", "ox", "iy", "ix"]]
       global_size = [C.bs*C.cout, C.oy, C.ox]
+      assert ret.shape == C.out_shape, "output shape is wrong (can't reduce and conv together)"
+      # now input and weight can be anywhere in bufs
+      bufs = [(x[0], x[1].contiguous_op()) if x[0] in ["input", "weight"] else x for x in bufs]
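
After this hunk, reduce_op no longer builds its own "reduce" CLProgram: it only picks an accumulation expression and a start value and forwards them to _processing_op. A minimal standalone sketch of that mapping (the code/start strings are the ones in the diff; the surrounding scaffolding is illustrative only):

# reduce_op is now a table lookup plus a _processing_op call
code_start = {
  "SUM": ("acc + A", "0.0"),            # running sum
  "MAX": ("max(A, acc)", "-INFINITY"),  # running max
}
for op, (code, start) in code_start.items():
  # the unified kernel emits: float acc = {start}; ... acc = {code}; ... output[gid] = acc;
  print(f"{op}: float acc = {start}; ... acc = {code};")
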
@@ -159,7 +141,7 @@ class GPUBuffer:
       conv_src = """
      int B = gid/(groups*rcout); int g = (gid/rcout)%groups; int c = gid % rcout;
-     int Y = get_global_id(1); int X = get_global_id(2); gid = gid*oy*ox + Y*ox + X;
+     int Y = get_global_id(1); int X = get_global_id(2); gid = gid*oy*ox + Y*ox + X; idx = gid;
      for (int ci = 0; ci < cin; ci++) {
      for (int y = 0; y < H; y++) { for (int x = 0; x < W; x++) {
        int idx_y = y*dy + Y*sy - py;
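
The only conv-side change is the trailing "idx = gid;": in the unified kernel (next hunk) elementwise buffers are read at idx rather than gid, and the conv path has no reduce loops that would set idx, so it aliases idx to the flattened output position. A small sketch of the accessor line that consumes idx (the join expression is copied from the new _processing_op; the ewbufs/views contents are hypothetical, for illustration only):

ewbufs = [("A", None), ("B", None)]  # hypothetical fused elementwise inputs
views = {"A": ("/* get_A src */", True), "B": ("/* get_B src */", False)}  # True = still needs a __global pointer
line = ''.join([f'float {name} = get_{name}({name}_g, idx);' if views[name][1] else f'float {name} = get_{name}(idx);' for name, _ in ewbufs])
print(line)  # float A = get_A(A_g, idx);float B = get_B(idx);
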
@@ -168,22 +150,28 @@ class GPUBuffer:
        acc += valid * input_g[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + clamp(idx_y, 0, iy-1)*ix + clamp(idx_x, 0, ix-1)] * \
          weight_g[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x];
      } }
-     }
-     """
-    else:
-      ints, params = '', []
-      global_size = [prod(ret.shape), 1, 1]
-      ewbufs = bufs
-      conv_src = ""
-
-    kernel_name = "conv" if C is not None else "elementwise"
+     }"""
+    elif ret.shape != bufs[0][1].shape: # this is a reduce
+      # reverse operation of expand, this validates inputs
+      st = ShapeTracker(ret.shape).movement_op(MovementOps.EXPAND, bufs[0][1].shape)
+
+      # generate loops with combined adjacent reduce axis
+      acc = 1
+      for shp,stride in st.views[-1].shape_strides[::-1]:
+        if stride == 0: loop.append((f"for (int axis_{len(loop)} = 0; axis_{len(loop)} < {shp}; axis_{len(loop)}++) {{", f"idx += {acc}; }} idx -= {shp*acc};"))
+        acc *= shp
+
+    kernel_name = "conv" if C is not None else ("reduce" if len(loop) > 0 else "elementwise")
     views = {name:buf.contiguous_view_constant_fold(name) for name, buf in ewbufs}
     buf_types = [f"__global const float *{name}_g" for name, _ in bufs if name not in views or views[name][1]]
 
     conv_prg = CLProgram(kernel_name, f"""{''.join([x[0] for x in views.values()])}
-    __kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types + [x[0] for x in params])}) {{
-      float acc = 0.0; int gid = get_global_id(0); {ints} {conv_src}
-      {''.join([f'float {name} = get_{name}({name}_g, gid);' if views[name][1] else f'float {name} = get_{name}(gid);' for name, _ in ewbufs])}
-      output[gid] = {code};
+    __kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types + [x[0] for x in params])}) {{ {ints}
+      float acc = {start}; int gid = get_global_id(0); int idx = gid; {view.expr.replace('//', '/')}; {conv_src}
+      {''.join([ls for ls, _ in loop[::-1]])}
+        {''.join([f'float {name} = get_{name}({name}_g, idx);' if views[name][1] else f'float {name} = get_{name}(idx);' for name, _ in ewbufs])}
+        acc = {code};
+      {''.join([le for _, le in loop])}
+      output[gid] = acc;
     }}""", argdtypes=tuple(None if i < 1+len(buf_types) else np.int32 for i in range(1+len(buf_types)+len(params))))
     conv_prg(global_size, None, ret.cl, *[buf.cl for name, buf in bufs if name not in views or views[name][1]], *[x[1] for x in params])
     return ret
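
The new reduce branch above is the heart of the unification: expanding the output shape back over the input shape marks every reduced axis with stride 0, and each zero-stride axis becomes one generated for-loop that steps idx by the input's contiguous stride. A worked example with hardcoded numbers (the (shape, stride) pairs are what the expanded view would give when reducing a contiguous (3, 4, 5) input to (3, 1, 5); nothing here calls tinygrad, and the loop-building lines are copied from the diff):

shape_strides = [(3, 5), (4, 0), (5, 1)]  # expanded output view; stride 0 marks the reduced axis
acc, loop = 1, []
for shp, stride in shape_strides[::-1]:
  if stride == 0: loop.append((f"for (int axis_{len(loop)} = 0; axis_{len(loop)} < {shp}; axis_{len(loop)}++) {{", f"idx += {acc}; }} idx -= {shp*acc};"))
  acc *= shp
print(''.join([ls for ls, _ in loop[::-1]]))  # for (int axis_0 = 0; axis_0 < 4; axis_0++) {
print(''.join([le for _, le in loop]))        # idx += 5; } idx -= 20;

The view built from strides_for_shape of the input then maps each output gid to its starting idx (for this example, effectively idx = (gid/5)*20 + gid%5), so the generated loop reads the four reduced elements at stride 5 and folds each one in with acc = {code}.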