mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-19 02:44:40 -05:00
Added test_chunk and fixed (#1283)
This commit is contained in:
@@ -340,10 +340,10 @@ class Tensor:
|
||||
|
||||
# TODO: make this nicer with syntactic sugar in slice
|
||||
def chunk(self, num, dim):
|
||||
slice_params = [[(0, s) for s in self.shape] for _ in range(num)]
|
||||
for i,k in enumerate(range(0, self.shape[dim], self.shape[dim]//num)):
|
||||
slice_params[i][dim] = (k, min(self.shape[dim], k+self.shape[dim]//num))
|
||||
return [self.slice(p) for p in slice_params]
|
||||
slice_params = [[slice(None) for s in self.shape] for _ in range(ceil(self.shape[dim]/ceil(self.shape[dim]/num)))]
|
||||
for i, k in enumerate(range(0, self.shape[dim], ceil(self.shape[dim]/num))):
|
||||
slice_params[i][dim] = slice(k, k + ceil(self.shape[dim]/num))
|
||||
return [self[tuple(sl)] for sl in slice_params]
|
||||
|
||||
def squeeze(self, dim=None):
|
||||
if dim is None: return self if 1 not in self.shape else self.reshape(*[size for size in self.shape if size != 1])
|
||||
|
||||
Reference in New Issue
Block a user