fixup training loop

This commit is contained in:
George Hotz
2020-12-27 18:35:56 -05:00
parent f15bec6dbc
commit a361ef6861
4 changed files with 80 additions and 45 deletions

View File

@@ -11,9 +11,9 @@ from extra.utils import fetch, get_parameters
def fetch_mnist():
import gzip
parse = lambda dat: np.frombuffer(gzip.decompress(dat), dtype=np.uint8).copy()
X_train = parse(fetch("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28, 28))
X_train = parse(fetch("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28*28)).astype(np.float32)
Y_train = parse(fetch("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"))[8:]
X_test = parse(fetch("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28, 28))
X_test = parse(fetch("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28*28)).astype(np.float32)
Y_test = parse(fetch("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"))[8:]
return X_train, Y_train, X_test, Y_test