Small Tensor.cat optimization and reformatting (#1347)

This commit is contained in:
Umut Zengin
2023-07-27 01:01:12 +03:00
committed by GitHub
parent 4056f97187
commit d4ebadf2da
3 changed files with 14 additions and 9 deletions

View File

@@ -1042,8 +1042,8 @@ class TestOps(unittest.TestCase):
lambda x: Tensor.avg_pool2d(x, kernel_size=(111,28)), rtol=1e-5)
def test_cat(self):
for dim in range(-1, 2):
helper_test_op([(45,65), (45,65)], lambda x,y: torch.cat((x,y), dim), lambda x,y: x.cat(y, dim=dim))
for dim in range(-2, 3):
helper_test_op([(45,65, 90), (45,65,90), (45,65,90)], lambda x,y,z: torch.cat((x,y,z), dim), lambda x,y,z: x.cat(y, z, dim=dim))
with self.assertRaises(AssertionError):
a = Tensor(3.14)

View File

@@ -157,6 +157,10 @@ class TestSpeed(unittest.TestCase):
helper_test_generic_square('cumsum_0', 256, f0, f0, onearg=True)
helper_test_generic_square('cumsum_1', 256, f1, f1, onearg=True)
# Speed test for concatenation: hands a torch.cat lambda and the matching
# x.cat lambda to helper_test_generic_square, once per concatenation axis.
# NOTE(review): 256 is presumably the side length of the square inputs the
# helper builds, and the string its report label — confirm against the helper.
def test_cat(self):
# two inputs joined along dim 0
helper_test_generic_square('cat_0', 256, lambda x,y: torch.cat((x,y),dim=0), lambda x,y: x.cat(y,dim=0))
# same pair of inputs joined along dim 1
helper_test_generic_square('cat_1', 256, lambda x,y: torch.cat((x,y),dim=1), lambda x,y: x.cat(y,dim=1))
def test_array_packing(self):
N = 2048
def f(a, b): return a.reshape(N, N // 32, 32).permute(1,0,2).contiguous()

View File

@@ -322,13 +322,14 @@ class Tensor:
# Concatenate self with the tensors in *args along axis `dim`; a negative dim
# is counted from the last axis.
# NOTE(review): this span comes from a diff rendering whose +/- markers and
# indentation were stripped. It appears to hold BOTH the removed slice-based
# body and the added pad-based body back to back — confirm against the
# repository before treating either half as the live code.
def cat(self, *args, dim=0):
# map a negative dim onto the equivalent non-negative axis index
dim = (dim + len(self.shape)) if dim < 0 else dim
# every arg must have self's rank and match self's extent on all axes except dim
assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
# --- slice-based variant (presumably the lines this commit removed) ---
catargs = [self] + list(args)
assert all(len(t.shape) != 0 for t in catargs), "zero-dimensional tensor cannot be concatenated"
# running totals of the extents along dim, with a leading 0; the last entry is the total length
shape_cumsum = [0, *accumulate([y.shape[dim] for y in catargs])]
# one (start, end) pair per axis for each tensor, initially spanning self's full extent
slc = [[(0, s) for s in self.shape] for _ in catargs]
# zip stops at len(slc), so the final cumsum entry (the total) is never used as an offset k
for s,k in zip(slc, shape_cumsum):
# window on dim runs from -k to total-k; assumes slice() fills out-of-range positions with zeros — TODO confirm
s[dim] = (-k, shape_cumsum[-1]-k)
# summing the per-tensor results combines them into one tensor
return reduce(Tensor.__add__, [arg.slice(s) for arg,s in zip(catargs, slc)])
# --- pad-based variant (presumably the lines this commit added) ---
catargs = [self, *args]
# relies on truthiness of .shape; presumably a tuple, so a 0-d tensor's empty shape is falsy
assert all(t.shape for t in catargs), "zero-dimensional tensor cannot be concatenated"
# extent of each tensor along dim
shapes = [s.shape[dim] for s in catargs]
# shape_cumsum[i] is the offset of tensor i along dim; shape_cumsum[-1] is the total length
shape_cumsum = [0, *accumulate(shapes)]
# one pair per axis for each tensor, presumably (pad_before, pad_after); starts with no padding
slc = [[(0, 0) for _ in self.shape] for _ in catargs]
# [:-1] drops the total so each tensor is paired with its own starting offset
for shp,k,s in zip(shapes, shape_cumsum[:-1], slc):
# k entries before the tensor and total-k-shp after it, so the padded extent equals the total
s[dim] = (k, shape_cumsum[-1] - k - shp)
# assumes pad() fills with zeros, so the sum of the padded tensors is the concatenation — TODO confirm
return reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
@staticmethod
def stack(tensors, dim=0):