diff --git a/tinygrad/mixin/__init__.py b/tinygrad/mixin/__init__.py index 5766f30694..e4be275b9c 100644 --- a/tinygrad/mixin/__init__.py +++ b/tinygrad/mixin/__init__.py @@ -1,6 +1,15 @@ +from typing import Self from tinygrad.mixin.elementwise import ElementwiseMixin from tinygrad.mixin.movement import MovementMixin +from tinygrad.uop.ops import _broadcast_shape +from tinygrad.dtype import least_upper_dtype, sum_acc_dtype class OpMixin(ElementwiseMixin, MovementMixin): - pass + def _broadcasted(self, y, reverse=False) -> tuple[Self, Self]: + if not isinstance(y, type(self)): y = self.ufix(y) + x, y = (self, y) if not reverse else (y, self) + out_shape, out_dtype = _broadcast_shape(x.shape, y.shape), least_upper_dtype(x.dtype, y.dtype) + # NOTE: the backward cast is no-op in forward and uses sum_acc_dtype in the backward sum + return x.cast(sum_acc_dtype(x.dtype))._broadcast_to(out_shape).cast(x.dtype).cast(out_dtype), \ + y.cast(sum_acc_dtype(y.dtype))._broadcast_to(out_shape).cast(y.dtype).cast(out_dtype) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 14bed09aea..446ff3b538 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2952,17 +2952,6 @@ class Tensor(OpMixin): dtype = self.dtype if dtypes.is_float(self.dtype) or (dtypes.is_int(self.dtype) and isinstance(x, (int, InvalidType))) else None return Tensor(x, self.device, dtype, requires_grad=False) - def _broadcasted(self, y:Tensor|ConstType|UOp, reverse:bool=False) -> tuple[Tensor, Tensor]: - if not isinstance(y, Tensor): y = self.ufix(y) - - x, y = (self, y) if not reverse else (y, self) - - out_shape, out_dtype = _broadcast_shape(x.shape, y.shape), least_upper_dtype(x.dtype, y.dtype) - - # NOTE: the backward cast is no-op in forward and uses sum_acc_dtype in the backward sum - return x.cast(sum_acc_dtype(x.dtype))._broadcast_to(out_shape).cast(x.dtype).cast(out_dtype), \ - y.cast(sum_acc_dtype(y.dtype))._broadcast_to(out_shape).cast(y.dtype).cast(out_dtype) - def sub(self, x:Tensor|ConstType, reverse=False) -> Tensor: """ Subtracts `x` from `self`.