mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
fix bug caused by rounding
This commit is contained in:
@@ -262,6 +262,12 @@ class TestOps(unittest.TestCase):
|
||||
lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
|
||||
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
|
||||
|
||||
def test_strided_conv2d_simple(self):
|
||||
bs,H,W = 2,3,1
|
||||
helper_test_op([(bs,1,5,1), (1,1,H,W)],
|
||||
lambda x,w: torch.nn.functional.conv2d(x,w,stride=2).relu(),
|
||||
lambda x,w: Tensor.conv2d(x,w,stride=2).relu(), atol=1e-4)
|
||||
|
||||
def test_strided_conv2d(self):
|
||||
bs = 4
|
||||
cin = 3
|
||||
|
||||
@@ -13,7 +13,7 @@ def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tup
|
||||
assert len(shape) == len(strides)
|
||||
ret = [(shape[0], strides[0])]
|
||||
for i in range(1, len(shape)):
|
||||
if (strides[i] != 0 and ret[-1][1]//strides[i] == shape[i]) or (strides[i] == 0 and ret[-1][1] == 0):
|
||||
if (strides[i] != 0 and ret[-1][1] == shape[i]*strides[i]) or (strides[i] == 0 and ret[-1][1] == 0):
|
||||
ret[-1] = (ret[-1][0] * shape[i], strides[i])
|
||||
else:
|
||||
ret.append((shape[i], strides[i]))
|
||||
|
||||
Reference in New Issue
Block a user