mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
pad_to to mixin [pr] (#15365)
This commit is contained in:
@@ -174,6 +174,9 @@ class MovementMixin:
|
||||
def shrink_to(self, shape, *args) -> Self:
|
||||
return self.shrink(tuple([None if ns is None else (0, ns) for ns in argfix(shape, *args)]))
|
||||
|
||||
def pad_to(self, shape, *args) -> Self:
|
||||
return self._mop(Ops.PAD, tuple([(0, 0 if ns is None else ns-s) for s,ns in zip(self.shape, argfix(shape, *args), strict=True)]))
|
||||
|
||||
def view(self, shape, *args) -> Self:
|
||||
"""`.view` is an alias for `.reshape`."""
|
||||
return self.reshape(shape, *args)
|
||||
|
||||
@@ -1168,11 +1168,6 @@ class Tensor(OpMixin):
|
||||
if mode in {"reflect", "replicate"}: return self._pad_reflect_replicate(pX, mode)
|
||||
raise NotImplementedError(f"{mode=} is not supported")
|
||||
|
||||
# convenience
|
||||
def pad_to(self, shape, *args):
|
||||
if len(new_shape := argfix(shape, *args)) != self.ndim: raise ValueError(f"dim mismatch, cannot pad {self.shape} to {new_shape}")
|
||||
return self.pad(tuple([None if ns is None else (0, ns-s) for s,ns in zip(self.shape, new_shape)]))
|
||||
|
||||
# ***** movement high level ops *****
|
||||
|
||||
def _getitem(self, indices, v: Tensor|None = None) -> Tensor:
|
||||
|
||||
Reference in New Issue
Block a user