break maxpool2d on GPU

This commit is contained in:
George Hotz
2020-12-29 13:05:57 -05:00
parent 061e37de39
commit 02655c07d5
5 changed files with 27 additions and 108 deletions

View File

@@ -164,6 +164,7 @@ class TestOps(unittest.TestCase):
lambda x,w: torch.nn.functional.conv2d(x,w,stride=stride).relu(),
lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu(), device=self.device)
@cpu_only
def test_maxpool2d(self):
for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
with self.subTest(kernel_size=ksz):