diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index d1371dc0a9..e69db10c36 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1206,7 +1206,7 @@ class Tensor(OpMixin): if any(st != 1 for st in strides): # pad shape to multiple of stride if not all_int(x.shape): raise RuntimeError("symbolic shape not supported") - x = x.pad(tuple((0, round_up(s, st) - s) for s, st in zip(x.shape, strides))) + x = x.pad_to(tuple(round_up(s, st) for s, st in zip(x.shape, strides))) x = x.reshape(tuple(flatten((s // st, st) for s, st in zip(x.shape, strides)))) x = x.shrink(tuple(flatten(((0, s), (0, 1)) for s in x.shape[::2]))).reshape(x.shape[::2]) @@ -1333,7 +1333,7 @@ class Tensor(OpMixin): dim = self._resolve_dim(dim) assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim" index = index.to(self.device) - x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim) + x = self.shrink_to(tuple(i if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim) return (index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).where(x, 0)).sum(-1, dtype=self.dtype) def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor: @@ -2386,7 +2386,7 @@ class Tensor(OpMixin): # interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO) ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]]) # merge groups and rcout, tyx and HWO: (bs, groups, cout, *yx), shrink to final - ret = ret.reshape(bs, cout, *[c * HWO[i] for i, c in enumerate(tyx)]).shrink(tuple((0, s) for s in [bs, cout, *oyx])) + ret = ret.reshape(bs, cout, *[c * HWO[i] for i, c in enumerate(tyx)]).shrink_to(bs, cout, *oyx) return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward() @@ -2425,7 +2425,7 @@ class Tensor(OpMixin): x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:])) x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride))) x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)]) - x = x.shrink((None, None, *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)])) + x = x.shrink_to(None, None, *[k-(s-1) for k,s in zip(x.shape[2:], stride)]) padding = flatten((((k-1)*d-pB,(k-1)*d-pA+op) for k,d,(pB,pA),op in reversed(list(zip(HW, dilation, padding, output_padding))))) return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding) @@ -2645,13 +2645,12 @@ class Tensor(OpMixin): f"All dimensions of {index.shape=} should be <= to all dimensions of {src.shape=} and all dimensions except dimension {dim} of {self.shape=}" if self.dtype != src.dtype: raise RuntimeError(f"expect {self.dtype=} to be equal to {src.dtype=}") # shrink src to index shape to shrink away the unused values - src = src.shrink(tuple((0,s) for s in index.shape)) + src = src.shrink_to(index.shape) # prepare src and mask for reduce with respect to dim src = src.unsqueeze(-1).expand(*src.shape, self.shape[dim]).transpose(-1, dim) mask = index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).transpose(-1, dim) # pad src and mask to self.shape so that reduce can be done with padded values as no-ops - src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask)) - return src, mask + return src.pad_to(*self.shape, None), mask.pad_to(*self.shape, None) def scatter(self, dim:int, index:Tensor, src:Tensor|PyConst, reduce:Literal['multiply', 'add']|None=None) -> Tensor: """ @@ -2769,7 +2768,7 @@ class Tensor(OpMixin): # flip wires back to undo the crossover blue_box, flipped_green_box = x.split(1, crossover_dim) x = blue_box.cat(flipped_green_box.flip(flip_dims), dim=crossover_dim) - x = x.flatten(dim, dim+n_stages-1).shrink(tuple((0, s) for s in self.shape)) + x = x.flatten(dim, dim+n_stages-1).shrink_to(self.shape) # compute indices for sorted values mask = Tensor.ones(orig_len, orig_len, dtype=dtypes.bool, device=self.device).tril().reshape((None, None) + (1,)*(self.ndim-dim-1)) def compute_counts(t:Tensor): return (mask & (t.unsqueeze(dim) == t.unsqueeze(dim+1))).sum(dim+1) @@ -2809,8 +2808,8 @@ class Tensor(OpMixin): if not sorted_: raise NotImplementedError("topk with sorted_=False is not supported") if k > self.shape[dim:=self._resolve_dim(dim)]: raise ValueError(f"selected index {k=} is out of range") x, idx = self.sort(dim, descending=largest) - shrink_to_k = tuple((0, k) if i == dim else None for i in range(self.ndim)) - return x.shrink(shrink_to_k), idx.shrink(shrink_to_k) + topk_shape = tuple(k if i == dim else None for i in range(self.ndim)) + return x.shrink_to(topk_shape), idx.shrink_to(topk_shape) # ***** unary ops *****