mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
[Fix] fix typo: test_mnist -> datasets (#492)
* test_mnist -> datasets * fix mnist_gan
This commit is contained in:
@@ -6,10 +6,10 @@ from tqdm import tqdm
|
||||
sys.path.append(os.getcwd())
|
||||
sys.path.append(os.path.join(os.getcwd(), 'test'))
|
||||
|
||||
from tinygrad.tensor import Tensor, Function, register
|
||||
from tinygrad.tensor import Tensor
|
||||
from extra.utils import get_parameters
|
||||
import tinygrad.nn.optim as optim
|
||||
from test_mnist import X_train
|
||||
from datasets import fetch_mnist
|
||||
from torchvision.utils import make_grid, save_image
|
||||
import torch
|
||||
GPU = os.getenv("GPU") is not None
|
||||
@@ -61,6 +61,7 @@ if __name__ == "__main__":
|
||||
disc_loss = []
|
||||
output_folder = "outputs"
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
X_train = fetch_mnist()[0]
|
||||
train_data_size = len(X_train)
|
||||
ds_noise = Tensor(np.random.randn(64,128).astype(np.float32), requires_grad=False)
|
||||
n_steps = int(train_data_size/batch_size)
|
||||
|
||||
@@ -4,7 +4,7 @@ import os
|
||||
import sys
|
||||
sys.path.append(os.getcwd())
|
||||
sys.path.append(os.path.join(os.getcwd(), 'test'))
|
||||
from test_mnist import fetch_mnist
|
||||
from datasets import fetch_mnist
|
||||
from tqdm import trange
|
||||
|
||||
def augment_img(X, rotate=10, px=3):
|
||||
|
||||
Reference in New Issue
Block a user