From 906cce9916cb43ed6d28c2d950fb17420dd60085 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Tue, 14 Jun 2022 16:38:33 -0700
Subject: [PATCH] reduce with loops

---
 tinygrad/llops/ops_gpu.py | 55 +++++++++++++++++----------------------
 tinygrad/shapetracker.py  |  4 +--
 2 files changed, 25 insertions(+), 34 deletions(-)

diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py
index b6d4e7797f..7c10ef1554 100644
--- a/tinygrad/llops/ops_gpu.py
+++ b/tinygrad/llops/ops_gpu.py
@@ -3,7 +3,7 @@ import numpy as np
 import pyopencl as cl
 from tinygrad.helpers import prod
 from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
-from tinygrad.shapetracker import ShapeTracker
+from tinygrad.shapetracker import ShapeTracker, View, strides_for_shape
 
 cl_ctx, cl_queue = None, None
 def get_cl_ctx(): return cl_ctx
@@ -90,38 +90,31 @@ def reduce_op(op, inp, ret):
   if op == ReduceOps.SUM: code, start = "out += a", "0.0"
   elif op == ReduceOps.MAX: code, start = "out = max(a,out)", "-INFINITY"
   else: raise Exception(f"{op} isn't supported")
-  # TODO: this is insanely slow
-  # NOTE: ret.shape can be (1,), it's mostly by luck that this works
-  reduce = clbuild("reduce", """
-  __kernel void reduce(__global const float *a_g, int sz, __global float *res_g, int prod, int n_dims,
-                       __global const int *shape_x, __global const int *shape_ret) {
-    int gid = get_global_id(0);
-    float out = """+start+""";
-    for (int x = 0; x < sz; x++) {
-      int idx = 0;  // compute index into a_g
-      int tprod = prod;
-      int tsz = sz;
-      for (int dim = 0; dim < n_dims; dim++) {
-        idx *= shape_x[dim];
-        if (shape_x[dim] == shape_ret[dim]) {  // dim from gid, don't reduce
-          tprod /= shape_x[dim];
-          idx += (gid / tprod) % shape_x[dim];
-        } else {  // dim from x
-          tsz /= shape_x[dim];
-          idx += (x / tsz) % shape_x[dim];
-        }
-      }
-      float a = a_g[idx];
-      """+code+""";
-    }
+  # reverse operation of expand
+  # this takes a ret index to an inp index
+  view = View(ret.shape, strides_for_shape(inp.shape))
+
+  acc = 1
+  loop_start, loop_end = [], []
+  for i,o in list(zip(inp.shape, ret.shape))[::-1]:
+    if i != o:  # reduce axis
+      assert o == 1
+      loop_start.append(f"for (int axis_{len(loop_start)} = 0; axis_{len(loop_start)} < {i}; axis_{len(loop_start)}++) {{")
+      loop_end.append(f"idx += {acc}; }} idx -= {i}*{acc};")
+    acc *= i
+
+  prg = """
+  __kernel void reduce(__global const float *a_g, __global float *res_g) {
+    int gid = get_global_id(0); int idx = gid;"""+view.expr.replace('//', '/')+""";
+    float out = """+start+""";\n"""+ \
+    '\n'.join(loop_start[::-1])+"""
+      float a = a_g[idx];
+      """+code+""";\n"""+ \
+    '\n'.join(loop_end)+"""
     res_g[gid] = out;
-  }""")
-  reduce([prod(ret.shape)], None, inp.cl,
-    i32(prod(inp.shape)//prod(ret.shape)), ret.cl,
-    i32(prod(ret.shape)), i32(len(ret.shape)),
-    buffer_np(np.array(inp.shape, dtype=np.int32)),
-    buffer_np(np.array(ret.shape, dtype=np.int32)))
+  }"""
+  clbuild("reduce", prg)([prod(ret.shape)], None, inp.cl, ret.cl)
 
 def contiguous(x, ret, st):
   clbuild("contiguous", """__kernel void contiguous(__global const float *x, __global float *ret) {
diff --git a/tinygrad/shapetracker.py b/tinygrad/shapetracker.py
index f7d8cae4cb..393d132466 100644
--- a/tinygrad/shapetracker.py
+++ b/tinygrad/shapetracker.py
@@ -78,10 +78,8 @@ class ShapeTracker:
     strides = strides_for_shape(self.shape)
     self.views.append(View([self.shape[a] for a in axis], [strides[a] for a in axis]))
 
-  def slice(self, *arg):  # NOTE: this slice cannot pad the edges
+  def slice(self, *arg):
     assert len(arg) == len(self.shape)
-    # if you pass in an expansion, it will be correct for the shrink, but return junk in the expanded region
-    #assert all([x>=0 and y<=self.shape[i] for i,(x,y) in enumerate(arg)])
     strides = strides_for_shape(self.shape)
     offset = sum([strides[i]*x for i,(x,_) in enumerate(arg)])
     self.views += [View([y-x for x,y in arg], strides, offset), ZeroView(self.shape, arg)]
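
For reference (not part of the patch): a standalone sketch of the kernel the string assembly in reduce_op generates, for a concrete (3, 4) -> (3, 1) sum. strides_for_shape and the gid->idx expression are re-derived inline here rather than taken from tinygrad's View.expr, so the exact expression text may differ from what the patched code emits at runtime.

def strides_for_shape(shape):
  strides = [1]
  for d in shape[::-1][:-1]:
    strides = [d*strides[0]] + strides
  return strides

inp_shape, ret_shape = (3, 4), (3, 1)
code, start = "out += a", "0.0"   # ReduceOps.SUM

# map gid (an index into ret) to the first inp element of its reduce group
inp_strides = strides_for_shape(inp_shape)   # [4, 1]
ret_strides = strides_for_shape(ret_shape)   # [1, 1]
terms = [f"((gid/{rs})%{d})*{xs}" for d, rs, xs in zip(ret_shape, ret_strides, inp_strides) if d != 1]
idx_expr = "idx = " + ("+".join(terms) if terms else "0")

# one C loop per reduced axis, innermost first; acc is the inp stride of the current dim
acc, loop_start, loop_end = 1, [], []
for i, o in list(zip(inp_shape, ret_shape))[::-1]:
  if i != o:   # reduce axis
    assert o == 1
    loop_start.append(f"for (int axis_{len(loop_start)} = 0; axis_{len(loop_start)} < {i}; axis_{len(loop_start)}++) {{")
    loop_end.append(f"idx += {acc}; }} idx -= {i}*{acc};")
  acc *= i

prg = """
__kernel void reduce(__global const float *a_g, __global float *res_g) {
  int gid = get_global_id(0); int idx = gid; """+idx_expr+""";
  float out = """+start+""";\n"""+ \
  '\n'.join(loop_start[::-1])+"""
    float a = a_g[idx];
    """+code+""";\n"""+ \
  '\n'.join(loop_end)+"""
  res_g[gid] = out;
}"""
print(prg)   # for gid in 0..2 this sums a_g[gid*4 .. gid*4+3] into res_g[gid]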
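
The shapetracker.py half of the diff drops the "this slice cannot pad the edges" note: with ZeroView on the view stack, slice bounds are allowed to run past the tensor and the out-of-range region reads back as zeros instead of junk. A rough NumPy sketch of that intended behavior, illustrating the semantics only and not tinygrad's implementation:

import numpy as np

# arg gives a (start, end) pair per dimension; bounds may run past the edges,
# and anything outside the original tensor reads back as zero
def padded_slice(x, arg):
  out = np.zeros(tuple(e - s for s, e in arg), dtype=x.dtype)
  src = tuple(slice(max(s, 0), min(e, d)) for (s, e), d in zip(arg, x.shape))
  dst = tuple(slice(max(-s, 0), max(-s, 0) + (sl.stop - sl.start)) for (s, _), sl in zip(arg, src))
  out[dst] = x[src]
  return out

x = np.arange(6, dtype=np.float32).reshape(2, 3)
print(padded_slice(x, ((0, 2), (-1, 4))))   # one zero column added on each side of x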