mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Fix in _reshape_mask (#4332)
* handle reshape with remainder in _reshape_mask * remove trailing whitespce * use helper_test_op to generate tensors from shapes * test in shapetracket too * remove whitespace * revert property name in other class tests
This commit is contained in:
@@ -986,6 +986,17 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (0,0,3,4), value=1), lambda x: x.pad(((3,4), None), value=1))
|
||||
helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (0,0,0,0), value=1), lambda x: x.pad((None, None), value=1))
|
||||
|
||||
def test_pad_reshape(self):
|
||||
helper_test_op([(1, 2)],
|
||||
lambda x: torch.nn.functional.pad(x, (0, 1, 1, 0)).reshape((3, 2)),
|
||||
lambda x: x.pad2d((0, 1, 1, 0)).reshape((3, 2)), forward_only=True)
|
||||
helper_test_op([(1, 2)],
|
||||
lambda x: torch.nn.functional.pad(x, (0, 2, 1, 1)).reshape((4, 3)),
|
||||
lambda x: x.pad2d((0, 2, 1, 1)).reshape((4, 3)), forward_only=True)
|
||||
helper_test_op([(1, 1, 1, 2)],
|
||||
lambda x: torch.nn.functional.pad(x, (0, 4, 2, 2, 1, 2, 0, 2)).reshape((4, 3, 6, 5)),
|
||||
lambda x: x.pad(((0, 2), (1, 2), (2, 2), (0, 4))).reshape((4, 3, 6, 5)), forward_only=True)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGL", "incorrect result")
|
||||
def test_pad_slice(self):
|
||||
for value in 0., 3.456:
|
||||
|
||||
@@ -569,6 +569,22 @@ class TestMaskedShapeTracker(unittest.TestCase):
|
||||
self.st.pad(((1,1), (1,1)))
|
||||
self.st.assert_same()
|
||||
|
||||
def test_pad_reshape(self):
|
||||
st1 = CheckingShapeTracker((1, 2))
|
||||
st1.pad(((1, 0), (0, 1)))
|
||||
st1.reshape((3, 2))
|
||||
st1.assert_same()
|
||||
|
||||
st2 = CheckingShapeTracker((1, 2))
|
||||
st2.pad(((1, 1), (0, 2)))
|
||||
st2.reshape((4, 3))
|
||||
st2.assert_same()
|
||||
|
||||
st3 = CheckingShapeTracker((1, 1, 1, 2))
|
||||
st3.pad(((0, 2), (1, 2), (2, 2), (0, 4)))
|
||||
st3.reshape((4, 3, 6, 5))
|
||||
st3.assert_same()
|
||||
|
||||
class TestShapeTracker(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.st = CheckingShapeTracker((7,4))
|
||||
|
||||
Reference in New Issue
Block a user