[Fix] fix typo: test_mnist -> datasets (#492)

* test_mnist -> datasets

* fix mnist_gan
This commit is contained in:
AllentDan
2023-01-30 13:30:47 +08:00
committed by GitHub
parent 2db272c7f7
commit 7b6b1f32b1
2 changed files with 4 additions and 3 deletions

View File

@@ -6,10 +6,10 @@ from tqdm import tqdm
sys.path.append(os.getcwd())
sys.path.append(os.path.join(os.getcwd(), 'test'))
from tinygrad.tensor import Tensor, Function, register
from tinygrad.tensor import Tensor
from extra.utils import get_parameters
import tinygrad.nn.optim as optim
from test_mnist import X_train
from datasets import fetch_mnist
from torchvision.utils import make_grid, save_image
import torch
GPU = os.getenv("GPU") is not None
@@ -61,6 +61,7 @@ if __name__ == "__main__":
disc_loss = []
output_folder = "outputs"
os.makedirs(output_folder, exist_ok=True)
X_train = fetch_mnist()[0]
train_data_size = len(X_train)
ds_noise = Tensor(np.random.randn(64,128).astype(np.float32), requires_grad=False)
n_steps = int(train_data_size/batch_size)

View File

@@ -4,7 +4,7 @@ import os
import sys
sys.path.append(os.getcwd())
sys.path.append(os.path.join(os.getcwd(), 'test'))
from test_mnist import fetch_mnist
from datasets import fetch_mnist
from tqdm import trange
def augment_img(X, rotate=10, px=3):