add circular pad mode to Tensor.pad (#7918)

* start

* send it

* no more neg circular pads

* quick fix onnx too

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
geohotstan
2024-11-27 23:30:51 +08:00
committed by GitHub
parent a58e289d77
commit 753f07e193
3 changed files with 25 additions and 29 deletions

View File

@@ -1468,6 +1468,18 @@ class TestOps(unittest.TestCase):
# no max pad sizes for replicate
helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (3,11,0,30), mode="replicate"), lambda x: x.pad((3,11,0,30), mode="replicate"))
def test_pad_circular_mode(self):
helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (0,2,3,2), mode="circular"), lambda x: x.pad((0,2,3,2), mode="circular"))
helper_test_op([(5,5,5)], lambda x: torch.nn.functional.pad(x, (0,2), mode="circular"), lambda x: x.pad((0,2), mode="circular"))
helper_test_op([(1,1,5,5,5)], lambda x: torch.nn.functional.pad(x, (1,2,3,5,1,2),mode="circular"),lambda x:x.pad((1,2,3,5,1,2),mode="circular"))
# circular pad cannot wrap around more than once
self.helper_test_exception([(1,1,5,5)],
lambda x: torch.nn.functional.pad(x, (3,6,0,0), mode="circular"), lambda x: x.pad((3,6,0,0), mode="circular"),
expected=(RuntimeError, ValueError))
with self.assertRaises(NotImplementedError):
# negative pads with circular pads is not supported
Tensor.randn(1,1,5,5).pad((3,-5,1,-5), mode="circular")
def test_pad_reshape(self):
helper_test_op([(1, 2)],
lambda x: torch.nn.functional.pad(x, (0, 1, 1, 0)).reshape((3, 2)),