implement Tensor.round (#3225)

This commit is contained in:
Obada Khalili
2024-01-24 18:49:17 +02:00
committed by GitHub
parent 842053873d
commit 0e103b4aa0
2 changed files with 5 additions and 0 deletions

View File

@@ -245,6 +245,10 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65)], lambda x: x.ceil(), forward_only=True)
a, b = Tensor([1.0, 2.1, 0.0, -5.0, -2.5]), torch.tensor([1.0, 2.1, 0.0, -5.0, -2.5])
helper_test_op([], lambda: torch.ceil(b), lambda: Tensor.ceil(a), forward_only=True)
def test_round(self):
helper_test_op([(45,65)], lambda x: x.round(), forward_only=True)
a, b = Tensor([1.0, 2.1, 0.0, -5.0, -2.5]), torch.tensor([1.0, 2.1, 0.0, -5.0, -2.5])
helper_test_op([], lambda: torch.round(b), lambda: Tensor.round(a), forward_only=True)
def test_tril(self):
helper_test_op([(3,3)], lambda x: x.tril())
helper_test_op([(3,3)], lambda x: x.tril(1))