mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
add cat support
This commit is contained in:
@@ -220,6 +220,10 @@ class TestOps(unittest.TestCase):
|
||||
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz),
|
||||
lambda x: Tensor.avg_pool2d(x, kernel_size=ksz), rtol=1e-5)
|
||||
|
||||
def test_cat(self):
|
||||
for dim in range(-1, 2):
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: torch.cat((x,y), dim), lambda x,y: x.cat(y, dim), forward_only=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
np.random.seed(1337)
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user