diff --git a/examples/train_efficientnet.py b/examples/train_efficientnet.py index da1c76b79e..dbbd8ef5c8 100644 --- a/examples/train_efficientnet.py +++ b/examples/train_efficientnet.py @@ -59,15 +59,17 @@ if __name__ == "__main__": p.daemon = True p.start() else: - X_train, Y_train = fetch_cifar() + X_train, Y_train, _, _ = fetch_cifar() + X_train = X_train.reshape((-1, 3, 32, 32)) + Y_train = Y_train.reshape((-1,)) - with Tensor.train() + with Tensor.train(): for i in (t := trange(steps)): if IMAGENET: X, Y = q.get(True) else: samp = np.random.randint(0, X_train.shape[0], size=(BS)) - X, Y = X_train[samp], Y_train[samp] + X, Y = X_train.numpy()[samp], Y_train.numpy()[samp] st = time.time() out = model.forward(Tensor(X.astype(np.float32), requires_grad=False))