if you like your transformers twice as slow, use the GPU

This commit is contained in:
George Hotz
2020-12-29 17:14:23 -05:00
parent 6a6a82e999
commit f9170505b3
4 changed files with 7 additions and 5 deletions

View File

@@ -77,6 +77,7 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, device=self.device)
def test_multidot(self):
helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, device=self.device)
helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, device=self.device)
def test_sum(self):
helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum, device=self.device)
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)), device=self.device)