add cat support

This commit is contained in:
George Hotz
2021-11-28 23:21:49 -05:00
parent ce3d198bb7
commit 3cdc77f526
2 changed files with 18 additions and 0 deletions

View File

@@ -220,6 +220,10 @@ class TestOps(unittest.TestCase):
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz),
lambda x: Tensor.avg_pool2d(x, kernel_size=ksz), rtol=1e-5)
def test_cat(self):
for dim in range(-1, 2):
helper_test_op([(45,65), (45,65)], lambda x,y: torch.cat((x,y), dim), lambda x,y: x.cat(y, dim), forward_only=True)
if __name__ == '__main__':
np.random.seed(1337)
unittest.main(verbosity=2)