diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index d73b66b562..8f4cc62800 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -1,5 +1,5 @@
 import math
-from typing import Tuple, Optional, cast
+from typing import Tuple, Optional
 from tinygrad.helpers import argsort, DType
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
 from tinygrad.tensor import Function
@@ -203,9 +203,7 @@ class Shrink(Function):
     return x.shrink(arg)
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    assert all(isinstance(x[0], int) and isinstance(x[1], int) for x in self.narg), "symbolic shrink does not support backward"
-    # need this cast because mypy cannot narrow the type even with assert
-    return grad_output.pad(cast(Tuple[Tuple[int, int], ...], self.narg))
+    return grad_output.pad(self.narg)
 
 class Flip(Function):
   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index c47d5df8e2..9b615d5f28 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -406,9 +406,9 @@ class Tensor:
 
   # NOTE: using slice is discouraged and things should migrate to pad and shrink
   def slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor:
-    arg_ = tuple([a if a is not None else (0,s) for s,a in zip(self.shape, arg)])
-    padding = tuple([(max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_)])
-    return self.pad(padding, value=value).shrink(tuple([(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)]))
+    arg_ = tuple(a if a is not None else (0, s) for s,a in zip(self.shape, arg))
+    padding = tuple((max(0, -l), max(0, r-s)) for s,(l,r) in zip(self.shape, arg_))
+    return self.pad(padding, value=value).shrink(tuple((l + pl, r + pl) for (l,r),(pl,_) in zip(arg_, padding)))
 
   def gather(self:Tensor, idx:Tensor, dim:int) -> Tensor:
     assert idx.ndim == self.ndim, "self.ndim must equal idx.ndim"
@@ -420,13 +420,13 @@ class Tensor:
     return ((idx == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim)  # noqa: E501
 
   def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
-    dim = (dim + len(self.shape)) if dim < 0 else dim
+    if dim < 0: dim += self.ndim
     assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
     catargs = [self, *args]
     assert all(t.shape for t in catargs), "zero-dimensional tensor cannot be concatenated"
     shapes = [s.shape[dim] for s in catargs]
     shape_cumsum = [0, *accumulate(shapes)]
-    slc:List[List[Tuple[sint, sint]]] = [[(0, 0) for _ in self.shape] for _ in catargs]
+    slc:List[List[Optional[Tuple[sint, sint]]]] = [[None for _ in self.shape] for _ in catargs]
     for shp,k,s in zip(shapes, shape_cumsum[:-1], slc): s[dim] = (k, shape_cumsum[-1] - k - shp)
     return reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
 
@@ -457,7 +457,7 @@ class Tensor:
     return self if self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim+1:])
 
   def unsqueeze(self, dim:int) -> Tensor:
-    if dim < 0: dim = len(self.shape) + dim + 1
+    if dim < 0: dim = self.ndim + dim + 1
     return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
 
   # (padding_left, padding_right, padding_top, padding_bottom)
@@ -576,10 +576,10 @@ class Tensor:
     x, w = self, weight.reshape(groups, weight.shape[0]//groups, weight.shape[1], *weight.shape[2:]).permute(0,2,1,*trailing).flip(trailing)
     stride = make_pair(stride, len(HW))
     if any(s>1 for s in stride):
-      x = x.reshape(*x.shape[:2], *flatten((k,1) for k in x.shape[2:]))
-      x = x.pad(((0,0), (0,0), *flatten(((0,0),(0,s-1)) for s in stride)))
-      x = x.reshape(*x.shape[:2], *[k*s for k,s in zip(x.shape[2::2], stride)])
-      x = x.shrink(((0,x.shape[0]), (0,x.shape[1]), *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
+      x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:]))
+      x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride)))
+      x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)])
+      x = x.shrink((None, None, *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
     padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW)))))))  # noqa: E501
     return x.conv2d(w.reshape(w.shape[0]*w.shape[1],*w.shape[2:]), groups=groups, bias=bias, dilation=dilation, padding=padding)
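The new call sites above lean on the convention that a None entry in a pad/shrink argument means "leave this axis alone", and that a None in reshape keeps that axis's existing size. A minimal sketch of those semantics, not part of the diff (the tensor shape and print statements are illustrative only):

from tinygrad.tensor import Tensor

t = Tensor.ones(2, 3)
print(t.pad((None, (1, 1))).shape)     # (2, 5): axis 0 untouched, axis 1 padded by 1 on each side
print(t.shrink((None, (0, 2))).shape)  # (2, 2): axis 0 untouched, axis 1 cut to the range [0, 2)
print(t.reshape(None, 1, 3).shape)     # (2, 1, 3): None keeps the original size of axis 0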