diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 518aa6efe5..494f59837a 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3173,7 +3173,7 @@ class Tensor(MathTrait): print(Tensor([-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]).round().numpy()) ``` """ - return ((self > 0) == ((b := self.cast(dtypes.int32) / 2.0).cast(dtypes.int32) == b)).where((self - 0.5).ceil(), (self + 0.5).floor()) + return ((self > 0) == ((b := self.trunc() / 2.0).trunc() == b)).where((self - 0.5).ceil(), (self + 0.5).floor()) def isinf(self:Tensor, detect_positive:bool=True, detect_negative:bool=True) -> Tensor: """