rewrite Tensor.round to not use cast int (#11654)

This commit is contained in:
chenyu
2025-08-13 10:51:08 -07:00
committed by GitHub
parent d2521d828a
commit 94e6d84e32

View File

@@ -3173,7 +3173,7 @@ class Tensor(MathTrait):
print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).round().numpy())
```
"""
return ((self > 0) == ((b := self.cast(dtypes.int32) / 2.0).cast(dtypes.int32) == b)).where((self - 0.5).ceil(), (self + 0.5).floor())
return ((self > 0) == ((b := self.trunc() / 2.0).trunc() == b)).where((self - 0.5).ceil(), (self + 0.5).floor())
def isinf(self:Tensor, detect_positive:bool=True, detect_negative:bool=True) -> Tensor:
"""