diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 571d111f42..35b42bc261 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -2946,13 +2946,15 @@ class Tensor(OpMixin):
 
   # ***** broadcasted elementwise ops *****
 
-  def _broadcasted(self, y:Tensor|ConstType|UOp, reverse:bool=False, backward_cast:bool=True) -> tuple[Tensor, Tensor]:
+  def ufix(self, x) -> Tensor:
+    # TODO: x:ConstType|UOp does not work because mixin only accepts Self | ConstType
+    assert isinstance(x, (*get_args(ConstType), UOp)), f"{type(x)=}, {x=}"
+    dtype = self.dtype if dtypes.is_float(self.dtype) or (dtypes.is_int(self.dtype) and isinstance(x, (int, InvalidType))) else None
+    return Tensor(x, self.device, dtype, requires_grad=False)
+
+  def _broadcasted(self, y:Tensor|ConstType|UOp, reverse:bool=False) -> tuple[Tensor, Tensor]:
     x: Tensor = self
-    if not isinstance(y, Tensor):
-      # make y a Tensor
-      assert isinstance(y, (*get_args(ConstType), UOp)), f"{type(y)=}, {y=}"
-      y_dtype = x.dtype if dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, (int, InvalidType))) else None
-      y = Tensor(y, x.device, y_dtype, requires_grad=False)
+    if not isinstance(y, Tensor): y = x.ufix(y)
 
     if x.dtype != y.dtype:
       output_dtype = least_upper_dtype(x.dtype, y.dtype)
@@ -2965,8 +2967,8 @@ class Tensor(OpMixin):
 
     # broadcast
     # NOTE: the backward cast is no-op in forward and uses sum_acc_dtype in the backward sum
-    return x.cast(sum_acc_dtype(x.dtype) if backward_cast else x.dtype)._broadcast_to(out_shape).cast(x.dtype), \
-      y.cast(sum_acc_dtype(y.dtype) if backward_cast else y.dtype)._broadcast_to(out_shape).cast(y.dtype)
+    return x.cast(sum_acc_dtype(x.dtype))._broadcast_to(out_shape).cast(x.dtype), \
+      y.cast(sum_acc_dtype(y.dtype))._broadcast_to(out_shape).cast(y.dtype)
 
   def sub(self, x:Tensor|ConstType, reverse=False) -> Tensor:
     """