diff --git a/tinygrad/mixin/elementwise.py b/tinygrad/mixin/elementwise.py
index c1cb59cd82..8a4129c495 100644
--- a/tinygrad/mixin/elementwise.py
+++ b/tinygrad/mixin/elementwise.py
@@ -974,3 +974,16 @@ class ElementwiseMixin(DTypeMixin, CreationMixin):
     """
     if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
     return self.logical_not() if self.dtype == dtypes.bool else self ^ -1
+
+  def lerp(self, end: Self, weight: Self | ConstType) -> Self:
+    """
+    Linearly interpolates between `self` and `end` by `weight`.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor([1., 2., 3.]).lerp(Tensor([4., 5., 6.]), 0.5).numpy())
+    ```
+    """
+    if self.dtype == dtypes.uint8 and isinstance(weight, ElementwiseMixin):
+      w_i = (weight * (1<<(W_PREC:=7)) + 0.5).cast(dtypes.int16)
+      return (self+(((end - self).cast(dtypes.int8) * w_i + (1<<(W_PREC-1))) >> W_PREC)).cast(dtypes.uint8)
+    return self + (end - self) * weight
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 2f42f83da7..45c45d62ce 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -2385,21 +2385,6 @@ class Tensor(OpMixin):
     """
     return self._apply_uop(UOp.contiguous_backward)
 
-  # ***** math functions *****
-
-  def lerp(self, end:Tensor, weight:Tensor|float) -> Tensor:
-    """
-    Linearly interpolates between `self` and `end` by `weight`.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(Tensor([1., 2., 3.]).lerp(Tensor([4., 5., 6.]), 0.5).numpy())
-    ```
-    """
-    if self.dtype == dtypes.uint8 and isinstance(weight, Tensor):
-      w_i = (weight * (1<<(W_PREC:=7)) + 0.5).cast(dtypes.int16)
-      return (self+(((end - self).cast(dtypes.int8) * w_i + (1<<(W_PREC-1))) >> W_PREC)).cast(dtypes.uint8)
-    return self + (end - self) * weight
-
   # ***** broadcasted elementwise ops *****
 
 def ufix(self, x) -> Tensor: