mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 05:48:08 -05:00
1s are always mergable
This commit is contained in:
@@ -13,7 +13,7 @@ def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tup
|
||||
assert len(shape) == len(strides)
|
||||
ret = [(shape[0], strides[0])]
|
||||
for i in range(1, len(shape)):
|
||||
if (strides[i] != 0 and ret[-1][1] == shape[i]*strides[i]) or (strides[i] == 0 and ret[-1][1] == 0):
|
||||
if (strides[i] != 0 and ret[-1][1] == shape[i]*strides[i]) or ret[-1][0] == 1 or (strides[i] == 0 and ret[-1][1] == 0):
|
||||
ret[-1] = (ret[-1][0] * shape[i], strides[i])
|
||||
else:
|
||||
ret.append((shape[i], strides[i]))
|
||||
|
||||
Reference in New Issue
Block a user