move _broadcasted to OpMixin (#15461)

it needs both ElementwiseMixin and MovementMixin
This commit is contained in:
chenyu
2026-03-24 23:56:01 -04:00
committed by GitHub
parent 519ba22470
commit 02878c5a2f
2 changed files with 10 additions and 12 deletions

View File

@@ -1,6 +1,15 @@
from typing import Self
from tinygrad.mixin.elementwise import ElementwiseMixin
from tinygrad.mixin.movement import MovementMixin
from tinygrad.uop.ops import _broadcast_shape
from tinygrad.dtype import least_upper_dtype, sum_acc_dtype
class OpMixin(ElementwiseMixin, MovementMixin):
  """Combines the elementwise and movement op mixins.

  Hosts helpers (like `_broadcasted`) that need both: elementwise for the
  casts and movement for the broadcasting.
  """
  def _broadcasted(self, y, reverse=False) -> tuple[Self, Self]:
    """Return `self` and `y` broadcast to a common shape and promoted dtype.

    `y` may be a plain constant; it is lifted to the same class via `ufix`.
    If `reverse` is True the operand order is swapped before broadcasting.
    """
    if not isinstance(y, type(self)): y = self.ufix(y)
    x, y = (self, y) if not reverse else (y, self)
    out_shape, out_dtype = _broadcast_shape(x.shape, y.shape), least_upper_dtype(x.dtype, y.dtype)
    # NOTE: the backward cast is no-op in forward and uses sum_acc_dtype in the backward sum
    return x.cast(sum_acc_dtype(x.dtype))._broadcast_to(out_shape).cast(x.dtype).cast(out_dtype), \
           y.cast(sum_acc_dtype(y.dtype))._broadcast_to(out_shape).cast(y.dtype).cast(out_dtype)

View File

@@ -2952,17 +2952,6 @@ class Tensor(OpMixin):
dtype = self.dtype if dtypes.is_float(self.dtype) or (dtypes.is_int(self.dtype) and isinstance(x, (int, InvalidType))) else None
return Tensor(x, self.device, dtype, requires_grad=False)
def _broadcasted(self, y:Tensor|ConstType|UOp, reverse:bool=False) -> tuple[Tensor, Tensor]:
  """Broadcast `self` and `y` to a shared shape and promoted dtype.

  Non-Tensor `y` is lifted to a Tensor via `ufix`. With `reverse=True` the
  pair is swapped before broadcasting, so the swapped order is returned.
  """
  if not isinstance(y, Tensor): y = self.ufix(y)
  if reverse:
    x, y = y, self
  else:
    x = self
  out_shape = _broadcast_shape(x.shape, y.shape)
  out_dtype = least_upper_dtype(x.dtype, y.dtype)
  # NOTE: the backward cast is no-op in forward and uses sum_acc_dtype in the backward sum
  def _expand(t: Tensor) -> Tensor:
    return t.cast(sum_acc_dtype(t.dtype))._broadcast_to(out_shape).cast(t.dtype).cast(out_dtype)
  return _expand(x), _expand(y)
def sub(self, x:Tensor|ConstType, reverse=False) -> Tensor:
"""
Subtracts `x` from `self`.