diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 1839045cac..25fd540921 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -260,20 +260,22 @@ class Kernel: shapes.append(self.output_shape) strides.append(special_strides) - # merge dimensions if we can, multi get_shape_strides + # merge dimensions if we can, multi _merge_dims # NOTE: this does not always preserve the reduce dimension # TODO: move this into shapetracker, with tests! - rets = [[(shapes[j][0], strides[j][0])] for j in range(len(shapes))] + # TODO: how does this work with multi-reduce? + rets = [[(s[0], st[0])] for s,st in zip(shapes, strides)] for i in range(1, len(shapes[0])): can_merge = [] - for j in range(len(shapes)): + for s,st,ret in zip(shapes, strides, rets): # TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case - can_merge.append(strides[j][i] is not None and ((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*cast(int, strides[j][i])) or (strides[j][i] == 0 and rets[j][-1][1] == 0))) # noqa: E501 + si, sti, last_st = s[i], st[i], ret[-1][1] + can_merge.append((sti is not None) and ((sti != 0 and last_st == si*sti) or (sti == 0 and last_st == 0))) # more can merge than this mergeable = all(can_merge) and i != self.first_reduce - for j in range(len(shapes)): - if mergeable: rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i]) - else: rets[j].append((shapes[j][i], strides[j][i])) + for j,(s,st) in enumerate(zip(shapes, strides)): + if mergeable: rets[j][-1] = (rets[j][-1][0] * s[i], st[i]) + else: rets[j].append((s[i], st[i])) # do the reshapes for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))