From 567707a5f68ea582866a3c811e3c0908c24f759f Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sun, 25 Oct 2020 17:16:47 -0700 Subject: [PATCH] rename max_pool2d to match torch, remove more fast conv crap --- test/test_conv_speed.py | 54 ++++------------------------------------- test/test_mnist.py | 4 +-- test/test_tensor.py | 2 +- tinygrad/ops.py | 4 +-- 4 files changed, 10 insertions(+), 54 deletions(-) diff --git a/test/test_conv_speed.py b/test/test_conv_speed.py index a74fdc4978..d4c8feccc9 100644 --- a/test/test_conv_speed.py +++ b/test/test_conv_speed.py @@ -1,15 +1,4 @@ #!/usr/bin/env python - -# if you'd like to use the line profiler -try: - import line_profiler - prof = line_profiler.LineProfiler() - import builtins - builtins.__dict__['profile'] = prof - # add @profile decorator to probe -except ImportError: - prof = None - import time import cProfile import pstats @@ -18,20 +7,6 @@ import numpy as np import torch from tinygrad.tensor import Tensor -def profile_conv(bs, chans, conv, cnt=10): - img = Tensor.zeros(bs, 1, 28, 28) - conv = Tensor.randn(chans, 1, conv, conv) - fpt, bpt = 0.0, 0.0 - for i in range(cnt): - et0 = time.time() - out = img.conv2d(conv) - et1 = time.time() - g = out.mean().backward() - et2 = time.time() - fpt += (et1-et0) - bpt += (et2-et1) - return fpt/cnt, bpt/cnt - def start_profile(): import time pr = cProfile.Profile(timer=lambda: int(time.time()*1e9), timeunit=1e-6) @@ -45,27 +20,12 @@ def stop_profile(pr, sort='cumtime'): ps.sort_stats(sort) ps.print_stats(0.2) - if prof is not None: - prof.print_stats() - class TestConvSpeed(unittest.TestCase): - def test_forward_backward_3x3(self): - # warmup - profile_conv(128, 16, 3, cnt=1) - - pr = start_profile() - fpt, bpt = profile_conv(128, 16, 3) - stop_profile(pr) - - print("forward pass: %.3f ms" % (fpt*1000)) - print("backward pass: %.3f ms" % (bpt*1000)) - def test_mnist(self): # https://keras.io/examples/vision/mnist_convnet/ conv = 3 inter_chan, out_chan = 32, 64 - # ****** torch baseline ******* torch.backends.mkldnn.enabled = False @@ -83,7 +43,7 @@ class TestConvSpeed(unittest.TestCase): with torch.autograd.profiler.profile(record_shapes=True) as tprof: cnt = 5 fpt, bpt = 0.0, 0.0 - for i in range(1+cnt): + for i in range(cnt): et0 = time.time() x = torch.randn(128, 1, 28, 28, requires_grad=True) x = mp(c2d(x,c1).relu()) @@ -94,13 +54,9 @@ class TestConvSpeed(unittest.TestCase): et1 = time.time() out.backward() et2 = time.time() - if i == 0: - pr = start_profile() - else: - fpt += (et1-et0) - bpt += (et2-et1) + fpt += (et1-et0) + bpt += (et2-et1) - stop_profile(pr, sort='time') fpt_baseline = (fpt*1000/cnt) bpt_baseline = (bpt*1000/cnt) print("torch forward pass: %.3f ms" % fpt_baseline) @@ -119,8 +75,8 @@ class TestConvSpeed(unittest.TestCase): for i in range(1+cnt): et0 = time.time() x = Tensor.randn(128, 1, 28, 28) - x = x.conv2d(c1).relu().maxpool2x2() - x = x.conv2d(c2).relu().maxpool2x2() + x = x.conv2d(c1).relu().max_pool2d() + x = x.conv2d(c2).relu().max_pool2d() x = x.reshape(Tensor(np.array((x.shape[0], -1)))) out = x.dot(l1).logsoftmax() out = out.mean() diff --git a/test/test_mnist.py b/test/test_mnist.py index 1affa06d52..46a0b95848 100644 --- a/test/test_mnist.py +++ b/test/test_mnist.py @@ -32,8 +32,8 @@ class TinyConvNet: def forward(self, x): x.data = x.data.reshape((-1, 1, 28, 28)) # hacks - x = x.conv2d(self.c1).relu().maxpool2x2() - x = x.conv2d(self.c2).relu().maxpool2x2() + x = x.conv2d(self.c1).relu().max_pool2d() + x = x.conv2d(self.c2).relu().max_pool2d() x = x.reshape(Tensor(np.array((x.shape[0], -1)))) return x.dot(self.l1).logsoftmax() diff --git a/test/test_tensor.py b/test/test_tensor.py index 18299dd084..676adef4af 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -90,7 +90,7 @@ class TestOps(unittest.TestCase): xt = Tensor(x.detach().numpy()) # in tinygrad - ret = xt.maxpool2x2() + ret = xt.max_pool2d() assert ret.shape == (5,2,10//2,8//2) ret.mean().backward() diff --git a/tinygrad/ops.py b/tinygrad/ops.py index b4cb8b1a06..18fa9c579e 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -124,7 +124,7 @@ class Conv2D(Function): return dx, dw register('conv2d', Conv2D) -class MaxPool2x2(Function): +class MaxPool2D(Function): @staticmethod def forward(ctx, x): my, mx = (x.shape[2]//2)*2, (x.shape[3]//2)*2 @@ -147,5 +147,5 @@ class MaxPool2x2(Function): for X in range(2): ret[:, :, Y:my:2, X:mx:2] = grad_output * (idxs == (Y*2+X)) return ret -register('maxpool2x2', MaxPool2x2) +register('max_pool2d', MaxPool2D)