mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-23 22:08:08 -05:00)
much simpler reduce
@@ -6,7 +6,7 @@ from collections import defaultdict
 from typing import List, Tuple, Optional, Dict, Union, Set
 from tinygrad.helpers import prod, ConvArgs
 from tinygrad.ops import DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps
-from tinygrad.shapetracker import ShapeTracker, View, strides_for_shape
+from tinygrad.shapetracker import ShapeTracker
 
 CLCACHE = int(os.getenv("CLCACHE", "1"))
 class CLBuffer:
@@ -117,30 +117,24 @@ class GPUBuffer:
   def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None, start="0.0", reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc") -> GPUBuffer:
     assert C is None
 
-    # this takes a ret index to an inp index, indexing 0 on the reduced strides
-    # if it's not a reduce, this should be a NOOP
+    # get the input/output shape and the reduce amount
     reduce_shape = (bufs[0][1].shape, ret.shape) if reduce_shape is None else reduce_shape
-    view = View(reduce_shape[1], strides_for_shape(reduce_shape[0]))
-    loop : List[Tuple[str, str]] = []
-    if reduce_shape[1] != reduce_shape[0]: # this is a reduce
-      # reverse operation of expand, this validates inputs
-      # generate loops with combined adjacent reduce axis
-      acc = 1
-      for shp,stride in ShapeTracker(reduce_shape[1]).movement_op(MovementOps.EXPAND, reduce_shape[0]).views[-1].shape_strides[::-1]:
-        if stride == 0: loop.append((f"for (int axis_{len(loop)} = 0; axis_{len(loop)} < {shp}; axis_{len(loop)}++) {{", f"idx += {acc}; }} idx -= {shp*acc};"))
-        acc *= shp
+    red = prod([s for s,n in zip(*reduce_shape) if n == 1])
 
-    kernel_name = "reduce" if len(loop) > 0 else "elementwise"
+    # if it's a partial reduce, assert last non reduced axis is before the first reduced axis
+    if red > 1 and prod(ret.shape) != 1: assert max([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s == n and n != 1]) < min([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s != 1 and n == 1])
+
+    kernel_name = "reduce" if red > 1 else "elementwise"
     views = {name:buf.contiguous_view_constant_fold(name) for name, buf in bufs}
     buf_types = [f"__global const float *{name}_g" for name, _ in bufs if name not in views or views[name][1]]
     conv_prg = CLProgram(kernel_name, f"""{chr(10).join([x[0] for x in views.values()])}
 __kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types)}) {{
-  float acc = {start}; int gid = get_global_id(0); int idx = gid; {view.expr.replace('//', '/')};
-  {' '.join([ls for ls, _ in loop[::-1]])}
+  float acc = {start}; int gid = get_global_id(0);
+  for (int idx = gid * {red}; idx < gid * {red} + {red}; idx++) {{
 {chr(10).join([f'    float {name} = ' + (f'get_{name}({name}_g, idx);' if views[name][1] else f'get_{name}(idx);') for name, _ in bufs if name in earlybufs])}
     acc = {earlycode};
-  {' '.join([le for _, le in loop])} idx = gid;
-{chr(10).join([f'  float {name} = ' + (f'get_{name}({name}_g, idx);' if views[name][1] else f'get_{name}(idx);') for name, _ in bufs if name not in earlybufs])}
+  }}
+{chr(10).join([f'  float {name} = ' + (f'get_{name}({name}_g, gid);' if views[name][1] else f'get_{name}(gid);') for name, _ in bufs if name not in earlybufs])}
   output[gid] = {code};
 }}""", argdtypes=tuple(None if i < 1+len(buf_types) else np.int32 for i in range(1+len(buf_types))))
     conv_prg([prod(ret.shape), 1, 1], None, ret.cl, *[buf.cl for name, buf in bufs if name not in views or views[name][1]], op_estimate=prod(reduce_shape[0])*len(earlybufs) + prod(reduce_shape[1])*len(bufs))
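
Note: the whole simplification hangs on red, the number of input elements folded into each output element, replacing the old View/strides_for_shape index math and the EXPAND-based per-axis loop generator. A minimal sketch of the computation in plain Python (the shapes are invented for illustration; prod here matches the tinygrad.helpers behavior):

    from functools import reduce
    import operator

    def prod(xs): return reduce(operator.mul, xs, 1)  # same result as tinygrad.helpers.prod

    # reduce_shape pairs the input shape with the output shape;
    # an axis is being reduced exactly where the output extent n is 1
    reduce_shape = ((3, 4, 5), (3, 1, 1))  # e.g. a sum over the last two axes, keeping dims
    red = prod([s for s, n in zip(*reduce_shape) if n == 1])
    print(red)  # 20 -> each of the 3 output elements accumulates 20 input elements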
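
The partial-reduce assert only fires when red > 1 and the output is not a scalar; it admits exactly the shapes where every kept axis precedes every reduced axis, i.e. where the elements being reduced are contiguous under row-major strides, so the flat gid * red loop indexes them correctly. A hedged illustration of the condition it checks (the helper name is mine, not the commit's):

    def reduce_axes_are_last(in_shape, out_shape):
        # kept axes: extent unchanged and > 1; reduced axes: extent collapsed to 1
        kept = [i for i, (s, n) in enumerate(zip(in_shape, out_shape)) if s == n and n != 1]
        reduced = [i for i, (s, n) in enumerate(zip(in_shape, out_shape)) if s != 1 and n == 1]
        return max(kept) < min(reduced)  # what the assert requires

    print(reduce_axes_are_last((3, 4, 5), (3, 1, 1)))  # True: trailing axes reduced
    print(reduce_axes_are_last((3, 4, 5), (1, 4, 5)))  # False: would trip the assert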
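
To see what the new template emits, here is a rough by-hand expansion for a one-buffer sum (values picked for illustration; get_A comes from contiguous_view_constant_fold in the real code and is not defined here):

    red, start, earlycode, code, name = 20, "0.0", "acc + A", "acc", "A"
    kernel = f"""__kernel void reduce(__global float* restrict output, __global const float *{name}_g) {{
      float acc = {start}; int gid = get_global_id(0);
      for (int idx = gid * {red}; idx < gid * {red} + {red}; idx++) {{
        float {name} = get_{name}({name}_g, idx);
        acc = {earlycode};
      }}
      output[gid] = {code};
    }}"""
    print(kernel)  # one work-item per output element; each runs a length-20 serial loop

The design trade is plain: one work-item per output element with a serial inner loop of length red, in exchange for deleting the axis_N loop bookkeeping entirely.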