Support scalars by kartik4949
@@ -66,9 +66,14 @@ class TestOps(unittest.TestCase):
   def test_logsoftmax(self):
     helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.logsoftmax, atol=1e-7, grad_atol=1e-7, gpu=self.gpu)
   def test_tanh(self):
-    helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6)
+    helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6, gpu=self.gpu)
   def test_topo_sort(self):
-    helper_test_op([(45,65)], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6)
+    helper_test_op([(45,65)], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6, gpu=self.gpu)
+
+  def test_scalar_mul(self):
+    helper_test_op([(45,65)], lambda x: x*2, lambda x: x*2, gpu=self.gpu)
+  def test_scalar_rmul(self):
+    helper_test_op([(45,65)], lambda x: 2*x, lambda x: 2*x, gpu=self.gpu)
 
   def test_broadcast_full(self):
     for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul),
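
The two new tests, test_scalar_mul and test_scalar_rmul, only pass if Tensor's binary operators accept a plain Python number on either side. The Tensor-side change is not shown in this hunk, so the sketch below is illustrative only: a hypothetical stand-in Tensor (not tinygrad's real class) that promotes a scalar operand to a tensor before dispatching to the existing elementwise op.

import numpy as np

# Hypothetical stand-in for tinygrad's Tensor, just to show the scalar-promotion idea.
class Tensor:
  def __init__(self, data):
    self.data = np.asarray(data, dtype=np.float32)

  def mul(self, other):
    # Elementwise multiply; the real version would also build the autograd graph.
    return Tensor(self.data * other.data)

  def __mul__(self, other):
    # Scalar support: promote a plain Python number to a Tensor first.
    if isinstance(other, (int, float)):
      other = Tensor(np.full_like(self.data, other))
    return self.mul(other)

  # 2 * x takes the same path as x * 2, which is what test_scalar_rmul exercises.
  __rmul__ = __mul__

x = Tensor(np.ones((45, 65)))
assert np.allclose((x * 2).data, (2 * x).data)

Judging from its use above, helper_test_op runs the torch lambda and the tinygrad op on inputs of the given shapes and compares outputs (and gradients) within atol/grad_atol, optionally on the GPU when gpu=self.gpu is set; the scalar tests can pass the same lambda for both sides because x*2 is valid syntax in either framework.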