diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 87f18d9417..2a6a5c6147 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1947,21 +1947,15 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method noop_, i_, ksz_ = [None] * len(self.shape[:-len(k_)]), self.shape[-len(k_):], [d*(k-1)+1 for k,d in zip(k_,d_)] assert all(ksz <= i for ksz,i in zip(ksz_, i_)), "kernel size cannot be greater than actual input size" o_ = [ceildiv(i - d * (k-1), s) for i,d,k,s in zip(i_, d_, k_, s_)] - if any(resolve(k > s) for k,s in zip(k_, s_)) or any(d != 1 for d in d_): - # repeats such that we don't need padding - xup = self.repeat([1]*len(noop_) + [o+ceildiv((o*s),i) for o,i,s in zip(o_,i_,s_)]) - # handle stride - xup = xup.shrink(tuple(noop_ + [(0, o*(i+s)) for o,i,s in zip(o_,i_,s_)])).reshape(noop_ + flatten((o,i+s) for o,i,s in zip(o_,i_,s_))) - # handle dilation - xup = xup.shrink( - tuple(noop_ + flatten((None, (0,ksz+d-1)) for ksz,d in zip(ksz_, d_)))).reshape(noop_ + flatten((o,k,d) for o,k,d in zip(o_,k_,d_))) - xup = xup.shrink(tuple(noop_ + flatten(((0,o), (0,k), (0,1)) for o,k in zip(o_,k_)))).reshape(noop_ + flatten((o,k) for o,k in zip(o_,k_))) - # permute to move reduce to the end - return xup.permute(*range(len(noop_)), *[len(noop_)+i*2 for i in range(len(i_))], *[len(noop_)+i*2+1 for i in range(len(i_))]) - # TODO: once the shapetracker can optimize well, remove this alternative implementation - xup = self.pad(tuple(noop_ + [(0, max(0,o*s-i)) for i,o,s in zip(i_, o_, s_)])).shrink(tuple(noop_ + [(0,o*s) for o,s in zip(o_, s_)])) - xup = xup.reshape(noop_ + flatten(((o,s) for o,s in zip(o_, s_)))) - xup = xup.shrink(tuple(noop_ + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))) + # repeats such that we don't need padding + xup = self.repeat([1]*len(noop_) + [o+ceildiv((o*s),i) for o,i,s in zip(o_,i_,s_)]) + # handle stride + xup = xup.shrink(tuple(noop_ + [(0, o*(i+s)) for o,i,s in zip(o_,i_,s_)])).reshape(noop_ + flatten((o,i+s) for o,i,s in zip(o_,i_,s_))) + # handle dilation + xup = xup.shrink( + tuple(noop_ + flatten((None, (0,ksz+d-1)) for ksz,d in zip(ksz_, d_)))).reshape(noop_ + flatten((o,k,d) for o,k,d in zip(o_,k_,d_))) + xup = xup.shrink(tuple(noop_ + flatten(((0,o), (0,k), (0,1)) for o,k in zip(o_,k_)))).reshape(noop_ + flatten((o,k) for o,k in zip(o_,k_))) + # permute to move reduce to the end return xup.permute(*range(len(noop_)), *[len(noop_)+i*2 for i in range(len(i_))], *[len(noop_)+i*2+1 for i in range(len(i_))]) def _padding2d(self, padding:Union[int, Sequence[int]], dims:int) -> Sequence[int]: