fix bug caused by rounding

This commit is contained in:
George Hotz
2022-07-17 12:49:58 -07:00
parent cff297ef9d
commit f93e297804
2 changed files with 7 additions and 1 deletions

View File

@@ -262,6 +262,12 @@ class TestOps(unittest.TestCase):
lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
def test_strided_conv2d_simple(self):
bs,H,W = 2,3,1
helper_test_op([(bs,1,5,1), (1,1,H,W)],
lambda x,w: torch.nn.functional.conv2d(x,w,stride=2).relu(),
lambda x,w: Tensor.conv2d(x,w,stride=2).relu(), atol=1e-4)
def test_strided_conv2d(self):
bs = 4
cin = 3

View File

@@ -13,7 +13,7 @@ def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tup
assert len(shape) == len(strides)
ret = [(shape[0], strides[0])]
for i in range(1, len(shape)):
if (strides[i] != 0 and ret[-1][1]//strides[i] == shape[i]) or (strides[i] == 0 and ret[-1][1] == 0):
if (strides[i] != 0 and ret[-1][1] == shape[i]*strides[i]) or (strides[i] == 0 and ret[-1][1] == 0):
ret[-1] = (ret[-1][0] * shape[i], strides[i])
else:
ret.append((shape[i], strides[i]))