expand broadcast functions a bit (#4891)

taking some good stuff from #4886. I think `from_, to` is more readable than `sh, s` too
[run_process_replay]
chenyu
2024-06-08 20:16:54 -04:00
committed by GitHub
parent 2849d0a2a1
commit a3ec4234df

@@ -66,8 +66,11 @@ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
   assert isinstance(ret, Tensor), "sum didn't return a Tensor"
   return ret
-def _pad_left(*shps:Tuple[sint, ...], v=1): return tuple((v,) * (max(len(i_) for i_ in shps) - len(i)) + i for i in shps)
-def _broadcast_shape(*shps:Tuple[sint, ...]): return tuple(0 if any(sh_ == 0 for sh_ in sh) else max(sh) for sh in zip(*_pad_left(*shps)))
+def _pad_left(*shapes:Tuple[sint, ...]) -> Tuple[Tuple[sint, ...], ...]:
+  max_dim = max(len(shape) for shape in shapes)
+  return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes)
+def _broadcast_shape(*shapes:Tuple[sint, ...]) -> Tuple[sint, ...]:
+  return tuple(0 if any(size == 0 for size in nth_dim_sizes) else max(nth_dim_sizes) for nth_dim_sizes in zip(*_pad_left(*shapes)))
 class Tensor:
   """
@@ -756,7 +759,7 @@ class Tensor:
     print(t.expand(4, -1).numpy())
     ```
     """
-    return self._broadcast_to(tuple(sh if s==-1 or s is None else s for s, sh in zip(*(_pad_left(argfix(shape, *args), self.shape)))))
+    return self._broadcast_to(tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(_pad_left(self.shape, argfix(shape, *args))))))
   def permute(self, order, *args) -> Tensor:
     """
@@ -2265,11 +2268,14 @@ class Tensor:
     return self / (1 + self.abs())
   # ***** broadcasted elementwise ops *****
-  def _broadcast_to(self, shape:Tuple[sint, ...]):
-    reshape_arg, _ = _pad_left(self.shape, shape)
-    if self.ndim > len(shape) or not all(sh in {s,1} or (s==0 and sh==1) for sh,s in zip(reshape_arg, shape)):
-      raise ValueError(f"cannot broadcast tensor with shape={self.shape} to {shape=}")
-    return F.Expand.apply(self.reshape(reshape_arg), shape=shape) if shape != self.shape else self
+  def _broadcast_to(self, shape:Tuple[sint, ...]) -> Tensor:
+    if self.shape == shape: return self
+    if self.ndim > len(shape): raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {shape=}")
+    # first pad left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html
+    padded, _ = _pad_left(self.shape, shape)
+    # for each dimension, check either from_ is 1, or it does not change
+    if any(from_ != 1 and from_ != to for from_,to in zip(padded, shape)): raise ValueError(f"cannot broadcast from shape={self.shape} to {shape=}")
+    return F.Expand.apply(self.reshape(padded), shape=shape)
   def _broadcasted(self, y:Union[Tensor, ConstType], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
     x: Tensor = self
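
The new `_broadcast_to` spells out the standard left-padding rule from the linked array API spec. A standalone sketch of the same check on plain tuples (`can_broadcast` is an illustrative name, not tinygrad API):

```python
from typing import Tuple

def can_broadcast(src: Tuple[int, ...], target: Tuple[int, ...]) -> bool:
  if len(src) > len(target): return False            # cannot broadcast away dimensions
  padded = (1,) * (len(target) - len(src)) + src     # pad left with 1s
  # each source dim must either be 1 or already equal the target dim
  return all(from_ == 1 or from_ == to for from_, to in zip(padded, target))

assert can_broadcast((3, 1), (2, 3, 4))    # 1s expand; missing dims are added on the left
assert not can_broadcast((3, 2), (3, 4))   # 2 -> 4 is not a valid broadcast
assert not can_broadcast((2, 3), (3,))     # more dims than the target
```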
@@ -2281,7 +2287,7 @@ class Tensor:
       if isinstance(y, Node): y = Tensor.from_node(y, device=self.device)
       else: y = Tensor(dtypes.as_const(y, y_dtype), self.device, y_dtype, requires_grad=False)
-    if match_dtype:
+    if match_dtype and x.dtype != y.dtype:
       output_dtype = least_upper_dtype(x.dtype, y.dtype)
       x, y = x.cast(output_dtype), y.cast(output_dtype)
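
The added `x.dtype != y.dtype` guard only skips the promote-and-cast step when both sides already share a dtype; promotion itself is unchanged. A rough usage sketch (assuming default dtype settings, with the promoted dtype decided by `least_upper_dtype`):

```python
from tinygrad import Tensor, dtypes

a = Tensor([1, 2, 3], dtype=dtypes.int32)
b = Tensor([4, 5, 6], dtype=dtypes.int32)
print((a + b).dtype)  # still int32: dtypes already match, so no cast is inserted
c = Tensor([1.0, 2.0, 3.0])
print((a + c).dtype)  # promoted to a float dtype via least_upper_dtype
```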