mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
Fix examples/train_efficientnet (#1947)
* added missing colon * bug fixes for cifar10 dataset loading needed a reshape to work with conv layers and resolve fetched tensor to numpy since further code expects numpy array
This commit is contained in:
@@ -59,15 +59,17 @@ if __name__ == "__main__":
|
||||
p.daemon = True
|
||||
p.start()
|
||||
else:
|
||||
X_train, Y_train = fetch_cifar()
|
||||
X_train, Y_train, _, _ = fetch_cifar()
|
||||
X_train = X_train.reshape((-1, 3, 32, 32))
|
||||
Y_train = Y_train.reshape((-1,))
|
||||
|
||||
with Tensor.train()
|
||||
with Tensor.train():
|
||||
for i in (t := trange(steps)):
|
||||
if IMAGENET:
|
||||
X, Y = q.get(True)
|
||||
else:
|
||||
samp = np.random.randint(0, X_train.shape[0], size=(BS))
|
||||
X, Y = X_train[samp], Y_train[samp]
|
||||
X, Y = X_train.numpy()[samp], Y_train.numpy()[samp]
|
||||
|
||||
st = time.time()
|
||||
out = model.forward(Tensor(X.astype(np.float32), requires_grad=False))
|
||||
|
||||
Reference in New Issue
Block a user