much simpler reduce

George Hotz
2022-08-20 08:12:43 -07:00
parent 1eb12dafbc
commit e34eb855fe


@@ -6,7 +6,7 @@ from collections import defaultdict
 from typing import List, Tuple, Optional, Dict, Union, Set
 from tinygrad.helpers import prod, ConvArgs
 from tinygrad.ops import DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps
-from tinygrad.shapetracker import ShapeTracker, View, strides_for_shape
+from tinygrad.shapetracker import ShapeTracker
 CLCACHE = int(os.getenv("CLCACHE", "1"))
 class CLBuffer:
@@ -117,30 +117,24 @@ class GPUBuffer:
   def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None, start="0.0", reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc") -> GPUBuffer:
     assert C is None
-    # this takes a ret index to an inp index, indexing 0 on the reduced strides
-    # if it's not a reduce, this should be a NOOP
+    # get the input/output shape and the reduce amount
     reduce_shape = (bufs[0][1].shape, ret.shape) if reduce_shape is None else reduce_shape
-    view = View(reduce_shape[1], strides_for_shape(reduce_shape[0]))
-    loop : List[Tuple[str, str]] = []
-    if reduce_shape[1] != reduce_shape[0]: # this is a reduce
-      # reverse operation of expand, this validates inputs
-      # generate loops with combined adjacent reduce axis
-      acc = 1
-      for shp,stride in ShapeTracker(reduce_shape[1]).movement_op(MovementOps.EXPAND, reduce_shape[0]).views[-1].shape_strides[::-1]:
-        if stride == 0: loop.append((f"for (int axis_{len(loop)} = 0; axis_{len(loop)} < {shp}; axis_{len(loop)}++) {{", f"idx += {acc}; }} idx -= {shp*acc};"))
-        acc *= shp
+    red = prod([s for s,n in zip(*reduce_shape) if n == 1])
-    kernel_name = "reduce" if len(loop) > 0 else "elementwise"
+    # if it's a partial reduce, assert last non reduced axis is before the first reduced axis
+    if red > 1 and prod(ret.shape) != 1: assert max([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s == n and n != 1]) < min([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s != 1 and n == 1])
+    kernel_name = "reduce" if red > 1 else "elementwise"
     views = {name:buf.contiguous_view_constant_fold(name) for name, buf in bufs}
     buf_types = [f"__global const float *{name}_g" for name, _ in bufs if name not in views or views[name][1]]
     conv_prg = CLProgram(kernel_name, f"""{chr(10).join([x[0] for x in views.values()])}
 __kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types)}) {{
-  float acc = {start}; int gid = get_global_id(0); int idx = gid; {view.expr.replace('//', '/')};
-  {' '.join([ls for ls, _ in loop[::-1]])}
+  float acc = {start}; int gid = get_global_id(0);
+  for (int idx = gid * {red}; idx < gid * {red} + {red}; idx++) {{
 {chr(10).join([f' float {name} = ' + (f'get_{name}({name}_g, idx);' if views[name][1] else f'get_{name}(idx);') for name, _ in bufs if name in earlybufs])}
     acc = {earlycode};
-  {' '.join([le for _, le in loop])} idx = gid;
-{chr(10).join([f' float {name} = ' + (f'get_{name}({name}_g, idx);' if views[name][1] else f'get_{name}(idx);') for name, _ in bufs if name not in earlybufs])}
+  }}
+{chr(10).join([f' float {name} = ' + (f'get_{name}({name}_g, gid);' if views[name][1] else f'get_{name}(gid);') for name, _ in bufs if name not in earlybufs])}
   output[gid] = {code};
 }}""", argdtypes=tuple(None if i < 1+len(buf_types) else np.int32 for i in range(1+len(buf_types))))
     conv_prg([prod(ret.shape), 1, 1], None, ret.cl, *[buf.cl for name, buf in bufs if name not in views or views[name][1]], op_estimate=prod(reduce_shape[0])*len(earlybufs) + prod(reduce_shape[1])*len(bufs))
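Commentary (not part of the commit): the rewrite drops the generated per-axis loops entirely. Output element gid now accumulates the red contiguous input elements [gid*red, gid*red + red), which is only valid when every reduced axis comes after the last non-reduced axis in the row-major flattening; that is exactly what the new assert checks (reducing (4,5,6) to (4,1,1) passes, while (4,5,6) to (1,5,1) would fire the assert). A minimal Python sketch of the same indexing, where reduce_flat is a made-up name for illustration:

from functools import reduce
import operator

def prod(xs): return reduce(operator.mul, list(xs), 1)

# Illustrative sketch, not tinygrad code: emulate the simplified reduce.
# Output element gid folds the red contiguous inputs [gid*red, gid*red + red).
def reduce_flat(inp, in_shape, out_shape, start=0.0, combine=lambda acc, x: acc + x):
  red = prod(s for s, n in zip(in_shape, out_shape) if n == 1)  # same as the commit's `red`
  assert len(inp) == prod(in_shape)
  return [reduce(combine, inp[gid*red:gid*red+red], start) for gid in range(prod(out_shape))]

# sum a row-major (2,3) buffer over its last axis: red == 3
assert reduce_flat([1., 2., 3., 4., 5., 6.], (2, 3), (2, 1)) == [6., 15.]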
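For concreteness, here is roughly what the template renders to for a sum over the last axis of a (2,3) input with a single buffer A, so red == 3. The earlycode is assumed here to be "acc+A", and the real get_A accessor is emitted by contiguous_view_constant_fold; it is sketched below as a plain load:

float get_A(__global const float *x, int idx) { return x[idx]; }

__kernel void reduce(__global float* restrict output,__global const float *A_g) {
  float acc = 0.0; int gid = get_global_id(0);
  for (int idx = gid * 3; idx < gid * 3 + 3; idx++) {
    float A = get_A(A_g, idx);
    acc = acc+A;
  }
  output[gid] = acc;
}

The launch at the bottom uses a global size of prod(ret.shape), so two work-items would run here, each writing one reduced output element.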