simplify reshape_multi [pr] (#14864)

This commit is contained in:
chenyu
2026-02-18 11:46:26 -05:00
committed by GitHub
parent 5746a605ce
commit 0260406f49

View File

@@ -153,11 +153,9 @@ def _shape_to_single_shard(axis, shape:tuple[sint, ...], lb:UOp) -> tuple[sint,
return tuple(lb.shape[axis] if a == axis else s for a,s in enumerate(shape))
def reshape_multi(root:UOp, multi:UOp):
arg = root.marg
if (new_axis:=root.axis) is None: return multi.src[0].reshape(arg).multi(new_axis)
assert prod(multi.shape) == prod(arg), "reshape must maintain prod(shape)"
new_shape_axis = prod(multi.src[0].shape[multi.axis:]) // prod(arg[new_axis+1:])
return multi.src[0].reshape(tuple(s if a!=new_axis else new_shape_axis for a,s in enumerate(arg))).multi(new_axis)
if prod(multi.shape) != prod(new_shape:=root.marg): raise RuntimeError("reshape must maintain prod(shape)")
if (new_axis:=root.axis) is not None: new_shape = tuple(s//len(multi.device) if a==new_axis else s for a,s in enumerate(new_shape))
return multi.src[0].reshape(new_shape).multi(new_axis)
def expand_multi(root:UOp, multi:UOp):
# NOTE: this assert isn't needed, sharded axis can have dim 1