tinygrad/test/test_optim.py
Liam bcf1518309 All devices are equal! (#196)
* Update all devices to be tested

ANE, CPU and OCL all now support all tests.

However, tests are not currently passing on GPU and I cannot test on CPU.

The failing GPU tests are not an issue caused by this update; they have not
been passing because the required "six" package was missing.

OpenCL tests have not been run since commit 1a1c63a08b.

Devices have three types and are handled by a new DeviceTypes enum. (The goal
is to revert to Tensor.<type>, but the current setup allows for keyword
argument defaults: `device=DeviceTypes.CPU`; see the sketch below.)

All references to Tensor.GPU/CPU/ANE have been converted to the
corresponding `DeviceTypes` enum values.

The conversion code has been refactored to allow conversion from any device
to any device, as sketched below.
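
As a rough illustration of the enum keyword default and the any-to-any
conversion, a minimal sketch (this is not the real tinygrad code; the enum
was later renamed to Device, and the `step` and `to_device` names here are
hypothetical):

    # Illustrative sketch only, not the actual tinygrad implementation.
    from tinygrad.tensor import Tensor

    class DeviceTypes:  # later renamed to Device
      CPU, GPU, ANE = 0, 1, 2

    def step(optim, kwargs={}, device=DeviceTypes.CPU):
      # the enum makes a keyword-argument default possible
      ...

    def to_device(t, device=DeviceTypes.CPU):
      # any-device-to-any-device: pull the data back to host memory,
      # then push it to the requested target
      out = Tensor(t.cpu().data)
      if device == DeviceTypes.GPU: out.gpu_()
      elif device == DeviceTypes.ANE: out.ane_()
      return out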

* Add six dependency in requirements.txt

* Resolve failure to run tests

Move six into the GPU install requirements and remove it from the standard
installation.
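
A hedged sketch of what that packaging change might look like in setup.py
(the surrounding contents are illustrative, not the repo's actual file):

    # illustrative setup.py fragment, not the repo's actual contents
    from setuptools import setup

    setup(
      name="tinygrad",
      install_requires=["numpy"],                    # six no longer required here
      extras_require={"gpu": ["pyopencl", "six"]},   # six moves into the GPU extra
    )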

* Remove repeated data conversion

* Refactor method names

Also reduce code with `.to` and `.to_`.
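
For example, the per-device branching still visible in test_optim.py below
could, assuming `.to_` accepts a Device value (hedged; the exact signature
is not shown here), collapse to something like:

    # before: one branch per device
    if device == Device.GPU: [x.gpu_() for x in get_parameters([net, optim])]
    elif device == Device.ANE: [x.ane_() for x in get_parameters([net, optim])]

    # after (sketch, assuming .to_ takes a Device value)
    for x in get_parameters([net, optim]):
      x.to_(device)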

* Dynamic device handlers

* Refactor DeviceTypes -> Device

* Add mem copy profiling back

* test_backward_pass_diamond_model passing

* Resolve Sum issue on GPU

* Revert batchnorm2d tests

* Update README with updated API

* ANE testing with

* Last minute line gains
2020-12-15 23:44:08 -08:00


import numpy as np
import torch
import unittest
from tinygrad.tensor import Tensor, GPU, Device
from tinygrad.optim import Adam, SGD, RMSprop
from extra.utils import get_parameters
from .config import ANE

# shared initial values so the tinygrad and PyTorch nets start from identical weights
x_init = np.random.randn(1,3).astype(np.float32)
W_init = np.random.randn(3,3).astype(np.float32)
m_init = np.random.randn(1,3).astype(np.float32)
def step_tinygrad(optim, kwargs={}, device=Device.CPU):
  # run a single optimizer step on the tinygrad net, optionally on GPU/ANE
  net = TinyNet()
  optim = optim([net.x, net.W], **kwargs)
  if device==Device.GPU: [x.gpu_() for x in get_parameters([net, optim])]
  elif device==Device.ANE: [x.ane_() for x in get_parameters([net, optim])]
  out = net.forward()
  out.backward()
  optim.step()
  return net.x.cpu().data, net.W.cpu().data

def step_pytorch(optim, kwargs={}):
  # run the equivalent single optimizer step in PyTorch for comparison
  net = TorchNet()
  optim = optim([net.x, net.W], **kwargs)
  out = net.forward()
  out.backward()
  optim.step()
  return net.x.detach().numpy(), net.W.detach().numpy()
class TinyNet():
  def __init__(self):
    self.x = Tensor(x_init.copy())
    self.W = Tensor(W_init.copy())
    self.m = Tensor(m_init.copy())

  def forward(self):
    out = self.x.dot(self.W).relu()
    out = out.logsoftmax()
    out = out.mul(self.m).add(self.m).sum()
    return out

class TorchNet():
  def __init__(self):
    self.x = torch.tensor(x_init.copy(), requires_grad=True)
    self.W = torch.tensor(W_init.copy(), requires_grad=True)
    self.m = torch.tensor(m_init.copy())

  def forward(self):
    out = self.x.matmul(self.W).relu()
    out = torch.nn.functional.log_softmax(out, dim=1)
    out = out.mul(self.m).add(self.m).sum()
    return out
class TestOptim(unittest.TestCase):
  device = Device.CPU  # overridden by the GPU/ANE subclasses below

  def test_adam(self):
    for x,y in zip(step_tinygrad(Adam, device=self.device),
                   step_pytorch(torch.optim.Adam)):
      np.testing.assert_allclose(x, y, atol=1e-4)

  def test_sgd(self):
    for x,y in zip(step_tinygrad(SGD, kwargs={'lr': 0.001}, device=self.device),
                   step_pytorch(torch.optim.SGD, kwargs={'lr': 0.001})):
      np.testing.assert_allclose(x, y, atol=1e-5)

  def test_rmsprop(self):
    for x,y in zip(step_tinygrad(RMSprop, kwargs={'lr': 0.001, 'decay': 0.99}, device=self.device),
                   step_pytorch(torch.optim.RMSprop,
                                kwargs={'lr': 0.001, 'alpha': 0.99})):
      np.testing.assert_allclose(x, y, atol=1e-5)
@unittest.skipUnless(GPU, "Requires GPU")
class TestOptimGPU(TestOptim):
device = Device.GPU
@unittest.skipUnless(ANE, "Requires ANE")
class TestOptimANE(TestOptim):
device = Device.ANE
if __name__ == '__main__':
unittest.main()
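
These are plain unittest cases, with the GPU/ANE subclasses reusing every test by
only overriding `device`. From the repository root, something like
`python -m pytest test/test_optim.py` should run them (the exact invocation
depends on the test package layout); the relative `from .config import ANE`
import means the file is meant to be run as part of the test package rather
than directly as a script.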