Fix in _reshape_mask (#4332)

* handle reshape with remainder in _reshape_mask

* remove trailing whitespce

* use helper_test_op to generate tensors from shapes

* test in shapetracket too

* remove whitespace

* revert property name in other class tests
This commit is contained in:
Obada Khalili
2024-04-28 18:57:39 +03:00
committed by GitHub
parent 664b563c91
commit e4befa41d7
3 changed files with 29 additions and 1 deletions

View File

@@ -986,6 +986,17 @@ class TestOps(unittest.TestCase):
helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (0,0,3,4), value=1), lambda x: x.pad(((3,4), None), value=1))
helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (0,0,0,0), value=1), lambda x: x.pad((None, None), value=1))
def test_pad_reshape(self):
helper_test_op([(1, 2)],
lambda x: torch.nn.functional.pad(x, (0, 1, 1, 0)).reshape((3, 2)),
lambda x: x.pad2d((0, 1, 1, 0)).reshape((3, 2)), forward_only=True)
helper_test_op([(1, 2)],
lambda x: torch.nn.functional.pad(x, (0, 2, 1, 1)).reshape((4, 3)),
lambda x: x.pad2d((0, 2, 1, 1)).reshape((4, 3)), forward_only=True)
helper_test_op([(1, 1, 1, 2)],
lambda x: torch.nn.functional.pad(x, (0, 4, 2, 2, 1, 2, 0, 2)).reshape((4, 3, 6, 5)),
lambda x: x.pad(((0, 2), (1, 2), (2, 2), (0, 4))).reshape((4, 3, 6, 5)), forward_only=True)
@unittest.skipIf(Device.DEFAULT == "WEBGL", "incorrect result")
def test_pad_slice(self):
for value in 0., 3.456:

View File

@@ -569,6 +569,22 @@ class TestMaskedShapeTracker(unittest.TestCase):
self.st.pad(((1,1), (1,1)))
self.st.assert_same()
def test_pad_reshape(self):
st1 = CheckingShapeTracker((1, 2))
st1.pad(((1, 0), (0, 1)))
st1.reshape((3, 2))
st1.assert_same()
st2 = CheckingShapeTracker((1, 2))
st2.pad(((1, 1), (0, 2)))
st2.reshape((4, 3))
st2.assert_same()
st3 = CheckingShapeTracker((1, 1, 1, 2))
st3.pad(((0, 2), (1, 2), (2, 2), (0, 4)))
st3.reshape((4, 3, 6, 5))
st3.assert_same()
class TestShapeTracker(unittest.TestCase):
def setUp(self):
self.st = CheckingShapeTracker((7,4))