diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py
index cb56081026..81b921cfec 100644
--- a/tinygrad/llops/ops_gpu.py
+++ b/tinygrad/llops/ops_gpu.py
@@ -112,9 +112,9 @@ class GPUBuffer:
   def binary_op(x, op:BinaryOps, y:GPUBuffer): return type(x)(x.shape)._processing_op([("A", x), ("B", y)], GPUBuffer.code_for_op[op])
   def contiguous_op(x): return x if x.st.contiguous else x.unary_op(UnaryOps.NOOP)
   def movement_op(x, op:MovementOps, arg) -> GPUBuffer: return type(x)(ShapeTracker(x.st).movement_op(op, arg), x)
-  def reduce_op(x, op:ReduceOps, new_shape:Tuple[int, ...]): return type(x)(new_shape)._processing_op([("A", x)], code="acc", earlycode=GPUBuffer.code_for_op[op], earlybufs=set("A"), start=GPUBuffer.start_for_op[op])
+  def reduce_op(x, op:ReduceOps, new_shape:Tuple[int, ...]): return type(x)(new_shape)._processing_op([("A", x)], code="acc", earlycode=GPUBuffer.code_for_op[op], earlybufs=set("A"), op=op)
 
-  def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None, start="0.0", reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc") -> GPUBuffer:
+  def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None, op=ReduceOps.SUM, reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc") -> GPUBuffer:
     assert C is None
 
     # get the input/output shape and the reduce amount
@@ -123,19 +123,27 @@ class GPUBuffer:
     # if it's a partial reduce, assert last non reduced axis is before the first reduced axis
     if red > 1 and prod(ret.shape) != 1: assert max([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s == n and n != 1]) < min([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s != 1 and n == 1])
 
+    inter_red = 256 if (prod(ret.shape) < 256 and red >= 256) else 1
     kernel_name = "reduce" if red > 1 else "elementwise"
     views = {name:buf.contiguous_view_constant_fold(name) for name, buf in bufs}
     buf_types = [f"__global const float *{name}_g" for name, _ in bufs if name not in views or views[name][1]]
     conv_prg = CLProgram(kernel_name, f"""{chr(10).join([x[0] for x in views.values()])}
-    __kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types)}) {{
-      float acc = {start}; int gid = get_global_id(0);
-      for (int idx = gid * {red}; idx < gid * {red} + {red}; idx++) {{
+    __kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types + (["__local float *temp"] if inter_red > 1 else []))}) {{
+      float acc = {GPUBuffer.start_for_op[op]}; int gid = get_global_id(0); int mid = get_global_id(1);
+      for (int idx = gid * {red} + {red//inter_red + 1} * mid; idx < gid * {red} + min({red}, {red//inter_red + 1} * (mid+1)); idx++) {{
 {chr(10).join([f'        float {name} = ' + (f'get_{name}({name}_g, idx);' if views[name][1] else f'get_{name}(idx);') for name, _ in bufs if name in earlybufs])}
         acc = {earlycode};
+      }}"""+(f"""
+      temp[mid] = acc; barrier(CLK_LOCAL_MEM_FENCE);
+      if (mid == 0) {{ acc = {GPUBuffer.start_for_op[op]};
+        for (int rdx = 0; rdx < {inter_red}; rdx++) {{ acc = {GPUBuffer.code_for_op[op].replace('A', 'temp[rdx]')}; }}
+      """ if inter_red != 1 else "{")+f"""
+{chr(10).join([f'        float {name} = ' + (f'get_{name}({name}_g, gid);' if views[name][1] else f'get_{name}(gid);') for name, _ in bufs if name not in earlybufs])}
+        output[gid] = {code};
       }}
-{chr(10).join([f'      float {name} = ' + (f'get_{name}({name}_g, gid);' if views[name][1] else f'get_{name}(gid);') for name, _ in bufs if name not in earlybufs])}
-      output[gid] = {code};
-    }}""",
-    argdtypes=tuple(None if i < 1+len(buf_types) else np.int32 for i in range(1+len(buf_types))))
-    conv_prg([prod(ret.shape), 1, 1], None, ret.cl, *[buf.cl for name, buf in bufs if name not in views or views[name][1]], op_estimate=prod(reduce_shape[0])*len(earlybufs) + prod(reduce_shape[1])*len(bufs))
+    }}""")
+    conv_prg([prod(ret.shape), inter_red, 1], [1, inter_red, 1] if inter_red > 1 else None, ret.cl,
+      *([buf.cl for name, buf in bufs if name not in views or views[name][1]] + ([cl.LocalMemory(inter_red*4)] if inter_red > 1 else [])),
+      op_estimate=prod(reduce_shape[0])*len(earlybufs) + prod(reduce_shape[1])*len(bufs))
     return ret
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 7fbecb4bf9..50a8142fb4 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -123,7 +123,7 @@ def _realize_reduceops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
     buf_names : Dict[LazyBuffer, str] = {x:f"arg_{i}" for i,x in enumerate(real_srcs.keys())}
     return self.dbuffer(self.shape)._processing_op([(buf_names[lb], db) for lb,db in real_srcs.items()],
-      earlycode=_ast(LazyOp(self.op.op, (src.op,), self.op.arg), buf_names, self.dbuffer.code_for_op), earlybufs=buf_names.values(), start=self.dbuffer.start_for_op[self.op.op]), \
+      earlycode=_ast(LazyOp(self.op.op, (src.op,), self.op.arg), buf_names, self.dbuffer.code_for_op), earlybufs=buf_names.values(), op=self.op.op), \
       list(real_srcs.values()), ReduceOps
   else:
     real_src = src.realize(self.device)
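The `inter_red` condition reads as: only bring in the second stage when there are few output elements (so a one-work-item-per-output launch cannot fill the GPU) and the reduce itself is large. Each output element then gets a work-group of `inter_red` work-items along global dimension 1; every work-item reduces its own `red//inter_red + 1`-element chunk (over-allocated by one and clamped with `min`, so trailing work-items that fall off the end simply keep the identity value `start_for_op[op]`), parks its partial in `__local` memory, and after the barrier work-item 0 folds the partials and writes `output[gid]`. To make that concrete, here is a minimal standalone pyopencl sketch of the kernel this template generates for a full SUM reduce. It is not code from this PR: the names `A_g` and `temp`, the sizes `red = 4096` / `inter_red = 256`, and the host-side setup are illustrative, and it assumes the device supports a work-group size of 256.

```python
# Hypothetical standalone demo of the two-stage reduce pattern (not PR code).
import numpy as np
import pyopencl as cl

red, inter_red = 4096, 256     # elements to reduce, stage-1 parallelism (illustrative)
chunk = red // inter_red + 1   # per-work-item chunk, as computed in the diff

ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx)

src = f"""
__kernel void reduce(__global float *output, __global const float *A_g, __local float *temp) {{
  float acc = 0.0;              // start_for_op for SUM
  int gid = get_global_id(0);   // which output element
  int mid = get_global_id(1);   // position within the stage-1 group
  // stage 1: each work-item reduces its own chunk of the row
  for (int idx = gid * {red} + {chunk} * mid; idx < gid * {red} + min({red}, {chunk} * (mid+1)); idx++) {{
    acc += A_g[idx];            // earlycode for SUM
  }}
  // stage 2: work-item 0 folds the partials from local memory
  temp[mid] = acc; barrier(CLK_LOCAL_MEM_FENCE);
  if (mid == 0) {{
    acc = 0.0;
    for (int rdx = 0; rdx < {inter_red}; rdx++) acc += temp[rdx];
    output[gid] = acc;
  }}
}}"""

prg = cl.Program(ctx, src).build()
a = np.random.rand(red).astype(np.float32)
a_g = cl.Buffer(ctx, cl.mem_flags.READ_ONLY | cl.mem_flags.COPY_HOST_PTR, hostbuf=a)
out_g = cl.Buffer(ctx, cl.mem_flags.WRITE_ONLY, 4)
# global size (1, inter_red), local size (1, inter_red): one work-group per
# output element, mirroring conv_prg([prod(ret.shape), inter_red, 1], [1, inter_red, 1], ...)
prg.reduce(queue, (1, inter_red), (1, inter_red), out_g, a_g, cl.LocalMemory(inter_red * 4))
res = np.empty(1, dtype=np.float32)
cl.enqueue_copy(queue, res, out_g)
np.testing.assert_allclose(res[0], a.sum(), rtol=1e-4)
```

The same structure works for the other reduce ops because idle work-items hold the identity element (e.g. `-INFINITY` for MAX), so they contribute nothing when stage 2 folds `temp`; that is why the template substitutes `start_for_op[op]` and `code_for_op[op]` rather than hardcoding the SUM shown above.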