From b02f77b3542a5ca0c8ca7129e19952cbe0b79d4d Mon Sep 17 00:00:00 2001
From: Roelof van Dijk <3604013+roelofvandijk@users.noreply.github.com>
Date: Mon, 21 Aug 2023 23:21:46 +0200
Subject: [PATCH] perf: faster broadcasted (#1601)

Co-authored-by: Roelof van Dijk
---
 tinygrad/tensor.py | 24 +++++++++++-------------
 1 file changed, 11 insertions(+), 13 deletions(-)

diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index b6aa22eac3..f96467720a 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -577,23 +577,21 @@ class Tensor:
 
   # ***** broadcasted binary mlops *****
 
-  def _broadcasted(self, fxn:Type[Function], other:Union[Tensor, float], reverse:bool=False) -> Tensor:
-    dtype = self.dtype if self.dtype != dtypes.bool and self.dtype.__class__ is not ImageDType else dtypes.float32
+  def _broadcasted(self, fxn:Type[Function], y:Union[Tensor, float], reverse:bool=False) -> Tensor:
     x: Tensor = self
-    y: Tensor = Tensor(cast(float, other), device=self.device, requires_grad=False, dtype=dtype) if other.__class__ is not Tensor else cast(Tensor, other)
+    if not isinstance(y, Tensor):
+      y = Tensor(y, device=self.device, requires_grad=False, dtype=self.dtype if self.dtype != dtypes.bool and self.dtype.__class__ is not ImageDType else dtypes.float32)
     if reverse: x, y = y, x
-    if x.shape == y.shape: return fxn.apply(x, y)
+    if (xshape:=x.shape) == (yshape:=y.shape): return fxn.apply(x, y)
 
-    len_x_shape, len_y_shape = len(x.shape), len(y.shape)
-    max_shape = max(len_x_shape, len_y_shape)
-
-    if len_x_shape != max_shape: x = x.reshape((1,) * (max_shape - len_x_shape) + x.shape)
-    if len_y_shape != max_shape: y = y.reshape((1,) * (max_shape - len_y_shape) + y.shape)
-
-    shape_ret = tuple([max(x, y) for x, y in zip(x.shape, y.shape)])
-    if x.shape != shape_ret: x = x.expand(shape_ret)
-    if y.shape != shape_ret: y = y.expand(shape_ret)
+    shape_delta = len(xshape) - len(yshape)
+    if shape_delta > 0: y = y.reshape((1,) * shape_delta + yshape)
+    elif shape_delta < 0: x = x.reshape((1,) * -shape_delta + xshape)
+    if (xshape:=x.shape) == (yshape:=y.shape): return fxn.apply(x, y)
+    shape_ret = tuple([max(x, y) for x, y in zip(xshape, yshape)])
+    if xshape != shape_ret: x = x.expand(shape_ret)
+    if yshape != shape_ret: y = y.expand(shape_ret)
     return fxn.apply(x, y)
 
   def add(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Add, x, reverse) if x.__class__ is Tensor or x else self
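
Note: for context, below is a minimal standalone sketch of the broadcast-shape
rule this patch implements, using plain shape tuples instead of tinygrad
Tensors. The name broadcast_shape and the asserts are illustrative only, not
part of the patch.

    def broadcast_shape(xshape: tuple, yshape: tuple) -> tuple:
        # Fast path: identical shapes need no work (the patch's first early return).
        if xshape == yshape: return xshape
        # Left-pad the shorter shape with 1s, mirroring the shape_delta reshapes.
        shape_delta = len(xshape) - len(yshape)
        if shape_delta > 0: yshape = (1,) * shape_delta + yshape
        elif shape_delta < 0: xshape = (1,) * -shape_delta + xshape
        # Second early return: padding alone may already equalize the shapes.
        if xshape == yshape: return xshape
        # Otherwise each output dimension is the max of the pair (1s expand).
        return tuple(max(a, b) for a, b in zip(xshape, yshape))

    assert broadcast_shape((2, 3), (2, 3)) == (2, 3)
    assert broadcast_shape((3, 1), (4,)) == (3, 4)
    assert broadcast_shape((5, 1, 4), (3, 1)) == (5, 3, 4)

The second early return appears to be the heart of the speedup: left-padding
with 1s often already equalizes the shapes (e.g. a (1, 4) tensor combined with
a (4,) tensor), so the per-dimension max and the expand calls are skipped
entirely, and the walrus assignments avoid re-reading .shape.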