mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
Match torch on fractional negative base pow (#1352)
* feat: match torch on fractional negative base pow * feat: tests for trunc
This commit is contained in:
@@ -182,6 +182,10 @@ class TestOps(unittest.TestCase):
|
||||
tt2 = Tensor.ones(4, requires_grad=True)
|
||||
self.assertRaises(RuntimeError, (tt1 < tt2).sum().backward)
|
||||
|
||||
def test_trunc(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.trunc(x), lambda x: x.trunc(), forward_only=True)
|
||||
a, b = Tensor([1.0, 2.1, 0.0, -5.0, -2.5]), torch.tensor([1.0, 2.1, 0.0, -5.0, -2.5])
|
||||
helper_test_op([], lambda: torch.trunc(b), lambda: Tensor.trunc(a), forward_only=True)
|
||||
def test_floor(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.floor(x), lambda x: x.floor(), forward_only=True)
|
||||
a, b = Tensor([1.0, 2.1, 0.0, -5.0, -2.5]), torch.tensor([1.0, 2.1, 0.0, -5.0, -2.5])
|
||||
@@ -271,6 +275,11 @@ class TestOps(unittest.TestCase):
|
||||
# Regression tests for https://github.com/tinygrad/tinygrad/issues/1151
|
||||
helper_test_op([(45,65)], lambda x: x**3, lambda x: Tensor.pow(x,3), a=-10)
|
||||
helper_test_op([()], lambda x: x**3, lambda x: Tensor.pow(x,3), a=-10)
|
||||
# Regression tests for https://github.com/tinygrad/tinygrad/issues/1251
|
||||
helper_test_op([(45,65)], lambda x: x**0.2, lambda x: Tensor.pow(x,0.2), a=-10)
|
||||
helper_test_op([(45,65)], lambda x: x**1.2, lambda x: Tensor.pow(x,1.2), a=-10)
|
||||
helper_test_op([()], lambda x: x**0.2, lambda x: Tensor.pow(x,0.2), a=-10)
|
||||
helper_test_op([()], lambda x: x**1.2, lambda x: Tensor.pow(x,1.2), a=-10)
|
||||
def test_pow_const(self):
|
||||
helper_test_op([(45,65)], lambda x: x**1.0, lambda x: x**1.0)
|
||||
helper_test_op([(45,65)], lambda x: x**-1.0, lambda x: x**-1.0)
|
||||
|
||||
Reference in New Issue
Block a user