mirror of https://github.com/tinygrad/tinygrad.git
tensor.py cleanup around Tensor.slice (#2921)
use None for no-op slice and pad
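What "None for no-op" means in practice: pad and shrink now accept None for an axis that should be left untouched, instead of an explicit (0,0) or (0,size) tuple. A minimal shape-level sketch, assuming a tinygrad checkout that includes this change:

from tinygrad.tensor import Tensor

t = Tensor.ones(4, 3)
# None leaves the axis alone; previously the no-op had to be spelled out
assert t.pad((None, (1, 1))).shape == t.pad(((0, 0), (1, 1))).shape == (4, 5)
assert t.shrink((None, (0, 2))).shape == t.shrink(((0, 4), (0, 2))).shape == (4, 2)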
@@ -1,5 +1,5 @@
 import math
-from typing import Tuple, Optional, cast
+from typing import Tuple, Optional
 from tinygrad.helpers import argsort, DType
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
 from tinygrad.tensor import Function
@@ -203,9 +203,7 @@ class Shrink(Function):
     return x.shrink(arg)
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     assert all(isinstance(x[0], int) and isinstance(x[1], int) for x in self.narg), "symbolic shrink does not support backward"
-    # need this cast because mypy cannot narrow the type even with assert
-    return grad_output.pad(cast(Tuple[Tuple[int, int], ...], self.narg))
+    return grad_output.pad(self.narg)
 
 class Flip(Function):
   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
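With the cast gone, Shrink.backward simply pads the incoming gradient back out to the original shape using the stored bounds. A quick shape-level sanity check (a sketch, not part of the diff):

from tinygrad.tensor import Tensor

x = Tensor.ones(4, 4, requires_grad=True)
y = x.shrink(((1, 3), (0, 4)))   # keep rows 1..2, all columns -> shape (2, 4)
y.sum().backward()
assert x.grad.shape == (4, 4)    # gradient is padded back to x's shape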
@@ -406,9 +406,9 @@ class Tensor:
 
   # NOTE: using slice is discouraged and things should migrate to pad and shrink
   def slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor:
-    arg_ = tuple([a if a is not None else (0,s) for s,a in zip(self.shape, arg)])
-    padding = tuple([(max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_)])
-    return self.pad(padding, value=value).shrink(tuple([(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)]))
+    arg_ = tuple(a if a is not None else (0, s) for s,a in zip(self.shape, arg))
+    padding = tuple((max(0, -l), max(0, r-s)) for s,(l,r) in zip(self.shape, arg_))
+    return self.pad(padding, value=value).shrink(tuple((l + pl, r + pl) for (l,r),(pl,_) in zip(arg_, padding)))
 
   def gather(self:Tensor, idx:Tensor, dim:int) -> Tensor:
     assert idx.ndim == self.ndim, "self.ndim must equal idx.ndim"
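The rewritten slice still decomposes into a pad (to cover any out-of-range bounds) followed by a shrink in the padded coordinates. A worked 1-D example of that decomposition, with illustrative values:

from tinygrad.tensor import Tensor

t = Tensor.arange(4)                       # shape (4,)
out = t.slice(((-1, 5),), value=0)         # start before 0, stop past the end
# per the body above: padding = ((1, 1),), shrink window = ((0, 6),)
same = t.pad(((1, 1),), value=0).shrink(((0, 6),))
assert out.shape == same.shape == (6,)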
@@ -420,13 +420,13 @@ class Tensor:
     return ((idx == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim) # noqa: E501
 
   def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
-    dim = (dim + len(self.shape)) if dim < 0 else dim
+    if dim < 0: dim += self.ndim
     assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
     catargs = [self, *args]
     assert all(t.shape for t in catargs), "zero-dimensional tensor cannot be concatenated"
     shapes = [s.shape[dim] for s in catargs]
     shape_cumsum = [0, *accumulate(shapes)]
-    slc:List[List[Tuple[sint, sint]]] = [[(0, 0) for _ in self.shape] for _ in catargs]
+    slc:List[List[Optional[Tuple[sint, sint]]]] = [[None for _ in self.shape] for _ in catargs]
     for shp,k,s in zip(shapes, shape_cumsum[:-1], slc): s[dim] = (k, shape_cumsum[-1] - k - shp)
     return reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
 
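cat builds its result by padding every operand out to the full concatenated length along dim (now using None as the no-op for the other axes) and summing the pieces. A small illustration with hypothetical shapes:

from tinygrad.tensor import Tensor

a, b = Tensor.ones(2, 3), Tensor.ones(4, 3)
# pads along dim 0 are (0, 4) for a and (2, 0) for b, so both become (6, 3)
assert a.pad(((0, 4), None)).shape == b.pad(((2, 0), None)).shape == (6, 3)
assert a.cat(b, dim=0).shape == (6, 3)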
@@ -457,7 +457,7 @@ class Tensor:
     return self if self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim+1:])
 
   def unsqueeze(self, dim:int) -> Tensor:
-    if dim < 0: dim = len(self.shape) + dim + 1
+    if dim < 0: dim = self.ndim + dim + 1
     return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
 
   # (padding_left, padding_right, padding_top, padding_bottom)
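The unsqueeze change is cosmetic (self.ndim over len(self.shape)); the negative-dim rule is unchanged: a negative dim counts from the end of the new shape. For example:

from tinygrad.tensor import Tensor

t = Tensor.ones(2, 3)
assert t.unsqueeze(-1).shape == (2, 3, 1)   # dim = 2 + (-1) + 1 = 2
assert t.unsqueeze(0).shape == (1, 2, 3)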
@@ -576,10 +576,10 @@ class Tensor:
     x, w = self, weight.reshape(groups, weight.shape[0]//groups, weight.shape[1], *weight.shape[2:]).permute(0,2,1,*trailing).flip(trailing)
     stride = make_pair(stride, len(HW))
     if any(s>1 for s in stride):
-      x = x.reshape(*x.shape[:2], *flatten((k,1) for k in x.shape[2:]))
-      x = x.pad(((0,0), (0,0), *flatten(((0,0),(0,s-1)) for s in stride)))
-      x = x.reshape(*x.shape[:2], *[k*s for k,s in zip(x.shape[2::2], stride)])
-      x = x.shrink(((0,x.shape[0]), (0,x.shape[1]), *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
+      x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:]))
+      x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride)))
+      x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)])
+      x = x.shrink((None, None, *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
     padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW))))))) # noqa: E501
     return x.conv2d(w.reshape(w.shape[0]*w.shape[1],*w.shape[2:]), groups=groups, bias=bias, dilation=dilation, padding=padding)
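The stride > 1 branch of conv_transpose2d spaces the input out by inserting stride-1 zeros after every element and then trimming the trailing zeros, so a regular conv2d can finish the job. A 1-D, shapes-only sketch of the same reshape/pad/reshape/shrink sequence (not the actual conv_transpose2d call):

from tinygrad.tensor import Tensor

s = 2                                             # stride
x = Tensor.ones(1, 1, 3)                          # (batch, channel, width)
x = x.reshape(None, None, 3, 1)                   # give every element its own slot
x = x.pad((None, None, None, (0, s - 1)))         # append s-1 zeros per element
x = x.reshape(None, None, 3 * s)                  # (1, 1, 6): elements spaced s apart
x = x.shrink((None, None, (0, 3 * s - (s - 1))))  # drop the trailing zeros
assert x.shape == (1, 1, 5)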