Fix Tensor ceil and floor for whole numbers (#1071)

* Works on non-special numbers

* Test different cases
This commit is contained in:
Jacky Lee
2023-06-27 23:22:17 -07:00
committed by GitHub
parent 1f5d45ca8c
commit 754e54ebb9
2 changed files with 13 additions and 5 deletions

View File

@@ -496,9 +496,11 @@ class Tensor:
# ***** math functions (unary) *****
def ceil(self: Tensor) -> Tensor:
b = self.cast(dtypes.int32).contiguous()
return (self > 0).where(b+1, b)
def floor(self: Tensor) -> Tensor: return self.ceil() - 1
b = self.cast(dtypes.int32).contiguous().cast(self.dtype)
return (self > b).where(b+1, b)
def floor(self: Tensor) -> Tensor:
b = self.cast(dtypes.int32).contiguous().cast(self.dtype)
return (self < b).where(b-1, b)
def __neg__(self): return 0.0-self
def sqrt(self): return self.pow(0.5)