mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
write sqrt and div using pow
This commit is contained in:
@@ -31,6 +31,20 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=1e-7, grad_atol=1e-7):
|
||||
print("testing %30r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms" % (shps, torch_fp, tinygrad_fp, torch_fbp-torch_fp, tinygrad_fbp-tinygrad_fp))
|
||||
|
||||
class TestOps(unittest.TestCase):
|
||||
def test_add(self):
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x+y, Tensor.add)
|
||||
def test_sub(self):
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x-y, Tensor.sub)
|
||||
def test_mul(self):
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x*y, Tensor.mul)
|
||||
def test_div(self):
|
||||
# TODO: why does this need more tolerance?
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div, atol=5e-5)
|
||||
def test_pow(self):
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow)
|
||||
def test_sqrt(self):
|
||||
helper_test_op([(45,65)], lambda x: x.sqrt(), Tensor.sqrt)
|
||||
|
||||
def test_conv2d(self):
|
||||
for bs in [1,8]:
|
||||
for cin in [1,3]:
|
||||
|
||||
Reference in New Issue
Block a user