From a8aeebfb0c28a7ab18215a9ebb924a82c24b454e Mon Sep 17 00:00:00 2001 From: George Hotz Date: Tue, 14 Jun 2022 17:08:12 -0700 Subject: [PATCH] use shapetracker to combine adj reduce axis --- tinygrad/llops/ops_gpu.py | 17 +++++++++-------- tinygrad/shapetracker.py | 2 +- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py index 7c10ef1554..84cacf6fef 100644 --- a/tinygrad/llops/ops_gpu.py +++ b/tinygrad/llops/ops_gpu.py @@ -91,18 +91,19 @@ def reduce_op(op, inp, ret): elif op == ReduceOps.MAX: code, start = "out = max(a,out)", "-INFINITY" else: raise Exception(f"{op} isn't supported") - # reverse operation of expand - # this take a ret index to an inp index + # reverse operation of expand, this validates inputs + st = ShapeTracker(*ret.shape).movement_op(MovementOps.EXPAND, inp.shape) + # this takes a ret index to an inp index, indexing 0 on the reduced strides view = View(ret.shape, strides_for_shape(inp.shape)) + # combined adjacent reduce axis acc = 1 loop_start, loop_end = [], [] - for i,o in list(zip(inp.shape, ret.shape))[::-1]: - if i != o: # reduce axis - assert o == 1 - loop_start.append(f"for (int axis_{len(loop_start)} = 0; axis_{len(loop_start)} < {i}; axis_{len(loop_start)}++) {{") - loop_end.append(f"idx += {acc}; }} idx -= {i}*{acc};") - acc *= i + for shp,stride in st.views[-1].shape_strides[::-1]: + if stride == 0: + loop_start.append(f"for (int axis_{len(loop_start)} = 0; axis_{len(loop_start)} < {shp}; axis_{len(loop_start)}++) {{") + loop_end.append(f"idx += {acc}; }} idx -= {shp*acc};") + acc *= shp prg = """ __kernel void reduce(__global const float *a_g, __global float *res_g) { diff --git a/tinygrad/shapetracker.py b/tinygrad/shapetracker.py index 393d132466..0c46db3b9b 100644 --- a/tinygrad/shapetracker.py +++ b/tinygrad/shapetracker.py @@ -13,7 +13,7 @@ class View: self.shape_strides = [(shape[0], strides[0])] for i in range(1, len(shape)): - if strides[i] != 0 and self.shape_strides[-1][1]//strides[i] == shape[i]: + if (strides[i] != 0 and self.shape_strides[-1][1]//strides[i] == shape[i]) or (strides[i] == 0 and self.shape_strides[-1][1] == 0): self.shape_strides[-1] = (self.shape_strides[-1][0] * shape[i], strides[i]) else: self.shape_strides.append((shape[i], strides[i]))