mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
implement Tensor.round (#3225)
This commit is contained in:
@@ -245,6 +245,10 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(45,65)], lambda x: x.ceil(), forward_only=True)
|
||||
a, b = Tensor([1.0, 2.1, 0.0, -5.0, -2.5]), torch.tensor([1.0, 2.1, 0.0, -5.0, -2.5])
|
||||
helper_test_op([], lambda: torch.ceil(b), lambda: Tensor.ceil(a), forward_only=True)
|
||||
def test_round(self):
|
||||
helper_test_op([(45,65)], lambda x: x.round(), forward_only=True)
|
||||
a, b = Tensor([1.0, 2.1, 0.0, -5.0, -2.5]), torch.tensor([1.0, 2.1, 0.0, -5.0, -2.5])
|
||||
helper_test_op([], lambda: torch.round(b), lambda: Tensor.round(a), forward_only=True)
|
||||
def test_tril(self):
|
||||
helper_test_op([(3,3)], lambda x: x.tril())
|
||||
helper_test_op([(3,3)], lambda x: x.tril(1))
|
||||
|
||||
Reference in New Issue
Block a user