Tensor.ufix (#15452)

* Tensor.ufix

prep moving _broadcasted to mixin

* remove backward_cast
This commit is contained in:
chenyu
2026-03-24 22:34:43 -04:00
committed by GitHub
parent 1b3d00d6ac
commit f6ed4da268

View File

@@ -2946,13 +2946,15 @@ class Tensor(OpMixin):
# ***** broadcasted elementwise ops *****
def _broadcasted(self, y:Tensor|ConstType|UOp, reverse:bool=False, backward_cast:bool=True) -> tuple[Tensor, Tensor]:
def ufix(self, x) -> Tensor:
# Wrap a scalar constant (or UOp) as a Tensor on self's device, so elementwise ops can consume it.
# NOTE(review): ConstType and InvalidType are project-declared types not visible here — semantics assumed from usage.
# TODO: x:ConstType|UOp does not work because mixin only accepts Self | ConstType
assert isinstance(x, (*get_args(ConstType), UOp)), f"{type(x)=}, {x=}"
# Reuse self's dtype when it is float, or when it is int and x is an int-like value; otherwise
# pass dtype=None and let the Tensor constructor pick a default for x.
dtype = self.dtype if dtypes.is_float(self.dtype) or (dtypes.is_int(self.dtype) and isinstance(x, (int, InvalidType))) else None
# requires_grad=False: a wrapped constant never participates in autograd.
return Tensor(x, self.device, dtype, requires_grad=False)
def _broadcasted(self, y:Tensor|ConstType|UOp, reverse:bool=False) -> tuple[Tensor, Tensor]:
x: Tensor = self
if not isinstance(y, Tensor):
# make y a Tensor
assert isinstance(y, (*get_args(ConstType), UOp)), f"{type(y)=}, {y=}"
y_dtype = x.dtype if dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, (int, InvalidType))) else None
y = Tensor(y, x.device, y_dtype, requires_grad=False)
if not isinstance(y, Tensor): y = x.ufix(y)
if x.dtype != y.dtype:
output_dtype = least_upper_dtype(x.dtype, y.dtype)
@@ -2965,8 +2967,8 @@ class Tensor(OpMixin):
# broadcast
# NOTE: the backward cast is no-op in forward and uses sum_acc_dtype in the backward sum
return x.cast(sum_acc_dtype(x.dtype) if backward_cast else x.dtype)._broadcast_to(out_shape).cast(x.dtype), \
y.cast(sum_acc_dtype(y.dtype) if backward_cast else y.dtype)._broadcast_to(out_shape).cast(y.dtype)
return x.cast(sum_acc_dtype(x.dtype))._broadcast_to(out_shape).cast(x.dtype), \
y.cast(sum_acc_dtype(y.dtype))._broadcast_to(out_shape).cast(y.dtype)
def sub(self, x:Tensor|ConstType, reverse=False) -> Tensor:
"""