diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py
index ac3285ea30..a26bc1bf27 100644
--- a/tinygrad/llops/ops_gpu.py
+++ b/tinygrad/llops/ops_gpu.py
@@ -117,40 +117,22 @@ class GPUBuffer:
     return type(x)(C.out_shape)._processing_op([("input", x.contiguous_op()), ("weight", w.contiguous_op())], "acc", C)
 
   def reduce_op(x, op:ReduceOps, new_shape:Tuple[int, ...]):
-    if op == ReduceOps.SUM: code, start = "out += a", "0.0"
-    elif op == ReduceOps.MAX: code, start = "out = max(a,out)", "-INFINITY"
-    else: raise Exception(f"{op} isn't supported")
+    if op == ReduceOps.SUM: code, start = "acc + A", "0.0"
+    elif op == ReduceOps.MAX: code, start = "max(A, acc)", "-INFINITY"
+    return type(x)(new_shape)._processing_op([("A", x)], code, None, start)
 
-    # reverse operation of expand, this validates inputs
-    st = ShapeTracker(new_shape).movement_op(MovementOps.EXPAND, x.shape)
-    # this takes a ret index to an inp index, indexing 0 on the reduced strides
-    view = View(new_shape, strides_for_shape(x.shape))
-
-    # generate loops with combined adjacent reduce axis
-    acc = 1
+  def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None, start="0.0") -> GPUBuffer:
+    ints, params, ewbufs, conv_src = '', [], bufs, ''
+    global_size = [prod(ret.shape), 1, 1]
     loop : List[Tuple[str, str]] = []
-    for shp,stride in st.views[-1].shape_strides[::-1]:
-      if stride == 0: loop.append((f"for (int axis_{len(loop)} = 0; axis_{len(loop)} < {shp}; axis_{len(loop)}++) {{", f"idx += {acc}; }} idx -= {shp*acc};"))
-      acc *= shp
-
-    # TODO: support multistage reduces
-    ret = type(x)(new_shape)
-    CLProgram("reduce", f"""{x.contiguous_view('A')}
-    __kernel void reduce(__global const float *a_g, __global float *res_g) {{
-      int gid = get_global_id(0); int idx = gid; {view.expr.replace('//', '/')};
-      float out = {start};
-      {''.join([ls for ls, _ in loop[::-1]])}
-        float a = get_A(a_g, idx); {code};
-      {''.join([le for _, le in loop])}
-      res_g[gid] = out;
-    }}""")([prod(ret.shape)], None, x.cl, ret.cl)
-    return ret
-
-  def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None) -> GPUBuffer:
-    if C is not None:
+    # this takes a ret index to an inp index, indexing 0 on the reduced strides
+    # if it's not a reduce, this should be a NOOP
+    view = View(ret.shape, strides_for_shape(bufs[0][1].shape))
+    if C is not None:  # this is a conv
       ints = ''.join(f"int {x} = {getattr(C, x)};" for x in ["H", "W", "sy", "sx", "dx", "dy", "px", "py", "groups", "rcout", "cin"])
       params = [(f"int {x}", getattr(C, x)) for x in ["oy", "ox", "iy", "ix"]]
       global_size = [C.bs*C.cout, C.oy, C.ox]
+      assert ret.shape == C.out_shape, "output shape is wrong (can't reduce and conv together)"
 
       # now input and weight can be anywhere in bufs
       bufs = [(x[0], x[1].contiguous_op()) if x[0] in ["input", "weight"] else x for x in bufs]
@@ -159,7 +141,7 @@ class GPUBuffer:
       conv_src = """
       int B = gid/(groups*rcout); int g = (gid/rcout)%groups; int c = gid % rcout;
-      int Y = get_global_id(1); int X = get_global_id(2); gid = gid*oy*ox + Y*ox + X;
+      int Y = get_global_id(1); int X = get_global_id(2); gid = gid*oy*ox + Y*ox + X; idx = gid;
 
       for (int ci = 0; ci < cin; ci++) {
       for (int y = 0; y < H; y++) { for (int x = 0; x < W; x++) {
       int idx_y = y*dy + Y*sy - py;
@@ -168,22 +150,28 @@ class GPUBuffer:
       acc += valid * input_g[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + clamp(idx_y, 0, iy-1)*ix + clamp(idx_x, 0, ix-1)] * \
         weight_g[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x];
       } }
-      }
-      """
-    else:
-      ints, params = '', []
-      global_size = [prod(ret.shape), 1, 1]
-      ewbufs = bufs
-      conv_src = ""
+      }"""
+    elif ret.shape != bufs[0][1].shape:  # this is a reduce
+      # reverse operation of expand, this validates inputs
+      st = ShapeTracker(ret.shape).movement_op(MovementOps.EXPAND, bufs[0][1].shape)
 
-    kernel_name = "conv" if C is not None else "elementwise"
+      # generate loops with combined adjacent reduce axis
+      acc = 1
+      for shp,stride in st.views[-1].shape_strides[::-1]:
+        if stride == 0: loop.append((f"for (int axis_{len(loop)} = 0; axis_{len(loop)} < {shp}; axis_{len(loop)}++) {{", f"idx += {acc}; }} idx -= {shp*acc};"))
+        acc *= shp
+
+    kernel_name = "conv" if C is not None else ("reduce" if len(loop) > 0 else "elementwise")
     views = {name:buf.contiguous_view_constant_fold(name) for name, buf in ewbufs}
     buf_types = [f"__global const float *{name}_g" for name, _ in bufs if name not in views or views[name][1]]
     conv_prg = CLProgram(kernel_name, f"""{''.join([x[0] for x in views.values()])}
-    __kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types + [x[0] for x in params])}) {{
-      float acc = 0.0; int gid = get_global_id(0); {ints} {conv_src}
-      {''.join([f'float {name} = get_{name}({name}_g, gid);' if views[name][1] else f'float {name} = get_{name}(gid);' for name, _ in ewbufs])}
-      output[gid] = {code};
+    __kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types + [x[0] for x in params])}) {{ {ints}
+      float acc = {start}; int gid = get_global_id(0); int idx = gid; {view.expr.replace('//', '/')}; {conv_src}
+      {''.join([ls for ls, _ in loop[::-1]])}
+      {''.join([f'float {name} = get_{name}({name}_g, idx);' if views[name][1] else f'float {name} = get_{name}(idx);' for name, _ in ewbufs])}
+      acc = {code};
+      {''.join([le for _, le in loop])}
+      output[gid] = acc;
     }}""", argdtypes=tuple(None if i < 1+len(buf_types) else np.int32 for i in range(1+len(buf_types)+len(params))))
     conv_prg(global_size, None, ret.cl, *[buf.cl for name, buf in bufs if name not in views or views[name][1]], *[x[1] for x in params])
     return ret