mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
add circular pad mode to Tensor.pad (#7918)
* start * send it * no more neg circular pads * quick fix onnx too --------- Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -1468,6 +1468,18 @@ class TestOps(unittest.TestCase):
|
||||
# no max pad sizes for replicate
|
||||
helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (3,11,0,30), mode="replicate"), lambda x: x.pad((3,11,0,30), mode="replicate"))
|
||||
|
||||
def test_pad_circular_mode(self):
|
||||
helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (0,2,3,2), mode="circular"), lambda x: x.pad((0,2,3,2), mode="circular"))
|
||||
helper_test_op([(5,5,5)], lambda x: torch.nn.functional.pad(x, (0,2), mode="circular"), lambda x: x.pad((0,2), mode="circular"))
|
||||
helper_test_op([(1,1,5,5,5)], lambda x: torch.nn.functional.pad(x, (1,2,3,5,1,2),mode="circular"),lambda x:x.pad((1,2,3,5,1,2),mode="circular"))
|
||||
# circular pad cannot wrap around more than once
|
||||
self.helper_test_exception([(1,1,5,5)],
|
||||
lambda x: torch.nn.functional.pad(x, (3,6,0,0), mode="circular"), lambda x: x.pad((3,6,0,0), mode="circular"),
|
||||
expected=(RuntimeError, ValueError))
|
||||
with self.assertRaises(NotImplementedError):
|
||||
# negative pads with circular pads is not supported
|
||||
Tensor.randn(1,1,5,5).pad((3,-5,1,-5), mode="circular")
|
||||
|
||||
def test_pad_reshape(self):
|
||||
helper_test_op([(1, 2)],
|
||||
lambda x: torch.nn.functional.pad(x, (0, 1, 1, 0)).reshape((3, 2)),
|
||||
|
||||
Reference in New Issue
Block a user