From a5fd297df59926edd2b9b2b4b23686384853cdd7 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Fri, 7 Nov 2025 15:52:33 -0800 Subject: [PATCH] Revert "try this now" This reverts commit 607cdc21642449be21b9c687c3e5f9a38d4b0242. --- tinygrad/tensor.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 3ac2981a29..dba79f91ee 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2100,17 +2100,23 @@ class Tensor(OpMixin): noop, i_ = [None] * (self.ndim-len(k_)), self.shape[-len(k_):] assert all(resolve(d*(k-1)+1 <= i) for k,d,i in zip(k_,d_,i_)), "kernel size cannot be greater than actual input size" o_ = [ceildiv(i-d*(k-1), s) for i,d,k,s in zip(i_,d_,k_,s_)] - # input size scaling factor to make sure shrink for stride is possible - f_ = [smax(1, ceildiv(o*s - d, i)) for o,s,i,d in zip(o_,s_,i_,d_)] - # repeats such that we don't need padding - x = self.repeat([1]*len(noop) + [ceildiv(k*(i*f+d),i) for k,i,d,f in zip(k_,i_,d_,f_)]) - # handle dilation - x = x.shrink_to(noop + [k*(i*f+d) for k,i,d,f in zip(k_,i_,d_,f_)]).reshape(noop + flatten((k,(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_))) - # handle stride - x = x.shrink_to(noop + flatten((k,o*s) for k,o,s in zip(k_,o_,s_))).reshape(noop + flatten((k,o,s) for k,o,s in zip(k_,o_,s_))) - x = x.shrink_to(noop + flatten((k,o,1) for k,o in zip(k_,o_))).reshape(noop + flatten((k,o) for k,o in zip(k_,o_))) - # permute to move reduce to the end - return x.permute(*range(len(noop)), *[len(noop)+i*2+1 for i in range(len(i_))], *[len(noop)+i*2 for i in range(len(i_))]) + if any(resolve(k > s) for k,s in zip(k_,s_)) or any(d != 1 for d in d_): + # input size scaling factor to make sure shrink for stride is possible + f_ = [smax(1, ceildiv(o*s - d, i)) for o,s,i,d in zip(o_,s_,i_,d_)] + # repeats such that we don't need padding + x = self.repeat([1]*len(noop) + [ceildiv(k*(i*f+d),i) for k,i,d,f in zip(k_,i_,d_,f_)]) + # handle dilation + x = x.shrink_to(noop + [k*(i*f+d) for k,i,d,f in zip(k_,i_,d_,f_)]).reshape(noop + flatten((k,(i*f+d)) for k,i,d,f in zip(k_,i_,d_,f_))) + # handle stride + x = x.shrink_to(noop + flatten((k,o*s) for k,o,s in zip(k_,o_,s_))).reshape(noop + flatten((k,o,s) for k,o,s in zip(k_,o_,s_))) + x = x.shrink_to(noop + flatten((k,o,1) for k,o in zip(k_,o_))).reshape(noop + flatten((k,o) for k,o in zip(k_,o_))) + # permute to move reduce to the end + return x.permute(*range(len(noop)), *[len(noop)+i*2+1 for i in range(len(i_))], *[len(noop)+i*2 for i in range(len(i_))]) + # TODO: once the shapetracker can optimize well, remove this alternative implementation + x = self.pad(tuple(noop + [(0, max(0,o*s-i)) for i,o,s in zip(i_,o_,s_)])).shrink(tuple(noop + [(0,o*s) for o,s in zip(o_,s_)])) + x = x.reshape(noop + flatten(((o,s) for o,s in zip(o_,s_)))) + x = x.shrink_to(noop + flatten((o,k) for o,k in zip(o_,k_))) + return x.permute(*range(len(noop)), *[len(noop)+i*2 for i in range(len(i_))], *[len(noop)+i*2+1 for i in range(len(i_))]) def _resolve_pool_pads(self, padding:int|Sequence[int], dims:int) -> Sequence[int]: if not isinstance(padding, int) and not (len(padding) == 2*dims or len(padding) == dims):