Mirror of https://github.com/tinygrad/tinygrad.git
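The hunks below, all from class Tensor in tensor.py, swap hand-built shrink/pad argument tuples for the target-shape shrink_to/pad_to forms. A brief plain-Python sketch follows each hunk to show the rewrite it performs; helper names in those sketches are made up for illustration.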
@@ -1206,7 +1206,7 @@ class Tensor(OpMixin):
     if any(st != 1 for st in strides):
       # pad shape to multiple of stride
       if not all_int(x.shape): raise RuntimeError("symbolic shape not supported")
-      x = x.pad(tuple((0, round_up(s, st) - s) for s, st in zip(x.shape, strides)))
+      x = x.pad_to(tuple(round_up(s, st) for s, st in zip(x.shape, strides)))
       x = x.reshape(tuple(flatten((s // st, st) for s, st in zip(x.shape, strides))))
       x = x.shrink(tuple(flatten(((0, s), (0, 1)) for s in x.shape[::2]))).reshape(x.shape[::2])
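Sketch (not part of the diff): judging from this hunk, pad_to takes per-axis target sizes and pads at the end of each axis, where pad was given explicit (before, after) pairs. The hypothetical pad_to_args helper below spells out the arithmetic the old line did by hand:

    def round_up(n: int, m: int) -> int: return ((n + m - 1) // m) * m

    def pad_to_args(shape, target):
      # the (before, after) pairs the old x.pad(...) call constructed
      return tuple((0, t - s) for s, t in zip(shape, target))

    shape, strides = (10, 7), (4, 3)
    target = tuple(round_up(s, st) for s, st in zip(shape, strides))
    assert target == (12, 9)
    assert pad_to_args(shape, target) == ((0, 2), (0, 2))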
@@ -1333,7 +1333,7 @@ class Tensor(OpMixin):
     dim = self._resolve_dim(dim)
     assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
     index = index.to(self.device)
-    x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
+    x = self.shrink_to(tuple(i if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
     return (index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).where(x, 0)).sum(-1, dtype=self.dtype)

   def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
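Sketch: as in the pad_to hunk, shrink_to reads as target sizes, where an int keeps the leading elements of that axis and None (used here for dim) leaves the axis untouched. With a hypothetical shrink_to_args helper:

    def shrink_to_args(shape, target):
      # int t -> the old (0, t) window; None -> the axis is left alone
      return tuple(None if t is None else (0, t) for t in target)

    # gather shrinks every axis except dim down to index.shape
    assert shrink_to_args((5, 4, 3), (2, None, 3)) == ((0, 2), None, (0, 3))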
@@ -2386,7 +2386,7 @@ class Tensor(OpMixin):
     # interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO)
     ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]])
     # merge groups and rcout, tyx and HWO: (bs, groups, cout, *yx), shrink to final
-    ret = ret.reshape(bs, cout, *[c * HWO[i] for i, c in enumerate(tyx)]).shrink(tuple((0, s) for s in [bs, cout, *oyx]))
+    ret = ret.reshape(bs, cout, *[c * HWO[i] for i, c in enumerate(tyx)]).shrink_to(bs, cout, *oyx)

     return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()
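Note the two call styles across this diff: shrink_to(bs, cout, *oyx) here versus shrink_to(tuple(...)) in the gather hunk, which suggests the new methods normalize their arguments the way tinygrad's other shape-taking methods (reshape, expand) do. A standalone sketch of that argfix-style normalization, under that assumption:

    def argfix(*x):
      # accept both f(2, 3, 4) and f((2, 3, 4))
      return tuple(x[0]) if x and isinstance(x[0], (tuple, list)) else x

    assert argfix(2, 3, 4) == argfix((2, 3, 4)) == (2, 3, 4)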
@@ -2425,7 +2425,7 @@ class Tensor(OpMixin):
     x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:]))
     x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride)))
     x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)])
-    x = x.shrink((None, None, *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
+    x = x.shrink_to(None, None, *[k-(s-1) for k,s in zip(x.shape[2:], stride)])
     padding = flatten((((k-1)*d-pB,(k-1)*d-pA+op) for k,d,(pB,pA),op in reversed(list(zip(HW, dilation, padding, output_padding)))))
     return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
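This reshape/pad/reshape/shrink_to sequence is the stride-dilation step of the transposed convolution: insert stride-1 zeros after every element, then cut the s-1 trailing zeros so each spatial axis ends at k*s-(s-1). A 1-D plain-Python sketch of the same effect:

    def dilate(xs, stride):
      out = []
      for v in xs: out += [v] + [0]*(stride - 1)   # reshape + pad + reshape
      return out[:len(out) - (stride - 1)]         # shrink_to: k*s -> k*s-(s-1)

    assert dilate([1, 2, 3], 2) == [1, 0, 2, 0, 3]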
@@ -2645,13 +2645,12 @@ class Tensor(OpMixin):
       f"All dimensions of {index.shape=} should be <= to all dimensions of {src.shape=} and all dimensions except dimension {dim} of {self.shape=}"
     if self.dtype != src.dtype: raise RuntimeError(f"expect {self.dtype=} to be equal to {src.dtype=}")
     # shrink src to index shape to shrink away the unused values
-    src = src.shrink(tuple((0,s) for s in index.shape))
+    src = src.shrink_to(index.shape)
     # prepare src and mask for reduce with respect to dim
     src = src.unsqueeze(-1).expand(*src.shape, self.shape[dim]).transpose(-1, dim)
     mask = index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).transpose(-1, dim)
     # pad src and mask to self.shape so that reduce can be done with padded values as no-ops
-    src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
-    return src, mask
+    return src.pad_to(*self.shape, None), mask.pad_to(*self.shape, None)

   def scatter(self, dim:int, index:Tensor, src:Tensor|PyConst, reduce:Literal['multiply', 'add']|None=None) -> Tensor:
     """
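Sketch: here pad_to also accepts None, leaving the trailing one-hot axis (size self.shape[dim]) unpadded while every other axis grows to self.shape. The pairs the removed generator expression built, using a hypothetical pad_to_args like the one above:

    def pad_to_args(shape, target):
      # None skips the axis; otherwise pad at the end up to the target size
      return tuple(None if t is None else (0, t - s) for s, t in zip(shape, target))

    # e.g. src of shape (2, 3, 4) padded toward self.shape == (5, 6), one-hot axis kept
    assert pad_to_args((2, 3, 4), (5, 6, None)) == ((0, 3), (0, 3), None)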
@@ -2769,7 +2768,7 @@ class Tensor(OpMixin):
       # flip wires back to undo the crossover
       blue_box, flipped_green_box = x.split(1, crossover_dim)
       x = blue_box.cat(flipped_green_box.flip(flip_dims), dim=crossover_dim)
-    x = x.flatten(dim, dim+n_stages-1).shrink(tuple((0, s) for s in self.shape))
+    x = x.flatten(dim, dim+n_stages-1).shrink_to(self.shape)
     # compute indices for sorted values
     mask = Tensor.ones(orig_len, orig_len, dtype=dtypes.bool, device=self.device).tril().reshape((None, None) + (1,)*(self.ndim-dim-1))
     def compute_counts(t:Tensor): return (mask & (t.unsqueeze(dim) == t.unsqueeze(dim+1))).sum(dim+1)
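Sketch: after the merge stages the sort axis is still padded, presumably out to a power of two for the bitonic network given the 2**n_stages structure; flattening the stage axes back and calling shrink_to(self.shape) trims every axis to its original extent. A tiny illustration of that final trim:

    orig_len, n_stages = 6, 3               # padded length 2**n_stages == 8
    padded = [0, 1, 2, 3, 4, 5, 99, 99]     # 99s stand in for pad values
    assert padded[:orig_len] == [0, 1, 2, 3, 4, 5]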
@@ -2809,8 +2808,8 @@ class Tensor(OpMixin):
     if not sorted_: raise NotImplementedError("topk with sorted_=False is not supported")
     if k > self.shape[dim:=self._resolve_dim(dim)]: raise ValueError(f"selected index {k=} is out of range")
     x, idx = self.sort(dim, descending=largest)
-    shrink_to_k = tuple((0, k) if i == dim else None for i in range(self.ndim))
-    return x.shrink(shrink_to_k), idx.shrink(shrink_to_k)
+    topk_shape = tuple(k if i == dim else None for i in range(self.ndim))
+    return x.shrink_to(topk_shape), idx.shrink_to(topk_shape)

   # ***** unary ops *****
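The topk change is the same translation spelled out: the old code built (start, stop) windows for shrink, the new code passes target sizes to shrink_to. Comparing the two tuples directly:

    ndim, dim, k = 3, 1, 2
    old = tuple((0, k) if i == dim else None for i in range(ndim))  # for shrink
    new = tuple(k if i == dim else None for i in range(ndim))       # for shrink_to
    assert old == (None, (0, 2), None)
    assert new == (None, 2, None)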