mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Const pad support to pad2d and slice (#1392)
* slice to pad2d migrate * Gain line * Mypy happy * Mypy happy * Revert * whitespace
This commit is contained in:
@@ -251,10 +251,10 @@ class Tensor:
|
||||
# ***** movement hlops *****
|
||||
|
||||
# NOTE: using slice is discouraged and things should migrate to pad and shrink
|
||||
def slice(self, arg:Sequence[Optional[Tuple[int, int]]]) -> Tensor:
|
||||
def slice(self, arg:Sequence[Optional[Tuple[int, int]]], value:float=0) -> Tensor:
|
||||
arg_ = tuple([a if a is not None else (0,s) for s,a in zip(self.shape, arg)])
|
||||
padding = tuple([(max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_)])
|
||||
return self.pad(padding).shrink(tuple([(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)]))
|
||||
return self.pad(padding, value=value).shrink(tuple([(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)]))
|
||||
|
||||
# - Negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
|
||||
# - A slice i:j returns the elements with indices in [i, j)
|
||||
@@ -377,9 +377,9 @@ class Tensor:
|
||||
return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
|
||||
|
||||
# (padding_left, padding_right, padding_top, padding_bottom)
|
||||
def pad2d(self, padding:Union[List[int], Tuple[int, ...]]):
|
||||
def pad2d(self, padding:Union[List[int], Tuple[int, ...]], value:float=0):
|
||||
slc = [(-p0, s+p1) for p0,p1,s in zip(padding[::2], padding[1::2], self.shape[::-1])][::-1]
|
||||
return self.slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc)
|
||||
return self.slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc, value=value)
|
||||
|
||||
@property
|
||||
def T(self) -> Tensor: return self.transpose()
|
||||
|
||||
Reference in New Issue
Block a user