mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-22 21:38:10 -05:00)
reduce with loops
@@ -3,7 +3,7 @@ import numpy as np
 import pyopencl as cl
 from tinygrad.helpers import prod
 from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
-from tinygrad.shapetracker import ShapeTracker
+from tinygrad.shapetracker import ShapeTracker, View, strides_for_shape

 cl_ctx, cl_queue = None, None
 def get_cl_ctx(): return cl_ctx
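The rewritten reduce_op below builds its indexing from these two new imports. As a reading aid, here is a minimal sketch of the behavior the diff appears to assume from them (this is not tinygrad's implementation): strides_for_shape returning row-major strides, and View(shape, strides) exposing a C index expression, view.expr, that turns a flat output index into an offset into the underlying buffer.

# Illustrative sketch only (assumption): row-major strides, as strides_for_shape
# is presumed to compute them for the View used by the new reduce_op.
def strides_for_shape_sketch(shape):
  strides = [1]
  for d in reversed(shape[1:]):
    strides.insert(0, d * strides[0])
  return strides

assert strides_for_shape_sketch((3, 4, 5)) == [20, 5, 1]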
@@ -90,38 +90,31 @@ def reduce_op(op, inp, ret):
   if op == ReduceOps.SUM: code, start = "out += a", "0.0"
   elif op == ReduceOps.MAX: code, start = "out = max(a,out)", "-INFINITY"
   else: raise Exception(f"{op} isn't supported")
-  # TODO: this is insanely slow
-  # NOTE: ret.shape can be (1,), it's mostly by luck that this works
-  reduce = clbuild("reduce", """
-  __kernel void reduce(__global const float *a_g, int sz, __global float *res_g, int prod, int n_dims,
-                       __global const int *shape_x, __global const int *shape_ret) {
-    int gid = get_global_id(0);
-
-    float out = """+start+""";
-    for (int x = 0; x < sz; x++) {
-      int idx = 0; // compute index into a_g
-      int tprod = prod;
-      int tsz = sz;
-      for (int dim = 0; dim < n_dims; dim++) {
-        idx *= shape_x[dim];
-        if (shape_x[dim] == shape_ret[dim]) { // dim from gid, don't reduce
-          tprod /= shape_x[dim];
-          idx += (gid / tprod) % shape_x[dim];
-        } else { // dim from x
-          tsz /= shape_x[dim];
-          idx += (x / tsz) % shape_x[dim];
-        }
-      }
-      float a = a_g[idx];
-      """+code+""";
-    }
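For each output element gid, the removed kernel walked all sz reduced elements and rebuilt the full input index from scratch, doing a division and a modulo per dimension per element: kept dimensions contribute their digit from gid, reduced dimensions from the loop counter x. A host-side sketch of that index computation (the helper name is hypothetical, for illustration only):

# Hypothetical helper mirroring the removed kernel's per-element index math
# (integer division, as in the C code). Kept dims take their digit from gid,
# reduced dims take it from x.
def old_kernel_index(gid, x, shape_x, shape_ret):
  tprod = 1
  for s in shape_ret: tprod *= s        # the "prod" argument of the old kernel
  tsz = 1
  for sx, sr in zip(shape_x, shape_ret):
    if sx != sr: tsz *= sx              # the "sz" argument of the old kernel
  idx = 0
  for sx, sr in zip(shape_x, shape_ret):
    idx *= sx
    if sx == sr:                        # dim kept: digit comes from gid
      tprod //= sx
      idx += (gid // tprod) % sx
    else:                               # dim reduced: digit comes from x
      tsz //= sx
      idx += (x // tsz) % sx
  return idx

# reducing a (2, 3) input over axis 1 (ret.shape == (2, 1)):
# output element gid=1 reads input elements 3, 4 and 5
assert [old_kernel_index(1, x, (2, 3), (2, 1)) for x in range(3)] == [3, 4, 5]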
+  # reverse operation of expand
+  # this take a ret index to an inp index
+  view = View(ret.shape, strides_for_shape(inp.shape))
+
+  acc = 1
+  loop_start, loop_end = [], []
+  for i,o in list(zip(inp.shape, ret.shape))[::-1]:
+    if i != o: # reduce axis
+      assert o == 1
+      loop_start.append(f"for (int axis_{len(loop_start)} = 0; axis_{len(loop_start)} < {i}; axis_{len(loop_start)}++) {{")
+      loop_end.append(f"idx += {acc}; }} idx -= {i}*{acc};")
+    acc *= i
+
+  prg = """
+  __kernel void reduce(__global const float *a_g, __global float *res_g) {
+    int gid = get_global_id(0); int idx = gid;"""+view.expr.replace('//', '/')+""";
+    float out = """+start+""";\n"""+ \
+    '\n'.join(loop_start[::-1])+"""
+      float a = a_g[idx];
+      """+code+""";\n"""+ \
+    '\n'.join(loop_end)+"""
     res_g[gid] = out;
-  }""")
-  reduce([prod(ret.shape)], None, inp.cl,
-    i32(prod(inp.shape)//prod(ret.shape)), ret.cl,
-    i32(prod(ret.shape)), i32(len(ret.shape)),
-    buffer_np(np.array(inp.shape, dtype=np.int32)),
-    buffer_np(np.array(ret.shape, dtype=np.int32)))
+  }"""
+  clbuild("reduce", prg)([prod(ret.shape)], None, inp.cl, ret.cl)

 def contiguous(x, ret, st):
   clbuild("contiguous", """__kernel void contiguous(__global const float *x, __global float *ret) {
@@ -78,10 +78,8 @@ class ShapeTracker:
     strides = strides_for_shape(self.shape)
     self.views.append(View([self.shape[a] for a in axis], [strides[a] for a in axis]))

-  def slice(self, *arg): # NOTE: this slice cannot pad the edges
+  def slice(self, *arg):
     assert len(arg) == len(self.shape)
-    # if you pass in an expansion, it will be correct for the shrink, but return junk in the expanded region
-    #assert all([x>=0 and y<=self.shape[i] for i,(x,y) in enumerate(arg)])
     strides = strides_for_shape(self.shape)
     offset = sum([strides[i]*x for i,(x,_) in enumerate(arg)])
     self.views += [View([y-x for x,y in arg], strides, offset), ZeroView(self.shape, arg)]
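slice turns each (start, end) pair into a flat offset contribution of start*stride and a new dimension length of end-start, then records the result as a View plus a ZeroView over the original shape (presumably so indices that fall outside it read as zero). A small worked example of the offset/shape arithmetic, with illustrative values and assuming row-major strides:

# Worked example of the slice arithmetic above, assuming row-major strides
# (the values are illustrative, not tied to a particular tensor in tinygrad).
shape   = (4, 5)
strides = (5, 1)                       # strides_for_shape((4, 5))
arg     = ((1, 3), (2, 5))             # keep rows 1:3 and cols 2:5
offset    = sum(strides[i] * x for i, (x, _) in enumerate(arg))
new_shape = [y - x for x, y in arg]
assert offset == 7                     # 1*5 + 2*1: flat index of element (1, 2)
assert new_shape == [2, 3]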