From 287699c32caae565c0875324f6259f33cb272db6 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sat, 14 Jan 2023 10:51:43 -0800 Subject: [PATCH] simplify ones after axis splitting --- test/external_test_image.py | 2 +- tinygrad/llops/ops_gpu.py | 1 + tinygrad/ops.py | 33 +++++++++++++++++++-------------- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/test/external_test_image.py b/test/external_test_image.py index c94a1b5e1d..1f4156713d 100644 --- a/test/external_test_image.py +++ b/test/external_test_image.py @@ -45,7 +45,7 @@ class TestImage(unittest.TestCase): def test_op_conv(self): bs, in_chans, out_chans = 1,12,32 - tiny_conv = Conv2d(in_chans, out_chans, 3, bias=None, padding=0) + tiny_conv = Conv2d(in_chans, out_chans, 3, bias=None, padding=1) tiny_dconv = Conv2d(out_chans, out_chans, 1, bias=None, padding=0) tiny_dat = Tensor.ones(bs, 12, 64, 128) p2 = tiny_conv(tiny_dat).relu() diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py index a769b68369..da878561a6 100644 --- a/tinygrad/llops/ops_gpu.py +++ b/tinygrad/llops/ops_gpu.py @@ -251,6 +251,7 @@ class CLASTKernel(ASTKernel): [i for i in range(self.shape_len+1) if i != lb_valid+1] + [lb_valid+1]) self.late_are_float4 = True + self.simplify_ones() self.output_shape = self.shapes[0][:self.first_reduce] if DEBUG >= 2: diff --git a/tinygrad/ops.py b/tinygrad/ops.py index eef9534c60..18ea66a245 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -129,20 +129,28 @@ class ASTKernel: def process(self): # get shape, strides, and offset # if it's a multiview buffer we take the final view - shapes = [x.shape for x in self.bufs] - strides = [x.st.views[-1].strides for x in self.bufs] + self.shapes = [x.shape for x in self.bufs] + self.strides = [x.st.views[-1].strides for x in self.bufs] + self.offsets = [x.st.views[-1].offset for x in self.bufs] # include the offsets (as is) + self.last_reduce = len(self.shapes[0]) + self.simplify_ones() + self.simplify_merge_adjacent() + def simplify_ones(self): # remove places where the shape is all ones # TODO: this should be factored in to multi shape stride - all_ones = [all(s[i]==1 for s in shapes) for i in range(len(shapes[0]))] + all_ones = [all(s[i]==1 for s in self.shapes) for i in range(len(self.shapes[0]))] # keep at least 1 one if all(all_ones): all_ones[-1] = False - shapes = [[s[i] for i in range(len(s)) if not all_ones[i]] for s in shapes] - strides = [[s[i] for i in range(len(s)) if not all_ones[i]] for s in strides] - + self.shapes = [[s[i] for i in range(len(s)) if not all_ones[i]] for s in self.shapes] + self.strides = [[s[i] for i in range(len(s)) if not all_ones[i]] for s in self.strides] + self.last_reduce -= sum(all_ones) # find first mismatch, don't reduce this - first_reduce = get_first_reduce(shapes) + self.first_reduce = get_first_reduce(self.shapes) + + def simplify_merge_adjacent(self): + shapes, strides = self.shapes, self.strides # merge dimensions if we can, multi get_shape_strides # TODO: does this always preserve the reduce dimension, NO @@ -154,18 +162,16 @@ class ASTKernel: # TODO: added the always mergability of 1s, is this right? if so, add to shapetracker in the 1 case can_merge.append((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*strides[j][i]) or (strides[j][i] == 0 and rets[j][-1][1] == 0)) # more can merge than this - can_merge = all(can_merge) and i != first_reduce + can_merge = all(can_merge) and i != self.first_reduce + if can_merge: + self.last_reduce -= 1 for j in range(len(shapes)): if can_merge: rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i]) else: rets[j].append((shapes[j][i], strides[j][i])) self.shapes, self.strides = [[y[0] for y in x] for x in rets], [[y[1] for y in x] for x in rets] - self.first_reduce = get_first_reduce(self.shapes) # update this if axis merged - self.last_reduce = len(self.shapes[0]) - - # include the offsets (as is) - self.offsets = [x.st.views[-1].offset for x in self.bufs] + self.first_reduce = get_first_reduce(self.shapes) @property def shape_len(self): return len(self.shapes[0]) @@ -186,4 +192,3 @@ class ASTKernel: new_shapes.append(st.shape) new_strides.append(st.strides) self.shapes, self.strides = new_shapes, new_strides -