mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
improve test_example
This commit is contained in:
@@ -1,40 +1,59 @@
|
||||
import unittest
|
||||
from tinygrad.lazy import Device
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
def multidevice_test(fxn):
|
||||
def ret(self):
|
||||
for device in Device._buffers:
|
||||
with self.subTest(device=device):
|
||||
try:
|
||||
Device[device]
|
||||
except Exception:
|
||||
print(f"WARNING: {device} test isn't running")
|
||||
continue
|
||||
fxn(self, device)
|
||||
return ret
|
||||
|
||||
class TestExample(unittest.TestCase):
|
||||
def _test_example_readme(self, device):
|
||||
@multidevice_test
|
||||
def test_2_plus_3(self, device):
|
||||
a = Tensor([2], device=device)
|
||||
b = Tensor([3], device=device)
|
||||
result = a + b
|
||||
print(f"{a.numpy()} + {b.numpy()} = {result.numpy()}")
|
||||
assert result.numpy()[0] == 5.
|
||||
|
||||
@multidevice_test
|
||||
def test_example_readme(self, device):
|
||||
x = Tensor.eye(3, device=device, requires_grad=True)
|
||||
y = Tensor([[2.0,0,-2.0]], device=device, requires_grad=True)
|
||||
z = y.matmul(x).sum()
|
||||
z.backward()
|
||||
|
||||
print(x.grad.numpy()) # dz/dx
|
||||
print(y.grad.numpy()) # dz/dy
|
||||
x.grad.numpy() # dz/dx
|
||||
y.grad.numpy() # dz/dy
|
||||
|
||||
assert x.grad.device == device
|
||||
assert y.grad.device == device
|
||||
|
||||
def _test_example_matmul(self, device):
|
||||
@multidevice_test
|
||||
def test_example_matmul(self, device):
|
||||
try:
|
||||
Device[device]
|
||||
except Exception:
|
||||
print(f"WARNING: {device} test isn't running")
|
||||
return
|
||||
|
||||
x = Tensor.eye(64, device=device, requires_grad=True)
|
||||
y = Tensor.eye(64, device=device, requires_grad=True)
|
||||
z = y.matmul(x).sum()
|
||||
z.backward()
|
||||
|
||||
print(x.grad.numpy()) # dz/dx
|
||||
print(y.grad.numpy()) # dz/dy
|
||||
x.grad.numpy() # dz/dx
|
||||
y.grad.numpy() # dz/dy
|
||||
|
||||
assert x.grad.device == device
|
||||
assert y.grad.device == device
|
||||
|
||||
def test_example_readme_cpu(self): self._test_example_readme("CPU")
|
||||
def test_example_readme_gpu(self): self._test_example_readme("GPU")
|
||||
def test_example_readme_torch(self): self._test_example_readme("TORCH")
|
||||
def test_example_readme_llvm(self): self._test_example_readme("LLVM")
|
||||
|
||||
def test_example_matmul_cpu(self): self._test_example_matmul("CPU")
|
||||
def test_example_matmul_gpu(self): self._test_example_matmul("GPU")
|
||||
def test_example_matmul_torch(self): self._test_example_matmul("TORCH")
|
||||
def test_example_matmul_llvm(self): self._test_example_matmul("LLVM")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user