1s are always mergable

This commit is contained in:
George Hotz
2022-11-03 10:50:48 -07:00
parent c48fc47d01
commit caea34c529

View File

@@ -13,7 +13,7 @@ def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tup
assert len(shape) == len(strides)
ret = [(shape[0], strides[0])]
for i in range(1, len(shape)):
if (strides[i] != 0 and ret[-1][1] == shape[i]*strides[i]) or (strides[i] == 0 and ret[-1][1] == 0):
if (strides[i] != 0 and ret[-1][1] == shape[i]*strides[i]) or ret[-1][0] == 1 or (strides[i] == 0 and ret[-1][1] == 0):
ret[-1] = (ret[-1][0] * shape[i], strides[i])
else:
ret.append((shape[i], strides[i]))