mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
fixup training loop
This commit is contained in:
@@ -11,9 +11,9 @@ from extra.utils import fetch, get_parameters
|
||||
def fetch_mnist():
|
||||
import gzip
|
||||
parse = lambda dat: np.frombuffer(gzip.decompress(dat), dtype=np.uint8).copy()
|
||||
X_train = parse(fetch("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28, 28))
|
||||
X_train = parse(fetch("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28*28)).astype(np.float32)
|
||||
Y_train = parse(fetch("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"))[8:]
|
||||
X_test = parse(fetch("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28, 28))
|
||||
X_test = parse(fetch("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28*28)).astype(np.float32)
|
||||
Y_test = parse(fetch("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"))[8:]
|
||||
return X_train, Y_train, X_test, Y_test
|
||||
|
||||
|
||||
Reference in New Issue
Block a user