update Tensor.round doc and example (#5318)

document rounding half to even and update examples to show
This commit is contained in:
chenyu
2024-07-07 12:10:39 -04:00
committed by GitHub
parent c1e330f302
commit 296a1a36bb

View File

@@ -2012,7 +2012,7 @@ class Tensor:
Truncates the tensor element-wise.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).trunc().numpy())
print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).trunc().numpy())
```
"""
return self.cast(dtypes.int32).cast(self.dtype)
@@ -2021,7 +2021,7 @@ class Tensor:
Rounds the tensor element-wise towards positive infinity.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).ceil().numpy())
print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).ceil().numpy())
```
"""
return (self > (b := self.trunc())).where(b+1, b)
@@ -2030,19 +2030,20 @@ class Tensor:
Rounds the tensor element-wise towards negative infinity.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).floor().numpy())
print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).floor().numpy())
```
"""
return (self < (b := self.trunc())).where(b-1, b)
def round(self: Tensor) -> Tensor:
"""
Rounds the tensor element-wise.
Rounds the tensor element-wise with rounding half to even.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).round().numpy())
print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).round().numpy())
```
"""
return ((self > 0) == ((b := self.cast(dtypes.int32) / 2.0).cast(dtypes.int32) == b)).where((self - 0.5).ceil(), (self + 0.5).floor())
def lerp(self, end: Tensor, weight: Union[Tensor, float]) -> Tensor:
"""
Linearly interpolates between `self` and `end` by `weight`.