mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Tensor.ufix (#15452)
* Tensor.ufix prep moving _broadcasted to mixin * remove backward_cast
This commit is contained in:
@@ -2946,13 +2946,15 @@ class Tensor(OpMixin):
|
||||
|
||||
# ***** broadcasted elementwise ops *****
|
||||
|
||||
def _broadcasted(self, y:Tensor|ConstType|UOp, reverse:bool=False, backward_cast:bool=True) -> tuple[Tensor, Tensor]:
|
||||
def ufix(self, x) -> Tensor:
|
||||
# TODO: x:ConstType|UOp does not work because mixin only accepts Self | ConstType
|
||||
assert isinstance(x, (*get_args(ConstType), UOp)), f"{type(x)=}, {x=}"
|
||||
dtype = self.dtype if dtypes.is_float(self.dtype) or (dtypes.is_int(self.dtype) and isinstance(x, (int, InvalidType))) else None
|
||||
return Tensor(x, self.device, dtype, requires_grad=False)
|
||||
|
||||
def _broadcasted(self, y:Tensor|ConstType|UOp, reverse:bool=False) -> tuple[Tensor, Tensor]:
|
||||
x: Tensor = self
|
||||
if not isinstance(y, Tensor):
|
||||
# make y a Tensor
|
||||
assert isinstance(y, (*get_args(ConstType), UOp)), f"{type(y)=}, {y=}"
|
||||
y_dtype = x.dtype if dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, (int, InvalidType))) else None
|
||||
y = Tensor(y, x.device, y_dtype, requires_grad=False)
|
||||
if not isinstance(y, Tensor): y = x.ufix(y)
|
||||
|
||||
if x.dtype != y.dtype:
|
||||
output_dtype = least_upper_dtype(x.dtype, y.dtype)
|
||||
@@ -2965,8 +2967,8 @@ class Tensor(OpMixin):
|
||||
|
||||
# broadcast
|
||||
# NOTE: the backward cast is no-op in forward and uses sum_acc_dtype in the backward sum
|
||||
return x.cast(sum_acc_dtype(x.dtype) if backward_cast else x.dtype)._broadcast_to(out_shape).cast(x.dtype), \
|
||||
y.cast(sum_acc_dtype(y.dtype) if backward_cast else y.dtype)._broadcast_to(out_shape).cast(y.dtype)
|
||||
return x.cast(sum_acc_dtype(x.dtype))._broadcast_to(out_shape).cast(x.dtype), \
|
||||
y.cast(sum_acc_dtype(y.dtype))._broadcast_to(out_shape).cast(y.dtype)
|
||||
|
||||
def sub(self, x:Tensor|ConstType, reverse=False) -> Tensor:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user