move lerp to mixin (#15634)

last function of math function section
This commit is contained in:
chenyu
2026-04-07 15:13:00 -04:00
committed by GitHub
parent 890286e8d6
commit 9c6e925b56
2 changed files with 13 additions and 15 deletions

View File

@@ -974,3 +974,16 @@ class ElementwiseMixin(DTypeMixin, CreationMixin):
"""
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
return self.logical_not() if self.dtype == dtypes.bool else self ^ -1
def lerp(self, end: Self, weight: Self | ConstType) -> Self:
"""
Linearly interpolates between `self` and `end` by `weight`.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([1., 2., 3.]).lerp(Tensor([4., 5., 6.]), 0.5).numpy())
```
"""
if self.dtype == dtypes.uint8 and isinstance(weight, ElementwiseMixin):
w_i = (weight * (1<<(W_PREC:=7)) + 0.5).cast(dtypes.int16)
return (self+(((end - self).cast(dtypes.int8) * w_i + (1<<W_PREC-1)).cast(dtypes.uint16) >> W_PREC)).cast(dtypes.uint8)
return self + (end - self) * weight

View File

@@ -2385,21 +2385,6 @@ class Tensor(OpMixin):
"""
return self._apply_uop(UOp.contiguous_backward)
# ***** math functions *****
def lerp(self, end:Tensor, weight:Tensor|float) -> Tensor:
"""
Linearly interpolates between `self` and `end` by `weight`.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([1., 2., 3.]).lerp(Tensor([4., 5., 6.]), 0.5).numpy())
```
"""
if self.dtype == dtypes.uint8 and isinstance(weight, Tensor):
w_i = (weight * (1<<(W_PREC:=7)) + 0.5).cast(dtypes.int16)
return (self+(((end - self).cast(dtypes.int8) * w_i + (1<<W_PREC-1)).cast(dtypes.uint16) >> W_PREC)).cast(dtypes.uint8)
return self + (end - self) * weight
# ***** broadcasted elementwise ops *****
def ufix(self, x) -> Tensor: