diff --git a/test/test_ops.py b/test/test_ops.py
index 730100e4fa..74aacc9014 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1042,8 +1042,8 @@ class TestOps(unittest.TestCase):
       lambda x: Tensor.avg_pool2d(x, kernel_size=(111,28)), rtol=1e-5)
 
   def test_cat(self):
-    for dim in range(-1, 2):
-      helper_test_op([(45,65), (45,65)], lambda x,y: torch.cat((x,y), dim), lambda x,y: x.cat(y, dim=dim))
+    for dim in range(-2, 3):
+      helper_test_op([(45,65,90), (45,65,90), (45,65,90)], lambda x,y,z: torch.cat((x,y,z), dim), lambda x,y,z: x.cat(y, z, dim=dim))
 
     with self.assertRaises(AssertionError):
       a = Tensor(3.14)
diff --git a/test/test_speed_v_torch.py b/test/test_speed_v_torch.py
index 3c07f1439f..6ba4d98ce4 100644
--- a/test/test_speed_v_torch.py
+++ b/test/test_speed_v_torch.py
@@ -157,6 +157,10 @@ class TestSpeed(unittest.TestCase):
     helper_test_generic_square('cumsum_0', 256, f0, f0, onearg=True)
     helper_test_generic_square('cumsum_1', 256, f1, f1, onearg=True)
 
+  def test_cat(self):
+    helper_test_generic_square('cat_0', 256, lambda x,y: torch.cat((x,y),dim=0), lambda x,y: x.cat(y,dim=0))
+    helper_test_generic_square('cat_1', 256, lambda x,y: torch.cat((x,y),dim=1), lambda x,y: x.cat(y,dim=1))
+
   def test_array_packing(self):
     N = 2048
     def f(a, b): return a.reshape(N, N // 32, 32).permute(1,0,2).contiguous()
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index c3a2c4d0cb..838c782819 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -322,13 +322,14 @@ class Tensor:
   def cat(self, *args, dim=0):
     dim = (dim + len(self.shape)) if dim < 0 else dim
     assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
-    catargs = [self] + list(args)
-    assert all(len(t.shape) != 0 for t in catargs), "zero-dimensional tensor cannot be concatenated"
-    shape_cumsum = [0, *accumulate([y.shape[dim] for y in catargs])]
-    slc = [[(0, s) for s in self.shape] for _ in catargs]
-    for s,k in zip(slc, shape_cumsum):
-      s[dim] = (-k, shape_cumsum[-1]-k)
-    return reduce(Tensor.__add__, [arg.slice(s) for arg,s in zip(catargs, slc)])
+    catargs = [self, *args]
+    assert all(t.shape for t in catargs), "zero-dimensional tensor cannot be concatenated"
+    shapes = [s.shape[dim] for s in catargs]
+    shape_cumsum = [0, *accumulate(shapes)]
+    slc = [[(0, 0) for _ in self.shape] for _ in catargs]
+    for shp,k,s in zip(shapes, shape_cumsum[:-1], slc):
+      s[dim] = (k, shape_cumsum[-1] - k - shp)
+    return reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
 
   @staticmethod
   def stack(tensors, dim=0):
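
For context on the tensor.py change: the new `cat` zero-pads each input along `dim` out to the full output size, placing it at its cumulative offset, then sums the padded tensors; since the non-zero regions are disjoint, the sum equals the concatenation. A minimal NumPy sketch of that scheme, for illustration only (the `np_cat` helper is hypothetical, not part of this diff):

# Sketch of pad-and-sum concatenation; np_cat is illustrative, not tinygrad code.
from itertools import accumulate
import numpy as np

def np_cat(*tensors, dim=0):
  sizes = [t.shape[dim] for t in tensors]
  offsets = [0, *accumulate(sizes)]     # cumulative start of each tensor along dim
  total = offsets[-1]                   # size of the output along dim
  padded = []
  for t, off, sz in zip(tensors, offsets[:-1], sizes):
    pads = [(0, 0)] * t.ndim
    pads[dim] = (off, total - off - sz) # zeros before and after, along dim only
    padded.append(np.pad(t, pads))      # every padded tensor has the output shape
  return sum(padded)                    # non-zero regions are disjoint, so sum == cat

a, b = np.ones((2, 3)), 2 * np.ones((4, 3))
assert (np_cat(a, b, dim=0) == np.concatenate([a, b], axis=0)).all()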