improve test_example

This commit is contained in:
George Hotz
2023-03-12 22:59:40 -07:00
parent 5577634cf3
commit a4abcf0969

View File

@@ -1,40 +1,59 @@
import unittest
from tinygrad.lazy import Device
from tinygrad.tensor import Tensor
def multidevice_test(fxn):
def ret(self):
for device in Device._buffers:
with self.subTest(device=device):
try:
Device[device]
except Exception:
print(f"WARNING: {device} test isn't running")
continue
fxn(self, device)
return ret
class TestExample(unittest.TestCase):
def _test_example_readme(self, device):
@multidevice_test
def test_2_plus_3(self, device):
a = Tensor([2], device=device)
b = Tensor([3], device=device)
result = a + b
print(f"{a.numpy()} + {b.numpy()} = {result.numpy()}")
assert result.numpy()[0] == 5.
@multidevice_test
def test_example_readme(self, device):
x = Tensor.eye(3, device=device, requires_grad=True)
y = Tensor([[2.0,0,-2.0]], device=device, requires_grad=True)
z = y.matmul(x).sum()
z.backward()
print(x.grad.numpy()) # dz/dx
print(y.grad.numpy()) # dz/dy
x.grad.numpy() # dz/dx
y.grad.numpy() # dz/dy
assert x.grad.device == device
assert y.grad.device == device
def _test_example_matmul(self, device):
@multidevice_test
def test_example_matmul(self, device):
try:
Device[device]
except Exception:
print(f"WARNING: {device} test isn't running")
return
x = Tensor.eye(64, device=device, requires_grad=True)
y = Tensor.eye(64, device=device, requires_grad=True)
z = y.matmul(x).sum()
z.backward()
print(x.grad.numpy()) # dz/dx
print(y.grad.numpy()) # dz/dy
x.grad.numpy() # dz/dx
y.grad.numpy() # dz/dy
assert x.grad.device == device
assert y.grad.device == device
def test_example_readme_cpu(self): self._test_example_readme("CPU")
def test_example_readme_gpu(self): self._test_example_readme("GPU")
def test_example_readme_torch(self): self._test_example_readme("TORCH")
def test_example_readme_llvm(self): self._test_example_readme("LLVM")
def test_example_matmul_cpu(self): self._test_example_matmul("CPU")
def test_example_matmul_gpu(self): self._test_example_matmul("GPU")
def test_example_matmul_torch(self): self._test_example_matmul("TORCH")
def test_example_matmul_llvm(self): self._test_example_matmul("LLVM")
if __name__ == '__main__':
unittest.main()