mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-28 00:08:16 -05:00
cleanup kernel simplify_merge_adjacent (#4852)
cleanup kernel simplify_merge_adjacent
This commit is contained in:
@@ -260,20 +260,22 @@ class Kernel:
|
||||
shapes.append(self.output_shape)
|
||||
strides.append(special_strides)
|
||||
|
||||
# merge dimensions if we can, multi get_shape_strides
|
||||
# merge dimensions if we can, multi _merge_dims
|
||||
# NOTE: this does not always preserve the reduce dimension
|
||||
# TODO: move this into shapetracker, with tests!
|
||||
rets = [[(shapes[j][0], strides[j][0])] for j in range(len(shapes))]
|
||||
# TODO: how does this work with multi-reduce?
|
||||
rets = [[(s[0], st[0])] for s,st in zip(shapes, strides)]
|
||||
for i in range(1, len(shapes[0])):
|
||||
can_merge = []
|
||||
for j in range(len(shapes)):
|
||||
for s,st,ret in zip(shapes, strides, rets):
|
||||
# TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
|
||||
can_merge.append(strides[j][i] is not None and ((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*cast(int, strides[j][i])) or (strides[j][i] == 0 and rets[j][-1][1] == 0))) # noqa: E501
|
||||
si, sti, last_st = s[i], st[i], ret[-1][1]
|
||||
can_merge.append((sti is not None) and ((sti != 0 and last_st == si*sti) or (sti == 0 and last_st == 0)))
|
||||
# more can merge than this
|
||||
mergeable = all(can_merge) and i != self.first_reduce
|
||||
for j in range(len(shapes)):
|
||||
if mergeable: rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i])
|
||||
else: rets[j].append((shapes[j][i], strides[j][i]))
|
||||
for j,(s,st) in enumerate(zip(shapes, strides)):
|
||||
if mergeable: rets[j][-1] = (rets[j][-1][0] * s[i], st[i])
|
||||
else: rets[j].append((s[i], st[i]))
|
||||
|
||||
# do the reshapes
|
||||
for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))
|
||||
|
||||
Reference in New Issue
Block a user