Fix Tensor ceil and floor for whole numbers (#1071)

* Works on non-special numbers

* Test different cases
This commit is contained in:
Jacky Lee
2023-06-27 23:22:17 -07:00
committed by GitHub
parent 1f5d45ca8c
commit 754e54ebb9
2 changed files with 13 additions and 5 deletions

View File

@@ -125,8 +125,14 @@ class TestOps(unittest.TestCase):
tt2 = Tensor.ones(4, requires_grad=True)
self.assertRaises(RuntimeError, (tt1 < tt2).sum().backward)
def test_floor(self): helper_test_op([(45,65)], lambda x: torch.floor(x), lambda x: x.floor(), forward_only=True)
def test_ceil(self): helper_test_op([(45,65)], lambda x: torch.ceil(x), lambda x:x.ceil(), forward_only=True)
def test_floor(self):
helper_test_op([(45,65)], lambda x: torch.floor(x), lambda x: x.floor(), forward_only=True)
a, b = Tensor([1.0, 2.1, 0.0, -5.0, -2.5]), torch.tensor([1.0, 2.1, 0.0, -5.0, -2.5])
helper_test_op([], lambda: torch.floor(b), lambda: Tensor.floor(a), forward_only=True)
def test_ceil(self):
helper_test_op([(45,65)], lambda x: torch.ceil(x), lambda x: x.ceil(), forward_only=True)
a, b = Tensor([1.0, 2.1, 0.0, -5.0, -2.5]), torch.tensor([1.0, 2.1, 0.0, -5.0, -2.5])
helper_test_op([], lambda: torch.ceil(b), lambda: Tensor.ceil(a), forward_only=True)
def test_tril(self):
helper_test_op([(3,3)], lambda x: x.tril(), lambda x: x.tril())
helper_test_op([(3,3)], lambda x: x.tril(1), lambda x: x.tril(1))