diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py index f4d644d1f1..cb56081026 100644 --- a/tinygrad/llops/ops_gpu.py +++ b/tinygrad/llops/ops_gpu.py @@ -6,7 +6,7 @@ from collections import defaultdict from typing import List, Tuple, Optional, Dict, Union, Set from tinygrad.helpers import prod, ConvArgs from tinygrad.ops import DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps -from tinygrad.shapetracker import ShapeTracker, View, strides_for_shape +from tinygrad.shapetracker import ShapeTracker CLCACHE = int(os.getenv("CLCACHE", "1")) class CLBuffer: @@ -117,30 +117,24 @@ class GPUBuffer: def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None, start="0.0", reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc") -> GPUBuffer: assert C is None - # this takes a ret index to an inp index, indexing 0 on the reduced strides - # if it's not a reduce, this should be a NOOP + # get the input/output shape and the reduce amount reduce_shape = (bufs[0][1].shape, ret.shape) if reduce_shape is None else reduce_shape - view = View(reduce_shape[1], strides_for_shape(reduce_shape[0])) - loop : List[Tuple[str, str]] = [] - if reduce_shape[1] != reduce_shape[0]: # this is a reduce - # reverse operation of expand, this validates inputs - # generate loops with combined adjacent reduce axis - acc = 1 - for shp,stride in ShapeTracker(reduce_shape[1]).movement_op(MovementOps.EXPAND, reduce_shape[0]).views[-1].shape_strides[::-1]: - if stride == 0: loop.append((f"for (int axis_{len(loop)} = 0; axis_{len(loop)} < {shp}; axis_{len(loop)}++) {{", f"idx += {acc}; }} idx -= {shp*acc};")) - acc *= shp + red = prod([s for s,n in zip(*reduce_shape) if n == 1]) - kernel_name = "reduce" if len(loop) > 0 else "elementwise" + # if it's a partial reduce, assert last non reduced axis is before the first reduced axis + if red > 1 and prod(ret.shape) != 1: assert max([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s == n and n != 1]) < min([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s != 1 and n == 1]) + + kernel_name = "reduce" if red > 1 else "elementwise" views = {name:buf.contiguous_view_constant_fold(name) for name, buf in bufs} buf_types = [f"__global const float *{name}_g" for name, _ in bufs if name not in views or views[name][1]] conv_prg = CLProgram(kernel_name, f"""{chr(10).join([x[0] for x in views.values()])} __kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types)}) {{ - float acc = {start}; int gid = get_global_id(0); int idx = gid; {view.expr.replace('//', '/')}; - {' '.join([ls for ls, _ in loop[::-1]])} + float acc = {start}; int gid = get_global_id(0); + for (int idx = gid * {red}; idx < gid * {red} + {red}; idx++) {{ {chr(10).join([f' float {name} = ' + (f'get_{name}({name}_g, idx);' if views[name][1] else f'get_{name}(idx);') for name, _ in bufs if name in earlybufs])} acc = {earlycode}; - {' '.join([le for _, le in loop])} idx = gid; -{chr(10).join([f' float {name} = ' + (f'get_{name}({name}_g, idx);' if views[name][1] else f'get_{name}(idx);') for name, _ in bufs if name not in earlybufs])} + }} +{chr(10).join([f' float {name} = ' + (f'get_{name}({name}_g, gid);' if views[name][1] else f'get_{name}(gid);') for name, _ in bufs if name not in earlybufs])} output[gid] = {code}; }}""", argdtypes=tuple(None if i < 1+len(buf_types) else np.int32 for i in range(1+len(buf_types)))) conv_prg([prod(ret.shape), 1, 1], None, ret.cl, *[buf.cl for name, buf in bufs if name not in views or views[name][1]], op_estimate=prod(reduce_shape[0])*len(earlybufs) + prod(reduce_shape[1])*len(bufs))