small changes from lowerer (#5266)

This commit is contained in:
George Hotz
2024-07-02 15:03:54 -07:00
committed by GitHub
parent 7be776f9af
commit e53b164e1a
3 changed files with 11 additions and 4 deletions

View File

@@ -765,6 +765,12 @@ class TestOps(unittest.TestCase):
helper_test_op([(4,3), (1,3,3,5)], lambda x,y: x.matmul(y), Tensor.dot)
def test_small_gemm(self):
helper_test_op([(8,8), (8,8)], lambda x,y: x.matmul(y), lambda x,y: x@y)
def test_9_gemm(self):
helper_test_op([(9,9), (9,9)], lambda x,y: x.matmul(y), lambda x,y: x@y)
def test_small_gemm_padded(self):
helper_test_op([(9,9), (9,9)],
lambda x,y: torch.nn.functional.pad(x, (0,7,0,7)).matmul(torch.nn.functional.pad(y, (0,7,0,7))),
lambda x,y: x.pad(((0,7),(0,7)))@y.pad(((0,7),(0,7))))
def test_small_gemm_range(self):
helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, vals=[np.arange(0,64,dtype=np.float32).reshape(8,8),
np.arange(64,128,dtype=np.float32).reshape(8,8)])