pad_to to mixin [pr] (#15365)

This commit is contained in:
chenyu
2026-03-19 05:02:01 -04:00
committed by GitHub
parent 1abb6297f6
commit d81b03cff4
2 changed files with 3 additions and 5 deletions

View File

@@ -174,6 +174,9 @@ class MovementMixin:
def shrink_to(self, shape, *args) -> Self:
return self.shrink(tuple([None if ns is None else (0, ns) for ns in argfix(shape, *args)]))
def pad_to(self, shape, *args) -> Self:
return self._mop(Ops.PAD, tuple([(0, 0 if ns is None else ns-s) for s,ns in zip(self.shape, argfix(shape, *args), strict=True)]))
def view(self, shape, *args) -> Self:
"""`.view` is an alias for `.reshape`."""
return self.reshape(shape, *args)

View File

@@ -1168,11 +1168,6 @@ class Tensor(OpMixin):
if mode in {"reflect", "replicate"}: return self._pad_reflect_replicate(pX, mode)
raise NotImplementedError(f"{mode=} is not supported")
# convenience
def pad_to(self, shape, *args):
if len(new_shape := argfix(shape, *args)) != self.ndim: raise ValueError(f"dim mismatch, cannot pad {self.shape} to {new_shape}")
return self.pad(tuple([None if ns is None else (0, ns-s) for s,ns in zip(self.shape, new_shape)]))
# ***** movement high level ops *****
def _getitem(self, indices, v: Tensor|None = None) -> Tensor: