cleanup kernel simplify_merge_adjacent (#4852)

cleanup kernel simplify_merge_adjacent
This commit is contained in:
chenyu
2024-06-06 12:04:54 -04:00
committed by GitHub
parent 47bfd7c2b7
commit dd5378378b

View File

@@ -260,20 +260,22 @@ class Kernel:
shapes.append(self.output_shape)
strides.append(special_strides)
# merge dimensions if we can, multi get_shape_strides
# merge dimensions if we can, multi _merge_dims
# NOTE: this does not always preserve the reduce dimension
# TODO: move this into shapetracker, with tests!
rets = [[(shapes[j][0], strides[j][0])] for j in range(len(shapes))]
# TODO: how does this work with multi-reduce?
rets = [[(s[0], st[0])] for s,st in zip(shapes, strides)]
for i in range(1, len(shapes[0])):
can_merge = []
for j in range(len(shapes)):
for s,st,ret in zip(shapes, strides, rets):
# TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
can_merge.append(strides[j][i] is not None and ((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*cast(int, strides[j][i])) or (strides[j][i] == 0 and rets[j][-1][1] == 0))) # noqa: E501
si, sti, last_st = s[i], st[i], ret[-1][1]
can_merge.append((sti is not None) and ((sti != 0 and last_st == si*sti) or (sti == 0 and last_st == 0)))
# more can merge than this
mergeable = all(can_merge) and i != self.first_reduce
for j in range(len(shapes)):
if mergeable: rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i])
else: rets[j].append((shapes[j][i], strides[j][i]))
for j,(s,st) in enumerate(zip(shapes, strides)):
if mergeable: rets[j][-1] = (rets[j][-1][0] * s[i], st[i])
else: rets[j].append((s[i], st[i]))
# do the reshapes
for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))