From 754e54ebb9ee5ab116cfc44deba99a4928d72261 Mon Sep 17 00:00:00 2001 From: Jacky Lee <39754370+jla524@users.noreply.github.com> Date: Tue, 27 Jun 2023 23:22:17 -0700 Subject: [PATCH] Fix Tensor ceil and floor for whole numbers (#1071) * Works on non-special numbers * Test different cases --- test/test_ops.py | 10 ++++++++-- tinygrad/tensor.py | 8 +++++--- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 3f5c433a99..7a0d63783c 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -125,8 +125,14 @@ class TestOps(unittest.TestCase): tt2 = Tensor.ones(4, requires_grad=True) self.assertRaises(RuntimeError, (tt1 < tt2).sum().backward) - def test_floor(self): helper_test_op([(45,65)], lambda x: torch.floor(x), lambda x: x.floor(), forward_only=True) - def test_ceil(self): helper_test_op([(45,65)], lambda x: torch.ceil(x), lambda x:x.ceil(), forward_only=True) + def test_floor(self): + helper_test_op([(45,65)], lambda x: torch.floor(x), lambda x: x.floor(), forward_only=True) + a, b = Tensor([1.0, 2.1, 0.0, -5.0, -2.5]), torch.tensor([1.0, 2.1, 0.0, -5.0, -2.5]) + helper_test_op([], lambda: torch.floor(b), lambda: Tensor.floor(a), forward_only=True) + def test_ceil(self): + helper_test_op([(45,65)], lambda x: torch.ceil(x), lambda x: x.ceil(), forward_only=True) + a, b = Tensor([1.0, 2.1, 0.0, -5.0, -2.5]), torch.tensor([1.0, 2.1, 0.0, -5.0, -2.5]) + helper_test_op([], lambda: torch.ceil(b), lambda: Tensor.ceil(a), forward_only=True) def test_tril(self): helper_test_op([(3,3)], lambda x: x.tril(), lambda x: x.tril()) helper_test_op([(3,3)], lambda x: x.tril(1), lambda x: x.tril(1)) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 7531c8d02b..37f0487d8e 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -496,9 +496,11 @@ class Tensor: # ***** math functions (unary) ***** def ceil(self: Tensor) -> Tensor: - b = self.cast(dtypes.int32).contiguous() - return (self > 0).where(b+1, b) - def floor(self: Tensor) -> Tensor: return self.ceil() - 1 + b = self.cast(dtypes.int32).contiguous().cast(self.dtype) + return (self > b).where(b+1, b) + def floor(self: Tensor) -> Tensor: + b = self.cast(dtypes.int32).contiguous().cast(self.dtype) + return (self < b).where(b-1, b) def __neg__(self): return 0.0-self def sqrt(self): return self.pow(0.5)