reduce with loops

This commit is contained in:
George Hotz
2022-06-14 16:38:33 -07:00
parent 6261a0639b
commit 906cce9916
2 changed files with 25 additions and 34 deletions

View File

@@ -3,7 +3,7 @@ import numpy as np
import pyopencl as cl
from tinygrad.helpers import prod
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
from tinygrad.shapetracker import ShapeTracker
from tinygrad.shapetracker import ShapeTracker, View, strides_for_shape
cl_ctx, cl_queue = None, None
def get_cl_ctx(): return cl_ctx
@@ -90,38 +90,31 @@ def reduce_op(op, inp, ret):
if op == ReduceOps.SUM: code, start = "out += a", "0.0"
elif op == ReduceOps.MAX: code, start = "out = max(a,out)", "-INFINITY"
else: raise Exception(f"{op} isn't supported")
# TODO: this is insanely slow
# NOTE: ret.shape can be (1,), it's mostly by luck that this works
reduce = clbuild("reduce", """
__kernel void reduce(__global const float *a_g, int sz, __global float *res_g, int prod, int n_dims,
__global const int *shape_x, __global const int *shape_ret) {
int gid = get_global_id(0);
float out = """+start+""";
for (int x = 0; x < sz; x++) {
int idx = 0; // compute index into a_g
int tprod = prod;
int tsz = sz;
for (int dim = 0; dim < n_dims; dim++) {
idx *= shape_x[dim];
if (shape_x[dim] == shape_ret[dim]) { // dim from gid, don't reduce
tprod /= shape_x[dim];
idx += (gid / tprod) % shape_x[dim];
} else { // dim from x
tsz /= shape_x[dim];
idx += (x / tsz) % shape_x[dim];
}
}
float a = a_g[idx];
"""+code+""";
}
# reverse operation of expand
# this takes a ret index to an inp index
view = View(ret.shape, strides_for_shape(inp.shape))
acc = 1
loop_start, loop_end = [], []
for i,o in list(zip(inp.shape, ret.shape))[::-1]:
if i != o: # reduce axis
assert o == 1
loop_start.append(f"for (int axis_{len(loop_start)} = 0; axis_{len(loop_start)} < {i}; axis_{len(loop_start)}++) {{")
loop_end.append(f"idx += {acc}; }} idx -= {i}*{acc};")
acc *= i
prg = """
__kernel void reduce(__global const float *a_g, __global float *res_g) {
int gid = get_global_id(0); int idx = gid;"""+view.expr.replace('//', '/')+""";
float out = """+start+""";\n"""+ \
'\n'.join(loop_start[::-1])+"""
float a = a_g[idx];
"""+code+""";\n"""+ \
'\n'.join(loop_end)+"""
res_g[gid] = out;
}""")
reduce([prod(ret.shape)], None, inp.cl,
i32(prod(inp.shape)//prod(ret.shape)), ret.cl,
i32(prod(ret.shape)), i32(len(ret.shape)),
buffer_np(np.array(inp.shape, dtype=np.int32)),
buffer_np(np.array(ret.shape, dtype=np.int32)))
}"""
clbuild("reduce", prg)([prod(ret.shape)], None, inp.cl, ret.cl)
def contiguous(x, ret, st):
clbuild("contiguous", """__kernel void contiguous(__global const float *x, __global float *ret) {

View File

@@ -78,10 +78,8 @@ class ShapeTracker:
strides = strides_for_shape(self.shape)
self.views.append(View([self.shape[a] for a in axis], [strides[a] for a in axis]))
def slice(self, *arg): # NOTE: this slice cannot pad the edges
def slice(self, *arg):
assert len(arg) == len(self.shape)
# if you pass in an expansion, it will be correct for the shrink, but return junk in the expanded region
#assert all([x>=0 and y<=self.shape[i] for i,(x,y) in enumerate(arg)])
strides = strides_for_shape(self.shape)
offset = sum([strides[i]*x for i,(x,_) in enumerate(arg)])
self.views += [View([y-x for x,y in arg], strides, offset), ZeroView(self.shape, arg)]