diff --git a/tinygrad/shapetracker.py b/tinygrad/shapetracker.py index 47fe5c9175..2c937bbbee 100644 --- a/tinygrad/shapetracker.py +++ b/tinygrad/shapetracker.py @@ -13,7 +13,7 @@ def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tup assert len(shape) == len(strides) ret = [(shape[0], strides[0])] for i in range(1, len(shape)): - if (strides[i] != 0 and ret[-1][1] == shape[i]*strides[i]) or (strides[i] == 0 and ret[-1][1] == 0): + if (strides[i] != 0 and ret[-1][1] == shape[i]*strides[i]) or ret[-1][0] == 1 or (strides[i] == 0 and ret[-1][1] == 0): ret[-1] = (ret[-1][0] * shape[i], strides[i]) else: ret.append((shape[i], strides[i]))