support multidot on GPU

This commit is contained in:
George Hotz
2020-12-29 16:56:30 -05:00
parent 27208d729b
commit 6a6a82e999
2 changed files with 19 additions and 17 deletions

View File

@@ -75,7 +75,6 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid, device=self.device)
def test_dot(self):
helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, device=self.device)
@cpu_only
def test_multidot(self):
helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, device=self.device)
def test_sum(self):
@@ -163,7 +162,6 @@ class TestOps(unittest.TestCase):
lambda x,w: torch.nn.functional.conv2d(x,w,stride=stride).relu(),
lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu(), device=self.device)
@cpu_only
def test_maxpool2d(self):
for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
with self.subTest(kernel_size=ksz):