diff --git a/test/test_ops.py b/test/test_ops.py index e1b9718562..09afec25f6 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -220,6 +220,10 @@ class TestOps(unittest.TestCase): lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz), lambda x: Tensor.avg_pool2d(x, kernel_size=ksz), rtol=1e-5) + def test_cat(self): + for dim in range(-1, 2): + helper_test_op([(45,65), (45,65)], lambda x,y: torch.cat((x,y), dim), lambda x,y: x.cat(y, dim), forward_only=True) + if __name__ == '__main__': np.random.seed(1337) unittest.main(verbosity=2) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 283a6fd3b9..5ae260c28b 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -177,6 +177,20 @@ class Tensor: assert s.step is None or s.step == 1 return self.slice(arg = arg + [(0,self.shape[i]) for i in range(len(arg), len(self.shape))]) + def cat(self, y, dim=0): + assert len(self.shape) == len(y.shape) + dim = (dim + len(self.shape)) if dim < 0 else dim + s1, s2 = [], [] + for i in range(len(self.shape)): + if i != dim: + assert self.shape[i] == y.shape[i] + s1.append((0, self.shape[i])) + s2.append((0, self.shape[i])) + else: + s1.append((0, self.shape[i]+y.shape[i])) + s2.append((-self.shape[i], y.shape[i])) + return self.slice(arg=s1) + y.slice(arg=s2) + def pad2d(self, padding): return self[:, :, -padding[2]:self.shape[2]+padding[3], -padding[0]:self.shape[3]+padding[1]]