@@ -1,3 +1,4 @@
+import os
 import torch
 import numpy as np
 import unittest
@@ -41,8 +42,6 @@ class TestOps(unittest.TestCase):
   gpu = False
   def test_add(self):
     helper_test_op([(45,65), (45,65)], lambda x,y: x+y, Tensor.add, gpu=self.gpu)
-  def test_broadcast_add(self):
-    helper_test_op([(1,32,32,32), (1,32,1,1)], lambda x,y: x+y, Tensor.add, gpu=self.gpu, forward_only=True)
   def test_sub(self):
     helper_test_op([(45,65), (45,65)], lambda x,y: x-y, Tensor.sub, gpu=self.gpu)
   def test_mul(self):
@@ -64,6 +63,16 @@ class TestOps(unittest.TestCase):
   def test_logsoftmax(self):
     helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.logsoftmax, atol=1e-7, grad_atol=1e-7, gpu=self.gpu)
 
+  def test_broadcast(self):
+    if os.getenv('CI') and self.gpu:
+      raise unittest.SkipTest('GPU broadcasting not fully supported')
+    for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul),
+                                  (torch.div, Tensor.div), (torch.pow, Tensor.pow)]:
+      for shapes in [((1,32,32,32), (1,32,1,1)), ((5,13,24,16), (5,1,24,1)),
+                     ((4,1), (4,5)), ((1,4), (5,4)), ((1,3,1,7,1), (2,1,5,1,8))]:
+        with self.subTest(op=torch_op.__name__, shapes=shapes):
+          helper_test_op(shapes, torch_op, tinygrad_op, gpu=self.gpu, forward_only=True)
+
   def test_pad2d(self):
     helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)), gpu=self.gpu)
 
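Note: the shape pairs exercised by test_broadcast follow NumPy broadcasting rules: shapes are aligned from the trailing dimension, and each aligned pair of dimensions must either match or contain a 1. The most general pair above, (1,3,1,7,1) against (2,1,5,1,8), therefore broadcasts to (2,3,5,7,8). A quick standalone check on the torch side:

import torch
a = torch.rand(1, 3, 1, 7, 1)
b = torch.rand(2, 1, 5, 1, 8)
print((a + b).shape)  # torch.Size([2, 3, 5, 7, 8])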
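For readers without the full file: every test above delegates to helper_test_op, which is defined near the top of test_ops.py and is not part of this diff. The sketch below is a minimal reconstruction of what it is assumed to do: build matching random inputs for torch and tinygrad, compare forward outputs, and, unless forward_only=True, backprop the same scalar loss through both graphs and compare input gradients. The default tolerances and the .data/.grad/.cuda()/.cpu() accesses follow the old tinygrad Tensor API and are assumptions, not a copy of the real helper.

import numpy as np
import torch
from tinygrad.tensor import Tensor

def helper_test_op(shapes, torch_fxn, tinygrad_fxn, atol=1e-6, grad_atol=1e-6,
                   gpu=False, forward_only=False):
  # matching random inputs for both frameworks
  ts = [torch.rand(s, requires_grad=True) for s in shapes]
  tst = [Tensor(t.detach().numpy()) for t in ts]
  if gpu:
    tst = [t.cuda() for t in tst]  # assumed GPU transfer API
  out = torch_fxn(*ts)
  ret = tinygrad_fxn(*tst)
  # forward outputs must agree within atol
  np.testing.assert_allclose(ret.cpu().data, out.detach().numpy(), atol=atol)
  if not forward_only:
    # backprop the same scalar loss through both graphs, compare input grads
    out.mean().backward()
    ret.mean().backward()
    for t, tt in zip(ts, tst):
      np.testing.assert_allclose(t.grad.numpy(), tt.cpu().grad.data, atol=grad_atol)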