mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
@@ -2100,17 +2100,23 @@ class Tensor(OpMixin):
|
||||
noop, i_ = [None] * (self.ndim-len(k_)), self.shape[-len(k_):]
|
||||
assert all(resolve(d*(k-1)+1 <= i) for k,d,i in zip(k_,d_,i_)), "kernel size cannot be greater than actual input size"
|
||||
o_ = [ceildiv(i-d*(k-1), s) for i,d,k,s in zip(i_,d_,k_,s_)]
|
||||
# input size scaling factor to make sure shrink for stride is possible
|
||||
f_ = [smax(1, ceildiv(o*s - d, i)) for o,s,i,d in zip(o_,s_,i_,d_)]
|
||||
# repeats such that we don't need padding
|
||||
x = self.repeat([1]*len(noop) + [ceildiv(k*(i*f+d),i) for k,i,d,f in zip(k_,i_,d_,f_)])
|
||||
# handle dilation
|
||||
x = x.shrink_to(noop + [k*(i*f+d) for k,i,d,f in zip(k_,i_,d_,f_)]).reshape(noop + flatten((k,(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)))
|
||||
# handle stride
|
||||
x = x.shrink_to(noop + flatten((k,o*s) for k,o,s in zip(k_,o_,s_))).reshape(noop + flatten((k,o,s) for k,o,s in zip(k_,o_,s_)))
|
||||
x = x.shrink_to(noop + flatten((k,o,1) for k,o in zip(k_,o_))).reshape(noop + flatten((k,o) for k,o in zip(k_,o_)))
|
||||
# permute to move reduce to the end
|
||||
return x.permute(*range(len(noop)), *[len(noop)+i*2+1 for i in range(len(i_))], *[len(noop)+i*2 for i in range(len(i_))])
|
||||
if any(resolve(k > s) for k,s in zip(k_,s_)) or any(d != 1 for d in d_):
|
||||
# input size scaling factor to make sure shrink for stride is possible
|
||||
f_ = [smax(1, ceildiv(o*s - d, i)) for o,s,i,d in zip(o_,s_,i_,d_)]
|
||||
# repeats such that we don't need padding
|
||||
x = self.repeat([1]*len(noop) + [ceildiv(k*(i*f+d),i) for k,i,d,f in zip(k_,i_,d_,f_)])
|
||||
# handle dilation
|
||||
x = x.shrink_to(noop + [k*(i*f+d) for k,i,d,f in zip(k_,i_,d_,f_)]).reshape(noop + flatten((k,(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_)))
|
||||
# handle stride
|
||||
x = x.shrink_to(noop + flatten((k,o*s) for k,o,s in zip(k_,o_,s_))).reshape(noop + flatten((k,o,s) for k,o,s in zip(k_,o_,s_)))
|
||||
x = x.shrink_to(noop + flatten((k,o,1) for k,o in zip(k_,o_))).reshape(noop + flatten((k,o) for k,o in zip(k_,o_)))
|
||||
# permute to move reduce to the end
|
||||
return x.permute(*range(len(noop)), *[len(noop)+i*2+1 for i in range(len(i_))], *[len(noop)+i*2 for i in range(len(i_))])
|
||||
# TODO: once the shapetracker can optimize well, remove this alternative implementation
|
||||
x = self.pad(tuple(noop + [(0, max(0,o*s-i)) for i,o,s in zip(i_,o_,s_)])).shrink(tuple(noop + [(0,o*s) for o,s in zip(o_,s_)]))
|
||||
x = x.reshape(noop + flatten(((o,s) for o,s in zip(o_,s_))))
|
||||
x = x.shrink_to(noop + flatten((o,k) for o,k in zip(o_,k_)))
|
||||
return x.permute(*range(len(noop)), *[len(noop)+i*2 for i in range(len(i_))], *[len(noop)+i*2+1 for i in range(len(i_))])
|
||||
|
||||
def _resolve_pool_pads(self, padding:int|Sequence[int], dims:int) -> Sequence[int]:
|
||||
if not isinstance(padding, int) and not (len(padding) == 2*dims or len(padding) == dims):
|
||||
|
||||
Reference in New Issue
Block a user