diff --git a/tinygrad/schedule/multi.py b/tinygrad/schedule/multi.py
index 7abe4754e0..5886c4f183 100644
--- a/tinygrad/schedule/multi.py
+++ b/tinygrad/schedule/multi.py
@@ -153,11 +153,9 @@ def _shape_to_single_shard(axis, shape:tuple[sint, ...], lb:UOp) -> tuple[sint,
   return tuple(lb.shape[axis] if a == axis else s for a,s in enumerate(shape))
 
 def reshape_multi(root:UOp, multi:UOp):
-  arg = root.marg
-  if (new_axis:=root.axis) is None: return multi.src[0].reshape(arg).multi(new_axis)
-  assert prod(multi.shape) == prod(arg), "reshape must maintain prod(shape)"
-  new_shape_axis = prod(multi.src[0].shape[multi.axis:]) // prod(arg[new_axis+1:])
-  return multi.src[0].reshape(tuple(s if a!=new_axis else new_shape_axis for a,s in enumerate(arg))).multi(new_axis)
+  if prod(multi.shape) != prod(new_shape:=root.marg): raise RuntimeError("reshape must maintain prod(shape)")
+  if (new_axis:=root.axis) is not None: new_shape = tuple(s//len(multi.device) if a==new_axis else s for a,s in enumerate(new_shape))
+  return multi.src[0].reshape(new_shape).multi(new_axis)
 
 def expand_multi(root:UOp, multi:UOp):
   # NOTE: this assert isn't needed, sharded axis can have dim 1