perf: faster broadcasted (#1601)
Co-authored-by: Roelof van Dijk <roelof.van.dijk@vitestro.com>
@@ -577,23 +577,21 @@ class Tensor:
 
   # ***** broadcasted binary mlops *****
 
-  def _broadcasted(self, fxn:Type[Function], other:Union[Tensor, float], reverse:bool=False) -> Tensor:
-    dtype = self.dtype if self.dtype != dtypes.bool and self.dtype.__class__ is not ImageDType else dtypes.float32
+  def _broadcasted(self, fxn:Type[Function], y:Union[Tensor, float], reverse:bool=False) -> Tensor:
     x: Tensor = self
-    y: Tensor = Tensor(cast(float, other), device=self.device, requires_grad=False, dtype=dtype) if other.__class__ is not Tensor else cast(Tensor, other)
+    if not isinstance(y, Tensor):
+      y = Tensor(y, device=self.device, requires_grad=False, dtype=self.dtype if self.dtype != dtypes.bool and self.dtype.__class__ is not ImageDType else dtypes.float32)
     if reverse: x, y = y, x
-    if x.shape == y.shape: return fxn.apply(x, y)
+    if (xshape:=x.shape) == (yshape:=y.shape): return fxn.apply(x, y)
 
-    len_x_shape, len_y_shape = len(x.shape), len(y.shape)
-    max_shape = max(len_x_shape, len_y_shape)
-    if len_x_shape != max_shape: x = x.reshape((1,) * (max_shape - len_x_shape) + x.shape)
-    if len_y_shape != max_shape: y = y.reshape((1,) * (max_shape - len_y_shape) + y.shape)
-
-    shape_ret = tuple([max(x, y) for x, y in zip(x.shape, y.shape)])
-    if x.shape != shape_ret: x = x.expand(shape_ret)
-    if y.shape != shape_ret: y = y.expand(shape_ret)
+    shape_delta = len(xshape) - len(yshape)
+    if shape_delta > 0: y = y.reshape((1,) * shape_delta + yshape)
+    elif shape_delta < 0: x = x.reshape((1,) * -shape_delta + xshape)
+    if (xshape:=x.shape) == (yshape:=y.shape): return fxn.apply(x, y)
+
+    shape_ret = tuple([max(x, y) for x, y in zip(xshape, yshape)])
+    if xshape != shape_ret: x = x.expand(shape_ret)
+    if yshape != shape_ret: y = y.expand(shape_ret)
     return fxn.apply(x, y)
 
   def add(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Add, x, reverse) if x.__class__ is Tensor or x else self
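
For reference, the shape handling above is pure tuple arithmetic. The sketch below mirrors the new control flow outside the Tensor class, assuming compatible shapes (like the original, it performs no validation); broadcast_shape is a hypothetical helper name, not part of tinygrad.

from typing import Tuple

def broadcast_shape(xshape:Tuple[int, ...], yshape:Tuple[int, ...]) -> Tuple[int, ...]:
  # first early return: equal shapes need no alignment at all
  if xshape == yshape: return xshape
  # pad the shorter shape with leading 1s in a single step (the shape_delta reshape)
  shape_delta = len(xshape) - len(yshape)
  if shape_delta > 0: yshape = (1,) * shape_delta + yshape
  elif shape_delta < 0: xshape = (1,) * -shape_delta + xshape
  # second early return: padding alone may already have made the shapes equal
  if xshape == yshape: return xshape
  # otherwise the result shape is the per-axis maximum (the expand targets)
  return tuple(max(a, b) for a, b in zip(xshape, yshape))

assert broadcast_shape((3, 1), (4,)) == (3, 4)    # pad then expand both sides
assert broadcast_shape((1, 3), (3,)) == (1, 3)    # hits the second early return
assert broadcast_shape((2, 3), (2, 3)) == (2, 3)  # hits the first early return

The two early returns are the point of the change: equal shapes, and shapes that match after a single leading-1 pad, now skip the per-axis max and the expand calls entirely, and the walrus assignments avoid recomputing x.shape and y.shape on every comparison.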
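A usage sketch of how this path is reached, assuming a tinygrad checkout from around this commit and the usual operator wrappers that route + to Tensor.add; exact printed values depend on the backend.

from tinygrad.tensor import Tensor

a = Tensor([[1.0, 2.0], [3.0, 4.0]])  # shape (2, 2)
b = Tensor([10.0, 20.0])              # shape (2,)

print((a + b).numpy())    # b is padded to (1, 2), expanded to (2, 2), then added elementwise
print((a + 5.0).numpy())  # the scalar is wrapped in a Tensor on a's device, then broadcast
print((a + 0) is a)       # True: per the add() line above, a falsy non-Tensor returns self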