add cat support

This commit is contained in:
George Hotz
2021-11-28 23:21:49 -05:00
parent ce3d198bb7
commit 3cdc77f526
2 changed files with 18 additions and 0 deletions

View File

@@ -220,6 +220,10 @@ class TestOps(unittest.TestCase):
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz),
lambda x: Tensor.avg_pool2d(x, kernel_size=ksz), rtol=1e-5)
def test_cat(self):
for dim in range(-1, 2):
helper_test_op([(45,65), (45,65)], lambda x,y: torch.cat((x,y), dim), lambda x,y: x.cat(y, dim), forward_only=True)
if __name__ == '__main__':
np.random.seed(1337)
unittest.main(verbosity=2)

View File

@@ -177,6 +177,20 @@ class Tensor:
assert s.step is None or s.step == 1
return self.slice(arg = arg + [(0,self.shape[i]) for i in range(len(arg), len(self.shape))])
def cat(self, y, dim=0):
assert len(self.shape) == len(y.shape)
dim = (dim + len(self.shape)) if dim < 0 else dim
s1, s2 = [], []
for i in range(len(self.shape)):
if i != dim:
assert self.shape[i] == y.shape[i]
s1.append((0, self.shape[i]))
s2.append((0, self.shape[i]))
else:
s1.append((0, self.shape[i]+y.shape[i]))
s2.append((-self.shape[i], y.shape[i]))
return self.slice(arg=s1) + y.slice(arg=s2)
def pad2d(self, padding):
return self[:, :, -padding[2]:self.shape[2]+padding[3], -padding[0]:self.shape[3]+padding[1]]