Added test_chunk and fixed (#1283)

This commit is contained in:
Umut Zengin
2023-07-20 05:21:26 +03:00
committed by GitHub
parent 3f2497160c
commit 74e63fe4ee
2 changed files with 29 additions and 4 deletions

View File

@@ -340,10 +340,10 @@ class Tensor:
# TODO: make this nicer with syntactic sugar in slice
def chunk(self, num, dim):
slice_params = [[(0, s) for s in self.shape] for _ in range(num)]
for i,k in enumerate(range(0, self.shape[dim], self.shape[dim]//num)):
slice_params[i][dim] = (k, min(self.shape[dim], k+self.shape[dim]//num))
return [self.slice(p) for p in slice_params]
slice_params = [[slice(None) for s in self.shape] for _ in range(ceil(self.shape[dim]/ceil(self.shape[dim]/num)))]
for i, k in enumerate(range(0, self.shape[dim], ceil(self.shape[dim]/num))):
slice_params[i][dim] = slice(k, k + ceil(self.shape[dim]/num))
return [self[tuple(sl)] for sl in slice_params]
def squeeze(self, dim=None):
if dim is None: return self if 1 not in self.shape else self.reshape(*[size for size in self.shape if size != 1])