Files
tinygrad/datasets/__init__.py
2022-09-25 12:50:28 -04:00

23 lines
993 B
Python

import os
import io
import numpy as np
import gzip
import tarfile
import pickle
from extra.utils import fetch
def fetch_mnist():
    """Load the MNIST train/test splits from the gzip files bundled beside this module.

    Despite the name, nothing is downloaded here: the four ``*.gz`` files are
    read from the ``mnist`` directory located next to this file.

    Returns:
        A tuple ``(X_train, Y_train, X_test, Y_test)`` where the ``X_*`` arrays
        are ``float32`` with shape ``(num_images, 28*28)`` and the ``Y_*``
        arrays are ``uint8`` label vectors.
    """
    # os.path.join keeps the path relative to this module even when
    # dirname(__file__) is empty (string concatenation would yield "/mnist/...").
    mnist_dir = os.path.join(os.path.dirname(__file__), "mnist")

    def parse(name):
        # The context manager closes the gzip handle deterministically
        # instead of leaving it open until garbage collection.
        with gzip.open(os.path.join(mnist_dir, name)) as f:
            # frombuffer over bytes is read-only; copy() makes it writable.
            return np.frombuffer(f.read(), dtype=np.uint8).copy()

    def images(name):
        # Skip the 16-byte file header, then one flattened 28x28 image per row.
        return parse(name)[0x10:].reshape((-1, 28*28)).astype(np.float32)

    def labels(name):
        # Skip the 8-byte file header; the remaining bytes are the labels.
        return parse(name)[8:]

    X_train = images("train-images-idx3-ubyte.gz")
    Y_train = labels("train-labels-idx1-ubyte.gz")
    X_test = images("t10k-images-idx3-ubyte.gz")
    Y_test = labels("t10k-labels-idx1-ubyte.gz")
    return X_train, Y_train, X_test, Y_test
def fetch_cifar():
    """Download the CIFAR-10 archive and return the first training batch.

    The archive is obtained through the project's ``fetch`` helper and opened
    in memory; only ``cifar-10-batches-py/data_batch_1`` is read.

    Returns:
        A tuple ``(X, Y)`` where ``X`` is the batch's image data reshaped to
        ``(num_images, 3, 32, 32)`` and ``Y`` is an array of the batch labels.
    """
    archive = fetch('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz')
    # Close the tar archive deterministically once the batch has been read.
    with tarfile.open(fileobj=io.BytesIO(archive), mode='r:gz') as tt:
        batch_file = tt.extractfile('cifar-10-batches-py/data_batch_1')
        # NOTE(security): pickle.load can execute arbitrary code if the
        # downloaded archive is tampered with. The batches are distributed
        # only as pickles, so this is kept; consider verifying a checksum of
        # the archive before unpickling.
        db = pickle.load(batch_file, encoding="bytes")
    # encoding="bytes" leaves the dictionary keys as bytes objects.
    X = db[b'data'].reshape((-1, 3, 32, 32))
    Y = np.array(db[b'labels'])
    return X, Y