mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
move _broadcasted to OpMixin (#15461)
it needs both ElementwiseMixin and MovementMixin
This commit is contained in:
@@ -1,6 +1,15 @@
|
||||
from typing import Self
|
||||
from tinygrad.mixin.elementwise import ElementwiseMixin
|
||||
from tinygrad.mixin.movement import MovementMixin
|
||||
from tinygrad.uop.ops import _broadcast_shape
|
||||
from tinygrad.dtype import least_upper_dtype, sum_acc_dtype
|
||||
|
||||
|
||||
class OpMixin(ElementwiseMixin, MovementMixin):
|
||||
pass
|
||||
def _broadcasted(self, y, reverse=False) -> tuple[Self, Self]:
|
||||
if not isinstance(y, type(self)): y = self.ufix(y)
|
||||
x, y = (self, y) if not reverse else (y, self)
|
||||
out_shape, out_dtype = _broadcast_shape(x.shape, y.shape), least_upper_dtype(x.dtype, y.dtype)
|
||||
# NOTE: the backward cast is no-op in forward and uses sum_acc_dtype in the backward sum
|
||||
return x.cast(sum_acc_dtype(x.dtype))._broadcast_to(out_shape).cast(x.dtype).cast(out_dtype), \
|
||||
y.cast(sum_acc_dtype(y.dtype))._broadcast_to(out_shape).cast(y.dtype).cast(out_dtype)
|
||||
|
||||
@@ -2952,17 +2952,6 @@ class Tensor(OpMixin):
|
||||
dtype = self.dtype if dtypes.is_float(self.dtype) or (dtypes.is_int(self.dtype) and isinstance(x, (int, InvalidType))) else None
|
||||
return Tensor(x, self.device, dtype, requires_grad=False)
|
||||
|
||||
def _broadcasted(self, y:Tensor|ConstType|UOp, reverse:bool=False) -> tuple[Tensor, Tensor]:
|
||||
if not isinstance(y, Tensor): y = self.ufix(y)
|
||||
|
||||
x, y = (self, y) if not reverse else (y, self)
|
||||
|
||||
out_shape, out_dtype = _broadcast_shape(x.shape, y.shape), least_upper_dtype(x.dtype, y.dtype)
|
||||
|
||||
# NOTE: the backward cast is no-op in forward and uses sum_acc_dtype in the backward sum
|
||||
return x.cast(sum_acc_dtype(x.dtype))._broadcast_to(out_shape).cast(x.dtype).cast(out_dtype), \
|
||||
y.cast(sum_acc_dtype(y.dtype))._broadcast_to(out_shape).cast(y.dtype).cast(out_dtype)
|
||||
|
||||
def sub(self, x:Tensor|ConstType, reverse=False) -> Tensor:
|
||||
"""
|
||||
Subtracts `x` from `self`.
|
||||
|
||||
Reference in New Issue
Block a user