mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 07:28:15 -05:00
expand broadcast functions a bit (#4891)
taking some good stuff from #4886. I think `from_, to` is more readable than `sh, s` too [run_process_replay]
@@ -66,8 +66,11 @@ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
   assert isinstance(ret, Tensor), "sum didn't return a Tensor"
   return ret
 
-def _pad_left(*shps:Tuple[sint, ...], v=1): return tuple((v,) * (max(len(i_) for i_ in shps) - len(i)) + i for i in shps)
-def _broadcast_shape(*shps:Tuple[sint, ...]): return tuple(0 if any(sh_ == 0 for sh_ in sh) else max(sh) for sh in zip(*_pad_left(*shps)))
+def _pad_left(*shapes:Tuple[sint, ...]) -> Tuple[Tuple[sint, ...], ...]:
+  max_dim = max(len(shape) for shape in shapes)
+  return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes)
+def _broadcast_shape(*shapes:Tuple[sint, ...]) -> Tuple[sint, ...]:
+  return tuple(0 if any(size == 0 for size in nth_dim_sizes) else max(nth_dim_sizes) for nth_dim_sizes in zip(*_pad_left(*shapes)))
 
 class Tensor:
   """
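The two helpers are pure shape math, so their behaviour is easy to check outside the library. A minimal standalone sketch (plain Python with plain int tuples, ignoring the `sint` typing; not tinygrad code):

```python
# Standalone sketch of the left-padding and broadcast-shape rules above.
def _pad_left(*shapes):
  max_dim = max(len(shape) for shape in shapes)
  return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes)

def _broadcast_shape(*shapes):
  return tuple(0 if any(size == 0 for size in nth_dim_sizes) else max(nth_dim_sizes)
               for nth_dim_sizes in zip(*_pad_left(*shapes)))

print(_pad_left((3,), (2, 1, 3)))         # ((1, 1, 3), (2, 1, 3))
print(_broadcast_shape((3,), (2, 1, 3)))  # (2, 1, 3)
print(_broadcast_shape((0,), (2, 1, 1)))  # (2, 1, 0): a 0-sized dim stays 0
```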
@@ -756,7 +759,7 @@ class Tensor:
     print(t.expand(4, -1).numpy())
     ```
     """
-    return self._broadcast_to(tuple(sh if s==-1 or s is None else s for s, sh in zip(*(_pad_left(argfix(shape, *args), self.shape)))))
+    return self._broadcast_to(tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(_pad_left(self.shape, argfix(shape, *args))))))
 
   def permute(self, order, *args) -> Tensor:
     """
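In the `expand` change, the renamed pair reads left to right: `from_` is the tensor's current size after left-padding, `to` is the requested size, and `-1`/`None` means keep `from_`. A hedged sketch of just that resolution step (reuses the standalone `_pad_left` sketch above; the concrete shapes are made up):

```python
# Sketch: how -1/None entries in expand's target shape resolve to the existing sizes.
self_shape, target = (1, 3), (4, -1)
padded_self, padded_target = _pad_left(self_shape, target)
resolved = tuple(from_ if to == -1 or to is None else to
                 for from_, to in zip(padded_self, padded_target))
print(resolved)  # (4, 3): the -1 keeps the existing size 3
```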
@@ -2265,11 +2268,14 @@ class Tensor:
     return self / (1 + self.abs())
 
   # ***** broadcasted elementwise ops *****
-  def _broadcast_to(self, shape:Tuple[sint, ...]):
-    reshape_arg, _ = _pad_left(self.shape, shape)
-    if self.ndim > len(shape) or not all(sh in {s,1} or (s==0 and sh==1) for sh,s in zip(reshape_arg, shape)):
-      raise ValueError(f"cannot broadcast tensor with shape={self.shape} to {shape=}")
-    return F.Expand.apply(self.reshape(reshape_arg), shape=shape) if shape != self.shape else self
+  def _broadcast_to(self, shape:Tuple[sint, ...]) -> Tensor:
+    if self.shape == shape: return self
+    if self.ndim > len(shape): raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {shape=}")
+    # first pad left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html
+    padded, _ = _pad_left(self.shape, shape)
+    # for each dimension, check either from_ is 1, or it does not change
+    if any(from_ != 1 and from_ != to for from_,to in zip(padded, shape)): raise ValueError(f"cannot broadcast from shape={self.shape} to {shape=}")
+    return F.Expand.apply(self.reshape(padded), shape=shape)
 
   def _broadcasted(self, y:Union[Tensor, ConstType], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
     x: Tensor = self
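The rewritten `_broadcast_to` makes the legality rule explicit: after left-padding, every source dimension must either be 1 or already equal the target size. A small standalone sketch of only that check (`can_broadcast` is a made-up name for illustration, not a tinygrad function):

```python
# Sketch of the per-dimension check: pad left with 1s, then each size must be 1 or unchanged.
def can_broadcast(src, dst):
  if len(src) > len(dst): return False          # can't broadcast to fewer dimensions
  padded = (1,) * (len(dst) - len(src)) + src
  return all(from_ == 1 or from_ == to for from_, to in zip(padded, dst))

print(can_broadcast((3, 1), (2, 3, 4)))  # True: (3, 1) -> (1, 3, 1) -> expand to (2, 3, 4)
print(can_broadcast((3, 2), (3, 4)))     # False: 2 can't expand to 4
```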
@@ -2281,7 +2287,7 @@ class Tensor:
       if isinstance(y, Node): y = Tensor.from_node(y, device=self.device)
       else: y = Tensor(dtypes.as_const(y, y_dtype), self.device, y_dtype, requires_grad=False)
 
-    if match_dtype:
+    if match_dtype and x.dtype != y.dtype:
       output_dtype = least_upper_dtype(x.dtype, y.dtype)
       x, y = x.cast(output_dtype), y.cast(output_dtype)
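End to end, both paths show up in any broadcasted binary op: shapes are padded and expanded, and dtypes are promoted only when they actually differ. A hedged usage sketch through the public API (assumes a working tinygrad install; the example shapes and dtypes are made up):

```python
from tinygrad import Tensor, dtypes

a = Tensor.ones(2, 1, 3, dtype=dtypes.float32)
b = Tensor.ones(4, 3, dtype=dtypes.int32)
c = a + b                # b is broadcast to (2, 4, 3) and cast up to float32
print(c.shape, c.dtype)  # shape (2, 4, 3), dtype float32
```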