mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
getting convs to work on gpu
This commit is contained in:
@@ -109,6 +109,7 @@ python -m pytest
|
||||
* Make broadcasting work on the backward pass (simple please)
|
||||
* EfficientNet backward pass
|
||||
* Tensors on GPU (GPU support, must support Mac)
|
||||
* Make tinygrad work on comma two and run driving model
|
||||
* Reduce code
|
||||
* Increase speed
|
||||
* Add features
|
||||
|
||||
@@ -37,7 +37,7 @@ class TinyConvNet:
|
||||
return [self.l1, self.c1, self.c2]
|
||||
|
||||
def forward(self, x):
|
||||
x.data = x.data.reshape((-1, 1, 28, 28)) # hacks
|
||||
x = x.reshape(shape=(-1, 1, 28, 28)) # hacks
|
||||
x = x.conv2d(self.c1).relu().max_pool2d()
|
||||
x = x.conv2d(self.c2).relu().max_pool2d()
|
||||
x = x.reshape(shape=[x.shape[0], -1])
|
||||
@@ -83,6 +83,15 @@ def evaluate(model, gpu=False):
|
||||
assert accuracy > 0.95
|
||||
|
||||
class TestMNIST(unittest.TestCase):
|
||||
@unittest.skipUnless(GPU, "Requires GPU")
|
||||
def test_conv_gpu(self):
|
||||
np.random.seed(1337)
|
||||
model = TinyConvNet()
|
||||
[x.cuda_() for x in model.parameters()]
|
||||
optimizer = optim.SGD(model.parameters(), lr=0.001)
|
||||
train(model, optimizer, steps=1000, gpu=True)
|
||||
evaluate(model, gpu=True)
|
||||
|
||||
def test_conv(self):
|
||||
np.random.seed(1337)
|
||||
model = TinyConvNet()
|
||||
|
||||
@@ -195,6 +195,23 @@ class Dot(Function):
|
||||
register('dot', Dot, gpu=True)
|
||||
register('matmul', Dot, gpu=True)
|
||||
|
||||
# ************* simple ops *************
|
||||
|
||||
class Reshape(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, shape):
|
||||
ctx.save_for_backward(x.shape)
|
||||
x.shape = shape
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
in_shape, = ctx.saved_tensors
|
||||
grad_output.shape = in_shape
|
||||
return grad_output
|
||||
register('reshape', Reshape, gpu=True)
|
||||
|
||||
# ************* activation ops *************
|
||||
|
||||
class ReLU(Function):
|
||||
@staticmethod
|
||||
|
||||
@@ -17,8 +17,8 @@ def fetch(url):
|
||||
dat = f.read()
|
||||
else:
|
||||
print("fetching %s" % url)
|
||||
with open(fp+".tmp", "wb") as f:
|
||||
dat = requests.get(url).content
|
||||
with open(fp+".tmp", "wb") as f:
|
||||
f.write(dat)
|
||||
os.rename(fp+".tmp", fp)
|
||||
return dat
|
||||
|
||||
Reference in New Issue
Block a user