From 0260406f497c69a91608b8f6bf904be2f4ffbfb5 Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 18 Feb 2026 11:46:26 -0500 Subject: [PATCH] simplify reshape_multi [pr] (#14864) --- tinygrad/schedule/multi.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tinygrad/schedule/multi.py b/tinygrad/schedule/multi.py index 7abe4754e0..5886c4f183 100644 --- a/tinygrad/schedule/multi.py +++ b/tinygrad/schedule/multi.py @@ -153,11 +153,9 @@ def _shape_to_single_shard(axis, shape:tuple[sint, ...], lb:UOp) -> tuple[sint, return tuple(lb.shape[axis] if a == axis else s for a,s in enumerate(shape)) def reshape_multi(root:UOp, multi:UOp): - arg = root.marg - if (new_axis:=root.axis) is None: return multi.src[0].reshape(arg).multi(new_axis) - assert prod(multi.shape) == prod(arg), "reshape must maintain prod(shape)" - new_shape_axis = prod(multi.src[0].shape[multi.axis:]) // prod(arg[new_axis+1:]) - return multi.src[0].reshape(tuple(s if a!=new_axis else new_shape_axis for a,s in enumerate(arg))).multi(new_axis) + if prod(multi.shape) != prod(new_shape:=root.marg): raise RuntimeError("reshape must maintain prod(shape)") + if (new_axis:=root.axis) is not None: new_shape = tuple(s//len(multi.device) if a==new_axis else s for a,s in enumerate(new_shape)) + return multi.src[0].reshape(new_shape).multi(new_axis) def expand_multi(root:UOp, multi:UOp): # NOTE: this assert isn't needed, sharded axis can have dim 1