replace slice by pad/shrink in _pool (#4082)

Author: Akshit Talwar
Date: 2024-04-05 21:17:22 +05:30
Committed by: GitHub
Parent: a337922c44
Commit: 750ecf8fef
2 changed files with 12 additions and 5 deletions
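The change itself is small: `Tensor.slice`, which implicitly zero-pads when its bounds run past the tensor's shape, is replaced inside `_pool` with an explicit `pad` followed by `shrink`. Below is a minimal 1-D sketch of that identity; it is not part of the patch, and the example tensor and bound are made up.

# Sketch of the slice -> pad/shrink identity (illustrative values):
# a slice whose end bound may exceed the axis length equals zero-padding
# out to that bound and then shrinking to it.
from tinygrad import Tensor

x = Tensor([1.0, 2.0, 3.0])                        # shape (3,)
end = 5                                            # bound past the end, as o*s can be in _pool
padded = x.pad(((0, max(0, end - x.shape[0])),))   # zero-pad to length 5
out = padded.shrink(((0, end),))                   # crop to exactly `end` elements
print(out.numpy())                                 # -> [1. 2. 3. 0. 0.]

In the dilation branch the preceding `repeat` already guarantees enough elements (the "repeats such that we don't need padding" comment below), so the pad widths would always be zero and a bare `shrink` suffices; only the no-dilation branch needs the `pad`.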

test/test_ops.py

@@ -1432,6 +1432,13 @@ class TestOps(unittest.TestCase):
           lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), stride=stride),
           lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), stride=stride))
+  def test_maxpool2d_bigger_stride_dilation(self):
+    for stride, dilation in zip([(2,3), (3,2), 2, 3, 4], [(3,2), (2,3), 2, 3, 6]):
+      with self.subTest(stride=stride):
+        helper_test_op([(32,2,110,28)],
+          lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), stride=stride, dilation=dilation),
+          lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), stride=stride, dilation=dilation))
   @unittest.skipIf(Device.DEFAULT == "CUDA", "CUDA fails on this")
   def test_maxpool2d_unit_stride(self):
     helper_test_op([(8, 2, 17, 14)],
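The new `test_maxpool2d_bigger_stride_dilation` sweeps paired stride/dilation values, including non-square combinations. A standalone sketch of what a single iteration checks, outside the `helper_test_op` harness (the random input and tolerance here are illustrative, not the harness defaults):

# Compare tinygrad's dilated max pool against the torch reference for one
# (stride, dilation) pair from the new test.
import numpy as np
import torch
from tinygrad import Tensor

x = np.random.randn(32, 2, 110, 28).astype(np.float32)
out_tg = Tensor(x).max_pool2d(kernel_size=(2,2), stride=(2,3), dilation=(3,2)).numpy()
out_pt = torch.nn.functional.max_pool2d(torch.from_numpy(x), kernel_size=(2,2), stride=(2,3), dilation=(3,2)).numpy()
np.testing.assert_allclose(out_tg, out_pt, atol=1e-6)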

tinygrad/tensor.py

@@ -691,17 +691,17 @@ class Tensor:
       # repeats such that we don't need padding
       xup = self.repeat([1]*len(noop_) + [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)])
       # slice by dilation
-      xup = xup.slice(noop_ + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)]).reshape(noop_ + flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
+      xup = xup.shrink(tuple(noop_ + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])).reshape(noop_ + flatten((k,i+d) for k,i,d in zip(k_, i_, d_)))
       # handle stride
-      xup = xup.slice(noop_ + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_))).reshape(noop_ + flatten((k,o,s) for k,o,s in zip(k_, o_, s_)))
-      xup = xup.slice(noop_ + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_))).reshape(noop_ + flatten((k,o) for k,o in zip(k_, o_)))
+      xup = xup.shrink(noop_ + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_))).reshape(noop_ + flatten((k,o,s) for k,o,s in zip(k_, o_, s_)))
+      xup = xup.shrink(noop_ + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_))).reshape(noop_ + flatten((k,o) for k,o in zip(k_, o_)))
       # permute to move reduce to the end
       return xup.permute(*range(len(noop_)), *[len(noop_)+i*2+1 for i in range(len(i_))], *[len(noop_)+i*2 for i in range(len(i_))])
     # TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker
     o_ = [(i+(s-k))//s for i,s,k in zip(i_, s_, k_)]
-    xup = self.slice(noop_ + [(0,o*s) for o,s in zip(o_, s_)])
+    xup = self.pad(tuple(noop_ + [(0, max(0,o*s-i)) for i,o,s in zip(i_, o_, s_)])).shrink(tuple(noop_ + [(0,o*s) for o,s in zip(o_, s_)]))
     xup = xup.reshape(noop_ + flatten(((o,s) for o,s in zip(o_, s_))))
-    xup = xup.slice(noop_ + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
+    xup = xup.shrink(noop_ + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))
     return xup.permute(*range(len(noop_)), *[len(noop_)+i*2 for i in range(len(i_))], *[len(noop_)+i*2+1 for i in range(len(i_))])
   # NOTE: these work for more than 2D
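On the no-dilation path that now reads `pad(...).shrink(...)`, the per-axis sequence is: extend the axis to exactly `o*s` elements, view it as `(o, s)` windows, then keep the first `k` entries of each window. A concrete 1-D sketch with made-up sizes (`i=7, k=2, s=3`, so `o=(7+(3-2))//3=2`):

# Walk one spatial axis through the no-dilation path by hand.
from tinygrad import Tensor

i, k, s = 7, 2, 3
o = (i + (s - k)) // s                                     # 2 output positions
x = Tensor.arange(i).float()                               # [0, 1, 2, 3, 4, 5, 6]
xup = x.pad(((0, max(0, o*s - i)),)).shrink(((0, o*s),))   # [0..5]; pad is a no-op here since o*s <= i
windows = xup.reshape(o, s).shrink(((0, o), (0, k)))       # [[0, 1], [3, 4]]
print(windows.numpy())

`max_pool2d` then reduces each window with `max`; the trailing `permute` in `_pool` just moves the window axes to the end so that reduce runs over the last dims.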