mirror of https://github.com/tinygrad/tinygrad.git
refactor _broadcasted (#2747)
also moved the expand noop check to .expand.
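For context, a minimal standalone sketch of the shape rule the refactored _broadcasted applies: left-pad the shorter shape with 1s, then take the elementwise max as the output shape. This is plain Python on shape tuples, not tinygrad's API; broadcast_shape is a hypothetical helper name used only for illustration.

def broadcast_shape(xshape: tuple, yshape: tuple) -> tuple:
  # left pad the shorter shape with 1s so both shapes have the same rank
  if len(yshape) < len(xshape): yshape = (1,) * (len(xshape) - len(yshape)) + yshape
  elif len(xshape) < len(yshape): xshape = (1,) * (len(yshape) - len(xshape)) + xshape
  # each output dim is the max of the pair (a 1 stretches to match the other dim;
  # mismatched non-1 dims are not validated here, mirroring the method body in the diff below)
  return tuple(max(xi, yi) for xi, yi in zip(xshape, yshape))

assert broadcast_shape((3, 1, 5), (4, 5)) == (3, 4, 5)
assert broadcast_shape((7,), ()) == (7,)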
@@ -276,6 +276,7 @@ class Tensor:
     new_shape = argfix(shape, *args)
     return mlops.Reshape.apply(self, shape=tuple([-prod(self.shape) // prod(new_shape) if s == -1 else (s if s is not None else self.shape[i]) for i,s in enumerate(new_shape)])) # noqa: E501
   def expand(self, shape, *args) -> Tensor:
+    if shape == self.shape: return self
     return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
   def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
   def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
@@ -725,24 +726,25 @@ class Tensor:

   # ***** broadcasted binary mlops *****

   # TODO: y can be int here
   def _broadcasted(self, y:Union[Tensor, float], reverse:bool=False) -> Tuple[Tensor, Tensor]:
-    x: Tensor = self
     if not isinstance(y, Tensor):
-      if 0 in x.shape: return x, x.full_like(y)
-      y = Tensor(y, device=self.device, requires_grad=False, dtype=self.dtype if self.dtype != dtypes.bool and self.dtype.__class__ is not ImageDType else dtypes.float32) # noqa: E501
+      # make y a Tensor
+      if 0 in self.shape: return self, self.full_like(y)
+      y_dtype = self.dtype if self.dtype != dtypes.bool and self.dtype.__class__ is not ImageDType else dtypes.float32
+      y = Tensor(y, self.device, dtype=y_dtype, requires_grad=False)

+    x: Tensor = self
     if reverse: x, y = y, x
-    if (xshape:=x.shape) == (yshape:=y.shape): return (x, y)

-    shape_delta = len(xshape) - len(yshape)
-    if shape_delta > 0: y = y.reshape((1,) * shape_delta + yshape)
-    elif shape_delta < 0: x = x.reshape((1,) * -shape_delta + xshape)
-    if (xshape:=x.shape) == (yshape:=y.shape): return (x, y)
+    # left pad shape with 1s
+    if len(y.shape) < len(x.shape): y = y.reshape((1,) * (len(x.shape) - len(y.shape)) + y.shape)
+    elif len(x.shape) < len(y.shape): x = x.reshape((1,) * (len(y.shape) - len(x.shape)) + x.shape)

-    shape_ret = tuple([max(x, y) for x, y in zip(xshape, yshape)])
-    if xshape != shape_ret: x = x.expand(shape_ret)
-    if yshape != shape_ret: y = y.expand(shape_ret)
-    return (x, y)
+    broadcasted_shape = tuple(max(xi, yi) for xi, yi in zip(x.shape, y.shape))
+    return x.expand(broadcasted_shape), y.expand(broadcasted_shape)

   # TODO: x can be int here
   def _to_float(self, x:Union[Tensor, float]):
     return x.lazydata.base.op.arg if isinstance(x, Tensor) and x.lazydata.is_unrealized_contiguous_const() \
       and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x
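A note on the design (an editorial observation, not part of the commit message): with the no-op check now living in .expand, _broadcasted can call .expand on both operands unconditionally, which is what lets the shape_ret bookkeeping and the two early returns in the old body be dropped. As a quick sanity check that the left-pad-then-max rule matches standard broadcasting semantics, NumPy's broadcast_shapes (available since NumPy 1.20, used here only for comparison) gives the same answers:

import numpy as np

# same shapes as the broadcast_shape sketch above, resolved by NumPy's own rule
assert np.broadcast_shapes((3, 1, 5), (4, 5)) == (3, 4, 5)
assert np.broadcast_shapes((7,), ()) == (7,)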