From 102e6356e95e314acace4e734656566ad8875993 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sun, 6 Dec 2020 13:44:31 -0800 Subject: [PATCH] replace layer_init_uniform with .uniform --- README.md | 5 ++--- examples/train_efficientnet.py | 7 +++---- test/test_mnist.py | 12 ++++++------ tinygrad/tensor.py | 4 ++++ tinygrad/utils.py | 4 ---- 5 files changed, 15 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 4a26e6fbfb..74f3f61fc1 100644 --- a/README.md +++ b/README.md @@ -57,12 +57,11 @@ It turns out, a decent autograd tensor library is 90% of what you need for neura ```python from tinygrad.tensor import Tensor import tinygrad.optim as optim -from tinygrad.utils import layer_init_uniform class TinyBobNet: def __init__(self): - self.l1 = Tensor(layer_init_uniform(784, 128)) - self.l2 = Tensor(layer_init_uniform(128, 10)) + self.l1 = Tensor.uniform(784, 128) + self.l2 = Tensor.uniform(128, 10) def forward(self, x): return x.dot(self.l1).relu().dot(self.l2).logsoftmax() diff --git a/examples/train_efficientnet.py b/examples/train_efficientnet.py index 1033b48af8..f50fa0e388 100644 --- a/examples/train_efficientnet.py +++ b/examples/train_efficientnet.py @@ -4,7 +4,6 @@ import numpy as np from extra.efficientnet import EfficientNet from tinygrad.tensor import Tensor from tinygrad.utils import get_parameters, fetch -from tinygrad.utils import layer_init_uniform from tqdm import trange import tinygrad.optim as optim import io @@ -15,9 +14,9 @@ class TinyConvNet: def __init__(self, classes=10): conv = 3 inter_chan, out_chan = 8, 16 # for speed - self.c1 = Tensor(layer_init_uniform(inter_chan,3,conv,conv)) - self.c2 = Tensor(layer_init_uniform(out_chan,inter_chan,conv,conv)) - self.l1 = Tensor(layer_init_uniform(out_chan*6*6, classes)) + self.c1 = Tensor.uniform(inter_chan,3,conv,conv) + self.c2 = Tensor.uniform(out_chan,inter_chan,conv,conv) + self.l1 = Tensor.uniform(out_chan*6*6, classes) def forward(self, x): x = x.conv2d(self.c1).relu().max_pool2d() diff --git a/test/test_mnist.py b/test/test_mnist.py index 8096153cbd..dbe63e1a83 100644 --- a/test/test_mnist.py +++ b/test/test_mnist.py @@ -3,7 +3,7 @@ import os import unittest import numpy as np from tinygrad.tensor import Tensor, GPU -from tinygrad.utils import layer_init_uniform, fetch +from tinygrad.utils import fetch import tinygrad.optim as optim from tqdm import trange @@ -23,8 +23,8 @@ X_train, Y_train, X_test, Y_test = fetch_mnist() # create a model class TinyBobNet: def __init__(self): - self.l1 = Tensor(layer_init_uniform(784, 128)) - self.l2 = Tensor(layer_init_uniform(128, 10)) + self.l1 = Tensor.uniform(784, 128) + self.l2 = Tensor.uniform(128, 10) def parameters(self): return [self.l1, self.l2] @@ -39,9 +39,9 @@ class TinyConvNet: conv = 3 #inter_chan, out_chan = 32, 64 inter_chan, out_chan = 8, 16 # for speed - self.c1 = Tensor(layer_init_uniform(inter_chan,1,conv,conv)) - self.c2 = Tensor(layer_init_uniform(out_chan,inter_chan,conv,conv)) - self.l1 = Tensor(layer_init_uniform(out_chan*5*5, 10)) + self.c1 = Tensor.uniform(inter_chan,1,conv,conv) + self.c2 = Tensor.uniform(out_chan,inter_chan,conv,conv) + self.l1 = Tensor.uniform(out_chan*5*5, 10) def parameters(self): return [self.l1, self.c1, self.c2] diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 181d4318cb..d271562b63 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -117,6 +117,10 @@ class Tensor: def randn(*shape, **kwargs): return Tensor(np.random.randn(*shape).astype(np.float32), **kwargs) + @staticmethod + def uniform(*shape, **kwargs): + return Tensor((np.random.uniform(-1., 1., size=shape)/np.sqrt(np.prod(shape))).astype(np.float32), **kwargs) + @staticmethod def eye(dim, **kwargs): return Tensor(np.eye(dim).astype(np.float32), **kwargs) diff --git a/tinygrad/utils.py b/tinygrad/utils.py index 7029862407..73de62f033 100644 --- a/tinygrad/utils.py +++ b/tinygrad/utils.py @@ -1,10 +1,6 @@ import numpy as np from tinygrad.tensor import Tensor -def layer_init_uniform(*x): - ret = np.random.uniform(-1., 1., size=x)/np.sqrt(np.prod(x)) - return ret.astype(np.float32) - def fetch(url): import requests, os, hashlib, tempfile fp = os.path.join(tempfile.gettempdir(), hashlib.md5(url.encode('utf-8')).hexdigest())