diff --git a/README.md b/README.md index 2b4bcd7879..b1a58a8f81 100644 --- a/README.md +++ b/README.md @@ -109,6 +109,7 @@ python -m pytest * Make broadcasting work on the backward pass (simple please) * EfficientNet backward pass * Tensors on GPU (GPU support, must support Mac) +* Make tinygrad work on comma two and run driving model * Reduce code * Increase speed * Add features diff --git a/test/test_mnist.py b/test/test_mnist.py index d8d39ba47d..502c95df31 100644 --- a/test/test_mnist.py +++ b/test/test_mnist.py @@ -37,7 +37,7 @@ class TinyConvNet: return [self.l1, self.c1, self.c2] def forward(self, x): - x.data = x.data.reshape((-1, 1, 28, 28)) # hacks + x = x.reshape(shape=(-1, 1, 28, 28)) # hacks x = x.conv2d(self.c1).relu().max_pool2d() x = x.conv2d(self.c2).relu().max_pool2d() x = x.reshape(shape=[x.shape[0], -1]) @@ -83,6 +83,15 @@ def evaluate(model, gpu=False): assert accuracy > 0.95 class TestMNIST(unittest.TestCase): + @unittest.skipUnless(GPU, "Requires GPU") + def test_conv_gpu(self): + np.random.seed(1337) + model = TinyConvNet() + [x.cuda_() for x in model.parameters()] + optimizer = optim.SGD(model.parameters(), lr=0.001) + train(model, optimizer, steps=1000, gpu=True) + evaluate(model, gpu=True) + def test_conv(self): np.random.seed(1337) model = TinyConvNet() diff --git a/tinygrad/opsgpu.py b/tinygrad/opsgpu.py index 3b723b7a00..a76d0b5fc6 100644 --- a/tinygrad/opsgpu.py +++ b/tinygrad/opsgpu.py @@ -195,6 +195,23 @@ class Dot(Function): register('dot', Dot, gpu=True) register('matmul', Dot, gpu=True) +# ************* simple ops ************* + +class Reshape(Function): + @staticmethod + def forward(ctx, x, shape): + ctx.save_for_backward(x.shape) + x.shape = shape + return x + + @staticmethod + def backward(ctx, grad_output): + in_shape, = ctx.saved_tensors + grad_output.shape = in_shape + return grad_output +register('reshape', Reshape, gpu=True) + +# ************* activation ops ************* class ReLU(Function): @staticmethod diff --git a/tinygrad/utils.py b/tinygrad/utils.py index e59f559a74..8fe2bb7367 100644 --- a/tinygrad/utils.py +++ b/tinygrad/utils.py @@ -17,8 +17,8 @@ def fetch(url): dat = f.read() else: print("fetching %s" % url) + dat = requests.get(url).content with open(fp+".tmp", "wb") as f: - dat = requests.get(url).content f.write(dat) os.rename(fp+".tmp", fp) return dat